
SSD Embedding Operators

CUDA Operators

enum RocksdbWriteMode

RocksDB write mode

In SSD offloading, there are three writes in each training iteration:

FWD_ROCKSDB_READ: cache lookup moves uncached data from RocksDB into the L2 cache on the forward path

FWD_L1_EVICTION: L1 cache eviction evicts data into the L2 cache on the forward path

BWD_L1_CNFLCT_MISS_WRITE_BACK: an L1 conflict miss inserts data into the L2 cache for the embedding update on the backward path

Any of the L2 cache fills above can trigger a RocksDB write once the L2 cache is full.

Additionally, SSD I/O is performed on L2 flush.

Values:

enumerator FWD_ROCKSDB_READ
enumerator FWD_L1_EVICTION
enumerator BWD_L1_CNFLCT_MISS_WRITE_BACK
enumerator FLUSH
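
For orientation, the modes can be read as tags for where in the training iteration an L2/RocksDB write originates. The snippet below is a minimal, hypothetical Python mirror of the enum for illustration only; the numeric values are assumptions, not the exported binding.

# Illustrative mirror of RocksdbWriteMode (numeric values are assumed)
from enum import IntEnum

class RocksdbWriteMode(IntEnum):
    # Forward path: cache lookup pulls uncached rows from RocksDB into L2
    FWD_ROCKSDB_READ = 0
    # Forward path: rows evicted from the L1 cache are written into L2
    FWD_L1_EVICTION = 1
    # Backward path: L1 conflict misses are inserted into L2 for the embedding update
    BWD_L1_CNFLCT_MISS_WRITE_BACK = 2
    # Explicit L2 flush, which also performs SSD I/O
    FLUSH = 3
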
inline size_t hash_shard(int64_t id, size_t num_shards)

Hash function used for the SSD L2 cache and RocksDB sharding algorithm

Parameters:
  • id – sharding key

  • num_shards – sharding range

Returns:

The shard ID, in the range [0, num_shards)
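
To illustrate the contract only (this is not the actual hash used by the library), a sharding helper of this shape deterministically maps a 64-bit id into [0, num_shards):

# Hypothetical Python sketch of the sharding contract; the real hash differs
def hash_shard(id: int, num_shards: int) -> int:
    # Mix the id with a 64-bit multiplicative constant, then reduce to a shard
    h = (id * 0x9E3779B97F4A7C15) & ((1 << 64) - 1)
    return h % num_shards

# Example: spread embedding ids across 8 L2/RocksDB shards
shards = [hash_shard(i, 8) for i in (1, 42, 123_456_789)]
assert all(0 <= s < 8 for s in shards)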

void cuda_callback_func(cudaStream_t stream, cudaError_t status, void *functor)

A callback function for cudaStreamAddCallback

A common callback function for cudaStreamAddCallback, i.e., a cudaStreamCallback_t callback. This function casts functor into a void function, invokes it, and then deletes it (the deletion occurs in another thread)

Parameters:
  • stream – CUDA stream that cudaStreamAddCallback operates on

  • status – CUDA status

  • functor – A functor that will be called

Returns:

None

Tensor masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count, const bool use_pipeline, const int64_t preferred_sms)

Similar to torch.Tensor.index_put, but ignores indices < 0

masked_index_put_cuda only supports 2D input values. It puts count rows in values into self using the row indices that are >= 0 in indices.

# Equivalent PyTorch Python code
indices = indices[:count]
filter_ = indices >= 0
indices_ = indices[filter_]
self[indices_] = values[filter_.nonzero().flatten()]
Parameters:
  • self – The 2D output tensor (the tensor that is indexed)

  • indices – The 1D index tensor

  • values – The 2D input tensor

  • count – The tensor that contains the length of indices to process

  • use_pipeline – A flag that indicates that this kernel will overlap with other kernels. If it is true, then use a fraction of SMs to reduce resource competition

  • preferred_sms – The number of preferred SMs for the kernel to use when use_pipeline=true. This value is ignored when use_pipeline=false.

Returns:

The self tensor
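
A small self-contained example of these semantics, using the equivalent PyTorch code above (the tensor values are made up for illustration, and count is a plain int here, while the operator takes a tensor):

# Worked example of masked_index_put semantics in plain PyTorch
import torch

self = torch.zeros(4, 2)                          # 2D output tensor
values = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])
indices = torch.tensor([2, -1, 0, 5])             # only the first count entries are used
count = 3

indices = indices[:count]                         # [2, -1, 0]
filter_ = indices >= 0                            # [True, False, True]
indices_ = indices[filter_]                       # [2, 0]
self[indices_] = values[filter_.nonzero().flatten()]
# self is now [[3., 3.], [0., 0.], [1., 1.], [0., 0.]]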

Tensor masked_index_select_cuda(Tensor self, Tensor indices, Tensor values, Tensor count, const bool use_pipeline, const int64_t preferred_sms)

Similar to torch.index_select, but ignores indices < 0

masked_index_select_cuda only supports 2D input values. It puts count rows that are specified in indices (where indices >= 0) from values into self.

# Equivalent PyTorch Python code
indices = indices[:count]
filter_ = indices >= 0
indices_ = indices[filter_]
self[filter_.nonzero().flatten()] = values[indices_]
Parameters:
  • self – The 2D output tensor

  • indices – The 1D index tensor

  • values – The 2D input tensor (the tensor that is indexed)

  • count – The tensor that contains the length of indices to process

  • use_pipeline – A flag that indicates that this kernel will overlap with other kernels. If it is true, then use a fraction of SMs to reduce resource competition

  • preferred_sms – The number of preferred SMs for the kernel to use when use_pipeline=true. This value is ignored when use_pipeline=false.

Returns:

The self tensor
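
As above, a small illustrative example of these semantics in plain PyTorch (tensor values are made up, and count is a plain int here for simplicity):

# Worked example of masked_index_select semantics in plain PyTorch
import torch

self = torch.zeros(4, 2)                          # 2D output tensor
values = torch.tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]])
indices = torch.tensor([2, -1, 0, 4])
count = 3

indices = indices[:count]                         # [2, -1, 0]
filter_ = indices >= 0                            # [True, False, True]
indices_ = indices[filter_]                       # [2, 0]
self[filter_.nonzero().flatten()] = values[indices_]
# self is now [[3., 3.], [0., 0.], [1., 1.], [0., 0.]]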

std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(const Tensor &lxu_cache_locations, const Tensor &assigned_cache_slots, const Tensor &linear_index_inverse_indices, const Tensor &unique_indices_count_cumsum, const Tensor &cache_set_inverse_indices, const Tensor &lxu_cache_weights, const Tensor &inserted_ssd_weights, const Tensor &unique_indices_length, const Tensor &cache_set_sorted_unique_indices)

Generate memory addresses for SSD TBE data.

The data retrieved from SSD can be stored in either a scratch pad (HBM) or LXU cache (also HBM). lxu_cache_locations is used to specify the location of the data. If the location is -1, the data for the associated index is in the scratch pad; otherwise, it is in the cache. To enable TBE kernels to access the data conveniently, this operator generates memory addresses of the first byte for each index. When accessing data, a TBE kernel only needs to convert addresses into pointers.

Moreover, this operator also generates the list of post-backward evicted indices, which are the indices whose data resides in the scratch pad.

Parameters:
  • lxu_cache_locations – The tensor that contains cache slots where data is stored for the full list of indices. -1 is a sentinel value that indicates that data is not in cache.

  • assigned_cache_slots – The tensor that contains cache slots for the unique list of indices. -1 indicates that data is not in cache

  • linear_index_inverse_indices – The tensor that contains the original position of linear indices before being sorted

  • unique_indices_count_cumsum – The tensor that contains the exclusive prefix sum results of the counts of unique indices

  • cache_set_inverse_indices – The tensor that contains the original positions of cache sets before being sorted

  • lxu_cache_weights – The LXU cache tensor

  • inserted_ssd_weights – The scratch pad tensor

  • unique_indices_length – The tensor that contains the number of unique indices (GPU tensor)

  • cache_set_sorted_unique_indices – The tensor that contains associated unique indices for the sorted unique cache sets

Returns:

A tuple of tensors (the SSD row address tensor and the post backward evicted index tensor)
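
Conceptually, ignoring the inverse-index bookkeeping that maps sorted unique indices back to their original positions, the address for each index comes from either the LXU cache or the scratch pad. The following is a hypothetical Python sketch of that core rule (row_address and row_in_scratch_pad are illustrative names, not actual parameters):

# Sketch of per-index address generation, assuming contiguous row-major tensors
import torch

def row_address(cache_slot, row_in_scratch_pad, lxu_cache_weights, inserted_ssd_weights):
    # Byte size of one embedding row in each 2D tensor
    cache_row_bytes = lxu_cache_weights.size(1) * lxu_cache_weights.element_size()
    sp_row_bytes = inserted_ssd_weights.size(1) * inserted_ssd_weights.element_size()
    if cache_slot >= 0:
        # Data is in the LXU cache: address of the first byte of the cache slot
        return lxu_cache_weights.data_ptr() + cache_slot * cache_row_bytes
    # Data is in the scratch pad: address of its row there; such indices are
    # also reported in the post-backward evicted index tensor
    return inserted_ssd_weights.data_ptr() + row_in_scratch_pad * sp_row_bytes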

void ssd_update_row_addrs_cuda(const Tensor &ssd_row_addrs_curr, const Tensor &inserted_ssd_weights_curr_next_map, const Tensor &lxu_cache_locations_curr, const Tensor &linear_index_inverse_indices_curr, const Tensor &unique_indices_count_cumsum_curr, const Tensor &cache_set_inverse_indices_curr, const Tensor &lxu_cache_weights, const Tensor &inserted_ssd_weights_next, const Tensor &unique_indices_length_curr)

Update memory addresses for SSD TBE data.

When pipeline prefetching is enabled, data in the scratch pad of the current iteration can be moved to L1 or to the scratch pad of the next iteration during the prefetch step. This operator updates the memory addresses of the relocated data so that they point to the correct location.

Parameters:
  • ssd_row_addrs_curr – The tensor that contains the row address of the current iteration

  • inserted_ssd_weights_curr_next_map – The tensor that contains the mapping between the location of each index in the scratch pad of the current iteration and its location in the scratch pad of the next iteration (-1 = the data has not been moved). inserted_ssd_weights_curr_next_map[i] is the location of index i in the next iteration's scratch pad

  • lxu_cache_locations_curr – The tensor that contains cache slots where data is stored for the full list of indices for the current iteration. -1 is a sentinel value that indicates that data is not in cache.

  • linear_index_inverse_indices_curr – The tensor that contains the original position of linear indices before being sorted for the current iteration

  • unique_indices_count_cumsum_curr – The tensor that contains the exclusive prefix sum results of the counts of unique indices for the current iteration

  • cache_set_inverse_indices_curr – The tensor that contains the original positions of cache sets before being sorted for the current iteration

  • lxu_cache_weights – The LXU cache tensor

  • inserted_ssd_weights_next – The scratch pad tensor for the next iteration

  • unique_indices_length_curr – The tensor that contains the number of unique indices (GPU tensor) for the current iteration

Returns:

None
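
In the same spirit as the sketch above, the core update rule can be summarized as a hypothetical Python snippet; names other than the documented parameters are illustrative:

# Sketch of per-index address update during pipeline prefetch
def updated_row_address(old_addr, cache_slot_curr, next_sp_row,
                        lxu_cache_weights, inserted_ssd_weights_next):
    cache_row_bytes = lxu_cache_weights.size(1) * lxu_cache_weights.element_size()
    sp_row_bytes = inserted_ssd_weights_next.size(1) * inserted_ssd_weights_next.element_size()
    if cache_slot_curr >= 0:
        # Data was moved into the L1 (LXU) cache during prefetch
        return lxu_cache_weights.data_ptr() + cache_slot_curr * cache_row_bytes
    if next_sp_row >= 0:
        # Data was moved into the next iteration's scratch pad
        # (next_sp_row comes from inserted_ssd_weights_curr_next_map)
        return inserted_ssd_weights_next.data_ptr() + next_sp_row * sp_row_bytes
    # -1 in the map: data has not been moved, so the old address is still valid
    return old_addr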

void compact_indices_cuda(std::vector<Tensor> compact_indices, Tensor compact_count, std::vector<Tensor> indices, Tensor masks, Tensor count)

Compact the given list of indices.

This operator compacts the given list of indices based on the given masks (a tensor that contains either 0 or 1). The operator removes the indices whose corresponding mask is 0. It only operates on the first count elements (not the full tensors).

Example:

indices = [[0, 3, -1, 3, -1, -1, 7], [0, 2, 2, 3, -1, 9, 7]]
masks = [1, 1, 0, 1, 0, 0, 1]
count = 5

# x represents an arbitrary value
compact_indices = [[0, 3, 3, x, x, x, x], [0, 2, 3, x, x, x, x]]
compact_count = 3
Parameters:
  • compact_indices – A list of compact indices (output indices).

  • compact_count – A tensor that contains the number of elements after being compacted

  • indices – An input list of indices to be compacted

  • masks – A tensor that contains 0 or 1 to indicate whether to keep or remove each element. 0 = remove the corresponding index. 1 = keep the corresponding index

  • count – A tensor that contains the number of elements to be compacted
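
The example above can be reproduced with plain PyTorch as an illustrative equivalent (the real operator writes into preallocated output buffers, so trailing values beyond compact_count are arbitrary):

# Worked example of the compaction semantics in plain PyTorch
import torch

indices = [torch.tensor([0, 3, -1, 3, -1, -1, 7]),
           torch.tensor([0, 2, 2, 3, -1, 9, 7])]
masks = torch.tensor([1, 1, 0, 1, 0, 0, 1])
count = 5

keep = masks[:count].bool()                       # [True, True, False, True, False]
compact_indices = [t[:count][keep] for t in indices]
compact_count = int(keep.sum())
# compact_indices == [tensor([0, 3, 3]), tensor([0, 2, 3])]
# compact_count == 3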

class CacheLibCache
#include <cachelib_cache.h>

A Cachelib wrapper class for Cachelib interaction.

It maintains all cache-related operations, including initialization, insertion, lookup, and eviction. It is stateful with respect to eviction logic: the caller has to explicitly fetch and reset eviction-related state. Cachelib-related optimizations are captured inside this class, e.g., fetch and delayed markUseful to boost get performance.

Note

This class only handles a single Cachelib read/update; parallelism is done on the caller side.

class EmbeddingParameterServer : public EmbeddingKVDB
#include <ps_table_batched_embeddings.h>

An implementation of EmbeddingKVDB for Training Parameter Service (TPS) client.

class CacheContext
#include <kv_db_table_batched_embeddings.h>

It holds the L2 cache lookup results.

num_misses is the number of misses in the L2 cache lookup. cached_addr_list is preallocated with the number of lookups for better parallelism, and invalid spots (cache misses) stay as the sentinel value.

struct QueueItem
#include <kv_db_table_batched_embeddings.h>

Queue item for the background L2/RocksDB update.

indices/weights/count are the corresponding set() parameters.

read_handles is the Cachelib-abstracted metadata for index/embedding pairs; it is used later to update the Cachelib LRU queue, since that update is separated from EmbeddingKVDB::get_cache().

mode is used for monitoring RocksDB writes; see RocksdbWriteMode for a detailed explanation.

class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB>
#include <kv_db_table_batched_embeddings.h>

A class for interacting with different cache layers and storage layers; public calls are executed on a CUDA stream.

Currently it is used by TBE to offload key (embedding index) / value (embeddings) pairs to DRAM, SSD, or remote storage, to provide better scalability without blowing up HBM resources.

Subclassed by EmbeddingParameterServer, EmbeddingRocksDB

class EmbeddingRocksDB : public EmbeddingKVDB
#include <ssd_table_batched_embeddings.h>

An implementation of EmbeddingKVDB for RocksDB.
