class fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]], feature_table_map: List[int] | None = None, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, cache_precision: SparseType | None = None, weights_precision: SparseType = SparseType.FP32, output_dtype: SparseType = SparseType.FP32, enforce_hbm: bool = False, optimizer: EmbOptimType = EmbOptimType.EXACT_SGD, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, stochastic_rounding: bool = True, gradient_clipping: bool = False, max_gradient: float = 1.0, max_norm: float = 0.0, learning_rate: float = 0.01, eps: float = 1e-08, momentum: float = 0.9, weight_decay: float = 0.0, weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, eta: float = 0.001, beta1: float = 0.9, beta2: float = 0.999, ensemble_mode: EnsembleModeDefinition | None = None, emainplace_mode: EmainplaceModeDefinition | None = None, counter_based_regularization: CounterBasedRegularizationDefinition | None = None, cowclip_regularization: CowClipDefinition | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, uvm_non_rowwise_momentum: bool = False, use_experimental_tbe: bool = False, prefetch_pipeline: bool = False, stats_reporter_config: TBEStatsReporterConfig | None = None, table_names: List[str] | None = None, optimizer_state_dtypes: Dict[str, SparseType] | None = None, multipass_prefetch_config: MultiPassPrefetchConfig | None = None, global_weight_decay: GlobalWeightDecayDefinition | None = None, uvm_host_mapped: bool = False, extra_optimizer_config: UserEnabledConfigDefinition | None = None, tbe_input_multiplexer_config: TBEInputMultiplexerConfig | None = None, embedding_table_index_type: dtype = torch.int64, embedding_table_offset_type: dtype = torch.int64)

Table Batched Embedding (TBE) operator. Looks up one or more embedding tables. This module is intended for training: the backward operator is fused with the optimizer, so the embedding tables are updated during the backward pass.

Parameters:

  • embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –

    A list of embedding specifications, one per physical embedding table. Each spec is a tuple of (number of embedding rows, embedding dimension (must be a multiple of 4), table placement (EmbeddingLocation), compute device (ComputeDevice)).

    Available EmbeddingLocation options are

    1. DEVICE = placing an embedding table in the GPU global memory (HBM)

    2. MANAGED = placing an embedding table in the unified virtual memory (UVM), which is accessible from both GPU and CPU

    3. MANAGED_CACHING = placing an embedding table in the unified virtual memory and using the GPU global memory (HBM) as a cache

    4. HOST = placing an embedding table in the CPU memory (DRAM)

    5. MTIA = placing an embedding table in the MTIA memory

    Available ComputeDevice options are

    1. CPU = performing table lookup on CPU

    2. CUDA = performing table lookup on GPU

    3. MTIA = performing table lookup on MTIA

  • feature_table_map (Optional[List[int]] = None) – An optional list that specifies feature-table mapping. feature_table_map[i] indicates the physical embedding table that feature i maps to.

  • cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –

    The cache algorithm (used when EmbeddingLocation is set to MANAGED_CACHING). Options are

    1. LRU = least recently used

    2. LFU = least frequently used

  • cache_load_factor (float = 0.2) – A factor used for determining the cache capacity when EmbeddingLocation.MANAGED_CACHING is used. The cache capacity (in rows) is cache_load_factor * the total number of rows in all embedding tables.

  • cache_sets (int = 0) – The number of cache sets (used when EmbeddingLocation is set to MANAGED_CACHING)

  • cache_reserved_memory (float = 0.0) – The amount of memory reserved in HBM for non-cache purposes (used when EmbeddingLocation is set to MANAGED_CACHING).

  • cache_precision (Optional[SparseType] = None) – The data type of the cache (used when EmbeddingLocation is set to MANAGED_CACHING). Options are SparseType.FP32 and SparseType.FP16

  • weights_precision (SparseType = SparseType.FP32) – The data type of embedding tables (also known as weights). Options are SparseType.FP32 and SparseType.FP16

  • output_dtype (SparseType = SparseType.FP32) – The data type of an output tensor. Options are SparseType.FP32 and SparseType.FP16

  • enforce_hbm (bool = False) – If True, place all weights/momentums in HBM when using EmbeddingLocation.MANAGED_CACHING

  • optimizer (OptimType = OptimType.EXACT_SGD) –

    An optimizer to use for embedding table update in the backward pass. Available OptimType options are

    1. ADAM = Adam

    2. EXACT_ADAGRAD = Adagrad

    3. EXACT_ROWWISE_ADAGRAD = Rowwise-Adagrad

    4. EXACT_SGD = SGD

    5. LAMB = Lamb

    6. LARS_SGD = LARS-SGD

    7. PARTIAL_ROWWISE_ADAM = Partial rowwise-Adam

    8. PARTIAL_ROWWISE_LAMB = Partial rowwise-Lamb

    9. ENSEMBLE_ROWWISE_ADAGRAD = Ensemble rowwise-Adagrad

    10. EMAINPLACE_ROWWISE_ADAGRAD = Ema inplace rowwise-Adagrad

    11. NONE = Not applying an optimizer update in the backward pass and outputting a sparse weight gradient

  • record_cache_metrics (Optional[RecordCacheMetrics] = None) – Record the number of cache hits, the number of requests, etc. if RecordCacheMetrics.record_cache_miss_counter is True, and record similar metrics per table if RecordCacheMetrics.record_tablewise_cache_miss is True

  • gather_uvm_cache_stats (Optional[bool] = False) – If True, collect the cache statistics when EmbeddingLocation is set to MANAGED_CACHING

  • stochastic_rounding (bool = True) – If True, apply stochastic rounding when the weight type is not SparseType.FP32

  • gradient_clipping (bool = False) – If True, apply gradient clipping

  • max_gradient (float = 1.0) – The threshold used for gradient clipping (when gradient_clipping is True)

  • max_norm (float = 0.0) – The max norm value

  • learning_rate (float = 0.01) – The learning rate

  • eps (float = 1.0e-8) – The epsilon value used by Adagrad, LAMB, and Adam. Note that the default differs from the torch.optim.Adagrad default of 1e-10

  • momentum (float = 0.9) – Momentum used by LARS-SGD

  • weight_decay (float = 0.0) –

    Weight decay used by LARS-SGD, LAMB, ADAM, and rowwise-Adagrad.

    1. EXACT_ADAGRAD, SGD, EXACT_SGD do not support weight decay

    2. LAMB, ADAM, PARTIAL_ROWWISE_ADAM, PARTIAL_ROWWISE_LAMB, LARS_SGD support decoupled weight decay

    3. EXACT_ROWWISE_ADAGRAD supports both L2 and decoupled weight decay (via weight_decay_mode)

  • weight_decay_mode (WeightDecayMode = WeightDecayMode.NONE) – Weight decay mode. Options are WeightDecayMode.NONE, WeightDecayMode.L2, and WeightDecayMode.DECOUPLE

  • eta (float = 0.001) – The eta value used by LARS-SGD

  • beta1 (float = 0.9) – The beta1 value used by LAMB and ADAM

  • beta2 (float = 0.999) – The beta2 value used by LAMB and ADAM

  • ensemble_mode (Optional[EnsembleModeDefinition] = None) – Used by Ensemble Rowwise Adagrad

  • emainplace_mode (Optional[EmainplaceModeDefinition] = None) – Used by EMA in-place Rowwise Adagrad

  • counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None) – Used by Rowwise Adagrad

  • cowclip_regularization (Optional[CowClipDefinition] = None) – Used by Rowwise Adagrad

  • pooling_mode (PoolingMode = PoolingMode.SUM) –

    Pooling mode. Available PoolingMode options are

    1. SUM = Sum pooling

    2. MEAN = Mean pooling

    3. NONE = No pooling (sequence embedding)

  • device (Optional[Union[str, int, torch.device]] = None) – The current device to place tensors on

  • bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –

    Input checking mode. Available BoundsCheckMode options are

    1. NONE = skip bounds check

    2. FATAL = throw an error when encountering an invalid index/offset

    3. WARNING = print a warning message when encountering an invalid index/offset and fix it (setting an invalid index to zero and adjusting an invalid offset to be within the bound)

    4. IGNORE = silently fix an invalid index/offset (setting an invalid index to zero and adjusting an invalid offset to be within the bound)

  • uvm_non_rowwise_momentum (bool = False) – If True, place non-rowwise momentum on the unified virtual memory

  • use_experimental_tbe (bool = False) – If True, use an optimized TBE implementation (TBE v2). Note that this is supported only on NVIDIA GPUs.

  • prefetch_pipeline (bool = False) – If True, enable the cache prefetch pipeline when using EmbeddingLocation.MANAGED_CACHING. Currently, only the LRU cache policy is supported. If a separate stream is used for prefetching, the optional forward_stream argument of the prefetch function must be set.

  • stats_reporter_config (Optional[TBEStatsReporterConfig] = None) – A config for TBE stats reporter

  • table_names (Optional[List[str]] = None) – A list of embedding table names in this TBE

  • optimizer_state_dtypes (Optional[Dict[str, SparseType]] = None) – A dict of optimizer state data types. Keys are the optimizer state names and values are their corresponding data types

  • multipass_prefetch_config (Optional[MultiPassPrefetchConfig] = None) – A config for multipass cache prefetching (when EmbeddingLocation.MANAGED_CACHING is used)

  • global_weight_decay (Optional[GlobalWeightDecayDefinition] = None) – A config for global weight decay

  • uvm_host_mapped (bool = False) – If True, allocate every UVM tensor using malloc + cudaHostRegister. Otherwise, use cudaMallocManaged

  • extra_optimizer_config (Optional[UserEnabledConfigDefinition] = None) –

    An extra config that enables optional optimizer modes, which are disabled by default. Available modes are

    1. use_rowwise_bias_correction = enable rowwise bias correction computation in Adam

  • embedding_table_index_type (torch.dtype = torch.int64) – The data type of the embedding table index tensor. Options are torch.int32 and torch.int64

  • embedding_table_offset_type (torch.dtype = torch.int64) – The data type of the embedding table offset tensor. Options are torch.int32 and torch.int64
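
The embedding_specs and feature_table_map parameters together fix the width of every forward() output row (total_D). A minimal pure-Python sketch, assuming only what is stated above plus one assumption: that an omitted feature_table_map means one feature per table. Location, Device, and feature_dims are hypothetical stand-ins written for this illustration, not fbgemm_gpu APIs.

```python
from enum import Enum
from typing import List, Optional, Tuple


# Hypothetical stand-ins for fbgemm_gpu's EmbeddingLocation and ComputeDevice,
# so this sketch runs without fbgemm_gpu installed.
class Location(Enum):
    DEVICE = 0
    MANAGED = 1
    MANAGED_CACHING = 2
    HOST = 3


class Device(Enum):
    CPU = 0
    CUDA = 1


# (number of rows, embedding dimension, placement, compute device)
Spec = Tuple[int, int, Location, Device]


def feature_dims(specs: List[Spec], feature_table_map: Optional[List[int]] = None) -> List[int]:
    """Embedding dimension seen by each feature (hypothetical helper)."""
    for _rows, dim, _location, _device in specs:
        if dim % 4 != 0:
            raise ValueError(f"embedding dimension {dim} is not a multiple of 4")
    # Assumption: with no map, feature i looks up table i (one feature per table).
    if feature_table_map is None:
        feature_table_map = list(range(len(specs)))
    # feature_table_map[i] is the physical table that feature i reads from.
    return [specs[t][1] for t in feature_table_map]


specs = [
    (3, 8, Location.DEVICE, Device.CUDA),
    (5, 4, Location.MANAGED, Device.CUDA),
]
print(sum(feature_dims(specs)))        # 12: the total_D of the forward() example
print(feature_dims(specs, [0, 0, 1]))  # [8, 8, 4]: features 0 and 1 share table 0
```

Because each feature contributes its table's dimension, sharing a table between features widens the output even though no new weights are allocated.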
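
The stochastic_rounding parameter matters when weights are stored in a lower-precision type such as FP16, where round-to-nearest can silently discard small updates. The following is a generic illustration of stochastic rounding on a grid of spacing `step`; it is not FBGEMM's bit-level FP16 kernel, and stochastic_round is a hypothetical helper.

```python
import math
import random


def stochastic_round(x: float, step: float, rng: random.Random) -> float:
    """Round x to a neighbouring multiple of `step`, chosen at random (illustrative)."""
    lower = math.floor(x / step) * step
    frac = (x - lower) / step  # distance to the lower grid point, in [0, 1)
    # Round up with probability `frac`, so the expected result equals x.
    return lower + step if rng.random() < frac else lower


rng = random.Random(0)
samples = [stochastic_round(0.3, 1.0, rng) for _ in range(10_000)]
# Round-to-nearest would map 0.3 to 0.0 every time; the stochastic mean stays near 0.3.
print(sum(samples) / len(samples))
```

Each individual result is still a representable grid value; only the average over many updates is unbiased, which is what lets accumulated small gradient steps survive low-precision storage.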
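
The weight_decay_mode options differ in where the decay term enters the update. A textbook-style sketch of one rowwise-Adagrad step on a single embedding row, assuming L2 means "add weight_decay * w to the gradient" and DECOUPLE means an AdamW-style shrink of the weights by (1 - learning_rate * weight_decay); rowwise_adagrad_step is hypothetical, and FBGEMM's fused CUDA kernel may differ in details.

```python
import math
from typing import List, Tuple


def rowwise_adagrad_step(
    row: List[float],
    grad: List[float],
    momentum: float,
    lr: float = 0.01,
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    mode: str = "NONE",  # "NONE", "L2", or "DECOUPLE"
) -> Tuple[List[float], float]:
    """One rowwise-Adagrad update of a single embedding row (illustrative only)."""
    if mode not in ("NONE", "L2", "DECOUPLE"):
        raise ValueError(f"unknown weight decay mode: {mode}")
    if mode == "L2":
        # L2: fold weight_decay * w into the gradient before the Adagrad statistics.
        grad = [g + weight_decay * w for g, w in zip(grad, row)]
    # Rowwise Adagrad keeps ONE accumulator per row: the mean squared gradient.
    momentum += sum(g * g for g in grad) / len(grad)
    step = lr / (math.sqrt(momentum) + eps)
    # DECOUPLE: shrink the weights directly, independent of the adaptive scaling.
    decay = 1.0 - lr * weight_decay if mode == "DECOUPLE" else 1.0
    new_row = [decay * w - step * g for w, g in zip(row, grad)]
    return new_row, momentum


for mode in ("NONE", "L2", "DECOUPLE"):
    print(mode, rowwise_adagrad_step([1.0, -1.0], [0.5, 0.5], 0.0, weight_decay=0.1, mode=mode))
```

With L2, the decay term is divided by the square root of the accumulated momentum like any other gradient component; with DECOUPLE it is not, which is why the two modes give different results for adaptive optimizers even though they coincide for plain SGD.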

forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None, feature_requires_grad: Tensor | None = None, batch_size_per_feature_per_rank: List[List[int]] | None = None, total_unique_indices: int | None = None) → Tensor

The forward pass function that

  1. Performs input bound checking

  2. Generates necessary variable batch size embedding (VBE) metadata (if VBE is used)

  3. Prefetches data from UVM to cache (if EmbeddingLocation.MANAGED_CACHING is used and the user has not explicitly prefetched data)

  4. Performs the embedding table lookup by invoking a corresponding Autograd function (based on the chosen optimizer)
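
Step 1 (input bounds checking) and the bounds_check_mode options can be illustrated with a pure-Python sketch. It follows the documented behaviour (an invalid index is set to zero and an invalid offset is moved within bounds), but the helper bounds_check, its arguments, and the choice to clamp offsets into [0, len(indices)] are assumptions of this illustration, not the fbgemm_gpu kernel.

```python
import warnings
from typing import List, Tuple


def bounds_check(
    indices: List[int],
    offsets: List[int],
    rows_per_feature: List[int],  # rows of the table each feature reads from
    batch_size: int,
    mode: str = "WARNING",  # "NONE", "FATAL", "WARNING", or "IGNORE"
) -> Tuple[List[int], List[int]]:
    """Validate (and optionally repair) TBE inputs; hypothetical helper."""
    if mode == "NONE":
        return indices, offsets  # skip the check entirely
    if len(offsets) != len(rows_per_feature) * batch_size + 1:
        raise ValueError("offsets must have B * T + 1 entries")
    indices, offsets = list(indices), list(offsets)  # repair copies, not inputs
    problems = []
    # Offsets must point inside the indices array.
    for k, off in enumerate(offsets):
        if not 0 <= off <= len(indices):
            problems.append(f"offsets[{k}]={off} out of range")
            offsets[k] = min(max(off, 0), len(indices))
    # Every index in a bag of feature t must be a valid row of that feature's table.
    for t, num_rows in enumerate(rows_per_feature):
        for b in range(batch_size):
            start, end = offsets[t * batch_size + b], offsets[t * batch_size + b + 1]
            for j in range(start, end):
                if not 0 <= indices[j] < num_rows:
                    problems.append(f"indices[{j}]={indices[j]} invalid for feature {t}")
                    indices[j] = 0  # invalid indices are replaced by row 0
    if problems and mode == "FATAL":
        raise IndexError("; ".join(problems))
    if problems and mode == "WARNING":
        warnings.warn("; ".join(problems))
    return indices, offsets


fixed, _ = bounds_check([0, 9, 4, 1], [0, 2, 4], rows_per_feature=[3, 5], batch_size=1, mode="IGNORE")
print(fixed)  # [0, 0, 4, 1]: index 9 is out of range for the 3-row table, so it becomes 0
```

WARNING and IGNORE repair the inputs identically and differ only in whether a message is emitted; FATAL refuses to proceed.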

Parameters:

  • indices (Tensor) – A 1D-tensor that contains the indices to be looked up from all embedding tables

  • offsets (Tensor) – A 1D-tensor that contains offsets of indices. Shape (B * T + 1) where B = batch size and T = the number of features. offsets[t * B + b + 1] - offsets[t * B + b] is the length of bag b of feature t

  • per_sample_weights (Optional[Tensor]) – An optional 1D float tensor that contains per-sample weights. If None, an unweighted embedding lookup is performed; otherwise, a weighted lookup is performed. The length of this tensor must equal the length of the indices tensor. Every element of the row looked up by indices[i] is multiplied by per_sample_weights[i], where 0 <= i < len(per_sample_weights).

  • feature_requires_grad (Optional[Tensor]) – An optional 1D-tensor for indicating if per_sample_weights requires gradient. The length of the tensor must be equal to the number of features

  • batch_size_per_feature_per_rank (Optional[List[List[int]]]) – An optional 2D list that contains the batch size for every feature and every rank. If None, TBE assumes that every feature has the same batch size and computes the batch size from the offsets shape. Otherwise, TBE assumes that different features can have different batch sizes and uses the variable batch size embedding (VBE) lookup mode. Shape (number of features, number of ranks). batch_size_per_feature_per_rank[f][r] is the batch size of feature f on rank r

  • total_unique_indices (Optional[int]) – An optional integer that represents the total number of unique indices. This value must be set when using OptimType.NONE. This is because TBE requires this information for allocating the weight gradient tensor in the backward pass.


Returns:

A 2D-tensor containing the looked-up data. Shape (B, total_D) where B = batch size and total_D = the sum of the embedding dimensions of all features
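
To make the offsets layout and the per_sample_weights semantics concrete, here is a hypothetical pure-Python reference of a pooled lookup (reference_tbe_forward is not an fbgemm_gpu function). It assumes the fixed-batch-size (non-VBE) layout, in which bag b of feature t spans indices[offsets[t * B + b]:offsets[t * B + b + 1]], and covers only the SUM and MEAN pooling modes.

```python
from typing import List, Optional, Sequence


def reference_tbe_forward(
    tables: Sequence[Sequence[Sequence[float]]],  # tables[t][row] -> embedding row
    feature_table_map: List[int],
    indices: List[int],
    offsets: List[int],
    per_sample_weights: Optional[List[float]] = None,
    pooling: str = "SUM",  # "SUM" or "MEAN"
) -> List[List[float]]:
    """Pure-Python pooled lookup returning a (B, total_D) nested list (illustrative)."""
    if pooling not in ("SUM", "MEAN"):
        raise ValueError("this sketch only covers SUM and MEAN pooling")
    T = len(feature_table_map)
    B = (len(offsets) - 1) // T
    if B * T + 1 != len(offsets):
        raise ValueError("offsets must have B * T + 1 entries")
    output: List[List[float]] = []
    for b in range(B):
        out_row: List[float] = []
        for t in range(T):
            table = tables[feature_table_map[t]]
            dim = len(table[0])
            start, end = offsets[t * B + b], offsets[t * B + b + 1]
            pooled = [0.0] * dim  # an empty bag pools to zeros
            for i in range(start, end):
                w = 1.0 if per_sample_weights is None else per_sample_weights[i]
                for d in range(dim):
                    pooled[d] += w * table[indices[i]][d]
            if pooling == "MEAN" and end > start:
                # Sketch choice: divide by bag length (not a claim about weighted MEAN in FBGEMM).
                pooled = [v / (end - start) for v in pooled]
            out_row.extend(pooled)  # features are concatenated along the columns
        output.append(out_row)
    return output


tables = [
    [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],  # table 0: 2 rows, dim 4
    [[2.0, 2.0, 2.0, 2.0]],                        # table 1: 1 row, dim 4
]
# B = 2, T = 2: feature 0 bags [0, 1] and [1]; feature 1 bags [0] and [] (empty)
indices, offsets = [0, 1, 1, 0], [0, 2, 3, 4, 4]
print(reference_tbe_forward(tables, [0, 1], indices, offsets))
# [[1.0, 1.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0], [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
```

Passing per_sample_weights scales each looked-up row by its weight before pooling, which is the weighted lookup described for that parameter.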

Example:

>>> import torch
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
...     EmbeddingLocation,
... )
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
...     SplitTableBatchedEmbeddingBagsCodegen,
...     ComputeDevice,
... )
>>> # Two tables
>>> embedding_specs = [
...     (3, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
...     (5, 4, EmbeddingLocation.MANAGED, ComputeDevice.CUDA)
... ]
>>> tbe = SplitTableBatchedEmbeddingBagsCodegen(embedding_specs)
>>> tbe.init_embedding_weights_uniform(-1, 1)
>>> print(tbe.split_embedding_weights())
[tensor([[-0.9426,  0.7046,  0.4214, -0.0419,  0.1331, -0.7856, -0.8124, -0.2021],
        [-0.5771,  0.5911, -0.7792, -0.1068, -0.6203,  0.4813, -0.1677,  0.4790],
        [-0.5587, -0.0941,  0.5754,  0.3475, -0.8952, -0.1964,  0.0810, -0.4174]],
       device='cuda:0'), tensor([[-0.2513, -0.4039, -0.3775,  0.3273],
        [-0.5399, -0.0229, -0.1455, -0.8770],
        [-0.9520,  0.4593, -0.7169,  0.6307],
        [-0.1765,  0.8757,  0.8614,  0.2051],
        [-0.0603, -0.9980, -0.7958, -0.5826]], device='cuda:0')]
>>> # Batch size B = 3 and T = 2 features, so offsets has B * T + 1 = 7 entries
>>> indices = torch.tensor([0, 1, 2, 0, 1, 2, 0, 3, 1, 4, 2, 0, 0],
...                        device="cuda",
...                        dtype=torch.long)
>>> offsets = torch.tensor([0, 2, 5, 7, 9, 12, 13],
...                        device="cuda",
...                        dtype=torch.long)
>>> output = tbe(indices, offsets)
>>> # Batch size = 3, total embedding dimension = 12
>>> print(output.shape)
torch.Size([3, 12])
>>> print(output)
tensor([[-1.5197,  1.2957, -0.3578, -0.1487, -0.4873, -0.3044, -0.9801,  0.2769,
         -0.7164,  0.8528,  0.7159, -0.6719],
        [-2.0784,  1.2016,  0.2176,  0.1988, -1.3825, -0.5008, -0.8991, -0.1405,
         -1.2637, -0.9427, -1.8902,  0.3754],
        [-1.5013,  0.6105,  0.9968,  0.3057, -0.7621, -0.9821, -0.7314, -0.6195,
         -0.2513, -0.4039, -0.3775,  0.3273]], device='cuda:0',
       grad_fn=<...>)
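
As a sanity check on the example's layout, the first output row can be recomputed by hand from the printed weights. Sample 0 concatenates bag 0 of feature 0 (offsets[0:2] = [0, 2], so indices[0:2] = [0, 1] in table 0) with bag 0 of feature 1 (offsets[3:5] = [7, 9], so indices[7:9] = [3, 1] in table 1). The printed values are rounded to four decimals, so the sketch compares with a small tolerance.

```python
import math

# Rows copied from the printed split_embedding_weights() output above.
t0 = [
    [-0.9426, 0.7046, 0.4214, -0.0419, 0.1331, -0.7856, -0.8124, -0.2021],
    [-0.5771, 0.5911, -0.7792, -0.1068, -0.6203, 0.4813, -0.1677, 0.4790],
    [-0.5587, -0.0941, 0.5754, 0.3475, -0.8952, -0.1964, 0.0810, -0.4174],
]
t1 = {1: [-0.5399, -0.0229, -0.1455, -0.8770], 3: [-0.1765, 0.8757, 0.8614, 0.2051]}
printed_row0 = [-1.5197, 1.2957, -0.3578, -0.1487, -0.4873, -0.3044, -0.9801, 0.2769,
                -0.7164, 0.8528, 0.7159, -0.6719]


def sum_rows(*rows):
    """SUM pooling: element-wise sum of the looked-up rows."""
    return [sum(col) for col in zip(*rows)]


# Feature 0 output (8 columns) followed by feature 1 output (4 columns).
recomputed = sum_rows(t0[0], t0[1]) + sum_rows(t1[3], t1[1])
ok = all(math.isclose(a, b, abs_tol=2e-4) for a, b in zip(recomputed, printed_row0))
print(len(recomputed), ok)  # 12 True
```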
set_learning_rate(lr: float) → None

Sets the learning rate.

Parameters:

lr (float) – The learning rate value to set

set_optimizer_step(step: int) → None

Sets the optimizer step.

Parameters:

step (int) – The step value to set

split_embedding_weights() → List[Tensor]

Returns a list of embedding weights (views), split by table.

Returns:

A list of weights. Length = the number of tables

split_optimizer_states() → List[List[Tensor]]

Returns a list of optimizer states (views), split by table.

Returns:

A list of lists of states. Shape = (the number of tables, the number of states).

The following shows the list of states (in the returned order) for each optimizer:

  1. ADAM: momentum1, momentum2

  2. EXACT_ADAGRAD: momentum1

  3. EXACT_ROWWISE_ADAGRAD: momentum1 (rowwise), prev_iter (rowwise; only when WeightDecayMode is COUNTER or COWCLIP, or when global_weight_decay is not None), row_counter (rowwise; only when WeightDecayMode is COUNTER or COWCLIP)

  4. EXACT_SGD: no states

  5. LAMB: momentum1, momentum2

  6. LARS_SGD: momentum1

  7. PARTIAL_ROWWISE_ADAM: momentum1, momentum2 (rowwise)

  8. PARTIAL_ROWWISE_LAMB: momentum1, momentum2 (rowwise)

  9. ENSEMBLE_ROWWISE_ADAGRAD: momentum1 (rowwise), momentum2

  10. NONE: no states (calling this method raises an error)
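
When checkpointing, it can help to attach names to the positional state lists. A minimal sketch that transcribes the table above into a dict; OPTIMIZER_STATES and label_states are hypothetical helpers (not fbgemm_gpu APIs), the optimizer is passed as a plain string, and the per-table lists would come from split_optimizer_states().

```python
from typing import Any, Dict, List

# Transcribed from the list above, in the documented return order.
# For EXACT_ROWWISE_ADAGRAD, prev_iter and row_counter are optional trailing states.
OPTIMIZER_STATES: Dict[str, List[str]] = {
    "ADAM": ["momentum1", "momentum2"],
    "EXACT_ADAGRAD": ["momentum1"],
    "EXACT_ROWWISE_ADAGRAD": ["momentum1", "prev_iter", "row_counter"],
    "EXACT_SGD": [],
    "LAMB": ["momentum1", "momentum2"],
    "LARS_SGD": ["momentum1"],
    "PARTIAL_ROWWISE_ADAM": ["momentum1", "momentum2"],
    "PARTIAL_ROWWISE_LAMB": ["momentum1", "momentum2"],
    "ENSEMBLE_ROWWISE_ADAGRAD": ["momentum1", "momentum2"],
}


def label_states(optimizer: str, states_per_table: List[List[Any]]) -> List[Dict[str, Any]]:
    """Map each table's positional state list to {state name: state}."""
    if optimizer == "NONE":
        raise ValueError("OptimType.NONE keeps no optimizer states")
    names = OPTIMIZER_STATES[optimizer]
    labeled = []
    for states in states_per_table:
        if len(states) > len(names):
            raise ValueError(f"{optimizer}: expected at most {len(names)} states, got {len(states)}")
        # zip stops at the shorter list, so absent optional trailing states are skipped.
        labeled.append(dict(zip(names, states)))
    return labeled


print(label_states("EXACT_ROWWISE_ADAGRAD", [["m"], ["m", "p"]]))
# [{'momentum1': 'm'}, {'momentum1': 'm', 'prev_iter': 'p'}]
```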

update_hyper_parameters(params_dict: Dict[str, float]) → None

Sets hyper-parameters from external control flow.

Parameters:

params_dict (Dict[str, float]) – A dict mapping hyper-parameter names to their values

