
Table Batched Embedding (TBE) Inference Module

Stable API

class fbgemm_gpu.split_table_batched_embeddings_ops_inference.IntNBitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]], feature_table_map: List[int] | None = None, index_remapping: List[Tensor] | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, weight_lists: List[Tuple[Tensor, Tensor | None]] | None = None, pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, output_dtype: SparseType = SparseType.FP16, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, enforce_hbm: bool = False, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, row_alignment: int | None = None, fp8_exponent_bits: int | None = None, fp8_exponent_bias: int | None = None, cache_assoc: int = 32, scale_bias_size_in_bytes: int = 4, cacheline_alignment: bool = True, uvm_host_mapped: bool = False, reverse_qparam: bool = False, feature_names_per_table: List[List[str]] | None = None, indices_dtype: dtype = torch.int32)[source]

Table-batched version of nn.EmbeddingBag(sparse=False) for inference, with support for FP32/FP16/FP8/INT8/INT4/INT2 weights

Parameters:
  • embedding_specs (List[Tuple[str, int, int, SparseType, EmbeddingLocation]]) –

    A list of embedding specifications. Each spec describes one physical embedding table and is a tuple of table name, number of embedding rows, embedding dimension (must be a multiple of 4), weight data type (SparseType), and table placement (EmbeddingLocation). See the construction sketch after this parameter list.

    Available EmbeddingLocation options are

    1. DEVICE = placing an embedding table in the GPU global memory (HBM)

    2. MANAGED = placing an embedding table in the unified virtual memory (accessible from both GPU and CPU)

    3. MANAGED_CACHING = placing an embedding table in the unified virtual memory and using the GPU global memory (HBM) as a cache

    4. HOST = placing an embedding table in the CPU memory (DRAM)

    5. MTIA = placing an embedding table in the MTIA memory

  • feature_table_map (Optional[List[int]] = None) – An optional list that specifies feature-table mapping. feature_table_map[i] indicates the physical embedding table that feature i maps to.

  • index_remapping (Optional[List[Tensor]] = None) – Index remapping for pruning

  • pooling_mode (PoolingMode = PoolingMode.SUM) –

    Pooling mode. Available PoolingMode options are

    1. SUM = Sum pooling

    2. MEAN = Mean pooling

    3. NONE = No pooling (sequence embedding)

  • device (Optional[Union[str, int, torch.device]] = None) – The current device to place tensors on

  • bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –

    Input checking mode. Available BoundsCheckMode options are

    1. NONE = skip bounds check

    2. FATAL = throw an error when encountering an invalid index/offset

    3. WARNING = print a warning message when encountering an invalid index/offset and fix it (setting an invalid index to zero and adjusting an invalid offset to be within the bound)

    4. IGNORE = silently fix an invalid index/offset (setting an invalid index to zero and adjusting an invalid offset to be within the bound)

  • weight_lists (Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None) – An optional list of pre-initialized weights, one entry per table; each entry is a tuple of a quantized weight tensor and an optional scale/shift tensor (the same format returned by split_embedding_weights())

  • pruning_hash_load_factor (float = 0.5) – Load factor for pruning hash

  • use_array_for_index_remapping (bool = True) – If True, use array for index remapping. Otherwise, use hash map.

  • output_dtype (SparseType = SparseType.FP16) – The data type of an output tensor.

  • cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –

    The cache algorithm (used when EmbeddingLocation is set to MANAGED_CACHING). Options are

    1. LRU = least recently used

    2. LFU = least frequently used

  • cache_load_factor (float = 0.2) – A factor used for determining the cache capacity when EmbeddingLocation.MANAGED_CACHING is used. The cache capacity is cache_load_factor * the total number of rows in all embedding tables

  • cache_sets (int = 0) – The number of cache sets (used when EmbeddingLocation is set to MANAGED_CACHING)

  • cache_reserved_memory (float = 0.0) – The amount of memory reserved in HBM for non-cache purposes (used when EmbeddingLocation is set to MANAGED_CACHING).

  • enforce_hbm (bool = False) – If True, place all weights/momentums in HBM when using EmbeddingLocation.MANAGED_CACHING

  • record_cache_metrics (Optional[RecordCacheMetrics] = None) – Record cache metrics such as the number of hits and the number of requests if RecordCacheMetrics.record_cache_miss_counter is True, and record the same metrics per table if RecordCacheMetrics.record_tablewise_cache_miss is True

  • gather_uvm_cache_stats (Optional[bool] = False) – If True, collect the cache statistics when EmbeddingLocation is set to MANAGED_CACHING

  • row_alignment (Optional[int] = None) – Row alignment

  • fp8_exponent_bits (Optional[int] = None) – Exponent bits when using FP8

  • fp8_exponent_bias (Optional[int] = None) – Exponent bias when using FP8

  • cache_assoc (int = 32) – The number of ways (set associativity) of the cache

  • scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES) – Size of scale and bias in bytes

  • cacheline_alignment (bool = True) – If True, align each table to a 128-byte cache line boundary

  • uvm_host_mapped (bool = False) – If True, allocate every UVM tensor using malloc + cudaHostRegister. Otherwise use cudaMallocManaged

  • reverse_qparam (bool = False) – If True, load qparams at the end of each row. Otherwise, load qparams at the beginning of each row.

  • feature_names_per_table (Optional[List[List[str]]] = None) – An optional list that specifies feature names per table. feature_names_per_table[t] indicates the feature names of table t.

  • indices_dtype (torch.dtype = torch.int32) – The expected dtype of the indices tensor that will be passed to the forward() call. This information will be used to construct the remap_indices array/hash. Options are torch.int32 and torch.int64.
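
For example, a minimal construction sketch. The import paths for SparseType and EmbeddingLocation follow the current fbgemm_gpu module layout, and the table names, sizes, and dtypes are illustrative, not prescriptive:

    import torch

    from fbgemm_gpu.split_embedding_configs import SparseType
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
    from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
        IntNBitTableBatchedEmbeddingBagsCodegen,
    )

    # Two INT4-quantized tables: (name, rows, dim, weight dtype, placement)
    tbe = IntNBitTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[
            ("table_0", 1000, 64, SparseType.INT4, EmbeddingLocation.DEVICE),
            ("table_1", 2000, 128, SparseType.INT4, EmbeddingLocation.DEVICE),
        ],
        output_dtype=SparseType.FP16,
        device=torch.device("cuda"),
    )
    tbe.fill_random_weights()  # materialize random quantized weights, table by table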

assign_embedding_weights(q_weight_list: List[Tuple[Tensor, Tensor | None]]) → None[source]

Assigns values from the input list of weights and scale_shifts to the module's embedding weights, i.e. the tensors returned by self.split_embedding_weights().
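
For example, weights exported with split_embedding_weights() can be assigned back; a minimal round-trip sketch, reusing the tbe module from the construction example above:

    # Export per-table (weights, scale_shift) tuples, then assign them back
    q_weights = tbe.split_embedding_weights()
    tbe.assign_embedding_weights(q_weights)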

fill_random_weights() → None[source]

Fill the buffer with random weights, table by table

forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None) → Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.
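
A minimal lookup sketch for the two-table module constructed above (batch size B = 2, T = 2 features; indices are concatenated across all features, and offsets marks bag boundaries with length B * T + 1):

    import torch

    # Two bags per feature, flattened across both features
    indices = torch.tensor([1, 5, 7, 3, 0, 9, 2, 4], dtype=torch.int32, device="cuda")
    offsets = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int32, device="cuda")

    # Call the module instance (not .forward()) so registered hooks run
    pooled = tbe(indices, offsets)
    # With SUM pooling, the output shape is (B, 64 + 128) = (2, 192)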

recompute_module_buffers() → None[source]

Computes the module buffers that are on the meta device and are not materialized by reset_weights_placements_and_offsets(). Currently these buffers are weights_tys, rows_per_table, D_offsets, and bounds_check_warning. Pruning-related and UVM-related buffers are not computed at this time.

split_embedding_weights(split_scale_shifts: bool = True) → List[Tuple[Tensor, Tensor | None]][source]

Returns a list of weights, split by table. If split_scale_shifts is True, the per-row quantization parameters are returned as a separate tensor in each tuple; otherwise they remain embedded in the weight tensor and the second element is None.
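
For instance, to inspect the raw quantized storage table by table (a sketch reusing the tbe module above):

    for weights, scale_shift in tbe.split_embedding_weights():
        # weights is the raw quantized tensor for one table; scale_shift holds
        # the per-row quantization parameters (or None)
        print(weights.shape, weights.dtype,
              None if scale_shift is None else scale_shift.shape)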

split_embedding_weights_with_scale_bias(split_scale_bias_mode: int = 1) → List[Tuple[Tensor, Tensor | None, Tensor | None]][source]

Returns a list of weights, split by table. split_scale_bias_mode controls how each table's rows are split:

0: return the whole row (quantization parameters left embedded in the weight tensor); 1: return weights + fused scale_bias; 2: return weights, scale, and bias as separate tensors.
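
A sketch of mode 2, which returns the quantization scale and bias as separate tensors (reusing the tbe module above):

    for weights, scale, bias in tbe.split_embedding_weights_with_scale_bias(
        split_scale_bias_mode=2
    ):
        print(weights.shape, scale.shape, bias.shape)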

Other API
