SSD Embedding Operators
CUDA Operators
-
enum RocksdbWriteMode
RocksDB write mode.
In SSD offloading, there are three writes in each training iteration:
FWD_ROCKSDB_READ: the cache lookup moves uncached data from RocksDB into the L2 cache on the forward path.
FWD_L1_EVICTION: L1 cache eviction evicts data into the L2 cache on the forward path.
BWD_L1_CNFLCT_MISS_WRITE_BACK: an L1 conflict miss inserts data into the L2 cache for the embedding update on the backward path.
Any of the L2 cache fills above can trigger a RocksDB write once the L2 cache is full.
Additionally, SSD I/O is performed when the L2 cache is flushed (FLUSH).
Values:
-
enumerator FWD_ROCKSDB_READ
-
enumerator FWD_L1_EVICTION
-
enumerator BWD_L1_CNFLCT_MISS_WRITE_BACK
-
enumerator FLUSH
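To illustrate how the three fill paths and the flush translate into SSD writes, here is a minimal Python sketch. It assumes a bounded L2 cache that spills its least recently filled entry to a backing store (a dict standing in for RocksDB) once it exceeds capacity, and it counts backend writes per mode. The class `ToyL2Cache`, its eviction policy, and the enum's numeric values are hypothetical and not part of the FBGEMM API.

```python
from collections import Counter, OrderedDict
from enum import Enum


class RocksdbWriteMode(Enum):
    # Mirrors the enumerator names above; numeric values are illustrative only.
    FWD_ROCKSDB_READ = 0
    FWD_L1_EVICTION = 1
    BWD_L1_CNFLCT_MISS_WRITE_BACK = 2
    FLUSH = 3


class ToyL2Cache:
    """Hypothetical bounded L2 cache that spills to a dict-backed 'rocksdb'."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.cache = OrderedDict()  # index -> embedding row, oldest fill first
        self.rocksdb = {}  # stand-in for the SSD backend
        self.ssd_writes = Counter()  # backend writes per RocksdbWriteMode

    def fill(self, index, row, mode):
        # Any L2 fill may push the least recently filled entry to the backend.
        self.cache[index] = row
        self.cache.move_to_end(index)
        while len(self.cache) > self.capacity:
            old_index, old_row = self.cache.popitem(last=False)
            self.rocksdb[old_index] = old_row
            self.ssd_writes[mode] += 1

    def flush(self):
        # Flushing writes every remaining L2 entry to the backend.
        for index, row in self.cache.items():
            self.rocksdb[index] = row
            self.ssd_writes[RocksdbWriteMode.FLUSH] += 1
        self.cache.clear()


l2 = ToyL2Cache(capacity=2)
l2.fill(1, [0.1], RocksdbWriteMode.FWD_ROCKSDB_READ)
l2.fill(2, [0.2], RocksdbWriteMode.FWD_L1_EVICTION)
l2.fill(3, [0.3], RocksdbWriteMode.BWD_L1_CNFLCT_MISS_WRITE_BACK)  # spills index 1
l2.flush()  # writes indices 2 and 3
```

Attributing each spill to the mode of the fill that caused it is what makes per-mode monitoring possible; the flush is counted separately because it is not triggered by a full cache.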
-
inline size_t hash_shard(int64_t id, size_t num_shards)
Hash function used by the sharding algorithm for the SSD L2 cache and RocksDB.
- Parameters:
id – sharding key
num_shards – sharding range
- Returns:
shard ID in the range [0, num_shards)
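hash_shard must map any int64 key deterministically into [0, num_shards). A minimal sketch of one common way to do this: scramble the key with a 64-bit mixer, then reduce the range with a multiply-high instead of a modulo. The SplitMix64 finalizer used here is an assumption for illustration; the actual hash inside hash_shard may differ.

```python
MASK64 = (1 << 64) - 1


def mix64(x):
    """SplitMix64 finalizer: scrambles a 64-bit key (assumed mixer)."""
    x &= MASK64  # treat negative int64 ids as their two's-complement bits
    x = ((x ^ (x >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    x = ((x ^ (x >> 27)) * 0x94D049BB133111EB) & MASK64
    return x ^ (x >> 31)


def hash_shard(id_, num_shards):
    """Map an int64 id to a shard in [0, num_shards).

    Multiply-shift range reduction: since h < 2**64, the value
    (num_shards * h) >> 64 is always strictly less than num_shards.
    """
    h = mix64(id_)
    return (num_shards * h) >> 64
```

The multiply-high reduction avoids an integer division and uses the high bits of the mixed hash, which a good mixer distributes evenly.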
-
void cuda_callback_func(cudaStream_t stream, cudaError_t status, void *functor)
A callback function for cudaStreamAddCallback.
A common callback function for cudaStreamAddCallback, i.e., a cudaStreamCallback_t callback. This function casts functor into a void function, invokes it, and then deletes it (the deletion occurs in another thread).
- Parameters:
stream – CUDA stream that cudaStreamAddCallback operates on
status – CUDA status
functor – A functor that will be called
- Returns:
None
-
Tensor masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count, const bool use_pipeline, const int64_t preferred_sms)
Similar to torch.Tensor.index_put but ignores indices < 0.
masked_index_put_cuda only supports 2D input values. It puts count rows in values into self using the row indices that are >= 0 in indices.
    # Equivalent PyTorch Python code
    indices = indices[:count]
    filter_ = indices >= 0
    indices_ = indices[filter_]
    self[indices_] = values[filter_.nonzero().flatten()]
- Parameters:
self – The 2D output tensor (the tensor that is indexed)
indices – The 1D index tensor
values – The 2D input tensor
count – The tensor that contains the length of indices to process
use_pipeline – A flag that indicates that this kernel will overlap with other kernels. If it is true, the kernel uses only a fraction of the SMs to reduce resource contention.
preferred_sms – The number of preferred SMs for the kernel to use when use_pipeline=true. This value is ignored when use_pipeline=false.
- Returns:
The self tensor
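A plain-Python reference on nested lists makes the semantics concrete without needing a GPU or torch: for each position i < count whose index is non-negative, row i of values lands in row indices[i] of self. The function name `masked_index_put` is hypothetical. The loop applies writes in order, so duplicate indices resolve to the last write here; this sketch makes no claim about how the CUDA kernel orders duplicate writes.

```python
def masked_index_put(self_rows, indices, values, count):
    """Reference semantics of masked_index_put_cuda on nested lists.

    For each i < count with indices[i] >= 0, copy values[i] into
    self_rows[indices[i]]. Negative indices are skipped.
    """
    for i in range(count):
        dst = indices[i]
        if dst >= 0:
            self_rows[dst] = list(values[i])
    return self_rows


out = [[0, 0] for _ in range(3)]
masked_index_put(out, indices=[2, -1, 0, 1],
                 values=[[1, 1], [2, 2], [3, 3], [4, 4]], count=3)
# out == [[3, 3], [0, 0], [1, 1]]; position 3 is beyond count and ignored
```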
-
Tensor masked_index_select_cuda(Tensor self, Tensor indices, Tensor values, Tensor count, const bool use_pipeline, const int64_t preferred_sms)
Similar to torch.index_select but ignores indices < 0.
masked_index_select_cuda only supports 2D input values. It puts count rows that are specified in indices (where indices >= 0) from values into self.
    # Equivalent PyTorch Python code
    indices = indices[:count]
    filter_ = indices >= 0
    indices_ = indices[filter_]
    self[filter_.nonzero().flatten()] = values[indices_]
- Parameters:
self – The 2D output tensor
indices – The 1D index tensor
values – The 2D input tensor (the tensor that is indexed)
count – The tensor that contains the length of indices to process
use_pipeline – A flag that indicates that this kernel will overlap with other kernels. If it is true, the kernel uses only a fraction of the SMs to reduce resource contention.
preferred_sms – The number of preferred SMs for the kernel to use when use_pipeline=true. This value is ignored when use_pipeline=false.
- Returns:
The self tensor
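The select variant is the gather counterpart of masked_index_put_cuda: the index now addresses values instead of self. A plain-Python reference on nested lists (the function name `masked_index_select` is hypothetical) shows that rows of self at positions with a negative index are left untouched:

```python
def masked_index_select(self_rows, indices, values, count):
    """Reference semantics of masked_index_select_cuda on nested lists.

    For each i < count with indices[i] >= 0, copy values[indices[i]]
    into self_rows[i]. Rows at negative-index positions keep their value.
    """
    for i in range(count):
        src = indices[i]
        if src >= 0:
            self_rows[i] = list(values[src])
    return self_rows


out = [[0, 0] for _ in range(3)]
masked_index_select(out, indices=[2, -1, 0, 1],
                    values=[[10, 10], [20, 20], [30, 30]], count=3)
# out == [[30, 30], [0, 0], [10, 10]]
```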
-
std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(const Tensor &lxu_cache_locations, const Tensor &assigned_cache_slots, const Tensor &linear_index_inverse_indices, const Tensor &unique_indices_count_cumsum, const Tensor &cache_set_inverse_indices, const Tensor &lxu_cache_weights, const Tensor &inserted_ssd_weights, const Tensor &unique_indices_length, const Tensor &cache_set_sorted_unique_indices)
Generate memory addresses for SSD TBE data.
The data retrieved from SSD can be stored in either a scratch pad (HBM) or the LXU cache (also HBM). lxu_cache_locations specifies the location of the data: if the location is -1, the data for the associated index is in the scratch pad; otherwise, it is in the cache. To let TBE kernels access the data conveniently, this operator generates the memory address of the first byte of each index's row. When accessing data, a TBE kernel only needs to convert the addresses into pointers.
Moreover, this operator also generates the list of post-backward evicted indices, i.e., the indices whose data is in the scratch pad.
- Parameters:
lxu_cache_locations – The tensor that contains cache slots where data is stored for the full list of indices. -1 is a sentinel value that indicates that data is not in cache.
assigned_cache_slots – The tensor that contains cache slots for the unique list of indices. -1 indicates that data is not in cache.
linear_index_inverse_indices – The tensor that contains the original position of linear indices before being sorted
unique_indices_count_cumsum – The tensor that contains the exclusive prefix sum results of the counts of unique indices
cache_set_inverse_indices – The tensor that contains the original positions of cache sets before being sorted
lxu_cache_weights – The LXU cache tensor
inserted_ssd_weights – The scratch pad tensor
unique_indices_length – The tensor that contains the number of unique indices (GPU tensor)
cache_set_sorted_unique_indices – The tensor that contains associated unique indices for the sorted unique cache sets
- Returns:
A tuple of tensors (the SSD row address tensor and the post backward evicted index tensor)
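The address rule above can be sketched in plain Python under strong simplifying assumptions: the sorted/inverse/cumsum tensors are collapsed into a single hypothetical list `unique_id_of` that maps each original index to its unique index, scratch-pad row u is assumed to hold unique index u, and row r of a buffer is assumed to start at base + r * row_bytes. The function name and all parameters are hypothetical; the post-backward evicted indices are returned here as a plain list of the scratch-pad-resident indices.

```python
def generate_row_addrs(unique_indices, assigned_cache_slots, unique_id_of,
                       cache_base, scratch_base, row_bytes):
    """Simplified, hypothetical model of ssd_generate_row_addrs_cuda.

    unique_indices[u]: embedding index of the u-th unique index
    assigned_cache_slots[u]: its cache slot, or -1 if it is in the scratch pad
    unique_id_of[k]: the unique index that original index k maps to
    """
    unique_addrs = []
    post_bwd_evicted = []
    for u, slot in enumerate(assigned_cache_slots):
        if slot == -1:
            # Not cached: address of scratch-pad row u (assumed layout).
            unique_addrs.append(scratch_base + u * row_bytes)
            post_bwd_evicted.append(unique_indices[u])
        else:
            # Cached: address of the assigned LXU cache row.
            unique_addrs.append(cache_base + slot * row_bytes)
    # Scatter each unique address back to every occurrence of that index.
    row_addrs = [unique_addrs[u] for u in unique_id_of]
    return row_addrs, post_bwd_evicted


addrs, evicted = generate_row_addrs(
    unique_indices=[5, 9, 42], assigned_cache_slots=[3, -1, 0],
    unique_id_of=[0, 1, 1, 2, 0],
    cache_base=1000, scratch_base=5000, row_bytes=16)
# addrs == [1048, 5016, 5016, 1000, 1048], evicted == [9]
```

Computing one address per unique index and scattering it to all duplicates is why the operator takes unique-index bookkeeping tensors rather than working on the raw index list alone.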
-
void ssd_update_row_addrs_cuda(const Tensor &ssd_row_addrs_curr, const Tensor &inserted_ssd_weights_curr_next_map, const Tensor &lxu_cache_locations_curr, const Tensor &linear_index_inverse_indices_curr, const Tensor &unique_indices_count_cumsum_curr, const Tensor &cache_set_inverse_indices_curr, const Tensor &lxu_cache_weights, const Tensor &inserted_ssd_weights_next, const Tensor &unique_indices_length_curr)
Update memory addresses for SSD TBE data.
When pipeline prefetching is enabled, data in the scratch pad of the current iteration can be moved to L1 or to the scratch pad of the next iteration during the prefetch step. This operator updates the memory addresses of the relocated data so that they point to its new location.
- Parameters:
ssd_row_addrs_curr – The tensor that contains the row address of the current iteration
inserted_ssd_weights_curr_next_map – The tensor that maps each index of the current iteration to its location in the scratch pad of the next iteration (-1 = the data has not been moved). inserted_ssd_weights_curr_next_map[i] is the location of index i in the scratch pad of the next iteration.
lxu_cache_locations_curr – The tensor that contains cache slots where data is stored for the full list of indices for the current iteration. -1 is a sentinel value that indicates that data is not in cache.
linear_index_inverse_indices_curr – The tensor that contains the original position of linear indices before being sorted for the current iteration
unique_indices_count_cumsum_curr – The tensor that contains the exclusive prefix sum results of the counts of unique indices for the current iteration
cache_set_inverse_indices_curr – The tensor that contains the original positions of cache sets before being sorted for the current iteration
lxu_cache_weights – The LXU cache tensor
inserted_ssd_weights_next – The scratch pad tensor for the next iteration
unique_indices_length_curr – The tensor that contains the number of unique indices (GPU tensor) for the current iteration
- Returns:
None
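A simplified plain-Python model of the relocation step, under stated assumptions: all inputs are per original index (the inverse/cumsum bookkeeping is dropped), a scratch-pad row is moved to at most one destination, `lxu_cache_locations_curr` is assumed to already reflect rows moved into L1, and buffers are laid out as base + row * row_bytes. The function name and parameters are hypothetical, and the order in which the two destinations are checked is an assumption of this sketch.

```python
def update_row_addrs(row_addrs, curr_next_map, lxu_cache_locations_curr,
                     cache_base, next_scratch_base, row_bytes):
    """Simplified, hypothetical model of ssd_update_row_addrs_cuda."""
    updated = list(row_addrs)
    for k, next_row in enumerate(curr_next_map):
        if next_row != -1:
            # Row moved into the next iteration's scratch pad.
            updated[k] = next_scratch_base + next_row * row_bytes
        elif lxu_cache_locations_curr[k] != -1:
            # Row lives in (or was moved into) the L1 (LXU) cache.
            updated[k] = cache_base + lxu_cache_locations_curr[k] * row_bytes
        # Otherwise the current address is still valid and is kept.
    return updated


new_addrs = update_row_addrs(
    row_addrs=[5000, 5016, 1048], curr_next_map=[2, -1, -1],
    lxu_cache_locations_curr=[-1, 7, 3],
    cache_base=1000, next_scratch_base=9000, row_bytes=16)
# new_addrs == [9032, 1112, 1048]
```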
-
void compact_indices_cuda(std::vector<Tensor> compact_indices, Tensor compact_count, std::vector<Tensor> indices, Tensor masks, Tensor count)
Compact the given list of indices.
This operator compacts the given list of indices based on the given masks (a tensor that contains either 0 or 1). The operator removes the indices whose corresponding mask is 0. It only operates on the first count elements (not the full tensor).
Example:
    indices = [[0, 3, -1, 3, -1, -1, 7], [0, 2, 2, 3, -1, 9, 7]]
    masks = [1, 1, 0, 1, 0, 0, 1]
    count = 5
    # x represents an arbitrary value
    compact_indices = [[0, 3, 3, x, x, x, x], [0, 2, 3, x, x, x, x]]
    compact_count = 3
- Parameters:
compact_indices – A list of compact indices (output indices).
compact_count – A tensor that contains the number of elements after being compacted
indices – An input list of indices to be compacted
masks – A tensor that contains 0 or 1 to indicate whether to remove/keep the element. 0 = remove the corresponding index. 1 = keep the corresponding index.
count – A tensor that contains the number of elements to be compacted
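The compaction rule is a stream compaction driven by one shared mask. A plain-Python sketch (the function name `compact_indices` is hypothetical) that reproduces the example above; unlike the operator, which writes into preallocated outputs and leaves the tail arbitrary, it returns only the meaningful prefix:

```python
def compact_indices(indices, masks, count):
    """Keep positions i < count whose mask is 1, in every index list.

    Returns (compacted lists, compact_count). The same kept positions are
    applied to every list, so the lists stay aligned with each other.
    """
    keep = [i for i in range(count) if masks[i] == 1]
    compact = [[row[i] for i in keep] for row in indices]
    return compact, len(keep)


compact, compact_count = compact_indices(
    [[0, 3, -1, 3, -1, -1, 7], [0, 2, 2, 3, -1, 9, 7]],
    masks=[1, 1, 0, 1, 0, 0, 1], count=5)
# compact == [[0, 3, 3], [0, 2, 3]], compact_count == 3
```

Position 6 has mask 1 but lies beyond count, so it is not kept, which matches the documented example.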
-
class CacheLibCache
- #include <cachelib_cache.h>
A Cachelib wrapper class for Cachelib interaction.
It maintains all cache-related operations, including initialization, insertion, lookup, and eviction. It is stateful with respect to eviction: the caller has to explicitly fetch and reset the eviction-related state. Cachelib-related optimizations are captured inside this class, e.g., fetch and delayed markUseful to boost get performance.
Note
This class only handles single Cachelib reads/updates. Parallelism is handled on the caller side.
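The "stateful eviction" contract can be illustrated with a toy LRU cache that accumulates evicted entries until the caller drains them. This is a minimal sketch of the pattern only; `ToyEvictingCache` and `fetch_and_reset_evictions` are hypothetical names, not CacheLibCache methods, and the sketch does not model Cachelib's delayed markUseful optimization.

```python
from collections import OrderedDict


class ToyEvictingCache:
    """Hypothetical LRU cache that records evictions until reset by the caller."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.items = OrderedDict()  # least recently used first
        self.evicted = {}  # eviction state accumulated since the last reset

    def get(self, key):
        if key not in self.items:
            return None
        self.items.move_to_end(key)  # mark as recently used
        return self.items[key]

    def put(self, key, value):
        self.items[key] = value
        self.items.move_to_end(key)
        if len(self.items) > self.capacity:
            old_key, old_value = self.items.popitem(last=False)
            self.evicted[old_key] = old_value

    def fetch_and_reset_evictions(self):
        # The caller must drain eviction state explicitly, e.g. to write
        # evicted rows to the next storage tier.
        drained, self.evicted = self.evicted, {}
        return drained


c = ToyEvictingCache(capacity=2)
c.put(1, "a")
c.put(2, "b")
c.get(1)  # key 1 becomes most recently used
c.put(3, "c")  # evicts key 2
drained = c.fetch_and_reset_evictions()
# drained == {2: "b"}
```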
-
class EmbeddingParameterServer : public EmbeddingKVDB
- #include <ps_table_batched_embeddings.h>
An implementation of EmbeddingKVDB for Training Parameter Service (TPS) client.
-
class CacheContext
- #include <kv_db_table_batched_embeddings.h>
It holds L2 cache lookup results.
num_misses is the number of misses in the L2 cache lookup. cached_addr_list is preallocated with one entry per lookup for better parallelism; entries for cache misses keep a sentinel value.
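Preallocating one slot per lookup lets parallel workers write results without coordinating: each worker owns exactly one position. A minimal sketch of that idea, assuming `None` as the miss sentinel and a dict as the cache; `ToyCacheContext` and `l2_lookup` are hypothetical names, and the misses are counted after the parallel phase rather than atomically.

```python
from concurrent.futures import ThreadPoolExecutor

SENTINEL = None  # hypothetical stand-in for the invalid-address sentinel


class ToyCacheContext:
    """Hypothetical mirror of CacheContext: one preallocated slot per lookup."""

    def __init__(self, num_lookups):
        self.num_misses = 0
        self.cached_addr_list = [SENTINEL] * num_lookups


def l2_lookup(cache, keys, num_workers=4):
    ctx = ToyCacheContext(len(keys))

    def lookup(pos):
        # Each task writes only its own preallocated slot; misses keep
        # the sentinel.
        addr = cache.get(keys[pos])
        if addr is not None:
            ctx.cached_addr_list[pos] = addr

    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        list(pool.map(lookup, range(len(keys))))  # wait for all lookups
    ctx.num_misses = sum(a is SENTINEL for a in ctx.cached_addr_list)
    return ctx


ctx = l2_lookup({10: "addr10", 30: "addr30"}, [10, 20, 30, 40])
# ctx.cached_addr_list == ["addr10", None, "addr30", None], ctx.num_misses == 2
```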
-
struct QueueItem
- #include <kv_db_table_batched_embeddings.h>
Queue item for background L2/RocksDB updates.
indices, weights, and count are the corresponding set() parameters.
read_handles holds Cachelib-abstracted index/embedding pair metadata. It is used later to update the Cachelib LRU queue, since that update is separated from EmbeddingKVDB::get_cache().
mode is used for monitoring RocksDB writes; see RocksdbWriteMode for a detailed explanation.
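The background-update pattern behind QueueItem is a producer/consumer queue: the caller enqueues the set() arguments and returns, and a worker thread applies them later. A minimal sketch using the standard library; `ToyQueueItem` and `background_writer` are hypothetical, read_handles is omitted, mode is a plain string, and `None` is used as an assumed shutdown signal.

```python
import queue
import threading
from dataclasses import dataclass


@dataclass
class ToyQueueItem:
    # Mirrors the documented fields; read_handles is omitted in this sketch.
    indices: list
    weights: list
    count: int
    mode: str  # e.g. a RocksdbWriteMode name, kept for monitoring


def background_writer(q, store, writes_per_mode):
    """Apply queued set() calls until a None item signals shutdown."""
    while True:
        item = q.get()
        if item is None:
            break
        # Only the first `count` entries are valid, as with set().
        for idx, row in zip(item.indices[:item.count], item.weights[:item.count]):
            store[idx] = row
        writes_per_mode[item.mode] = writes_per_mode.get(item.mode, 0) + item.count


q = queue.Queue()
store, stats = {}, {}
worker = threading.Thread(target=background_writer, args=(q, store, stats))
worker.start()
q.put(ToyQueueItem([1, 2, 3], [[0.1], [0.2], [0.3]], 2, "FWD_L1_EVICTION"))
q.put(None)  # shutdown signal
worker.join()
# store == {1: [0.1], 2: [0.2]}, stats == {"FWD_L1_EVICTION": 2}
```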
-
class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB>
- #include <kv_db_table_batched_embeddings.h>
A class for interacting with different cache and storage layers. Public calls are executed on a CUDA stream.
Currently, TBE uses it to offload key-value pairs (key: embedding index; value: embedding) to DRAM, SSD, or remote storage, which improves scalability without exhausting HBM resources.
Subclassed by EmbeddingParameterServer, EmbeddingRocksDB
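The subclassing structure (one base class in front of several storage backends) can be sketched as a base class that owns a DRAM cache tier and defers storage to backend hooks that each subclass implements. This is a pattern sketch only: `ToyKVDB`, `DictBackedKVDB`, `backend_get`, and `backend_set` are hypothetical names, and the write-through set() is a simplification that ignores the deferred L2 eviction writes described for RocksdbWriteMode.

```python
from abc import ABC, abstractmethod


class ToyKVDB(ABC):
    """Hypothetical two-tier store: a DRAM dict in front of a backend."""

    def __init__(self):
        self.dram_cache = {}

    @abstractmethod
    def backend_get(self, key): ...

    @abstractmethod
    def backend_set(self, key, value): ...

    def get(self, key):
        # Serve from the cache tier when possible; otherwise read through.
        if key in self.dram_cache:
            return self.dram_cache[key]
        value = self.backend_get(key)
        self.dram_cache[key] = value
        return value

    def set(self, key, value):
        # Write-through for simplicity.
        self.dram_cache[key] = value
        self.backend_set(key, value)


class DictBackedKVDB(ToyKVDB):
    # Stand-in for an SSD (RocksDB) or remote parameter-server backend.
    def __init__(self):
        super().__init__()
        self.store = {}
        self.backend_reads = 0

    def backend_get(self, key):
        self.backend_reads += 1
        return self.store.get(key)

    def backend_set(self, key, value):
        self.store[key] = value


db = DictBackedKVDB()
db.store[7] = [1.0]
db.get(7)
db.get(7)  # second read is served by the DRAM tier
db.set(8, [2.0])
```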
-
class EmbeddingRocksDB : public EmbeddingKVDB
- #include <ssd_table_batched_embeddings.h>
An implementation of EmbeddingKVDB for RocksDB.