Table Batched Embedding (TBE) Inference Module¶
Stable API¶
- class fbgemm_gpu.split_table_batched_embeddings_ops_inference.IntNBitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]], feature_table_map: List[int] | None = None, index_remapping: List[Tensor] | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, weight_lists: List[Tuple[Tensor, Tensor | None]] | None = None, pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, output_dtype: SparseType = SparseType.FP16, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, enforce_hbm: bool = False, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, row_alignment: int | None = None, fp8_exponent_bits: int | None = None, fp8_exponent_bias: int | None = None, cache_assoc: int = 32, scale_bias_size_in_bytes: int = 4, cacheline_alignment: bool = True, uvm_host_mapped: bool = False, reverse_qparam: bool = False, feature_names_per_table: List[List[str]] | None = None, indices_dtype: dtype = torch.int32)[source]¶
Table-batched version of nn.EmbeddingBag(sparse=False) for inference, with support for FP32/FP16/FP8/INT8/INT4/INT2 weights
- Parameters:
embedding_specs (List[Tuple[str, int, int, SparseType, EmbeddingLocation]]) –
A list of embedding specifications. Each spec describes a physical embedding table and is a tuple of the table name, number of embedding rows, embedding dimension (must be a multiple of 4), weight data type (SparseType), and table placement (EmbeddingLocation). See the construction sketch after this parameter list.
Available EmbeddingLocation options are
DEVICE = placing an embedding table in the GPU global memory (HBM)
MANAGED = placing an embedding table in the unified virtual memory (accessible from both GPU and CPU)
MANAGED_CACHING = placing an embedding table in the unified virtual memory and using the GPU global memory (HBM) as a cache
HOST = placing an embedding table in the CPU memory (DRAM)
MTIA = placing an embedding table in the MTIA memory
Available SparseType options for the weight data type are
FP32 = 32-bit floating point
FP16 = 16-bit floating point
FP8 = 8-bit floating point
INT8 = 8-bit integer
INT4 = 4-bit integer
INT2 = 2-bit integer
feature_table_map (Optional[List[int]] = None) – An optional list that specifies feature-table mapping. feature_table_map[i] indicates the physical embedding table that feature i maps to.
index_remapping (Optional[List[Tensor]] = None) – Index remapping for pruning
pooling_mode (PoolingMode = PoolingMode.SUM) –
Pooling mode. Available PoolingMode options are
SUM = Sum pooling
MEAN = Mean pooling
NONE = No pooling (sequence embedding)
device (Optional[Union[str, int, torch.device]] = None) – The current device to place tensors on
bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –
Input checking mode. Available BoundsCheckMode options are
NONE = skip bounds check
FATAL = throw an error when encountering an invalid index/offset
WARNING = print a warning message when encountering an invalid index/offset and fix it (setting an invalid index to zero and adjusting an invalid offset to be within the bound)
IGNORE = silently fix an invalid index/offset (setting an invalid index to zero and adjusting an invalid offset to be within the bound)
weight_lists (Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None) – An optional list, one entry per table (length T), of tuples of a quantized weight tensor and an optional scale/shift tensor. If None, weights can be assigned after construction, e.g., via assign_embedding_weights()
pruning_hash_load_factor (float = 0.5) – Load factor for pruning hash
use_array_for_index_remapping (bool = True) – If True, use array for index remapping. Otherwise, use hash map.
output_dtype (SparseType = SparseType.FP16) – The data type of the output tensor.
cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –
The cache algorithm (used when EmbeddingLocation is set to MANAGED_CACHING). Options are
LRU = least recently used
LFU = least frequently used
cache_load_factor (float = 0.2) – A factor used for determining the cache capacity when EmbeddingLocation.MANAGED_CACHING is used. The cache capacity is cache_load_factor * the total number of rows in all embedding tables
cache_sets (int = 0) – The number of cache sets (used when EmbeddingLocation is set to MANAGED_CACHING)
cache_reserved_memory (float = 0.0) – The amount of memory reserved in HBM for non-cache purposes (used when EmbeddingLocation is set to MANAGED_CACHING).
enforce_hbm (bool = False) – If True, place all weights in HBM when using EmbeddingLocation.MANAGED_CACHING
record_cache_metrics (Optional[RecordCacheMetrics] = None) – Records the number of cache hits, number of requests, etc., if RecordCacheMetrics.record_cache_miss_counter is True, and records the same metrics per table if RecordCacheMetrics.record_tablewise_cache_miss is True
gather_uvm_cache_stats (Optional[bool] = False) – If True, collect the cache statistics when EmbeddingLocation is set to MANAGED_CACHING
row_alignment (Optional[int] = None) – Row alignment
fp8_exponent_bits (Optional[int] = None) – Exponent bits when using FP8
fp8_exponent_bias (Optional[int] = None) – Exponent bias when using FP8
cache_assoc (int = 32) – Number of ways (set associativity) for the cache
scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES) – Size of scale and bias in bytes
cacheline_alignment (bool = True) – If True, align each table to a 128-byte cache line boundary
uvm_host_mapped (bool = False) – If True, allocate every UVM tensor using malloc + cudaHostRegister. Otherwise use cudaMallocManaged
reverse_qparam (bool = False) – If True, load qparams at the end of each row. Otherwise, load qparams at the beginning of each row.
feature_names_per_table (Optional[List[List[str]]] = None) – An optional list that specifies feature names per table. feature_names_per_table[t] indicates the feature names of table t.
indices_dtype (torch.dtype = torch.int32) – The expected dtype of the indices tensor that will be passed to the forward() call. This information will be used to construct the remap_indices array/hash. Options are torch.int32 and torch.int64.
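The following sketch shows one way to construct the module for two small INT4 tables placed in HBM, following the embedding_specs format described above. The import paths for SparseType, EmbeddingLocation, and PoolingMode are assumptions based on the FBGEMM-GPU module layout; the table names, sizes, and fill_random_weights() initialization are illustrative only.

```python
import torch

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    EmbeddingLocation,
    PoolingMode,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)

# Each spec: (table name, number of rows, embedding dim, weight dtype, placement).
# Embedding dims must be multiples of 4.
tbe = IntNBitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        ("table_0", 1000, 64, SparseType.INT4, EmbeddingLocation.DEVICE),
        ("table_1", 2000, 32, SparseType.INT4, EmbeddingLocation.DEVICE),
    ],
    pooling_mode=PoolingMode.SUM,
    output_dtype=SparseType.FP16,
    device="cuda",
)
# Populate the quantized weight buffers with random values for testing.
tbe.fill_random_weights()
```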
- assign_embedding_weights(q_weight_list: List[Tuple[Tensor, Tensor | None]]) None [source]¶
Sets the embedding table weights exposed by self.split_embedding_weights() to the values from the input list of weights and scale_shifts.
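A minimal usage sketch, assuming src_tbe and dst_tbe are two hypothetical modules constructed with identical embedding_specs:

```python
# Copy quantized weights and per-row scale/shift tensors from src_tbe
# into dst_tbe. split_embedding_weights() returns a compatible
# List[Tuple[Tensor, Optional[Tensor]]], one entry per table.
q_weights = src_tbe.split_embedding_weights()
dst_tbe.assign_embedding_weights(q_weights)
```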
- forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None) Tensor [source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
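The sketch below illustrates the input layout forward() expects, assuming the usual TBE convention: for T features and batch size B, indices is a flat 1-D tensor of all lookup indices in feature-major order, and offsets has B * T + 1 elements marking bag boundaries. It reuses the tbe module from the construction sketch above; the index values are illustrative only.

```python
import torch

# T = 2 features, B = 2 samples => offsets has 2 * 2 + 1 = 5 entries.
# Bags: feature 0 -> [1, 5] and [7]; feature 1 -> [2] and [0, 9].
indices = torch.tensor([1, 5, 7, 2, 0, 9], dtype=torch.int32, device="cuda")
offsets = torch.tensor([0, 2, 3, 4, 6], dtype=torch.int32, device="cuda")

# With SUM pooling the output has shape (B, total_D) = (2, 64 + 32)
# and the dtype given by output_dtype (FP16 here).
output = tbe(indices, offsets)
```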
- recompute_module_buffers() None [source]¶
Compute module buffers that are on the meta device and are not materialized in reset_weights_placements_and_offsets(). Currently those buffers are weights_tys, rows_per_table, D_offsets, and bounds_check_warning. Pruning-related and UVM-related buffers are not recomputed at this time.
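A short sketch of a context where this might be called; the meta-device workflow shown here is an assumption for illustration, not prescribed by the API:

```python
# Assuming tbe was constructed with its buffers on the meta device
# (e.g., for delayed materialization) and its weights have since been
# materialized: recompute the bookkeeping buffers (weights_tys,
# rows_per_table, D_offsets, bounds_check_warning) that were skipped.
tbe.recompute_module_buffers()
```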