Modules

Standard TorchRec modules represent collections of embedding tables:

  • EmbeddingBagCollection is a collection of torch.nn.EmbeddingBag

  • EmbeddingCollection is a collection of torch.nn.Embedding

These modules are constructed through standardized config classes:

  • EmbeddingBagConfig for EmbeddingBagCollection

  • EmbeddingConfig for EmbeddingCollection

class torchrec.modules.embedding_configs.EmbeddingBagConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: DataType = DataType.FP32, feature_names: List[str] = <factory>, weight_init_max: Optional[float] = None, weight_init_min: Optional[float] = None, num_embeddings_post_pruning: Optional[int] = None, init_fn: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]] = None, need_pos: bool = False, pooling: PoolingType = PoolingType.SUM)

Bases: BaseEmbeddingConfig

EmbeddingBagConfig is a dataclass that represents a single embedding table, where outputs are meant to be pooled.

Parameters:

pooling (PoolingType) – pooling type applied to each bag of embeddings. Defaults to PoolingType.SUM.
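To illustrate what the pooling type controls, the following is a minimal sketch in plain Python, not TorchRec code. The `pool` helper and its `mode` strings are hypothetical names for illustration. Sum pooling corresponds to the default PoolingType.SUM; mean pooling is shown for contrast. An empty bag is assumed to pool to a zero vector, which matches the row of zeros in the EmbeddingBagCollection example output.

```python
from typing import List

Row = List[float]


def pool(rows: List[Row], dim: int, mode: str = "sum") -> Row:
    """Combine the embedding rows of one bag into a single vector (hypothetical helper)."""
    if not rows:
        # Assumption: an empty bag pools to a zero vector of width `dim`.
        return [0.0] * dim
    # Column-wise sum over all rows in the bag.
    summed = [sum(col) for col in zip(*rows)]
    if mode == "sum":
        return summed
    if mode == "mean":
        return [x / len(rows) for x in summed]
    raise ValueError(f"unsupported mode: {mode}")


bag = [[1.0, 2.0], [3.0, 6.0]]
sum_pooled = pool(bag, dim=2, mode="sum")
mean_pooled = pool(bag, dim=2, mode="mean")
empty_pooled = pool([], dim=2)
```

Whatever the bag length, each bag yields exactly one vector of width `dim`, which is why pooled outputs can be stored in a dense tensor.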

class torchrec.modules.embedding_configs.EmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: DataType = DataType.FP32, feature_names: List[str] = <factory>, weight_init_max: Optional[float] = None, weight_init_min: Optional[float] = None, num_embeddings_post_pruning: Optional[int] = None, init_fn: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]] = None, need_pos: bool = False)

Bases: BaseEmbeddingConfig

EmbeddingConfig is a dataclass that represents a single embedding table.

class torchrec.modules.embedding_configs.BaseEmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: DataType = DataType.FP32, feature_names: List[str] = <factory>, weight_init_max: Optional[float] = None, weight_init_min: Optional[float] = None, num_embeddings_post_pruning: Optional[int] = None, init_fn: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]] = None, need_pos: bool = False)

Base class for embedding configs.

Parameters:
  • num_embeddings (int) – number of embeddings.

  • embedding_dim (int) – embedding dimension.

  • name (str) – name of the embedding table.

  • data_type (DataType) – data type of the embedding table.

  • feature_names (List[str]) – list of feature names.

  • weight_init_max (Optional[float]) – max value for weight initialization.

  • weight_init_min (Optional[float]) – min value for weight initialization.

  • num_embeddings_post_pruning (Optional[int]) – number of embeddings after pruning for inference. If None, no pruning is applied.

  • init_fn (Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]]) – init function for embedding weights.

  • need_pos (bool) – whether table is position weighted.
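The `<factory>` default shown for feature_names in the signature is how a dataclass field declared with `default_factory` is rendered: each instance gets its own fresh list. The sketch below shows this with a hypothetical stand-in dataclass, `TableSpec`, which is not a TorchRec class and mirrors only a few of the fields listed above.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TableSpec:  # hypothetical stand-in, not part of TorchRec
    num_embeddings: int
    embedding_dim: int
    name: str = ""
    # default_factory builds a new list per instance; a rendered
    # signature shows such a default as "<factory>".
    feature_names: List[str] = field(default_factory=list)
    weight_init_max: Optional[float] = None
    weight_init_min: Optional[float] = None


a = TableSpec(num_embeddings=10, embedding_dim=3, name="t1")
b = TableSpec(num_embeddings=10, embedding_dim=4, name="t2")
a.feature_names.append("f1")  # does not affect b
```

Because the list is created per instance, appending a feature name to one config leaves the other untouched.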

class torchrec.modules.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool = False, device: Optional[device] = None)

EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags).

Note

EmbeddingBagCollection is an unsharded module and is not performance optimized. For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingBagCollection.

It processes sparse data in the form of KeyedJaggedTensor with values of the form [F X B X L] where:

  • F: features (keys)

  • B: batch size

  • L: length of sparse features (jagged)

and outputs a KeyedTensor with values of the form [B * (F * D)] where:

  • F: features (keys)

  • D: each feature’s (key’s) embedding dimension

  • B: batch size

Parameters:
  • tables (List[EmbeddingBagConfig]) – list of embedding tables.

  • is_weighted (bool) – whether input KeyedJaggedTensor is weighted.

  • device (Optional[torch.device]) – default compute device.

Example:

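import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

table_0 = EmbeddingBagConfig(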
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)

ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

pooled_embeddings = ebc(features)
print(pooled_embeddings.values())
# Example output (values vary with random initialization):
# tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
#         [ 0.0000,  0.0000,  0.0000,  0.1598,  0.0695,  1.3265, -0.1011],
#         [-0.4256, -1.1846, -2.1648, -1.0893,  0.3590, -1.9784, -0.7681]],
#        grad_fn=<CatBackward0>)
print(pooled_embeddings.keys())
# ['f1', 'f2']
print(pooled_embeddings.offset_per_key())
# tensor([0, 3, 7])
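
The way the offsets in the example above partition the flat values into per-feature, per-sample bags can be sketched in plain Python. This is an illustration only, not how TorchRec implements it. The `split_bags` helper is a hypothetical name, and it assumes every key covers the same number of samples.

```python
from typing import Dict, List


def split_bags(
    keys: List[str], values: List[int], offsets: List[int]
) -> Dict[str, List[List[int]]]:
    """Decode a flat values/offsets pair into one list of ids per (key, sample)."""
    # Assumption: all keys share one batch size, so the offsets hold
    # len(keys) * batch_size + 1 entries.
    batch_size = (len(offsets) - 1) // len(keys)
    bags: Dict[str, List[List[int]]] = {}
    for f, key in enumerate(keys):
        start = f * batch_size
        bags[key] = [
            values[offsets[i]:offsets[i + 1]]
            for i in range(start, start + batch_size)
        ]
    return bags


bags = split_bags(
    keys=["f1", "f2"],
    values=[0, 1, 2, 3, 4, 5, 6, 7],
    offsets=[0, 2, 2, 3, 4, 5, 8],
)
```

Here `bags["f1"]` holds `[0, 1]`, an empty list, and `[2]`, matching the table in the example comment; the empty bag for the second sample is the `None` entry.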

property device: device

Returns:

The compute device.

Return type:

torch.device

embedding_bag_configs() List[EmbeddingBagConfig]
Returns:

The embedding bag configs.

Return type:

List[EmbeddingBagConfig]

forward(features: KeyedJaggedTensor) KeyedTensor

Run the EmbeddingBagCollection forward pass. This method takes in a KeyedJaggedTensor and returns a KeyedTensor, which is the result of pooling the embeddings for each feature.

Parameters:

features (KeyedJaggedTensor) – input KJT of form [F X B X L].

Returns:

KeyedTensor
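
In the returned KeyedTensor, each row concatenates the pooled embeddings of all keys, and the per-key offsets mark where each key's slice begins. The sketch below reproduces that layout in plain Python for the class example above (embedding dimensions 3 and 4). The helper names are hypothetical, and plain lists stand in for tensors.

```python
from itertools import accumulate
from typing import Dict, List


def offset_per_key(dims: List[int]) -> List[int]:
    """Running sum of embedding dimensions, starting at 0."""
    return [0, *accumulate(dims)]


def split_row(row: List[float], keys: List[str], dims: List[int]) -> Dict[str, List[float]]:
    """Slice one concatenated output row back into per-key embeddings."""
    offs = offset_per_key(dims)
    return {key: row[offs[i]:offs[i + 1]] for i, key in enumerate(keys)}


offsets = offset_per_key([3, 4])
# First row of the example output above: 3 values for "f1", then 4 for "f2".
row = [-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783]
per_key = split_row(row, ["f1", "f2"], [3, 4])
```

With dimensions 3 and 4 the offsets come out as 0, 3, 7, so "f1" occupies columns 0 to 2 and "f2" occupies columns 3 to 6.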

is_weighted() bool
Returns:

Whether the EmbeddingBagCollection is weighted.

Return type:

bool

reset_parameters() None

Reset the parameters of the EmbeddingBagCollection. Parameter values are initialized based on the init_fn of each EmbeddingBagConfig if it exists.
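
As a rough illustration of the behavior described above, the sketch below applies a per-table init function in place when one exists and leaves the table untouched otherwise. It uses plain nested lists instead of tensors, and `fill_constant` and this standalone `reset_parameters` are hypothetical helpers, not the TorchRec implementation.

```python
from typing import Callable, Dict, List, Optional

Weight = List[List[float]]
# Mirrors the documented init_fn shape: takes the weight and may return it or None.
InitFn = Callable[[Weight], Optional[Weight]]


def fill_constant(value: float) -> InitFn:
    """Build an init function that overwrites every entry in place (hypothetical)."""
    def init(weight: Weight) -> None:
        for row in weight:
            for j in range(len(row)):
                row[j] = value
    return init


def reset_parameters(
    weights: Dict[str, Weight], init_fns: Dict[str, Optional[InitFn]]
) -> None:
    """Apply each table's init function, if it has one."""
    for name, weight in weights.items():
        init_fn = init_fns.get(name)
        if init_fn is not None:
            init_fn(weight)


weights = {"t1": [[1.0, 2.0], [3.0, 4.0]], "t2": [[5.0]]}
reset_parameters(weights, {"t1": fill_constant(0.0), "t2": None})
```

After the call, table "t1" is zeroed while "t2", which has no init function in this sketch, keeps its values.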

class torchrec.modules.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: Optional[device] = None, need_indices: bool = False)

EmbeddingCollection represents a collection of non-pooled embeddings.

Note

EmbeddingCollection is an unsharded module and is not performance optimized. For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingCollection.

It processes sparse data in the form of KeyedJaggedTensor with values of the form [F X B X L] where:

  • F: features (keys)

  • B: batch size

  • L: length of sparse features (variable)

and outputs Dict[feature (key), JaggedTensor]. Each JaggedTensor contains values of the form (B * L) X D where:

  • B: batch size

  • L: length of sparse features (jagged)

  • D: each feature’s (key’s) embedding dimension

The lengths of each JaggedTensor are of the form L.

Parameters:
  • tables (List[EmbeddingConfig]) – list of embedding tables.

  • device (Optional[torch.device]) – default compute device.

  • need_indices (bool) – whether to pass indices to the final lookup dict.

Example:

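import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

e1_config = EmbeddingConfig(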
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)

ec = EmbeddingCollection(tables=[e1_config, e2_config])

#     0       1        2  <-- batch
# 0   [0,1] None    [2]
# 1   [3]    [4]    [5,6,7]
# ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
feature_embeddings = ec(features)
print(feature_embeddings['f2'].values())
# Example output (values vary with random initialization):
# tensor([[-0.2050,  0.5478,  0.6054],
#         [ 0.7352,  0.3210, -3.0399],
#         [ 0.1279, -0.1756, -0.4130],
#         [ 0.7519, -0.4341, -0.0499],
#         [ 0.9329, -1.0697, -0.8095]], grad_fn=<EmbeddingBackward>)

property device: device

Returns:

The compute device.

Return type:

torch.device

embedding_configs() List[EmbeddingConfig]
Returns:

The embedding configs.

Return type:

List[EmbeddingConfig]

embedding_dim() int
Returns:

The embedding dimension.

Return type:

int

embedding_names_by_table() List[List[str]]
Returns:

The embedding names by table.

Return type:

List[List[str]]

forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor]

Run the EmbeddingCollection forward pass. This method takes in a KeyedJaggedTensor and returns a Dict[str, JaggedTensor], which holds the individual (non-pooled) embeddings for each feature.

Parameters:

features (KeyedJaggedTensor) – KJT of form [F X B X L].

Returns:

Dict[str, JaggedTensor]
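
To make the non-pooled output concrete, the sketch below mimics the lookup for a single feature in plain Python: every id yields one embedding row, and the per-sample lengths record how many rows belong to each sample. The `lookup_unpooled` helper and the table contents are hypothetical, and lists stand in for tensors.

```python
from typing import List, Tuple

Row = List[float]


def lookup_unpooled(
    table: List[Row], ids_per_sample: List[List[int]]
) -> Tuple[List[Row], List[int]]:
    """Return one embedding row per id, plus the number of ids per sample."""
    values = [table[i] for ids in ids_per_sample for i in ids]
    lengths = [len(ids) for ids in ids_per_sample]
    return values, lengths


# Hypothetical 10 x 3 table; the numbers are placeholders, not trained weights.
table = [[float(r), float(r) + 0.5, float(r) + 0.25] for r in range(10)]

# Feature "f2" from the class example above: ids [3], [4], [5, 6, 7].
values, lengths = lookup_unpooled(table, [[3], [4], [5, 6, 7]])
```

The five ids produce five rows of width 3, matching the 5 x 3 shape printed for `feature_embeddings['f2'].values()` in the class example, with lengths 1, 1, 3 across the three samples.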

need_indices() bool
Returns:

Whether the EmbeddingCollection needs indices.

Return type:

bool

reset_parameters() None

Reset the parameters of the EmbeddingCollection. Parameter values are initialized based on the init_fn of each EmbeddingConfig if it exists.
