torchrec.modules

Torchrec Common Modules

The torchrec modules package contains a collection of common modules for building recommendation models.

These modules include:
  • extensions of nn.Embedding and nn.EmbeddingBag, called EmbeddingCollection and EmbeddingBagCollection respectively.

  • established modules such as DeepFM and CrossNet.

  • common module patterns such as MLP and SwishLayerNorm.

  • custom modules for TorchRec such as PositionWeightedModule and LazyModuleExtensionMixin.

  • EmbeddingTower and EmbeddingTowerCollection, a logical “tower” of embeddings passed to a provided interaction module.

torchrec.modules.activation

Activation Modules

class torchrec.modules.activation.SwishLayerNorm(input_dims: Union[int, List[int], Size], device: Optional[device] = None)

Bases: Module

Applies the Swish function with layer normalization: Y = X * Sigmoid(LayerNorm(X)).

Parameters:
  • input_dims (Union[int, List[int], torch.Size]) – dimensions to normalize over. If an input tensor has shape [batch_size, d1, d2, d3], setting input_dims=[d2, d3] will perform layer normalization over the last two dimensions.

  • device (Optional[torch.device]) – default compute device.

Example:

sln = SwishLayerNorm(100)
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – an input tensor.

Returns:

an output tensor.

Return type:

torch.Tensor

training: bool
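For illustration, a minimal runnable sketch of the module above; the shapes and values are arbitrary:

import torch
from torchrec.modules.activation import SwishLayerNorm

batch_size = 3
sln = SwishLayerNorm(100)           # normalize over the trailing dimension of size 100
input = torch.randn(batch_size, 100)
output = sln(input)                 # Y = X * Sigmoid(LayerNorm(X))
assert output.shape == input.shape  # shape is preserved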

torchrec.modules.crossnet

CrossNet API

class torchrec.modules.crossnet.CrossNet(in_features: int, num_layers: int)

Bases: Module

Cross Network:

Cross Net is a stack of “crossing” operations on a tensor of shape \((*, N)\) that preserves its shape, effectively creating \(N\) learnable polynomial functions over the input tensor.

In this module, the crossing operations are defined based on a full rank matrix (NxN), such that the crossing effect can cover all bits on each layer. On each layer l, the tensor is transformed into:

\[x_{l+1} = x_0 * (W_l \cdot x_l + b_l) + x_l\]

where \(W_l\) is a square matrix \((NxN)\), \(*\) means element-wise multiplication, \(\cdot\) means matrix multiplication.
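As a concrete illustration of the recursion above, here is a minimal sketch of the crossing step in plain PyTorch; the names x0, x_l, W, and b are illustrative, not the module's internal parameter names, and a real CrossNet learns a separate \(W_l\) and \(b_l\) per layer:

import torch

batch_size, in_features = 3, 10
x0 = torch.randn(batch_size, in_features)  # original input, reused on every layer
W = torch.randn(in_features, in_features)  # full-rank (N x N) crossing matrix
b = torch.zeros(in_features)

x_l = x0
for _ in range(2):                          # num_layers = 2
    x_l = x0 * (x_l @ W.T + b) + x_l        # x_{l+1} = x_0 * (W_l . x_l + b_l) + x_l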

Parameters:
  • in_features (int) – the dimension of the input.

  • num_layers (int) – the number of layers in the module.

Example:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = CrossNet(in_features=in_features, num_layers=num_layers)
output = dcn(input)
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor with shape [batch_size, in_features].

Returns:

tensor with shape [batch_size, in_features].

Return type:

torch.Tensor

training: bool
class torchrec.modules.crossnet.LowRankCrossNet(in_features: int, num_layers: int, low_rank: int = 1)

Bases: Module

Low Rank Cross Net is a highly efficient cross net. Instead of using full rank cross matrices (NxN) at each layer, it will use two kernels \(W (N x r)\) and \(V (r x N)\), where r << N, to simplify the matrix multiplication.

On each layer l, the tensor is transformed into:

\[x_{l+1} = x_0 * (W_l \cdot (V_l \cdot x_l) + b_l) + x_l\]

where \(W_l\) is an \((N \times r)\) matrix, \(V_l\) is an \((r \times N)\) matrix, \(*\) means element-wise multiplication, and \(\cdot\) means matrix multiplication.
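A minimal sketch of the low-rank crossing step, with illustrative names; note that the parameter count per layer drops from \(N^2\) to \(2Nr\):

import torch

batch_size, in_features, r = 3, 10, 3
x0 = torch.randn(batch_size, in_features)
V = torch.randn(r, in_features)            # projects down to rank r
W = torch.randn(in_features, r)            # projects back up to N
b = torch.zeros(in_features)

x_l = x0
x_l = x0 * ((x_l @ V.T) @ W.T + b) + x_l   # x_{l+1} = x_0 * (W_l (V_l x_l) + b_l) + x_l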

Note

Rank r should be chosen carefully. Usually, we expect r < N/2 to yield computational savings, and \(r \approx N/4\) to preserve the accuracy of the full-rank cross net.

Parameters:
  • in_features (int) – the dimension of the input.

  • num_layers (int) – the number of layers in the module.

  • low_rank (int) – the rank of the cross matrix (default = 1). Value must always be >= 1.

Example:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = LowRankCrossNet(in_features=in_features, num_layers=num_layers, low_rank=3)
output = dcn(input)
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor with shape [batch_size, in_features].

Returns:

tensor with shape [batch_size, in_features].

Return type:

torch.Tensor

training: bool
class torchrec.modules.crossnet.LowRankMixtureCrossNet(in_features: int, num_layers: int, num_experts: int = 1, low_rank: int = 1, activation: ~typing.Union[~torch.nn.modules.module.Module, ~typing.Callable[[~torch.Tensor], ~torch.Tensor]] = <built-in method relu of type object>)

Bases: Module

Low Rank Mixture Cross Net is a DCN V2 implementation from the paper “DCN V2: Improved Deep & Cross Network” (https://arxiv.org/abs/2008.13535).

LowRankMixtureCrossNet defines the learnable crossing parameter per layer as a low-rank matrix \((N \times r)\) together with a mixture of experts. Compared to LowRankCrossNet, instead of relying on one single expert to learn feature crosses, this module leverages \(K\) experts, each learning feature interactions in a different subspace, and adaptively combines the learned crosses using a gating mechanism that depends on the input \(x\).

On each layer l, the tensor is transformed into:

\[x_{l+1} = MoE({expert_i : i \in K_{experts}}) + x_l\]

and each \(expert_i\) is defined as:

\[expert_i = x_0 * (U_{li} \cdot g(C_{li} \cdot g(V_{li} \cdot x_l)) + b_l)\]

where \(U_{li}\) \((N \times r)\), \(C_{li}\) \((r \times r)\) and \(V_{li}\) \((r \times N)\) are low-rank matrices, \(*\) means element-wise multiplication, \(\cdot\) means matrix multiplication, and \(g()\) is the non-linear activation function.

When num_experts is 1, the gate evaluation and MoE are skipped to save computation.
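For illustration, here is a sketch of a single expert's crossing from the formula above; the names are illustrative, and the gating that mixes the \(K\) experts is omitted:

import torch

batch_size, in_features, r = 3, 10, 3
x0 = torch.randn(batch_size, in_features)
x_l = x0
U = torch.randn(in_features, r)
C = torch.randn(r, r)
V = torch.randn(r, in_features)
b = torch.zeros(in_features)
g = torch.relu

expert = x0 * (g(g(x_l @ V.T) @ C.T) @ U.T + b)  # expert_i = x_0 * (U g(C g(V x_l)) + b)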

Parameters:
  • in_features (int) – the dimension of the input.

  • num_layers (int) – the number of layers in the module.

  • num_experts (int) – the number of experts (default = 1). Value must always be >= 1.

  • low_rank (int) – the rank of the cross matrix (default = 1). Value must always be >= 1.

  • activation (Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]) – the non-linear activation function, used in defining experts. Default is relu.

Example:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = LowRankMixtureCrossNet(in_features=in_features, num_layers=num_layers, num_experts=5, low_rank=3)
output = dcn(input)
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor with shape [batch_size, in_features].

Returns:

tensor with shape [batch_size, in_features].

Return type:

torch.Tensor

training: bool
class torchrec.modules.crossnet.VectorCrossNet(in_features: int, num_layers: int)

Bases: Module

Vector Cross Network can be referred to as DCN-V1.

It is also a specialized low-rank cross net, where rank = 1. On each layer, instead of keeping two kernels W and V, we keep only one vector kernel W (N x 1), and use the dot operation to compute the “crossing” effect of the features, thus saving two matrix multiplications, further reducing computational cost and cutting the number of learnable parameters.

On each layer l, the tensor is transformed into

\[x_{l+1} = x_0 * (W_l \cdot x_l + b_l) + x_l\]

where \(W_l\) is a vector, \(*\) means element-wise multiplication, and \(\cdot\) means dot product.
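A sketch of the rank-1 crossing step with illustrative names; the dot product \(W_l \cdot x_l\) collapses to one scalar per sample, which then scales \(x_0\):

import torch

batch_size, in_features = 3, 10
x0 = torch.randn(batch_size, in_features)
w = torch.randn(in_features)                 # vector kernel W_l
b = torch.zeros(in_features)

x_l = x0
scale = x_l @ w                              # dot product -> shape [batch_size]
x_l = x0 * (scale.unsqueeze(-1) + b) + x_l   # x_{l+1} = x_0 * (W_l . x_l + b_l) + x_l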

Parameters:
  • in_features (int) – the dimension of the input.

  • num_layers (int) – the number of layers in the module.

Example:

batch_size = 3
num_layers = 2
in_features = 10
input = torch.randn(batch_size, in_features)
dcn = VectorCrossNet(in_features=in_features, num_layers=num_layers)
output = dcn(input)
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor with shape [batch_size, in_features].

Returns:

tensor with shape [batch_size, in_features].

Return type:

torch.Tensor

training: bool

torchrec.modules.deepfm

Deep Factorization-Machine Modules

The following modules are based on the Deep Factorization-Machine (DeepFM) paper (https://arxiv.org/abs/1703.04247):

  • Class DeepFM implements the DeepFM framework.

  • Class FactorizationMachine implements FM as noted in the above paper.

class torchrec.modules.deepfm.DeepFM(dense_module: Module)

Bases: Module

This is the DeepFM module.

This module does not cover the end-to-end functionality of the published paper. Instead, it covers only the deep component of the publication. It is used to learn high-order feature interactions. If low-order feature interactions should be learned, please use the FactorizationMachine module instead, which shares the same embedding inputs as this module.

To support modeling flexibility, we customize the key components as:

  • Different from the public paper, we change the input from raw sparse features to embeddings of the features. It allows flexibility in embedding dimensions and the number of embeddings, as long as all embedding tensors have the same batch size.

  • On top of the public paper, we allow users to customize the hidden layer to be any module, not limited to just MLP.

The general architecture of the module is like:

        1 x 10                  output
         /|\
          |                     pass into `dense_module`
          |
        1 x 90
         /|\
          |                     concat
          |
1 x 20, 1 x 30, 1 x 40          list of embeddings
Parameters:

dense_module (nn.Module) – any customized module (such as MLP) that can be used in DeepFM. The in_features of this module must equal the total element count of the flattened, concatenated embeddings. For example, if the input embeddings are [randn(3, 2, 3), randn(3, 4, 5)], the in_features should be: 2*3 + 4*5.
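For instance, the required in_features can be computed from the embedding list directly; a small sketch using the shapes from the parameter description above:

import torch

input_embeddings = [torch.randn(3, 2, 3), torch.randn(3, 4, 5)]
in_features = sum(e.shape[1] * e.shape[2] for e in input_embeddings)
assert in_features == 2 * 3 + 4 * 5  # 26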

Example:

import torch
import torch.nn as nn
from torchrec.modules.deepfm import DeepFM

batch_size = 3
output_dim = 30
# each input embedding is a torch.Tensor of shape [batch_size, num_embeddings, embedding_dim]
input_embeddings = [
    torch.randn(batch_size, 2, 64),
    torch.randn(batch_size, 2, 32),
]
dense_module = nn.Linear(192, output_dim)
deepfm = DeepFM(dense_module=dense_module)
deep_fm_output = deepfm(embeddings=input_embeddings)
forward(embeddings: List[Tensor]) Tensor
Parameters:

embeddings (List[torch.Tensor]) –

The list of all embeddings (e.g. dense, common_sparse, specialized_sparse, embedding_features, raw_embedding_features) in the shape of:

(batch_size, num_embeddings, embedding_dim)

For the ease of operation, embeddings that have the same embedding dimension have the option to be stacked into a single tensor. For example, when we have 1 trained embedding with dimension=32, 5 native embeddings with dimension=64, and 3 dense features with dimension=16, we can prepare the embeddings list to be the list of:

tensor(B, 1, 32) (trained_embedding with num_embeddings=1, embedding_dim=32)
tensor(B, 5, 64) (native_embedding with num_embeddings=5, embedding_dim=64)
tensor(B, 3, 16) (dense_features with num_embeddings=3, embedding_dim=16)

Note

The batch_size of all input tensors needs to be identical.

Returns:

output of dense_module with flattened and concatenated embeddings as input.

Return type:

torch.Tensor

training: bool
class torchrec.modules.deepfm.FactorizationMachine

Bases: Module

This is the Factorization Machine module, mentioned in the DeepFM paper (https://arxiv.org/abs/1703.04247).

This module does not cover the end-to-end functionality of the published paper. Instead, it covers only the FM part of the publication, and is used to learn 2nd-order feature interactions.

To support modeling flexibility, we customize the key components differently from the public paper:

We change the input from raw sparse features to embeddings of the features. This allows flexibility in embedding dimensions and the number of embeddings, as long as all embedding tensors have the same batch size.
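As background, the classic FM second-order interaction can be written with the sum-square trick; this is a generic sketch of the textbook formula over same-dimension field embeddings, not necessarily torchrec's exact kernel:

import torch

emb = torch.randn(3, 4, 8)            # [batch_size, num_fields, embedding_dim]
sum_sq = emb.sum(dim=1) ** 2          # (sum_f v_f)^2
sq_sum = (emb ** 2).sum(dim=1)        # sum_f v_f^2
fm_term = 0.5 * (sum_sq - sq_sum).sum(dim=1, keepdim=True)  # [batch_size, 1]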

The general architecture of the module is like:

        1 x 10                  output
         /|\
          |                     pass into `dense_module`
          |
        1 x 90
         /|\
          |                     concat
          |
1 x 20, 1 x 30, 1 x 40          list of embeddings

Example:

import torch
from torchrec.modules.deepfm import FactorizationMachine

batch_size = 3
# each input embedding is a torch.Tensor of shape [batch_size, num_embeddings, embedding_dim]
input_embeddings = [
    torch.randn(batch_size, 2, 64),
    torch.randn(batch_size, 2, 32),
]
fm = FactorizationMachine()
output = fm(embeddings=input_embeddings)
forward(embeddings: List[Tensor]) Tensor
Parameters:

embeddings (List[torch.Tensor]) –

The list of all embeddings (e.g. dense, common_sparse, specialized_sparse, embedding_features, raw_embedding_features) in the shape of:

(batch_size, num_embeddings, embedding_dim)

For the ease of operation, embeddings that have the same embedding dimension have the option to be stacked into a single tensor. For example, when we have 1 trained embedding with dimension=32, 5 native embeddings with dimension=64, and 3 dense features with dimension=16, we can prepare the embeddings list to be the list of:

tensor(B, 1, 32) (trained_embedding with num_embeddings=1, embedding_dim=32)
tensor(B, 5, 64) (native_embedding with num_embeddings=5, embedding_dim=64)
tensor(B, 3, 16) (dense_features with num_embeddings=3, embedding_dim=16)

Note

The batch_size of all input tensors needs to be identical.

Returns:

output of fm with flattened and concatenated embeddings as input. Expected to be [B, 1].

Return type:

torch.Tensor

training: bool

torchrec.modules.embedding_configs

class torchrec.modules.embedding_configs.BaseEmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, pruning_indices_remapping: Union[torch.Tensor, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False)

Bases: object

data_type: DataType = 'FP32'
embedding_dim: int
feature_names: List[str]
get_weight_init_max() float
get_weight_init_min() float
init_fn: Optional[Callable[[Tensor], Optional[Tensor]]] = None
name: str = ''
need_pos: bool = False
num_embeddings: int
num_features() int
pruning_indices_remapping: Optional[Tensor] = None
weight_init_max: Optional[float] = None
weight_init_min: Optional[float] = None
class torchrec.modules.embedding_configs.EmbeddingBagConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, pruning_indices_remapping: Union[torch.Tensor, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False, pooling: torchrec.modules.embedding_configs.PoolingType = <PoolingType.SUM: 'SUM'>)

Bases: BaseEmbeddingConfig

pooling: PoolingType = 'SUM'
class torchrec.modules.embedding_configs.EmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, pruning_indices_remapping: Union[torch.Tensor, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False)

Bases: BaseEmbeddingConfig

embedding_dim: int
feature_names: List[str]
num_embeddings: int
class torchrec.modules.embedding_configs.EmbeddingTableConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, pruning_indices_remapping: Union[torch.Tensor, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False, pooling: torchrec.modules.embedding_configs.PoolingType = <PoolingType.SUM: 'SUM'>, is_weighted: bool = False, has_feature_processor: bool = False, embedding_names: List[str] = <factory>)

Bases: BaseEmbeddingConfig

embedding_names: List[str]
has_feature_processor: bool = False
is_weighted: bool = False
pooling: PoolingType = 'SUM'
class torchrec.modules.embedding_configs.PoolingType(value)

Bases: Enum

An enumeration.

MEAN = 'MEAN'
NONE = 'NONE'
SUM = 'SUM'
class torchrec.modules.embedding_configs.QuantConfig(activation, weight, per_table_weight_dtype)

Bases: tuple

activation: PlaceholderObserver

Alias for field number 0

per_table_weight_dtype: Optional[Dict[str, dtype]]

Alias for field number 2

weight: PlaceholderObserver

Alias for field number 1

class torchrec.modules.embedding_configs.ShardingType(value)

Bases: Enum

Well-known sharding types, used by inter-module optimizations.

COLUMN_WISE = 'column_wise'
DATA_PARALLEL = 'data_parallel'
ROW_WISE = 'row_wise'
TABLE_COLUMN_WISE = 'table_column_wise'
TABLE_ROW_WISE = 'table_row_wise'
TABLE_WISE = 'table_wise'
torchrec.modules.embedding_configs.data_type_to_dtype(data_type: DataType) dtype
torchrec.modules.embedding_configs.data_type_to_sparse_type(data_type: DataType) SparseType
torchrec.modules.embedding_configs.dtype_to_data_type(dtype: dtype) DataType
torchrec.modules.embedding_configs.pooling_type_to_pooling_mode(pooling_type: PoolingType, sharding_type: Optional[ShardingType] = None) PoolingMode
torchrec.modules.embedding_configs.pooling_type_to_str(pooling_type: PoolingType) str

torchrec.modules.embedding_modules

class torchrec.modules.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool = False, device: Optional[device] = None)

Bases: EmbeddingBagCollectionInterface

EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags).

Note

EmbeddingBagCollection is an unsharded module and is not performance optimized. For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingBagCollection.

It processes sparse data in the form of KeyedJaggedTensor with values of the form [F X B X L] where:

  • F: features (keys)

  • B: batch size

  • L: length of sparse features (jagged)

and outputs a KeyedTensor with values of the form [B * (F * D)] where:

  • F: features (keys)

  • D: each feature’s (key’s) embedding dimension

  • B: batch size

Parameters:
  • tables (List[EmbeddingBagConfig]) – list of embedding tables.

  • is_weighted (bool) – whether input KeyedJaggedTensor is weighted.

  • device (Optional[torch.device]) – default compute device.

Example:

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)

ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

pooled_embeddings = ebc(features)
print(pooled_embeddings.values())
tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
        [ 0.0000,  0.0000,  0.0000,  0.1598,  0.0695,  1.3265, -0.1011],
        [-0.4256, -1.1846, -2.1648, -1.0893,  0.3590, -1.9784, -0.7681]],
       grad_fn=<CatBackward0>)
print(pooled_embeddings.keys())
['f1', 'f2']
print(pooled_embeddings.offset_per_key())
tensor([0, 3, 7])
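The pooled output width is the sum of the per-feature embedding dims (3 + 4 = 7 here). Assuming the KeyedTensor.to_dict() helper, the per-feature slices can be recovered:

per_feature = pooled_embeddings.to_dict()
print(per_feature["f1"].shape)
torch.Size([3, 3])
print(per_feature["f2"].shape)
torch.Size([3, 4])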
property device: device
embedding_bag_configs() List[EmbeddingBagConfig]
forward(features: KeyedJaggedTensor) KeyedTensor
Parameters:

features (KeyedJaggedTensor) – KJT of form [F X B X L].

Returns:

KeyedTensor

is_weighted() bool
reset_parameters() None
training: bool
class torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface(*args, **kwargs)

Bases: ABC, Module

Interface for EmbeddingBagCollection.

abstract embedding_bag_configs() List[EmbeddingBagConfig]
abstract forward(features: KeyedJaggedTensor) KeyedTensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

abstract is_weighted() bool
training: bool
class torchrec.modules.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: Optional[device] = None, need_indices: bool = False)

Bases: EmbeddingCollectionInterface

EmbeddingCollection represents a collection of non-pooled embeddings.

Note

EmbeddingCollection is an unsharded module and is not performance optimized. For performance-sensitive scenarios, consider using the sharded version ShardedEmbeddingCollection.

It processes sparse data in the form of KeyedJaggedTensor of the form [F X B X L] where:

  • F: features (keys)

  • B: batch size

  • L: length of sparse features (variable)

and outputs Dict[feature (key), JaggedTensor]. Each JaggedTensor contains values of the form (B * L) X D where:

  • B: batch size

  • L: length of sparse features (jagged)

  • D: each feature’s (key’s) embedding dimension; lengths are of the form L

Parameters:
  • tables (List[EmbeddingConfig]) – list of embedding tables.

  • device (Optional[torch.device]) – default compute device.

  • need_indices (bool) – if we need to pass indices to the final lookup dict.

Example:

e1_config = EmbeddingConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)

ec = EmbeddingCollection(tables=[e1_config, e2_config])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
feature_embeddings = ec(features)
print(feature_embeddings['f2'].values())
tensor([[-0.2050,  0.5478,  0.6054],
        [ 0.7352,  0.3210, -3.0399],
        [ 0.1279, -0.1756, -0.4130],
        [ 0.7519, -0.4341, -0.0499],
        [ 0.9329, -1.0697, -0.8095]], grad_fn=<EmbeddingBackward>)
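Each output row corresponds to one id in the jagged input; for “f2” the per-sample lengths are [1, 1, 3], i.e. 5 rows in total. Assuming the JaggedTensor.lengths() accessor:

print(feature_embeddings['f2'].lengths())
tensor([1, 1, 3])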
property device: device
embedding_configs() List[EmbeddingConfig]
embedding_dim() int
embedding_names_by_table() List[List[str]]
forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor]
Parameters:

features (KeyedJaggedTensor) – KJT of form [F X B X L].

Returns:

Dict[str, JaggedTensor]

need_indices() bool
reset_parameters() None
training: bool
class torchrec.modules.embedding_modules.EmbeddingCollectionInterface(*args, **kwargs)

Bases: ABC, Module

Interface for EmbeddingCollection.

abstract embedding_configs() List[EmbeddingConfig]
abstract embedding_dim() int
abstract embedding_names_by_table() List[List[str]]
abstract forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

abstract need_indices() bool
training: bool
torchrec.modules.embedding_modules.get_embedding_names_by_table(tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]) List[List[str]]
torchrec.modules.embedding_modules.process_pooled_embeddings(pooled_embeddings: List[Tensor], inverse_indices: Tensor) Tensor
torchrec.modules.embedding_modules.reorder_inverse_indices(inverse_indices: Optional[Tuple[List[str], Tensor]], feature_names: List[str]) Tensor

torchrec.modules.feature_processor

class torchrec.modules.feature_processor.BaseFeatureProcessor(*args, **kwargs)

Bases: Module

Abstract base class for feature processor.

abstract forward(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.modules.feature_processor.BaseGroupedFeatureProcessor(*args, **kwargs)

Bases: Module

Abstract base class for grouped feature processor

abstract forward(features: KeyedJaggedTensor) KeyedJaggedTensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.modules.feature_processor.PositionWeightedModule(max_feature_lengths: Dict[str, int], device: Optional[device] = None)

Bases: BaseFeatureProcessor

Adds position weights to id list features.

Parameters:

max_feature_lengths (Dict[str, int]) – feature name to max_length mapping. max_length, a.k.a. truncation size, specifies the maximum number of ids each sample has. For each feature, its position weight parameter size is max_length.

forward(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]
Parameters:

features (Dict[str, JaggedTensor]) – dictionary of keys to JaggedTensor, representing the features.

Returns:

same as input features with weights field being populated.

Return type:

Dict[str, JaggedTensor]

reset_parameters() None
training: bool
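A minimal usage sketch, assuming JaggedTensor from torchrec.sparse.jagged_tensor; the feature name, values, and lengths are arbitrary:

import torch
from torchrec.modules.feature_processor import PositionWeightedModule
from torchrec.sparse.jagged_tensor import JaggedTensor

pwm = PositionWeightedModule(max_feature_lengths={"f1": 10})
features = {
    "f1": JaggedTensor(
        values=torch.tensor([0, 1, 2]),
        lengths=torch.tensor([2, 1]),
    )
}
weighted = pwm(features)
print(weighted["f1"].weights())  # learnable position weights, one per id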
class torchrec.modules.feature_processor.PositionWeightedProcessor(max_feature_lengths: Dict[str, int], device: Optional[device] = None)

Bases: BaseGroupedFeatureProcessor

PositionWeightedProcessor represents a processor that applies position weights to a KeyedJaggedTensor.

It can handle both unsharded and sharded input, and produces the corresponding output.

Parameters:
  • max_feature_lengths (Dict[str, int]) – Dict of feature lengths; the key is the feature_name and the value is the max length.

  • device (Optional[torch.device]) – default compute device.

Example:

keys=["Feature0", "Feature1", "Feature2"]
values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 3, 4, 5, 6, 7])
lengths=torch.tensor([2, 0, 1, 1, 1, 3, 2, 3, 0])
features = KeyedJaggedTensor.from_lengths_sync(keys=keys, values=values, lengths=lengths)
pw = FeatureProcessorCollection(
    feature_processor_modules={key: PositionWeightedFeatureProcessor(max_feature_length=100) for key in keys}
)
result = pw(features)
# result is
# KeyedJaggedTensor({
#     "Feature0": {
#         "values": [[0, 1], [], [2]],
#         "weights": [[1.0, 1.0], [], [1.0]]
#     },
#     "Feature1": {
#         "values": [[3], [4], [5, 6, 7]],
#         "weights": [[1.0], [1.0], [1.0, 1.0, 1.0]]
#     },
#     "Feature2": {
#         "values": [[3, 4], [5, 6, 7], []],
#         "weights": [[1.0, 1.0], [1.0, 1.0, 1.0], []]
#     }
# })
forward(features: KeyedJaggedTensor) KeyedJaggedTensor

In an unsharded or non-pipelined model, the input features contain both fp_features and non_fp_features, and the output will filter out non_fp_features. In a sharded pipelined model, the input features can only contain either none or all feature-processed features, since the input features come from the input_dist() of the EBC, which filters out any keys not in the EBC; the input size is the same as the output size.

Parameters:

features (KeyedJaggedTensor) – input features

Returns:

KeyedJaggedTensor

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:
  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor) – Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing a whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
training: bool
torchrec.modules.feature_processor.offsets_to_range_traceble(offsets: Tensor, values: Tensor) Tensor
torchrec.modules.feature_processor.position_weighted_module_update_features(features: Dict[str, JaggedTensor], weighted_features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]

torchrec.modules.lazy_extension

class torchrec.modules.lazy_extension.LazyModuleExtensionMixin(*args, **kwargs)

Bases: LazyModuleMixin

This is a temporary extension of LazyModuleMixin to support passing keyword arguments to lazy module’s forward method.

The long-term plan is to upstream this feature to LazyModuleMixin. Please see https://github.com/pytorch/pytorch/issues/59923 for details.

Please see TestLazyModuleExtensionMixin, which contains unit tests that ensure:
  • LazyModuleExtensionMixin._infer_parameters has source code parity with torch.nn.modules.lazy.LazyModuleMixin._infer_parameters, except that the former can accept keyword arguments.

  • LazyModuleExtensionMixin._call_impl has source code parity with torch.nn.Module._call_impl, except that the former can pass keyword arguments to forward pre-hooks.

apply(fn: Callable[[Module], None]) Module

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model.

Note

Calling apply() on an uninitialized lazy module will result in an error. The user is required to initialize a lazy module (by running a dummy forward pass) before calling apply() on it.

Parameters:

fn (torch.nn.Module -> None) – function to be applied to each submodule.

Returns:

self

Return type:

torch.nn.Module

Example:

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == torch.nn.LazyLinear:
        m.weight.fill_(1.0)
        print(m.weight)

linear = torch.nn.LazyLinear(2)
linear.apply(init_weights)  # this fails, because `linear` (a lazy-module) hasn't been initialized yet

input = torch.randn(2, 10)
linear(input)  # run a dummy forward pass to initialize the lazy-module

linear.apply(init_weights)  # this works now
torchrec.modules.lazy_extension.lazy_apply(module: Module, fn: Callable[[Module], None]) Module

Attaches a function to a module, which will be applied recursively to every submodule (as returned by .children()) of the module as well as the module itself right after the first forward pass (i.e. after all submodules and parameters have been initialized).

Typical use includes initializing the numerical value of the parameters of a lazy module (i.e. modules inherited from LazyModuleMixin).

Note

lazy_apply() can be used on both lazy and non-lazy modules.

Parameters:
  • module (torch.nn.Module) – module to recursively apply fn on.

  • fn (Callable[[torch.nn.Module], None]) – function to be attached to module and later be applied to each submodule of module and the module itself.

Returns:

module with fn attached.

Return type:

torch.nn.Module

Example:

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == torch.nn.LazyLinear:
        m.weight.fill_(1.0)
        print(m.weight)

linear = torch.nn.LazyLinear(2)
lazy_apply(linear, init_weights)  # doesn't run `init_weights` immediately
input = torch.randn(2, 10)
linear(input)  # runs `init_weights` only once, right after first forward pass

seq = torch.nn.Sequential(torch.nn.LazyLinear(2), torch.nn.LazyLinear(2))
lazy_apply(seq, init_weights)  # doesn't run `init_weights` immediately
input = torch.randn(2, 10)
seq(input)  # runs `init_weights` only once, right after first forward pass

torchrec.modules.mlp

class torchrec.modules.mlp.MLP(in_size: int, layer_sizes: ~typing.List[int], bias: bool = True, activation: ~typing.Union[str, ~typing.Callable[[], ~torch.nn.modules.module.Module], ~torch.nn.modules.module.Module, ~typing.Callable[[~torch.Tensor], ~torch.Tensor]] = <built-in method relu of type object>, device: ~typing.Optional[~torch.device] = None, dtype: ~torch.dtype = torch.float32)

Bases: Module

Applies a stack of Perceptron modules sequentially (i.e. Multi-Layer Perceptron).

Parameters:
  • in_size (int) – in_size of the input.

  • layer_sizes (List[int]) – out_size of each Perceptron module.

  • bias (bool) – if set to False, the layer will not learn an additive bias. Default: True.

  • activation (str, Union[Callable[[], torch.nn.Module], torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]) – the activation function to apply to the output of the linear transformation of each Perceptron module. If activation is a str, we currently only support the following strings: “relu”, “sigmoid”, and “swish_layernorm”. If activation is a Callable[[], torch.nn.Module], activation() will be called once per Perceptron module to generate the activation module for that Perceptron module, and the parameters won’t be shared between those activation modules; one use case is when all the activation modules share the same constructor arguments, but not the actual module parameters (see the sketch after this list). Default: torch.relu.

  • device (Optional[torch.device]) – default compute device.
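For example, passing a module factory as the activation gives each Perceptron its own unshared activation instance; a brief sketch (PReLU is chosen here because it has learnable parameters):

import torch
from torchrec.modules.mlp import MLP

mlp = MLP(in_size=40, layer_sizes=[16, 8], activation=lambda: torch.nn.PReLU())
output = mlp(torch.randn(3, 40))  # each layer holds its own PReLU parameters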

Example:

batch_size = 3
in_size = 40
input = torch.randn(batch_size, in_size)

layer_sizes = [16, 8, 4]
mlp_module = MLP(in_size, layer_sizes, bias=True)
output = mlp_module(input)
assert list(output.shape) == [batch_size, layer_sizes[-1]]
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor of shape (B, I) where I is number of elements in each input sample.

Returns:

tensor of shape (B, O) where O is out_size of the last Perceptron module.

Return type:

torch.Tensor

training: bool
class torchrec.modules.mlp.Perceptron(in_size: int, out_size: int, bias: bool = True, activation: ~typing.Union[~torch.nn.modules.module.Module, ~typing.Callable[[~torch.Tensor], ~torch.Tensor]] = <built-in method relu of type object>, device: ~typing.Optional[~torch.device] = None, dtype: ~torch.dtype = torch.float32)

Bases: Module

Applies a linear transformation and activation.

Parameters:
  • in_size (int) – number of elements in each input sample.

  • out_size (int) – number of elements in each output sample.

  • bias (bool) – if set to False, the layer will not learn an additive bias. Default: True.

  • activation (Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]) – the activation function to apply to the output of linear transformation. Default: torch.relu.

  • device (Optional[torch.device]) – default compute device.

Example:

batch_size = 3
in_size = 40
input = torch.randn(batch_size, in_size)

out_size = 16
perceptron = Perceptron(in_size, out_size, bias=True)

output = perceptron(input)
assert list(output.shape) == [batch_size, out_size]
forward(input: Tensor) Tensor
Parameters:

input (torch.Tensor) – tensor of shape (B, I) where I is number of elements in each input sample.

Returns:

tensor of shape (B, O) where O is the number of elements per channel in each output sample (i.e. out_size).

Return type:

torch.Tensor

training: bool

torchrec.modules.utils

class torchrec.modules.utils.OpRegistryState

Bases: object

State of operator registry.

The op schema can only be registered once, so if we register multiple times we need a lock and a check that the schemas are identical.

op_registry_lock = <unlocked _thread.lock object>
op_registry_schema: Dict[str, str] = {}
torchrec.modules.utils.check_module_output_dimension(module: Union[Iterable[Module], Module], in_features: int, out_features: int) bool

Verify that the out_features of a given module or a list of modules matches the specified number. If a list of modules or a ModuleList is given, recursively check all the submodules.

torchrec.modules.utils.construct_jagged_tensors(embeddings: Tensor, features: KeyedJaggedTensor, embedding_names: List[str], need_indices: bool = False, features_to_permute_indices: Optional[Dict[str, List[int]]] = None, original_features: Optional[KeyedJaggedTensor] = None, reverse_indices: Optional[Tensor] = None) Dict[str, JaggedTensor]
torchrec.modules.utils.construct_jagged_tensors_inference(embeddings: Tensor, lengths: Tensor, values: Tensor, embedding_names: List[str], need_indices: bool = False, features_to_permute_indices: Optional[Dict[str, List[int]]] = None, reverse_indices: Optional[Tensor] = None) Dict[str, JaggedTensor]
torchrec.modules.utils.construct_modulelist_from_single_module(module: Module, sizes: Tuple[int, ...]) Module

Given a single module, construct a (nested) ModuleList with dimensions given by sizes, by making copies of the provided module and reinitializing the Linear layers.

torchrec.modules.utils.convert_list_of_modules_to_modulelist(modules: Iterable[Module], sizes: Tuple[int, ...]) Module
torchrec.modules.utils.extract_module_or_tensor_callable(module_or_callable: Union[Callable[[], Module], Module, Callable[[Tensor], Tensor]]) Union[Module, Callable[[Tensor], Tensor]]
torchrec.modules.utils.get_module_output_dimension(module: Union[Callable[[Tensor], Tensor], Module], in_features: int) int
torchrec.modules.utils.init_mlp_weights_xavier_uniform(m: Module) None
torchrec.modules.utils.register_custom_op(module: Module, dims: List[int]) Callable[[List[Optional[Tensor]], int], List[Tensor]]

Register a customized operator.

Parameters:
  • module – customized module instance

  • dims – output dimensions

torchrec.modules.mc_modules

class torchrec.modules.mc_modules.DistanceLFU_EvictionPolicy(decay_exponent: float = 1.0, threshold_filtering_func: Optional[Callable[[Tensor], Tuple[Tensor, Union[float, Tensor]]]] = None)

Bases: MCHEvictionPolicy

coalesce_history_metadata(current_iter: int, history_metadata: Dict[str, Tensor], unique_ids_counts: Tensor, unique_inverse_mapping: Tensor, additional_ids: Optional[Tensor] = None, threshold_mask: Optional[Tensor] = None) Dict[str, Tensor]

Coalesce metadata history buffers and return a dict of processed metadata tensors.

Parameters:
  • history_metadata (Dict[str, torch.Tensor]) – history metadata dict.

  • additional_ids (torch.Tensor) – additional ids to be used as part of history.

  • unique_inverse_mapping (torch.Tensor) – torch.unique inverse mapping generated from torch.cat([history_accumulator, additional_ids]). Used to map history metadata tensor indices to their coalesced tensor indices.

property metadata_info: List[MCHEvictionPolicyMetadataInfo]
record_history_metadata(current_iter: int, incoming_ids: Tensor, history_metadata: Dict[str, Tensor]) None

Compute and record metadata based on incoming ids for the implemented eviction policy.

Parameters:
  • current_iter (int) – current iteration.

  • incoming_ids (torch.Tensor) – incoming ids.

  • history_metadata (Dict[str, torch.Tensor]) – history metadata dict.

update_metadata_and_generate_eviction_scores(current_iter: int, mch_size: int, coalesced_history_argsort_mapping: Tensor, coalesced_history_sorted_unique_ids_counts: Tensor, coalesced_history_mch_matching_elements_mask: Tensor, coalesced_history_mch_matching_indices: Tensor, mch_metadata: Dict[str, Tensor], coalesced_history_metadata: Dict[str, Tensor]) Tuple[Tensor, Tensor]

Returns:

Tuple of (evicted_indices, selected_new_indices), where evicted_indices are indices in the mch map to be evicted and selected_new_indices are the indices of the ids in the coalesced history that are to be added to the mch.

class torchrec.modules.mc_modules.LFU_EvictionPolicy(threshold_filtering_func: Optional[Callable[[Tensor], Tuple[Tensor, Union[float, Tensor]]]] = None)

Bases: MCHEvictionPolicy

coalesce_history_metadata(current_iter: int, history_metadata: Dict[str, Tensor], unique_ids_counts: Tensor, unique_inverse_mapping: Tensor, additional_ids: Optional[Tensor] = None, threshold_mask: Optional[Tensor] = None) Dict[str, Tensor]

Coalesce metadata history buffers and return a dict of processed metadata tensors.

Parameters:
  • history_metadata (Dict[str, torch.Tensor]) – history metadata dict.

  • additional_ids (torch.Tensor) – additional ids to be used as part of history.

  • unique_inverse_mapping (torch.Tensor) – torch.unique inverse mapping generated from torch.cat([history_accumulator, additional_ids]). Used to map history metadata tensor indices to their coalesced tensor indices.

property metadata_info: List[MCHEvictionPolicyMetadataInfo]
record_history_metadata(current_iter: int, incoming_ids: Tensor, history_metadata: Dict[str, Tensor]) None

Compute and record metadata based on incoming ids for the implemented eviction policy.

Parameters:
  • current_iter (int) – current iteration.

  • incoming_ids (torch.Tensor) – incoming ids.

  • history_metadata (Dict[str, torch.Tensor]) – history metadata dict.

update_metadata_and_generate_eviction_scores(current_iter: int, mch_size: int, coalesced_history_argsort_mapping: Tensor, coalesced_history_sorted_unique_ids_counts: Tensor, coalesced_history_mch_matching_elements_mask: Tensor, coalesced_history_mch_matching_indices: Tensor, mch_metadata: Dict[str, Tensor], coalesced_history_metadata: Dict[str, Tensor]) Tuple[Tensor, Tensor]

Returns:

Tuple of (evicted_indices, selected_new_indices), where evicted_indices are indices in the mch map to be evicted and selected_new_indices are the indices of the ids in the coalesced history that are to be added to the mch.

class torchrec.modules.mc_modules.LRU_EvictionPolicy(decay_exponent: float = 1.0, threshold_filtering_func: Optional[Callable[[Tensor], Tuple[Tensor, Union[float, Tensor]]]] = None)

Bases: MCHEvictionPolicy

coalesce_history_metadata(current_iter: int, history_metadata: Dict[str, Tensor], unique_ids_counts: Tensor, unique_inverse_mapping: Tensor, additional_ids: Optional[Tensor] = None, threshold_mask: Optional[Tensor] = None) Dict[str, Tensor]

Coalesce metadata history buffers and return a dict of processed metadata tensors.

Parameters:
  • history_metadata (Dict[str, torch.Tensor]) – history metadata dict.

  • additional_ids (torch.Tensor) – additional ids to be used as part of history.

  • unique_inverse_mapping (torch.Tensor) – torch.unique inverse mapping generated from torch.cat([history_accumulator, additional_ids]). Used to map history metadata tensor indices to their coalesced tensor indices.

property metadata_info: List[MCHEvictionPolicyMetadataInfo]
record_history_metadata(current_iter: int, incoming_ids: Tensor, history_metadata: Dict[str, Tensor]) None

Compute and record metadata based on incoming ids for the implemented eviction policy.

Parameters:
  • current_iter (int) – current iteration.

  • incoming_ids (torch.Tensor) – incoming ids.

  • history_metadata (Dict[str, torch.Tensor]) – history metadata dict.

update_metadata_and_generate_eviction_scores(current_iter: int, mch_size: int, coalesced_history_argsort_mapping: Tensor, coalesced_history_sorted_unique_ids_counts: Tensor, coalesced_history_mch_matching_elements_mask: Tensor, coalesced_history_mch_matching_indices: Tensor, mch_metadata: Dict[str, Tensor], coalesced_history_metadata: Dict[str, Tensor]) Tuple[Tensor, Tensor]

Returns:

Tuple of (evicted_indices, selected_new_indices), where evicted_indices are indices in the mch map to be evicted and selected_new_indices are the indices of the ids in the coalesced history that are to be added to the mch.

class torchrec.modules.mc_modules.MCHEvictionPolicy(metadata_info: List[MCHEvictionPolicyMetadataInfo], threshold_filtering_func: Optional[Callable[[Tensor], Tuple[Tensor, Union[float, Tensor]]]] = None)

Bases: ABC

abstract coalesce_history_metadata(current_iter: int, history_metadata: Dict[str, Tensor], unique_ids_counts: Tensor, unique_inverse_mapping: Tensor, additional_ids: Optional[Tensor] = None, threshold_mask: Optional[Tensor] = None) Dict[str, Tensor]

Coalesce metadata history buffers and return a dict of processed metadata tensors.

Parameters:
  • history_metadata (Dict[str, torch.Tensor]) – history metadata dict.

  • additional_ids (torch.Tensor) – additional ids to be used as part of history.

  • unique_inverse_mapping (torch.Tensor) – torch.unique inverse mapping generated from torch.cat([history_accumulator, additional_ids]). Used to map history metadata tensor indices to their coalesced tensor indices.

abstract property metadata_info: List[MCHEvictionPolicyMetadataInfo]
abstract record_history_metadata(current_iter: int, incoming_ids: Tensor, history_metadata: Dict[str, Tensor]) None

Compute and record metadata based on incoming ids for the implemented eviction policy.

Parameters:
  • current_iter (int) – current iteration.

  • incoming_ids (torch.Tensor) – incoming ids.

  • history_metadata (Dict[str, torch.Tensor]) – history metadata dict.

abstract update_metadata_and_generate_eviction_scores(current_iter: int, mch_size: int, coalesced_history_argsort_mapping: Tensor, coalesced_history_sorted_unique_ids_counts: Tensor, coalesced_history_mch_matching_elements_mask: Tensor, coalesced_history_mch_matching_indices: Tensor, mch_metadata: Dict[str, Tensor], coalesced_history_metadata: Dict[str, Tensor]) Tuple[Tensor, Tensor]

Returns:

Tuple of (evicted_indices, selected_new_indices), where evicted_indices are indices in the mch map to be evicted and selected_new_indices are the indices of the ids in the coalesced history that are to be added to the mch.

class torchrec.modules.mc_modules.MCHEvictionPolicyMetadataInfo(metadata_name, is_mch_metadata, is_history_metadata)

Bases: tuple

is_history_metadata: bool

Alias for field number 2

is_mch_metadata: bool

Alias for field number 1

metadata_name: str

Alias for field number 0

class torchrec.modules.mc_modules.MCHManagedCollisionModule(zch_size: int, device: device, eviction_policy: MCHEvictionPolicy, eviction_interval: int, input_hash_size: int = 9223372036854775808, input_hash_func: Optional[Callable[[Tensor, int], Tensor]] = None, mch_size: Optional[int] = None, mch_hash_func: Optional[Callable[[Tensor, int], Tensor]] = None, name: Optional[str] = None, output_global_offset: int = 0)

Bases: ManagedCollisionModule

ZCH / MCH managed collision module

Parameters:
  • zch_size (int) – range of output ids, within [output_global_offset, output_global_offset + zch_size).

  • device (torch.device) – device on which this module will be executed

  • eviction_policy (MCHEvictionPolicy) – eviction policy to be used.

  • eviction_interval (int) – interval at which the eviction policy is triggered.

  • input_hash_size (int) – input feature id range; will be passed to input_hash_func as its second argument.

  • input_hash_func (Optional[Callable]) – function used to generate hashes for input features. This function is typically used to drive a uniform distribution over a range the same as or greater than that of the input data.

  • mch_size (Optional[int]) – size of the residual output (i.e. legacy MCH), an experimental feature. Ids are internally shifted by output_global_offset + zch_size.

  • mch_hash_func (Optional[Callable]) – function used to generate hashes for the residual feature; will hash down to mch_size.

  • output_global_offset (int) – offset of the output id for output range, typically only used in sharding applications.
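A minimal construction sketch, assuming a single table with 1000 ZCH slots and the DistanceLFU eviction policy documented above:

import torch
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    MCHManagedCollisionModule,
)

mc_module = MCHManagedCollisionModule(
    zch_size=1000,
    device=torch.device("cpu"),
    eviction_policy=DistanceLFU_EvictionPolicy(),
    eviction_interval=2,
)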

evict() Optional[Tensor]

Returns None if no eviction should be done this iteration; otherwise, returns the ids of slots to reset. On eviction, this module should reset its state for those slots, with the assumption that the downstream module will handle this properly.

forward(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]

Parameters:

features (Dict[str, JaggedTensor]) – feature representation.

Returns:

modified features.

Return type:

Dict[str, JaggedTensor]

input_size() int

Returns numerical range of input, for sharding info

output_size() int

Returns numerical range of output, for validation vs. downstream embedding lookups

preprocess(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]
profile(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]
rebuild_with_output_id_range(output_id_range: Tuple[int, int], device: Optional[device] = None) MCHManagedCollisionModule

Used for creating local MC modules for RW sharding, hack for now

remap(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]
training: bool
class torchrec.modules.mc_modules.ManagedCollisionCollection(managed_collision_modules: Dict[str, ManagedCollisionModule], embedding_configs: List[BaseEmbeddingConfig])

Bases: Module

ManagedCollisionCollection represents a collection of managed collision modules. The inputs passed to the MCC will be remapped by the managed collision modules and returned.

Parameters:
  • managed_collision_modules (Dict[str, ManagedCollisionModule]) – Dict of managed collision modules

  • embedding_configs (List[BaseEmbeddingConfig]) – List of embedding configs, one for each table with a managed collision module.

embedding_configs() List[BaseEmbeddingConfig]
evict() Dict[str, Optional[Tensor]]
forward(features: KeyedJaggedTensor) KeyedJaggedTensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.modules.mc_modules.ManagedCollisionModule(device: device)

Bases: Module

Abstract base class for ManagedCollisionModule. Maps input ids to range [0, max_output_id).

Parameters:
  • max_output_id (int) – Max output value of remapped ids.

  • input_hash_size (int) – Max value of input range i.e. [0, input_hash_size)

  • remapping_range_start_index (int) – Relative start index of remapping range

  • device (torch.device) – default compute device.

Example:

jt = JaggedTensor(…)
mcm = ManagedCollisionModule(…)
mcm_jt = mcm(jt)

property device: device
abstract evict() Optional[Tensor]

Returns None if no eviction should be done this iteration; otherwise, returns the ids of slots to reset. On eviction, this module should reset its state for those slots, with the assumption that the downstream module will handle this properly.

abstract forward(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

abstract input_size() int

Returns numerical range of input, for sharding info

abstract output_size() int

Returns numerical range of output, for validation vs. downstream embedding lookups

abstract preprocess(features: Dict[str, JaggedTensor]) Dict[str, JaggedTensor]
abstract rebuild_with_output_id_range(output_id_range: Tuple[int, int], device: Optional[device] = None) ManagedCollisionModule

Used for creating local MC modules for RW sharding, hack for now

training: bool
torchrec.modules.mc_modules.apply_mc_method_to_jt_dict(method: str, features_dict: Dict[str, JaggedTensor], table_to_features: Dict[str, List[str]], managed_collisions: ModuleDict) Dict[str, JaggedTensor]

Applies an MC method to a dictionary of JaggedTensors, returning the updated dictionary with the same ordering.

torchrec.modules.mc_modules.average_threshold_filter(id_counts: Tensor) Tuple[Tensor, Tensor]

Threshold is average of id_counts. An id is added if its count is strictly greater than the mean.

torchrec.modules.mc_modules.dynamic_threshold_filter(id_counts: Tensor, threshold_skew_multiplier: float = 10.0) Tuple[Tensor, Tensor]

Threshold is total_count / num_ids * threshold_skew_multiplier. An id is added if its count is strictly greater than the threshold.

torchrec.modules.mc_modules.probabilistic_threshold_filter(id_counts: Tensor, per_id_probability: float = 0.01) Tuple[Tensor, Tensor]

Each id has probability per_id_probability of being added. For example, if per_id_probability is 0.01 and an id appears 100 times, then it has roughly a 63% chance of being added, since 1 - 0.99^100 ≈ 0.63. More precisely, the id score is 1 - (1 - per_id_probability) ^ id_count, and for a randomly generated threshold, the id score is the chance of the id being added.
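The formula can be checked directly; for example, with per_id_probability = 0.01:

import torch

counts = torch.tensor([1, 10, 100])
p = 0.01
scores = 1 - (1 - p) ** counts
print(scores)
tensor([0.0100, 0.0956, 0.6340])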

torchrec.modules.mc_embedding_modules

class torchrec.modules.mc_embedding_modules.BaseManagedCollisionEmbeddingCollection(embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection], managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False)

Bases: Module

BaseManagedCollisionEmbeddingCollection represents an EC/EBC module together with a set of managed collision modules. The inputs into the MC-EC/EBC will first be modified by the managed collision modules before being passed into the embedding collection.

Parameters:
  • embedding_module – EmbeddingCollection or EmbeddingBagCollection used to look up embeddings.

  • managed_collision_collection – ManagedCollisionCollection of managed collision modules.

  • return_remapped_features (bool) – whether to return remapped input features in addition to embeddings

forward(features: KeyedJaggedTensor) Tuple[Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingBagCollection(embedding_bag_collection: EmbeddingBagCollection, managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False)

Bases: BaseManagedCollisionEmbeddingCollection

ManagedCollisionEmbeddingBagCollection represents an EmbeddingBagCollection module together with a set of managed collision modules. The inputs into the MC-EBC will first be modified by the managed collision modules before being passed into the embedding bag collection.

For details of input and output types, see EmbeddingBagCollection

Parameters:
  • embedding_bag_collection – EmbeddingBagCollection used to look up embeddings.

  • managed_collision_collection – ManagedCollisionCollection of managed collision modules.

  • return_remapped_features (bool) – whether to return remapped input features in addition to embeddings

training: bool
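A wiring sketch, assuming one table “t1” fed by feature “f1”; constructor arguments follow the signatures documented on this page:

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    MCHManagedCollisionModule,
    ManagedCollisionCollection,
)

table = EmbeddingBagConfig(
    name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
)
ebc = EmbeddingBagCollection(tables=[table])
mcc = ManagedCollisionCollection(
    managed_collision_modules={
        "t1": MCHManagedCollisionModule(
            zch_size=100,
            device=torch.device("cpu"),
            eviction_policy=DistanceLFU_EvictionPolicy(),
            eviction_interval=2,
        )
    },
    embedding_configs=[table],
)
mc_ebc = ManagedCollisionEmbeddingBagCollection(ebc, mcc, return_remapped_features=True)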
class torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection(embedding_collection: EmbeddingCollection, managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False)

Bases: BaseManagedCollisionEmbeddingCollection

ManagedCollisionEmbeddingCollection represents an EmbeddingCollection module together with a set of managed collision modules. The inputs into the MC-EC will first be modified by the managed collision modules before being passed into the embedding collection.

For details of input and output types, see EmbeddingCollection

Parameters:
  • embedding_collection – EmbeddingCollection used to look up embeddings.

  • managed_collision_collection – ManagedCollisionCollection of managed collision modules.

  • return_remapped_features (bool) – whether to return remapped input features in addition to embeddings

training: bool
torchrec.modules.mc_embedding_modules.evict(evictions: Dict[str, Optional[Tensor]], ebc: Union[EmbeddingBagCollection, EmbeddingCollection]) None
