Table Batched Embedding (TBE) Training Module¶
Stable API¶
- class fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]], feature_table_map: List[int] | None = None, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, cache_precision: SparseType | None = None, weights_precision: SparseType = SparseType.FP32, output_dtype: SparseType = SparseType.FP32, enforce_hbm: bool = False, optimizer: EmbOptimType = EmbOptimType.EXACT_SGD, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, stochastic_rounding: bool = True, gradient_clipping: bool = False, max_gradient: float = 1.0, max_norm: float = 0.0, learning_rate: float = 0.01, eps: float = 1e-08, momentum: float = 0.9, weight_decay: float = 0.0, weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, eta: float = 0.001, beta1: float = 0.9, beta2: float = 0.999, ensemble_mode: EnsembleModeDefinition | None = None, emainplace_mode: EmainplaceModeDefinition | None = None, counter_based_regularization: CounterBasedRegularizationDefinition | None = None, cowclip_regularization: CowClipDefinition | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, uvm_non_rowwise_momentum: bool = False, use_experimental_tbe: bool = False, prefetch_pipeline: bool = False, stats_reporter_config: TBEStatsReporterConfig | None = None, table_names: List[str] | None = None, optimizer_state_dtypes: Dict[str, SparseType] | None = None, multipass_prefetch_config: MultiPassPrefetchConfig | None = None, global_weight_decay: GlobalWeightDecayDefinition | None = None, uvm_host_mapped: bool = False)[source]¶
Table Batched Embedding (TBE) operator. Looks up one or more embedding tables. The module is intended for training. The backward operator is fused with the optimizer, so the embedding tables are updated during the backward pass.
- Parameters:
embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –
A list of embedding specifications. Each spec describes a physical embedding table as a tuple of the number of embedding rows, the embedding dimension (must be a multiple of 4), the table placement (EmbeddingLocation), and the compute device (ComputeDevice). See the construction sketch after this parameter list.
Available EmbeddingLocation options are
DEVICE = placing an embedding table in the GPU global memory (HBM)
MANAGED = placing an embedding table in the unified virtual memory (accessible from both GPU and CPU)
MANAGED_CACHING = placing an embedding table in the unified virtual memory and using the GPU global memory (HBM) as a cache
HOST = placing an embedding table in the CPU memory (DRAM)
MTIA = placing an embedding table in the MTIA memory
Available ComputeDevice options are
CPU = performing table lookup on CPU
CUDA = performing table lookup on GPU
MTIA = performing table lookup on MTIA
feature_table_map (Optional[List[int]] = None) – An optional list that specifies feature-table mapping. feature_table_map[i] indicates the physical embedding table that feature i maps to.
cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –
The cache algorithm (used when EmbeddingLocation is set to MANAGED_CACHING). Options are
LRU = least recently used
LFU = least frequently used
cache_load_factor (float = 0.2) – A factor used for determining the cache capacity when EmbeddingLocation.MANAGED_CACHING is used. The cache capacity is cache_load_factor * the total number of rows in all embedding tables
cache_sets (int = 0) – The number of cache sets (used when EmbeddingLocation is set to MANAGED_CACHING)
cache_reserved_memory (float = 0.0) – The amount of memory reserved in HBM for non-cache purposes (used when EmbeddingLocation is set to MANAGED_CACHING).
cache_precision (SparseType = SparseType.FP32) – The data type of the cache (used when EmbeddingLocation is set to MANAGED_CACHING). Options are SparseType.FP32 and SparseType.FP16
weights_precision (SparseType = SparseType.FP32) – The data type of embedding tables (also known as weights). Options are SparseType.FP32 and SparseType.FP16
output_dtype (SparseType = SparseType.FP32) – The data type of an output tensor. Options are SparseType.FP32 and SparseType.FP16
enforce_hbm (bool = False) – If True, place all weights/momentums in HBM when using EmbeddingLocation.MANAGED_CACHING
optimizer (OptimType = OptimType.EXACT_SGD) –
An optimizer to use for embedding table update in the backward pass. Available OptimType options are
ADAM = Adam
EXACT_ADAGRAD = Adagrad
EXACT_ROWWISE_ADAGRAD = Rowwise Adagrad
EXACT_SGD = SGD
LAMB = Lamb
LARS_SGD = LARS-SGD
PARTIAL_ROWWISE_ADAM = Partial rowwise-Adam
PARTIAL_ROWWISE_LAMB = Partial rowwise-Lamb
ENSEMBLE_ROWWISE_ADAGRAD = Ensemble rowwise-Adagrad
EMAINPLACE_ROWWISE_ADAGRAD = EMA in-place rowwise Adagrad
NONE = Not applying an optimizer update in the backward pass and outputting a sparse weight gradient instead
record_cache_metrics (Optional[RecordCacheMetrics] = None) – Record the number of cache hits, the number of requests, etc., if RecordCacheMetrics.record_cache_miss_counter is True, and record similar metrics per table if RecordCacheMetrics.record_tablewise_cache_miss is True
gather_uvm_cache_stats (Optional[bool] = False) – If True, collect the cache statistics when EmbeddingLocation is set to MANAGED_CACHING
stochastic_rounding (bool = True) – If True, apply stochastic rounding for weight types other than SparseType.FP32
gradient_clipping (bool = False) – If True, apply gradient clipping
max_gradient (float = 1.0) – The value for gradient clipping
max_norm (float = 0.0) – The max norm value
learning_rate (float = 0.01) – The learning rate
eps (float = 1.0e-8) – The epsilon value used by Adagrad, LAMB, and Adam. Note that this default differs from the torch.optim.Adagrad default of 1e-10
momentum (float = 0.9) – Momentum used by LARS-SGD
weight_decay (float = 0.0) –
Weight decay used by LARS-SGD, LAMB, ADAM, and rowwise-Adagrad.
EXACT_ADAGRAD, SGD, EXACT_SGD do not support weight decay
LAMB, ADAM, PARTIAL_ROWWISE_ADAM, PARTIAL_ROWWISE_LAMB, LARS_SGD support decoupled weight decay
EXACT_ROWWISE_ADAGRAD supports both L2 and decoupled weight decay (via weight_decay_mode)
weight_decay_mode (WeightDecayMode = WeightDecayMode.NONE) – Weight decay mode. Options are WeightDecayMode.NONE, WeightDecayMode.L2, and WeightDecayMode.DECOUPLE
eta (float = 0.001) – The eta value used by LARS-SGD
beta1 (float = 0.9) – The beta1 value used by LAMB and ADAM
beta2 (float = 0.999) – The beta2 value used by LAMB and ADAM
ensemble_mode (Optional[EnsembleModeDefinition] = None) – Used by Ensemble Rowwise Adagrad
emainplace_mode (Optional[EmainplaceModeDefinition] = None) – Used by EMA in-place Rowwise Adagrad
counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None) – Used by Rowwise Adagrad
cowclip_regularization (Optional[CowClipDefinition] = None) – Used by Rowwise Adagrad
pooling_mode (PoolingMode = PoolingMode.SUM) –
Pooling mode. Available PoolingMode options are
SUM = Sum pooling
MEAN = Mean pooling
NONE = No pooling (sequence embedding)
device (Optional[Union[str, int, torch.device]] = None) – The current device to place tensors on
bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –
Input checking mode. Available BoundsCheckMode options are
NONE = skip bounds check
FATAL = throw an error when encountering an invalid index/offset
WARNING = print a warning message when encountering an invalid index/offset and fix it (setting an invalid index to zero and adjusting an invalid offset to be within the bound)
IGNORE = silently fix an invalid index/offset (setting an invalid index to zero and adjusting an invalid offset to be within the bound)
uvm_non_rowwise_momentum (bool = False) – If True, place non-rowwise momentum on the unified virtual memory
use_experimental_tbe (bool = False) – If True, use an optimized TBE implementation (TBE v2). Note that this is supported only on NVIDIA GPUs.
prefetch_pipeline (bool = False) – If True, enable the cache prefetch pipeline when using EmbeddingLocation.MANAGED_CACHING. Currently only the LRU cache policy is supported. If a separate stream is used for prefetch, the optional forward_stream arg of the prefetch function must be set.
stats_reporter_config (Optional[TBEStatsReporterConfig] = None) – A config for TBE stats reporter
table_names (Optional[List[str]] = None) – A list of embedding table names in this TBE
optimizer_state_dtypes (Optional[Dict[str, SparseType]] = None) – A dict of optimizer state data types. Keys are the optimizer state names and values are their corresponding data types
multipass_prefetch_config (Optional[MultiPassPrefetchConfig] = None) – A config for multipass cache prefetching (when EmbeddingLocation.MANAGED_CACHING is used)
global_weight_decay (Optional[GlobalWeightDecayDefinition] = None) – A config for global weight decay
uvm_host_mapped (bool = False) – If True, allocate every UVM tensor using malloc + cudaHostRegister. Otherwise use cudaMallocManaged
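The construction sketch referenced from embedding_specs above is a minimal, illustrative configuration, not a prescription: the table sizes, the MANAGED_CACHING placement, the choice of EXACT_ROWWISE_ADAGRAD, and the import path used for EmbOptimType are assumptions made for demonstration.
>>> from fbgemm_gpu.split_embedding_configs import EmbOptimType
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
>>>     EmbeddingLocation,
>>> )
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
>>>     ComputeDevice,
>>>     SplitTableBatchedEmbeddingBagsCodegen,
>>> )
>>>
>>> # One table resident in HBM, one in UVM with an HBM cache (hypothetical sizes)
>>> tbe = SplitTableBatchedEmbeddingBagsCodegen(
>>>     embedding_specs=[
>>>         (1000, 64, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
>>>         (10000, 64, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
>>>     ],
>>>     optimizer=EmbOptimType.EXACT_ROWWISE_ADAGRAD,
>>>     learning_rate=0.05,
>>>     eps=1.0e-8,
>>>     cache_load_factor=0.2,
>>> )
The rowwise-Adagrad update is then applied to the looked-up rows inside the backward pass; no separate optimizer.step() call is needed for the embedding tables.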
- forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None, feature_requires_grad: Tensor | None = None, batch_size_per_feature_per_rank: List[List[int]] | None = None, total_unique_indices: int | None = None) Tensor [source]¶
The forward pass function that
Performs input bound checking
Generates necessary variable batch size embedding (VBE) metadata (if VBE is used)
Prefetches data from UVM to cache (if EmbeddingLocation.MANAGED_CACHING is used and the user has not explicitly prefetched data)
Performs the embedding table lookup by invoking a corresponding Autograd function (based on the chosen optimizer)
- Parameters:
indices (Tensor) – A 1D-tensor that contains indices to be looked up from all embedding tables
offsets (Tensor) – A 1D-tensor that contains offsets of indices. Shape (B * T + 1) where B = batch size and T = the number of features. offsets[t * B + b + 1] - offsets[t * B + b] is the length of bag b of feature t
per_sample_weights (Optional[Tensor]) – An optional 1D float tensor that contains per-sample weights. If None, an unweighted embedding lookup is performed; otherwise, a weighted lookup is used (see the weighted-lookup sketch after the example below). The length of this tensor must be the same as the length of the indices tensor. The value of per_sample_weights[i] multiplies every element of the looked-up row for indices[i], where 0 <= i < len(per_sample_weights).
feature_requires_grad (Optional[Tensor]) – An optional 1D-tensor for indicating if per_sample_weights requires gradient. The length of the tensor must be equal to the number of features
batch_size_per_feature_per_rank (Optional[List[List[int]]]) – An optional nested list that contains batch sizes for every rank and every feature. If None, TBE assumes that every feature has the same batch size and computes the batch size from the offsets shape. Otherwise, TBE assumes that different features can have different batch sizes and uses the variable batch size embedding lookup mode (VBE). Shape (number of features, number of ranks). batch_size_per_feature_per_rank[f][r] represents the batch size of feature f and rank r
total_unique_indices (Optional[int]) – An optional integer that represents the total number of unique indices. This value must be set when using OptimType.NONE. This is because TBE requires this information for allocating the weight gradient tensor in the backward pass.
- Returns:
A 2D-tensor containing looked up data. Shape (B, total_D) where B = batch size and total_D = the sum of all embedding dimensions in the tables
Example
>>> import torch
>>>
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
>>>     EmbeddingLocation,
>>> )
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
>>>     SplitTableBatchedEmbeddingBagsCodegen,
>>>     ComputeDevice,
>>> )
>>>
>>> # Two tables
>>> embedding_specs = [
>>>     (3, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
>>>     (5, 4, EmbeddingLocation.MANAGED, ComputeDevice.CUDA)
>>> ]
>>>
>>> tbe = SplitTableBatchedEmbeddingBagsCodegen(embedding_specs)
>>> tbe.init_embedding_weights_uniform(-1, 1)
>>>
>>> print(tbe.split_embedding_weights())
[tensor([[-0.9426,  0.7046,  0.4214, -0.0419,  0.1331, -0.7856, -0.8124, -0.2021],
         [-0.5771,  0.5911, -0.7792, -0.1068, -0.6203,  0.4813, -0.1677,  0.4790],
         [-0.5587, -0.0941,  0.5754,  0.3475, -0.8952, -0.1964,  0.0810, -0.4174]],
        device='cuda:0'),
 tensor([[-0.2513, -0.4039, -0.3775,  0.3273],
         [-0.5399, -0.0229, -0.1455, -0.8770],
         [-0.9520,  0.4593, -0.7169,  0.6307],
         [-0.1765,  0.8757,  0.8614,  0.2051],
         [-0.0603, -0.9980, -0.7958, -0.5826]], device='cuda:0')]

>>> # Batch size = 3
>>> indices = torch.tensor([0, 1, 2, 0, 1, 2, 0, 3, 1, 4, 2, 0, 0],
>>>                        device="cuda",
>>>                        dtype=torch.long)
>>> offsets = torch.tensor([0, 2, 5, 7, 9, 12, 13],
>>>                        device="cuda",
>>>                        dtype=torch.long)
>>>
>>> output = tbe(indices, offsets)
>>>
>>> # Batch size = 3, total embedding dimension = 12
>>> print(output.shape)
torch.Size([3, 12])

>>> print(output)
tensor([[-1.5197,  1.2957, -0.3578, -0.1487, -0.4873, -0.3044, -0.9801,  0.2769,
         -0.7164,  0.8528,  0.7159, -0.6719],
        [-2.0784,  1.2016,  0.2176,  0.1988, -1.3825, -0.5008, -0.8991, -0.1405,
         -1.2637, -0.9427, -1.8902,  0.3754],
        [-1.5013,  0.6105,  0.9968,  0.3057, -0.7621, -0.9821, -0.7314, -0.6195,
         -0.2513, -0.4039, -0.3775,  0.3273]], device='cuda:0',
       grad_fn=<CppNode<SplitLookupFunction_sgd_Op>>)
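Continuing the example above (a minimal sketch; the random weight values are purely illustrative), a weighted lookup supplies one float weight per entry in indices, and calling backward() on the pooled output applies the fused optimizer update to the embedding tables:
>>> # One FP32 weight per index in `indices` (13 entries in the example above)
>>> per_sample_weights = torch.rand(indices.numel(), device="cuda", dtype=torch.float)
>>>
>>> weighted_output = tbe(indices, offsets, per_sample_weights=per_sample_weights)
>>> print(weighted_output.shape)
torch.Size([3, 12])
>>>
>>> # The optimizer is fused with backward: the tables are updated in place here
>>> weighted_output.sum().backward()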
- set_learning_rate(lr: float) None [source]¶
Sets the learning rate.
- Parameters:
lr (float) – The learning rate value to set
- set_optimizer_step(step: int) None [source]¶
Sets the optimizer step.
- Parameters:
step (int) – The step value to set
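A minimal usage sketch covering this setter together with set_learning_rate above, e.g. when driving the fused optimizer from an external schedule (the values are illustrative):
>>> # Adjust the fused optimizer between training iterations
>>> tbe.set_learning_rate(0.05)
>>> tbe.set_optimizer_step(1000)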
- split_embedding_weights() List[Tensor] [source]¶
Returns a list of embedding weights (view), split by table
- Returns:
A list of weights. Length = the number of tables
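For example, with the two-table module from the forward example above (a minimal sketch):
>>> weights = tbe.split_embedding_weights()
>>> for t, w in enumerate(weights):
>>>     print(t, w.shape)
0 torch.Size([3, 8])
1 torch.Size([5, 4])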
- split_optimizer_states() List[List[Tensor]] [source]¶
Returns a list of optimizer states (view), split by table
- Returns:
A list of lists of states. Shape = (the number of tables, the number of states).
The following shows the list of states (in the returned order) for each optimizer:
ADAM: momentum1, momentum2
EXACT_ADAGRAD: momentum1
EXACT_ROWWISE_ADAGRAD: momentum1 (rowwise), prev_iter (rowwise; only when using WeightDecayMode = COUNTER or COWCLIP or global_weight_decay is not None), row_counter (rowwise; only when using WeightDecayMode = COUNTER or COWCLIP)
EXACT_SGD: no states
LAMB: momentum1, momentum2
LARS_SGD: momentum1
PARTIAL_ROWWISE_ADAM: momentum1, momentum2 (rowwise)
PARTIAL_ROWWISE_LAMB: momentum1, momentum2 (rowwise)
ENSEMBLE_ROWWISE_ADAGRAD: momentum1 (rowwise), momentum2
NONE: no states (calling this throws an error)
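As an illustration, a minimal inspection sketch (assumes a module constructed with optimizer=EmbOptimType.EXACT_ROWWISE_ADAGRAD, as in the construction sketch near the top of this page; for that optimizer each table has a single rowwise momentum1 state with one value per embedding row):
>>> # One rowwise momentum1 state per table for EXACT_ROWWISE_ADAGRAD
>>> states = tbe.split_optimizer_states()
>>> for t, table_states in enumerate(states):
>>>     for s in table_states:
>>>         print(t, s.shape)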