
torchrec.modules

Torchrec Common Modules

The torchrec modules package contains a collection of common modeling modules.

These modules include:
  • extensions of nn.EmbeddingBag and nn.Embedding, called EmbeddingBagCollection and EmbeddingCollection respectively.

  • established modules such as DeepFM and CrossNet.

  • common module patterns such as MLP and SwishLayerNorm.

  • custom modules for TorchRec such as PositionWeightedModule and LazyModuleExtensionMixin.

  • EmbeddingTower and EmbeddingTowerCollection, a logical “tower” of embeddings passed to a provided interaction module.

torchrec.modules.activation

Activation Modules

class torchrec.modules.activation.SwishLayerNorm(input_dims: Union[int, List[int], Size], device: Optional[device] = None)

Bases: Module

Applies the Swish function with layer normalization: Y = X * Sigmoid(LayerNorm(X)).

Parameters:
  • input_dims (Union[int, List[int], torch.Size]) – dimensions to normalize over. If an input tensor has shape [batch_size, d1, d2, d3], setting input_dims=[d2, d3] will perform layer normalization over the last two dimensions.

  • device (Optional[torch.device]) – default compute device.

Example:

sln = SwishLayerNorm(100)
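A short continuation of the example above (a hedged sketch, assuming a random input); it checks that the output shape matches the input shape, since the module only normalizes and gates the input:

import torch
from torchrec.modules.activation import SwishLayerNorm

batch_size = 3
sln = SwishLayerNorm(input_dims=100)   # normalize over the last dimension of size 100
input = torch.randn(batch_size, 100)
output = sln(input)                    # Y = X * Sigmoid(LayerNorm(X))
assert output.shape == input.shape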
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – an input tensor.

Returns:

an output tensor.

Return type:

torch.Tensor

training: bool

torchrec.modules.crossnet

CrossNet API

class torchrec.modules.crossnet.CrossNet(in_features: int, num_layers: int)

Bases: Module

Cross Network:

Cross Net is a stack of “crossing” operations on a tensor of shape \((*, N)\) to the same shape, effectively creating \(N\) learnable polynomial functions over the input tensor.

In this module, the crossing operations are defined based on a full rank matrix (NxN), such that the crossing effect can cover all bits on each layer. On each layer l, the tensor is transformed into:

\[x_{l+1} = x_0 * (W_l \cdot x_l + b_l) + x_l\]

where \(W_l\) is a square matrix \((NxN)\), \(*\) means element-wise multiplication, \(\cdot\) means matrix multiplication.

Parameters:
  • in_features (int) – the dimension of the input.

  • num_layers (int) – the number of layers in the module.

Example:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = CrossNet(in_features=in_features, num_layers=num_layers)
output = dcn(input)
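For reference, the per-layer update above can be written in a few lines of plain PyTorch. This is a hedged, illustrative sketch of the formula only, not the module's actual implementation (the helper cross_layer and its parameters W, b are hypothetical):

import torch

def cross_layer(x0, x_l, W, b):
    # x_{l+1} = x_0 * (W_l . x_l + b_l) + x_l, with W a full-rank (N x N) matrix
    return x0 * (torch.matmul(x_l, W.T) + b) + x_l

N = 10
x0 = torch.randn(3, N)
W, b = torch.randn(N, N), torch.zeros(N)
x1 = cross_layer(x0, x0, W, b)   # first layer: x_l == x_0
assert x1.shape == x0.shape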
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor with shape [batch_size, in_features].

Returns:

tensor with shape [batch_size, in_features].

Return type:

torch.Tensor

training: bool
class torchrec.modules.crossnet.LowRankCrossNet(in_features: int, num_layers: int, low_rank: int = 1)

Bases: Module

Low Rank Cross Net is a highly efficient cross net. Instead of using full rank cross matrices (NxN) at each layer, it will use two kernels \(W (N x r)\) and \(V (r x N)\), where r << N, to simplify the matrix multiplication.

On each layer l, the tensor is transformed into:

\[x_{l+1} = x_0 * (W_l \cdot (V_l \cdot x_l) + b_l) + x_l\]

where \(W_l\) is an \((N \times r)\) matrix, \(V_l\) is an \((r \times N)\) matrix, \(*\) means element-wise multiplication, and \(\cdot\) means matrix multiplication.

Note

The rank \(r\) should be chosen carefully. Usually, we expect \(r < N/2\) to yield computational savings, and \(r \approx N/4\) to preserve the accuracy of the full-rank cross net.

Parameters:
  • in_features (int) – the dimension of the input.

  • num_layers (int) – the number of layers in the module.

  • low_rank (int) – the rank of the cross matrix factorization (default = 1). The value must always be >= 1.

Example:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = LowRankCrossNet(in_features=in_features, num_layers=num_layers, low_rank=3)
output = dcn(input)
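A hedged sketch of one low-rank layer, purely to illustrate why the factorization is cheaper: the (N x N) matrix is replaced by W (N x r) and V (r x N), so per-layer weights drop from N*N to 2*N*r. The helper below is hypothetical and not the module's actual implementation:

import torch

def low_rank_cross_layer(x0, x_l, W, V, b):
    # x_{l+1} = x_0 * (W_l . (V_l . x_l) + b_l) + x_l
    return x0 * (torch.matmul(torch.matmul(x_l, V.T), W.T) + b) + x_l

N, r = 10, 3
x0 = torch.randn(3, N)
W, V, b = torch.randn(N, r), torch.randn(r, N), torch.zeros(N)
x1 = low_rank_cross_layer(x0, x0, W, V, b)
assert x1.shape == x0.shape   # 2*N*r = 60 weights vs. N*N = 100 in the full-rank layer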
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor with shape [batch_size, in_features].

Returns:

tensor with shape [batch_size, in_features].

Return type:

torch.Tensor

training: bool
class torchrec.modules.crossnet.LowRankMixtureCrossNet(in_features: int, num_layers: int, num_experts: int = 1, low_rank: int = 1, activation: ~typing.Union[~torch.nn.modules.module.Module, ~typing.Callable[[~torch.Tensor], ~torch.Tensor]] = <built-in method relu of type object>)

Bases: Module

Low Rank Mixture Cross Net is a DCN V2 implementation from the paper DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems (https://arxiv.org/abs/2008.13535).

LowRankMixtureCrossNet defines the learnable crossing parameter per layer as a low-rank matrix \((N \times r)\) together with a mixture of experts. Compared to LowRankCrossNet, instead of relying on a single expert to learn feature crosses, this module leverages \(K\) experts, each learning feature interactions in a different subspace, and adaptively combines the learned crosses using a gating mechanism that depends on the input \(x\).

On each layer l, the tensor is transformed into:

\[x_{l+1} = MoE({expert_i : i \in K_{experts}}) + x_l\]

and each \(expert_i\) is defined as:

\[expert_i = x_0 * (U_{li} \cdot g(C_{li} \cdot g(V_{li} \cdot x_l)) + b_l)\]

where \(U_{li} (N, r)\), \(C_{li} (r, r)\) and \(V_{li} (r, N)\) are low-rank matrices, \(*\) means element-wise multiplication, \(\cdot\) means matrix multiplication, and \(g()\) is the non-linear activation function.

When num_experts is 1, the gate evaluation and MoE are skipped to save computation.

Parameters:
  • in_features (int) – the dimension of the input.

  • num_layers (int) – the number of layers in the module.

  • num_experts (int) – the number of experts in the mixture (default = 1).

  • low_rank (int) – the rank of the cross matrix factorization (default = 1). The value must always be >= 1.

  • activation (Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]) – the non-linear activation function, used in defining experts. Default is relu.

Example:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = LowRankMixtureCrossNet(in_features=in_features, num_layers=num_layers, num_experts=5, low_rank=3)
output = dcn(input)
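A hedged sketch of a single expert from the formula above (with g = relu), only to illustrate the shapes of U (N x r), C (r x r) and V (r x N); the real module additionally combines K such experts with a learned gate. The helper is hypothetical:

import torch
import torch.nn.functional as F

def expert(x0, x_l, U, C, V, b):
    # expert_i = x_0 * (U . g(C . g(V . x_l)) + b)
    h = F.relu(torch.matmul(x_l, V.T))      # (B, r)
    h = F.relu(torch.matmul(h, C.T))        # (B, r)
    return x0 * (torch.matmul(h, U.T) + b)  # (B, N)

N, r = 10, 3
x0 = torch.randn(3, N)
U, C, V, b = torch.randn(N, r), torch.randn(r, r), torch.randn(r, N), torch.zeros(N)
out = expert(x0, x0, U, C, V, b)
assert out.shape == x0.shape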
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor with shape [batch_size, in_features].

Returns:

tensor with shape [batch_size, in_features].

Return type:

torch.Tensor

training: bool
class torchrec.modules.crossnet.VectorCrossNet(in_features: int, num_layers: int)

Bases: Module

Vector Cross Network can be referred to as DCN-V1.

It is also a specialized low-rank cross net, where rank = 1. In this version, on each layer, instead of keeping two kernels W and V, we keep only one vector kernel W (Nx1). We use the dot operation to compute the “crossing” effect of the features, thus saving two matrix multiplications to further reduce computational cost and the number of learnable parameters.

On each layer l, the tensor is transformed into

\[x_{l+1} = x_0 * (W_l . x_l + b_l) + x_l\]

where \(W_l\) is a vector, \(*\) means element-wise multiplication, and \(.\) means the dot operation.

Parameters:
  • in_features (int) – the dimension of the input.

  • num_layers (int) – the number of layers in the module.

Example:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = VectorCrossNet(in_features=in_features, num_layers=num_layers)
output = dcn(input)
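A hedged sketch of one rank-1 (DCN-V1) layer following the formula above: with a single vector kernel w (N x 1), the crossing term reduces to one dot product per sample. The helper is hypothetical, not the module's implementation:

import torch

def vector_cross_layer(x0, x_l, w, b):
    # x_{l+1} = x_0 * (W_l . x_l + b_l) + x_l, where (W_l . x_l) is a per-sample scalar
    return x0 * (torch.matmul(x_l, w) + b) + x_l

N = 10
x0 = torch.randn(3, N)
w, b = torch.randn(N, 1), torch.zeros(N)
x1 = vector_cross_layer(x0, x0, w, b)
assert x1.shape == x0.shape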
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor with shape [batch_size, in_features].

Returns:

tensor with shape [batch_size, in_features].

Return type:

torch.Tensor

training: bool

torchrec.modules.deepfm

Deep Factorization-Machine Modules

The following modules are based on the Deep Factorization-Machine (DeepFM) paper.

  • Class DeepFM implements the DeepFM framework.

  • Class FactorizationMachine implements FM as noted in the above paper.

class torchrec.modules.deepfm.DeepFM(dense_module: Module)

Bases: Module

This is the DeepFM module.

This module does not cover the end-to-end functionality of the published paper. Instead, it covers only the deep component of the publication. It is used to learn high-order feature interactions. If low-order feature interactions need to be learned, use the FactorizationMachine module instead, which shares the same embedding input as this module.

To support modeling flexibility, we customize the key components as follows:

  • Different from the public paper, we change the input from raw sparse features to embeddings of the features. This allows flexibility in embedding dimensions and the number of embeddings, as long as all embedding tensors have the same batch size.

  • On top of the public paper, we allow users to customize the hidden layer to be any module, not limited to just MLP.

The general architecture of the module is like:

        1 x 10                  output
         /|\
          |                     pass into `dense_module`
          |
        1 x 90
         /|\
          |                     concat
          |
1 x 20, 1 x 30, 1 x 40          list of embeddings
Parameters:

dense_module (nn.Module) – any customized module that can be used (such as MLP) in DeepFM. The in_features of this module must equal the total number of flattened embedding elements per sample. For example, if the input embeddings are [randn(3, 2, 3), randn(3, 4, 5)], the in_features should be 2*3 + 4*5.

Example:

import torch
import torch.nn as nn
from torchrec.modules.deepfm import DeepFM

batch_size = 3
output_dim = 30
# each input embedding is a torch.Tensor of shape [batch_size, num_embeddings, embedding_dim]
input_embeddings = [
    torch.randn(batch_size, 2, 64),
    torch.randn(batch_size, 2, 32),
]
dense_module = nn.Linear(192, output_dim)  # in_features = 2*64 + 2*32 = 192
deepfm = DeepFM(dense_module=dense_module)
deep_fm_output = deepfm(embeddings=input_embeddings)
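Continuing the example, a hedged sanity check showing that in_features can be derived from the embedding list rather than hard-coded, and that the output has the expected shape (assuming, as stated above, that DeepFM flattens and concatenates the embeddings before applying dense_module):

in_features = sum(e.shape[1] * e.shape[2] for e in input_embeddings)  # 2*64 + 2*32 = 192
assert in_features == 192
assert deep_fm_output.shape == (batch_size, output_dim)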
forward(embeddings: List[Tensor]) Tensor
Parameters:

embeddings (List[torch.Tensor]) –

The list of all embeddings (e.g. dense, common_sparse, specialized_sparse, embedding_features, raw_embedding_features) in the shape of:

(batch_size, num_embeddings, embedding_dim)

For the ease of operation, embeddings that have the same embedding dimension have the option to be stacked into a single tensor. For example, when we have 1 trained embedding with dimension=32, 5 native embeddings with dimension=64, and 3 dense features with dimension=16, we can prepare the embeddings list to be the list of:

tensor(B, 1, 32) (trained_embedding with num_embeddings=1, embedding_dim=32)
tensor(B, 5, 64) (native_embedding with num_embeddings=5, embedding_dim=64)
tensor(B, 3, 16) (dense_features with num_embeddings=3, embedding_dim=16)

Note

The batch_size of all input tensors needs to be identical.

Returns:

output of dense_module with flattened and concatenated embeddings as input.

Return type:

torch.Tensor

training: bool
class torchrec.modules.deepfm.FactorizationMachine

Bases: Module

This is the Factorization Machine module, mentioned in the DeepFM paper.

This module does not cover the end-to-end functionality of the published paper. Instead, it covers only the FM part of the publication, and is used to learn 2nd-order feature interactions.

To support modeling flexibility, we customize the key components differently from the public paper:

We change the input from raw sparse features to embeddings of the features. This allows flexibility in embedding dimensions and the number of embeddings, as long as all embedding tensors have the same batch size.

The general architecture of the module is like:

        1 x 10                  output
         /|\
          |                     pass into `dense_module`
          |
        1 x 90
         /|\
          |                     concat
          |
1 x 20, 1 x 30, 1 x 40          list of embeddings

Example:

import torch
from torchrec.modules.deepfm import FactorizationMachine

batch_size = 3
# each input embedding is a torch.Tensor of shape [batch_size, num_embeddings, embedding_dim]
input_embeddings = [
    torch.randn(batch_size, 2, 64),
    torch.randn(batch_size, 2, 32),
]
fm = FactorizationMachine()
output = fm(embeddings=input_embeddings)
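For intuition, the classic 2nd-order FM interaction over a set of embedding vectors can be computed as 0.5 * ((sum_i v_i)^2 - sum_i v_i^2), reduced over the embedding dimension. A hedged illustration of that textbook identity (not necessarily the exact computation inside this module):

import torch

v = torch.randn(3, 7, 16)                       # (batch_size, num_embeddings, embedding_dim)
sum_then_square = v.sum(dim=1).pow(2)           # (B, D)
square_then_sum = v.pow(2).sum(dim=1)           # (B, D)
second_order = 0.5 * (sum_then_square - square_then_sum).sum(dim=1, keepdim=True)
assert second_order.shape == (3, 1)             # one 2nd-order interaction score per sample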
forward(embeddings: List[Tensor]) Tensor
Parameters:

embeddings (List[torch.Tensor]) –

The list of all embeddings (e.g. dense, common_sparse, specialized_sparse, embedding_features, raw_embedding_features) in the shape of:

(batch_size, num_embeddings, embedding_dim)

For the ease of operation, embeddings that have the same embedding dimension have the option to be stacked into a single tensor. For example, when we have 1 trained embedding with dimension=32, 5 native embeddings with dimension=64, and 3 dense features with dimension=16, we can prepare the embeddings list to be the list of:

tensor(B, 1, 32) (trained_embedding with num_embeddings=1, embedding_dim=32)
tensor(B, 5, 64) (native_embedding with num_embeddings=5, embedding_dim=64)
tensor(B, 3, 16) (dense_features with num_embeddings=3, embedding_dim=16)

Note

The batch_size of all input tensors needs to be identical.

Returns:

output of the FM with flattened and concatenated embeddings as input; expected shape is [B, 1].

Return type:

torch.Tensor

training: bool

torchrec.modules.embedding_configs

class torchrec.modules.embedding_configs.BaseEmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.distributed.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False)

Bases: object

data_type: DataType = 'FP32'
embedding_dim: int
feature_names: List[str]
get_weight_init_max() float
get_weight_init_min() float
init_fn: Optional[Callable[[Tensor], Optional[Tensor]]] = None
name: str = ''
need_pos: bool = False
num_embeddings: int
num_features() int
weight_init_max: Optional[float] = None
weight_init_min: Optional[float] = None
class torchrec.modules.embedding_configs.EmbeddingBagConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.distributed.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False, pooling: torchrec.modules.embedding_configs.PoolingType = <PoolingType.SUM: 'SUM'>)

Bases: BaseEmbeddingConfig

pooling: PoolingType = 'SUM'
class torchrec.modules.embedding_configs.EmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.distributed.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False)

Bases: BaseEmbeddingConfig

embedding_dim: int
feature_names: List[str]
num_embeddings: int
class torchrec.modules.embedding_configs.EmbeddingTableConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.distributed.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False, pooling: torchrec.modules.embedding_configs.PoolingType = <PoolingType.SUM: 'SUM'>, is_weighted: bool = False, has_feature_processor: bool = False, embedding_names: List[str] = <factory>)

Bases: BaseEmbeddingConfig

embedding_names: List[str]
has_feature_processor: bool = False
is_weighted: bool = False
pooling: PoolingType = 'SUM'
class torchrec.modules.embedding_configs.PoolingType(value)

Bases: Enum

An enumeration.

MEAN = 'MEAN'
NONE = 'NONE'
SUM = 'SUM'
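A brief hedged example tying these configs together: constructing an EmbeddingBagConfig that uses mean pooling (the remaining fields keep the defaults shown above); it assumes num_features() simply counts feature_names:

from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType

table = EmbeddingBagConfig(
    name="t_user",
    embedding_dim=16,
    num_embeddings=1000,
    feature_names=["user_id"],
    pooling=PoolingType.MEAN,   # default is PoolingType.SUM
)
assert table.num_features() == 1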
class torchrec.modules.embedding_configs.TrecQuantConfig(activation, weight, per_table_weight_dtype)

Bases: tuple

activation: PlaceholderObserver

Alias for field number 0

per_table_weight_dtype: Optional[Dict[str, dtype]]

Alias for field number 2

weight: PlaceholderObserver

Alias for field number 1

torchrec.modules.embedding_configs.data_type_to_dtype(data_type: DataType) dtype
torchrec.modules.embedding_configs.data_type_to_sparse_type(data_type: DataType) SparseType
torchrec.modules.embedding_configs.dtype_to_data_type(dtype: dtype) DataType
torchrec.modules.embedding_configs.pooling_type_to_pooling_mode(pooling_type: PoolingType) PoolingMode
torchrec.modules.embedding_configs.pooling_type_to_str(pooling_type: PoolingType) str
torchrec.modules.embedding_configs.to_fbgemm_bounds_check_mode(bounds_check_mode: BoundsCheckMode) BoundsCheckMode
torchrec.modules.embedding_configs.to_fbgemm_cache_algorithm(cache_algorithm: CacheAlgorithm) CacheAlgorithm

torchrec.modules.embedding_modules

class torchrec.modules.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool = False, device: Optional[device] = None)

Bases: EmbeddingBagCollectionInterface

EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags).

It processes sparse data in the form of KeyedJaggedTensor with values of the form [F X B X L] where:

  • F: features (keys)

  • B: batch size

  • L: length of sparse features (jagged)

and outputs a KeyedTensor with values of the form [B * (F * D)] where:

  • F: features (keys)

  • D: each feature’s (key’s) embedding dimension

  • B: batch size

Parameters:
  • tables (List[EmbeddingBagConfig]) – list of embedding tables.

  • is_weighted (bool) – whether input KeyedJaggedTensor is weighted.

  • device (Optional[torch.device]) – default compute device.

Example:

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)

ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

pooled_embeddings = ebc(features)
print(pooled_embeddings.values())
tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
        [ 0.0000,  0.0000,  0.0000,  0.1598,  0.0695,  1.3265, -0.1011],
        [-0.4256, -1.1846, -2.1648, -1.0893,  0.3590, -1.9784, -0.7681]],
       grad_fn=<CatBackward0>)
print(pooled_embeddings.keys())
['f1', 'f2']
print(pooled_embeddings.offset_per_key())
tensor([0, 3, 7])
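A hedged sanity check continuing the example above: with embedding_dim 3 for "f1" and 4 for "f2", the pooled values have batch_size rows and 3 + 4 = 7 columns, and offset_per_key() marks each key's column range:

assert pooled_embeddings.values().shape == (3, 7)   # (batch_size, 3 + 4)
# offset_per_key() gives the column boundaries per key: "f1" -> [0, 3), "f2" -> [3, 7)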
property device: device
embedding_bag_configs() List[EmbeddingBagConfig]
forward(features: KeyedJaggedTensor) KeyedTensor
Parameters:

features (KeyedJaggedTensor) – KJT of form [F X B X L].

Returns:

KeyedTensor

is_weighted() bool
reset_parameters() None
training: bool
class torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface(*args, **kwargs)

Bases: ABC, Module

Interface for EmbeddingBagCollection.

abstract embedding_bag_configs() List[EmbeddingBagConfig]
abstract forward(features: KeyedJaggedTensor) KeyedTensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

abstract is_weighted() bool
training: bool
class torchrec.modules.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: Optional[device] = None, need_indices: bool = False)

Bases: EmbeddingCollectionInterface

EmbeddingCollection represents a collection of non-pooled embeddings.

It processes sparse data in the form of KeyedJaggedTensor of the form [F X B X L] where:

  • F: features (keys)

  • B: batch size

  • L: length of sparse features (variable)

and outputs Dict[feature (key), JaggedTensor]. Each JaggedTensor contains values of the form (B * L) X D where:

  • B: batch size

  • L: length of sparse features (jagged)

  • D: each feature’s (key’s) embedding dimension; lengths are of the form L

Parameters:
  • tables (List[EmbeddingConfig]) – list of embedding tables.

  • device (Optional[torch.device]) – default compute device.

  • need_indices (bool) – whether indices should be passed to the final lookup result dict.

Example:

e1_config = EmbeddingConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)

ec = EmbeddingCollection(tables=[e1_config, e2_config])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
feature_embeddings = ec(features)
print(feature_embeddings['f2'].values())
tensor([[-0.2050,  0.5478,  0.6054],
        [ 0.7352,  0.3210, -3.0399],
        [ 0.1279, -0.1756, -0.4130],
        [ 0.7519, -0.4341, -0.0499],
        [ 0.9329, -1.0697, -0.8095]], grad_fn=<EmbeddingBackward>)
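A hedged sanity check continuing the example: "f2" has lengths [1, 1, 3] across the batch, so its JaggedTensor holds (B * L) = 5 rows of embedding_dim 3:

assert feature_embeddings["f2"].values().shape == (5, 3)    # 1 + 1 + 3 ids, embedding_dim=3
assert feature_embeddings["f2"].lengths().tolist() == [1, 1, 3]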
property device: device
embedding_configs() List[EmbeddingConfig]
embedding_dim() int
embedding_names_by_table() List[List[str]]
forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor]
Parameters:

features (KeyedJaggedTensor) – KJT of form [F X B X L].

Returns:

Dict[str, JaggedTensor]

need_indices() bool
reset_parameters() None
training: bool
class torchrec.modules.embedding_modules.EmbeddingCollectionInterface(*args, **kwargs)

Bases: ABC, Module

Interface for EmbeddingCollection.

abstract embedding_configs() List[EmbeddingConfig]
abstract embedding_dim() int
abstract embedding_names_by_table() List[List[str]]
abstract forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

abstract need_indices() bool
training: bool
torchrec.modules.embedding_modules.get_embedding_names_by_table(tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]) List[List[str]]

torchrec.modules.feature_processor

class torchrec.modules.feature_processor.BaseFeatureProcessor(*args, **kwargs)

Bases: Module

Abstract base class for feature processor.

abstract forward(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.modules.feature_processor.BaseGroupedFeatureProcessor(*args, **kwargs)

Bases: Module

Abstract base class for grouped feature processor

abstract forward(features: KeyedJaggedTensor) KeyedJaggedTensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.modules.feature_processor.PositionWeightedModule(max_feature_lengths: Dict[str, int], device: Optional[device] = None)

Bases: BaseFeatureProcessor

Adds position weights to id list features.

Parameters:

max_feature_lengths (Dict[str, int]) – feature name to max_length mapping. max_length, a.k.a. truncation size, specifies the maximum number of ids each sample has. For each feature, its position weight parameter size is max_length.

forward(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]
Parameters:

features (Dict[str, JaggedTensor]) – dictionary of keys to JaggedTensor, representing the features.

Returns:

same as input features with weights field being populated.

Return type:

Dict[str, JaggedTensor]

reset_parameters() None
training: bool
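A hedged usage sketch for PositionWeightedModule (the input values and lengths are illustrative; it assumes the JaggedTensor import path torchrec.sparse.jagged_tensor):

import torch
from torchrec.modules.feature_processor import PositionWeightedModule
from torchrec.sparse.jagged_tensor import JaggedTensor

pwm = PositionWeightedModule(max_feature_lengths={"f1": 10})
features = {
    "f1": JaggedTensor(
        values=torch.tensor([0, 1, 2]),
        lengths=torch.tensor([2, 1]),   # two samples with 2 and 1 ids respectively
    ),
}
weighted = pwm(features)
print(weighted["f1"].weights())         # position weights aligned with the values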
class torchrec.modules.feature_processor.PositionWeightedProcessor(max_feature_lengths: Dict[str, int], device: Optional[device] = None)

Bases: BaseGroupedFeatureProcessor

PositionWeightedProcessor represents a processor that applies position weights to a KeyedJaggedTensor.

It can handle both unsharded and sharded input and produces the corresponding output.

Parameters:
  • max_feature_lengths (Dict[str, int]) – mapping from feature name to its maximum length.

  • device (Optional[torch.device]) – default compute device.

Example:

keys = ["Feature0", "Feature1", "Feature2"]
values = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 3, 4, 5, 6, 7])
lengths = torch.tensor([2, 0, 1, 1, 1, 3, 2, 3, 0])
features = KeyedJaggedTensor.from_lengths_sync(keys=keys, values=values, lengths=lengths)
pw = PositionWeightedProcessor(
    max_feature_lengths={key: 100 for key in keys},
)
result = pw(features)
# result is
# KeyedJaggedTensor({
#     "Feature0": {
#         "values": [[0, 1], [], [2]],
#         "weights": [[1.0, 1.0], [], [1.0]]
#     },
#     "Feature1": {
#         "values": [[3], [4], [5, 6, 7]],
#         "weights": [[1.0], [1.0], [1.0, 1.0, 1.0]]
#     },
#     "Feature2": {
#         "values": [[3, 4], [5, 6, 7], []],
#         "weights": [[1.0, 1.0], [1.0, 1.0, 1.0], []]
#     }
# })
forward(features: KeyedJaggedTensor) KeyedJaggedTensor

In an unsharded or non-pipelined model, the input features contain both fp_features and non_fp_features, and the output filters out the non_fp features. In a sharded, pipelined model, the input features can only contain either none or all feature-processed features, since the input features come from the input_dist() of the EBC, which filters out keys not in the EBC. The input size is the same as the output size.

Parameters:

features (KeyedJaggedTensor) – input features

Returns:

KeyedJaggedTensor

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor) – Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

Returns a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing a whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
training: bool
torchrec.modules.feature_processor.position_weighted_module_update_features(features: Dict[str, JaggedTensor], weighted_features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]

torchrec.modules.lazy_extension

class torchrec.modules.lazy_extension.LazyModuleExtensionMixin(*args, **kwargs)

Bases: LazyModuleMixin

This is a temporary extension of LazyModuleMixin to support passing keyword arguments to lazy module’s forward method.

The long-term plan is to upstream this feature to LazyModuleMixin. Please see https://github.com/pytorch/pytorch/issues/59923 for details.

Please see TestLazyModuleExtensionMixin, which contains unit tests that ensure:
  • LazyModuleExtensionMixin._infer_parameters has source code parity with torch.nn.modules.lazy.LazyModuleMixin._infer_parameters, except that the former can accept keyword arguments.

  • LazyModuleExtensionMixin._call_impl has source code parity with torch.nn.Module._call_impl, except that the former can pass keyword arguments to forward pre-hooks.

apply(fn: Callable[[Module], None]) Module

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model.

Note

Calling apply() on an uninitialized lazy-module will result in an error. User is required to initialize a lazy-module (by doing a dummy forward pass) before calling apply() on the lazy-module.

Parameters:

fn (torch.nn.Module -> None) – function to be applied to each submodule.

Returns:

self

Return type:

torch.nn.Module

Example:

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == torch.nn.LazyLinear:
        m.weight.fill_(1.0)
        print(m.weight)

linear = torch.nn.LazyLinear(2)
linear.apply(init_weights)  # this fails, because `linear` (a lazy-module) hasn't been initialized yet

input = torch.randn(2, 10)
linear(input)  # run a dummy forward pass to initialize the lazy-module

linear.apply(init_weights)  # this works now
torchrec.modules.lazy_extension.lazy_apply(module: Module, fn: Callable[[Module], None]) Module

Attaches a function to a module, which will be applied recursively to every submodule (as returned by .children()) of the module as well as the module itself right after the first forward pass (i.e. after all submodules and parameters have been initialized).

Typical use includes initializing the numerical value of the parameters of a lazy module (i.e. modules inherited from LazyModuleMixin).

Note

lazy_apply() can be used on both lazy and non-lazy modules.

Parameters:
  • module (torch.nn.Module) – module to recursively apply fn on.

  • fn (Callable[[torch.nn.Module], None]) – function to be attached to module and later be applied to each submodule of module and the module itself.

Returns:

module with fn attached.

Return type:

torch.nn.Module

Example:

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == torch.nn.LazyLinear:
        m.weight.fill_(1.0)
        print(m.weight)

linear = torch.nn.LazyLinear(2)
lazy_apply(linear, init_weights)  # doesn't run `init_weights` immediately
input = torch.randn(2, 10)
linear(input)  # runs `init_weights` only once, right after first forward pass

seq = torch.nn.Sequential(torch.nn.LazyLinear(2), torch.nn.LazyLinear(2))
lazy_apply(seq, init_weights)  # doesn't run `init_weights` immediately
input = torch.randn(2, 10)
seq(input)  # runs `init_weights` only once, right after first forward pass

torchrec.modules.mlp

class torchrec.modules.mlp.MLP(in_size: int, layer_sizes: ~typing.List[int], bias: bool = True, activation: ~typing.Union[str, ~typing.Callable[[], ~torch.nn.modules.module.Module], ~torch.nn.modules.module.Module, ~typing.Callable[[~torch.Tensor], ~torch.Tensor]] = <built-in method relu of type object>, device: ~typing.Optional[~torch.device] = None)

Bases: Module

Applies a stack of Perceptron modules sequentially (i.e. Multi-Layer Perceptron).

Parameters:
  • in_size (int) – in_size of the input.

  • layer_sizes (List[int]) – out_size of each Perceptron module.

  • bias (bool) – if set to False, the layer will not learn an additive bias. Default: True.

  • activation (str, Union[Callable[[], torch.nn.Module], torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]) – the activation function to apply to the output of the linear transformation of each Perceptron module. If activation is a str, we currently only support the following strings: “relu”, “sigmoid”, and “swish_layernorm”. If activation is a Callable[[], torch.nn.Module], activation() will be called once per Perceptron module to generate the activation module for that Perceptron module, and the parameters won’t be shared between those activation modules. One use case is when all the activation modules share the same constructor arguments, but don’t share the actual module parameters. Default: torch.relu.

  • device (Optional[torch.device]) – default compute device.

Example:

batch_size = 3
in_size = 40
input = torch.randn(batch_size, in_size)

layer_sizes = [16, 8, 4]
mlp_module = MLP(in_size, layer_sizes, bias=True)
output = mlp_module(input)
assert list(output.shape) == [batch_size, layer_sizes[-1]]
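A hedged sketch of the activation options described above: passing a supported string, or a no-argument callable that is invoked once per Perceptron module so the activation modules are not shared:

import torch
from torchrec.modules.mlp import MLP

in_size, layer_sizes = 40, [16, 8, 4]
input = torch.randn(3, in_size)

mlp_str = MLP(in_size, layer_sizes, activation="swish_layernorm")  # "relu", "sigmoid", or "swish_layernorm"
mlp_ctor = MLP(in_size, layer_sizes, activation=torch.nn.ReLU)     # called once per Perceptron module

assert mlp_str(input).shape == (3, layer_sizes[-1])
assert mlp_ctor(input).shape == (3, layer_sizes[-1])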
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor of shape (B, I) where I is number of elements in each input sample.

Returns:

tensor of shape (B, O) where O is out_size of the last Perceptron module.

Return type:

torch.Tensor

training: bool
class torchrec.modules.mlp.Perceptron(in_size: int, out_size: int, bias: bool = True, activation: ~typing.Union[~torch.nn.modules.module.Module, ~typing.Callable[[~torch.Tensor], ~torch.Tensor]] = <built-in method relu of type object>, device: ~typing.Optional[~torch.device] = None)

Bases: Module

Applies a linear transformation and activation.

Parameters:
  • in_size (int) – number of elements in each input sample.

  • out_size (int) – number of elements in each output sample.

  • bias (bool) – if set to False, the layer will not learn an additive bias. Default: True.

  • activation (Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]) – the activation function to apply to the output of linear transformation. Default: torch.relu.

  • device (Optional[torch.device]) – default compute device.

Example:

batch_size = 3
in_size = 40
input = torch.randn(batch_size, in_size)

out_size = 16
perceptron = Perceptron(in_size, out_size, bias=True)

output = perceptron(input)
assert list(output.shape) == [batch_size, out_size]
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor of shape (B, I) where I is number of elements in each input sample.

Returns:

tensor of shape (B, O) where O is the number of elements per channel in each output sample (i.e. out_size).

Return type:

torch.Tensor

training: bool

torchrec.modules.utils

torchrec.modules.utils.check_module_output_dimension(module: Union[Iterable[Module], Module], in_features: int, out_features: int) bool

Verify that the out_features of a given module or a list of modules matches the specified number. If a list of modules or a ModuleList is given, recursively check all the submodules.

torchrec.modules.utils.construct_modulelist_from_single_module(module: Module, sizes: Tuple[int, ...]) Module

Given a single module, construct a (nested) ModuleList of size sizes by making copies of the provided module and reinitializing the Linear layers.

torchrec.modules.utils.convert_list_of_modules_to_modulelist(modules: Iterable[Module], sizes: Tuple[int, ...]) Module
torchrec.modules.utils.extract_module_or_tensor_callable(module_or_callable: Union[Callable[[], Module], Module, Callable[[Tensor], Tensor]]) Union[Module, Callable[[Tensor], Tensor]]
torchrec.modules.utils.get_module_output_dimension(module: Union[Callable[[Tensor], Tensor], Module], in_features: int) int
torchrec.modules.utils.init_mlp_weights_xavier_uniform(m: Module) None
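A hedged usage sketch for init_mlp_weights_xavier_uniform, assuming it follows the standard nn.Module.apply pattern for re-initializing Linear layers:

import torch
from torchrec.modules.mlp import MLP
from torchrec.modules.utils import init_mlp_weights_xavier_uniform

mlp = MLP(in_size=40, layer_sizes=[16, 8])
mlp(torch.randn(3, 40))                      # run a forward pass so all parameters are materialized
mlp.apply(init_mlp_weights_xavier_uniform)   # re-initialize Linear weights with Xavier uniform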

