Modules
Standard TorchRec modules represent collections of embedding tables:
EmbeddingBagCollection is a collection of torch.nn.EmbeddingBag
EmbeddingCollection is a collection of torch.nn.Embedding
These modules are constructed through standardized config classes:
EmbeddingBagConfig for EmbeddingBagCollection
EmbeddingConfig for EmbeddingCollection
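As a quick illustration of the pairing (a minimal sketch; the table names and sizes below are arbitrary placeholders):

from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection, EmbeddingCollection

# Pooled embeddings: EmbeddingBagConfig -> EmbeddingBagCollection
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]
        )
    ]
)

# Non-pooled (sequence) embeddings: EmbeddingConfig -> EmbeddingCollection
ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
        )
    ]
)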
- class torchrec.modules.embedding_configs.EmbeddingBagConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False, pooling: ~torchrec.modules.embedding_configs.PoolingType = PoolingType.SUM)
Bases:
BaseEmbeddingConfig
EmbeddingBagConfig is a dataclass that represents a single embedding table, where outputs are meant to be pooled.
- Parameters:
pooling (PoolingType) – pooling type applied within each bag; defaults to PoolingType.SUM.
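For example, a table whose bags are averaged rather than summed can be declared by overriding the default pooling (a minimal sketch; the name and sizes are placeholders):

from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType

# Placeholder table that averages (rather than sums) the embeddings in each bag.
mean_table = EmbeddingBagConfig(
    name="t_mean",
    embedding_dim=8,
    num_embeddings=100,
    feature_names=["f_mean"],
    pooling=PoolingType.MEAN,
)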
- class torchrec.modules.embedding_configs.EmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False)
Bases:
BaseEmbeddingConfig
EmbeddingConfig is a dataclass that represents a single embedding table.
- class torchrec.modules.embedding_configs.BaseEmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False)
Base class for embedding configs.
- Parameters:
num_embeddings (int) – number of embeddings.
embedding_dim (int) – embedding dimension.
name (str) – name of the embedding table.
data_type (DataType) – data type of the embedding table.
feature_names (List[str]) – list of feature names.
weight_init_max (Optional[float]) – max value for weight initialization.
weight_init_min (Optional[float]) – min value for weight initialization.
num_embeddings_post_pruning (Optional[int]) – number of embeddings after pruning for inference. If None, no pruning is applied.
init_fn (Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]]) – init function for embedding weights.
need_pos (bool) – whether the table is position weighted.
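For instance, weight initialization can be customized either through weight_init_min/weight_init_max or through an explicit init_fn (a minimal sketch; the name, sizes, and bounds are placeholders):

import torch

from torchrec.modules.embedding_configs import EmbeddingBagConfig

# Placeholder config: initialize weights uniformly in [-0.05, 0.05] via init_fn.
table = EmbeddingBagConfig(
    name="t_custom_init",
    embedding_dim=16,
    num_embeddings=1000,
    feature_names=["f_custom"],
    init_fn=lambda t: torch.nn.init.uniform_(t, a=-0.05, b=0.05),
)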
- class torchrec.modules.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool = False, device: Optional[device] = None)
EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags).
Note
EmbeddingBagCollection is an unsharded module and is not performance optimized. For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingBagCollection.
It processes sparse data in the form of KeyedJaggedTensor with values of the form [F X B X L] where:
F: features (keys)
B: batch size
L: length of sparse features (jagged)
and outputs a KeyedTensor with values of the form [B X sum(D)], i.e. each feature's pooled embedding concatenated per batch element, where:
F: features (keys)
D: each feature’s (key’s) embedding dimension
B: batch size
- Parameters:
tables (List[EmbeddingBagConfig]) – list of embedding tables.
is_weighted (bool) – whether input KeyedJaggedTensor is weighted.
device (Optional[torch.device]) – default compute device.
Example:
table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)

ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        0       1        2  <-- batch
# "f1"   [0,1]   None     [2]
# "f2"   [3]     [4]      [5,6,7]
#  ^
# feature
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

pooled_embeddings = ebc(features)
print(pooled_embeddings.values())
tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
        [ 0.0000,  0.0000,  0.0000,  0.1598,  0.0695,  1.3265, -0.1011],
        [-0.4256, -1.1846, -2.1648, -1.0893,  0.3590, -1.9784, -0.7681]],
       grad_fn=<CatBackward0>)
print(pooled_embeddings.keys())
['f1', 'f2']
print(pooled_embeddings.offset_per_key())
tensor([0, 3, 7])
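When is_weighted=True, the input KeyedJaggedTensor is expected to carry per-value weights, which scale each looked-up embedding before pooling. A minimal sketch (the table, feature names, and weight values are placeholders):

import torch

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

weighted_ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
        )
    ],
    is_weighted=True,
)

# Two batch elements for "f1": ids [0, 1] and [2], each id weighted individually.
weighted_features = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([0, 1, 2]),
    offsets=torch.tensor([0, 2, 3]),
    weights=torch.tensor([0.5, 1.0, 2.0]),
)
pooled = weighted_ebc(weighted_features)  # KeyedTensor with values of shape [2, 3]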
- property device: device
Returns: torch.device: The compute device.
- embedding_bag_configs() → List[EmbeddingBagConfig]
- Returns:
The embedding bag configs.
- Return type:
List[EmbeddingBagConfig]
- forward(features: KeyedJaggedTensor) → KeyedTensor
Run the EmbeddingBagCollection forward pass. This method takes in a KeyedJaggedTensor and returns a KeyedTensor, which is the result of pooling the embeddings for each feature.
- Parameters:
features (KeyedJaggedTensor) – Input KJT
- Returns:
KeyedTensor
- is_weighted() → bool
- Returns:
Whether the EmbeddingBagCollection is weighted.
- Return type:
bool
- reset_parameters() → None
Reset the parameters of the EmbeddingBagCollection. Parameter values are initialized based on the init_fn of each EmbeddingBagConfig if it exists.
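For illustration, a minimal sketch of re-randomizing a collection in place (reusing the hypothetical t_custom_init config from the BaseEmbeddingConfig sketch above):

ebc = EmbeddingBagCollection(tables=[table])
ebc.reset_parameters()  # re-applies torch.nn.init.uniform_ via the table's init_fn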
- class torchrec.modules.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: Optional[device] = None, need_indices: bool = False)
EmbeddingCollection represents a collection of non-pooled embeddings.
Note
EmbeddingCollection is an unsharded module and is not performance optimized. For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingCollection.
It processes sparse data in the form of KeyedJaggedTensor of the form [F X B X L] where:
F: features (keys)
B: batch size
L: length of sparse features (variable)
and outputs Dict[feature (key), JaggedTensor]. Each JaggedTensor contains values of the form (B * L) X D where:
B: batch size
L: length of sparse features (jagged)
D: each feature's (key's) embedding dimension
The lengths of each returned JaggedTensor are of the form L.
- Parameters:
tables (List[EmbeddingConfig]) – list of embedding tables.
device (Optional[torch.device]) – default compute device.
need_indices (bool) – whether input indices need to be passed through to the final lookup dictionary.
Example:
e1_config = EmbeddingConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)

ec = EmbeddingCollection(tables=[e1_config, e2_config])

#     0       1        2  <-- batch
# 0   [0,1]   None     [2]
# 1   [3]     [4]      [5,6,7]
# ^
# feature
features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
feature_embeddings = ec(features)
print(feature_embeddings['f2'].values())
tensor([[-0.2050,  0.5478,  0.6054],
        [ 0.7352,  0.3210, -3.0399],
        [ 0.1279, -0.1756, -0.4130],
        [ 0.7519, -0.4341, -0.0499],
        [ 0.9329, -1.0697, -0.8095]],
       grad_fn=<EmbeddingBackward>)
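The jagged structure of the input is preserved per feature: in the example above, feature_embeddings['f2'].lengths() matches the per-sample lengths of "f2" in the input, i.e. tensor([1, 1, 3]).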
- property device: device
Returns: torch.device: The compute device.
- embedding_configs() → List[EmbeddingConfig]
- Returns:
The embedding configs.
- Return type:
List[EmbeddingConfig]
- embedding_dim() → int
- Returns:
The embedding dimension.
- Return type:
int
- embedding_names_by_table() → List[List[str]]
- Returns:
The embedding names by table.
- Return type:
List[List[str]]
- forward(features: KeyedJaggedTensor) → Dict[str, JaggedTensor]
Run the EmbeddingCollection forward pass. This method takes in a KeyedJaggedTensor and returns a Dict[str, JaggedTensor] containing the individual (non-pooled) embeddings for each feature.
- Parameters:
features (KeyedJaggedTensor) – KJT of form [F X B X L].
- Returns:
Dict[str, JaggedTensor]
- need_indices() → bool
- Returns:
Whether the EmbeddingCollection needs indices.
- Return type:
bool
- reset_parameters() → None
Reset the parameters of the EmbeddingCollection. Parameter values are initialized based on the init_fn of each EmbeddingConfig if it exists.