torchrec.models¶
Torchrec Models
Torchrec provides the architectures for two popular recsys models: DeepFM and DLRM (Deep Learning Recommendation Model).
Along with the overall models, the individual architectures of each layer are also provided (e.g. SparseArch, DenseArch, InteractionArch, and OverArch).
Examples can be found within each model.
The following notation is used throughout the documentation for the models:
F: number of sparse features
D: embedding_dimension of sparse features
B: batch size
num_features: number of dense features
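To make this notation concrete, here is a minimal shape sketch using plain tensors (the variable names are illustrative, not part of any torchrec API):

    import torch

    B = 4              # batch size
    F = 2              # number of sparse features
    D = 8              # embedding_dimension of sparse features
    num_features = 16  # number of dense features

    dense_features = torch.rand((B, num_features))  # dense input: B X num_features
    sparse_embeddings = torch.rand((B, F, D))       # pooled sparse embeddings: B X F X D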
torchrec.models.deepfm¶
- class torchrec.models.deepfm.DenseArch(in_features: int, hidden_layer_size: int, embedding_dim: int)¶
Bases:
Module
Processes the dense features of the DeepFMNN model. The output layer is sized to the embedding_dimension of the EmbeddingBagCollection embeddings.
- Parameters:
in_features (int) – dimensionality of the dense input features.
hidden_layer_size (int) – sizes of the hidden layers in the DenseArch.
embedding_dim (int) – the dimension of the output embedding; must match the embedding_dimension of the SparseArch.
Example:
    B = 20
    D = 3
    in_features = 10
    dense_arch = DenseArch(
        in_features=in_features, hidden_layer_size=10, embedding_dim=D
    )
    dense_arch_input = torch.rand((B, in_features))
    dense_embedded = dense_arch(dense_arch_input)
- forward(features: Tensor) Tensor ¶
- Parameters:
features (torch.Tensor) – size B X num_features.
- Returns:
an output tensor of size B X D.
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.models.deepfm.FMInteractionArch(fm_in_features: int, sparse_feature_names: List[str], deep_fm_dimension: int)¶
Bases:
Module
Processes the output of both the SparseArch (sparse_features) and the DenseArch (dense_features) and applies the general DeepFM interaction from the DeepFM paper: https://arxiv.org/pdf/1703.04247.pdf
The output is the concatenation of the dense features (D), the deep interaction output (DI), and the FM output (1), i.e. a tensor of size B X (D + DI + 1).
- Parameters:
fm_in_features (int) – the input dimension of the dense_module in DeepFM. For example, if the input embeddings are [randn(3, 2, 3), randn(3, 4, 5)], then fm_in_features should be 2 * 3 + 4 * 5.
sparse_feature_names (List[str]) – the list of F sparse feature names.
deep_fm_dimension (int) – the output dimension of the deep interaction (DI) branch in the DeepFM arch.
Example:
    D = 3
    B = 10
    keys = ["f1", "f2"]
    F = len(keys)
    DI = 2
    fm_inter_arch = FMInteractionArch(
        fm_in_features=D * F + D,
        sparse_feature_names=keys,
        deep_fm_dimension=DI,
    )
    dense_features = torch.rand((B, D))
    sparse_features = KeyedTensor(
        keys=keys,
        length_per_key=[D, D],
        values=torch.rand((B, D * F)),
    )
    # output size: B X (D + DI + 1)
    cat_fm_output = fm_inter_arch(dense_features, sparse_features)
- forward(dense_features: Tensor, sparse_features: KeyedTensor) Tensor ¶
- Parameters:
dense_features (torch.Tensor) – tensor of size B X D.
sparse_features (KeyedTensor) – KeyedTensor of size B X (F * D).
- Returns:
an output tensor of size B X (D + DI + 1).
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.models.deepfm.OverArch(in_features: int)¶
Bases:
Module
Final arch - a simple MLP. The output is a single target (logit).
- Parameters:
in_features (int) – the output dimension of the interaction arch.
Example:
    B = 20
    over_arch = OverArch(in_features=10)
    logits = over_arch(torch.rand((B, 10)))
- forward(features: Tensor) Tensor ¶
- Parameters:
features (torch.Tensor) – an input tensor of size B X in_features.
- Returns:
an output tensor of size B X 1.
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.models.deepfm.SimpleDeepFMNN(num_dense_features: int, embedding_bag_collection: EmbeddingBagCollection, hidden_layer_size: int, deep_fm_dimension: int)¶
Bases:
Module
Basic recsys module with the DeepFM arch. Processes sparse features by learning pooled embeddings for each feature. Learns the relationship between dense features and sparse features by projecting dense features into the same embedding space. Learns the interaction among those dense and sparse features via the deep_fm interaction proposed in this paper: https://arxiv.org/pdf/1703.04247.pdf
The module assumes all sparse features have the same embedding dimension (i.e., each EmbeddingBagConfig uses the same embedding_dim).
The following notation is used throughout the documentation for the models:
F: number of sparse features
D: embedding_dimension of sparse features
B: batch size
num_features: number of dense features
- Parameters:
num_dense_features (int) – the number of input dense features.
embedding_bag_collection (EmbeddingBagCollection) – collection of embedding bags used to define SparseArch.
hidden_layer_size (int) – the hidden layer size used in dense module.
deep_fm_dimension (int) – the output layer size used in deep_fm’s deep interaction module.
Example:
    B = 2
    D = 8

    eb1_config = EmbeddingBagConfig(
        name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
    )
    eb2_config = EmbeddingBagConfig(
        name="t2",
        embedding_dim=D,
        num_embeddings=100,
        feature_names=["f2"],
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    sparse_nn = SimpleDeepFMNN(
        num_dense_features=100,
        embedding_bag_collection=ebc,
        hidden_layer_size=20,
        deep_fm_dimension=5,
    )

    features = torch.rand((B, 100))

    #     0       1
    # 0   [1,2]   [4,5]
    # 1   [4,3]   [2,9]
    # ^
    # feature
    sparse_features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f3"],
        values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),
        offsets=torch.tensor([0, 2, 4, 6, 8]),
    )

    logits = sparse_nn(
        dense_features=features,
        sparse_features=sparse_features,
    )
- forward(dense_features: Tensor, sparse_features: KeyedJaggedTensor) Tensor ¶
- Parameters:
dense_features (torch.Tensor) – the dense features.
sparse_features (KeyedJaggedTensor) – the sparse features.
- Returns:
logits with size B X 1.
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.models.deepfm.SparseArch(embedding_bag_collection: EmbeddingBagCollection)¶
Bases:
Module
Processes the sparse features of the DeepFMNN model. Performs embedding lookups for all EmbeddingBags and embedding features of the collection.
- Parameters:
embedding_bag_collection (EmbeddingBagCollection) – represents a collection of pooled embeddings.
Example:
    eb1_config = EmbeddingBagConfig(
        name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
    )
    eb2_config = EmbeddingBagConfig(
        name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    sparse_arch = SparseArch(embedding_bag_collection=ebc)

    #     0       1       2   <-- batch
    # 0   [0,1]   None    [2]
    # 1   [3]     [4]     [5,6,7]
    # ^
    # feature
    features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2"],
        values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
        offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
    )

    sparse_embeddings = sparse_arch(features)
- forward(features: KeyedJaggedTensor) KeyedTensor ¶
- Parameters:
features (KeyedJaggedTensor) –
- Returns:
an output KeyedTensor of size B X (F * D).
- Return type:
KeyedTensor
- training: bool¶
torchrec.models.dlrm¶
- class torchrec.models.dlrm.DLRM(embedding_bag_collection: EmbeddingBagCollection, dense_in_features: int, dense_arch_layer_sizes: List[int], over_arch_layer_sizes: List[int], dense_device: Optional[device] = None)¶
Bases:
Module
Recsys model from “Deep Learning Recommendation Model for Personalization and Recommendation Systems” (https://arxiv.org/abs/1906.00091). Processes sparse features by learning pooled embeddings for each feature. Learns the relationship between dense features and sparse features by projecting dense features into the same embedding space. Also, learns the pairwise relationships between sparse features.
The module assumes all sparse features have the same embedding dimension (i.e. each EmbeddingBagConfig uses the same embedding_dim).
The following notation is used throughout the documentation for the models:
F: number of sparse features
D: embedding_dimension of sparse features
B: batch size
num_features: number of dense features
- Parameters:
embedding_bag_collection (EmbeddingBagCollection) – collection of embedding bags used to define SparseArch.
dense_in_features (int) – the dimensionality of the dense input features.
dense_arch_layer_sizes (List[int]) – the layer sizes for the DenseArch.
over_arch_layer_sizes (List[int]) – the layer sizes for the OverArch. The output dimension of the InteractionArch should not be manually specified here.
dense_device (Optional[torch.device]) – default compute device.
Example:
    B = 2
    D = 8

    eb1_config = EmbeddingBagConfig(
        name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]
    )
    eb2_config = EmbeddingBagConfig(
        name="t2",
        embedding_dim=D,
        num_embeddings=100,
        feature_names=["f2"],
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    model = DLRM(
        embedding_bag_collection=ebc,
        dense_in_features=100,
        dense_arch_layer_sizes=[20, D],
        over_arch_layer_sizes=[5, 1],
    )

    features = torch.rand((B, 100))

    #     0       1
    # 0   [1,2]   [4,5]
    # 1   [4,3]   [2,9]
    # ^
    # feature
    sparse_features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2"],
        values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),
        offsets=torch.tensor([0, 2, 4, 6, 8]),
    )

    logits = model(
        dense_features=features,
        sparse_features=sparse_features,
    )
- forward(dense_features: Tensor, sparse_features: KeyedJaggedTensor) Tensor ¶
- Parameters:
dense_features (torch.Tensor) – the dense features.
sparse_features (KeyedJaggedTensor) – the sparse features.
- Returns:
logits.
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.models.dlrm.DLRMTrain(dlrm_module: DLRM)¶
Bases:
Module
nn.Module that wraps a DLRM model for use with train_pipeline.
DLRM Recsys model from “Deep Learning Recommendation Model for Personalization and Recommendation Systems” (https://arxiv.org/abs/1906.00091). Processes sparse features by learning pooled embeddings for each feature. Learns the relationship between dense features and sparse features by projecting dense features into the same embedding space. Also, learns the pairwise relationships between sparse features.
The module assumes all sparse features have the same embedding dimension (i.e., each EmbeddingBagConfig uses the same embedding_dim).
- Parameters:
dlrm_module (DLRM) – the DLRM module (DLRM, DLRM_Projection, or DLRM_DCN) to be used in training.
Example:
    ebc = EmbeddingBagCollection(tables=ebc_config)  # ebc_config: a list of EmbeddingBagConfig
    dlrm_module = DLRM(
        embedding_bag_collection=ebc,
        dense_in_features=100,
        dense_arch_layer_sizes=[20],
        over_arch_layer_sizes=[5, 1],
    )
    dlrm_model = DLRMTrain(dlrm_module)
- forward(batch: Batch) Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]] ¶
- Parameters:
batch (Batch) – a batch of dense and sparse features, as produced by the criteo and random datasets in torchrec.datasets.
- Returns:
Tuple[loss, Tuple[loss, logits, labels]]
- training: bool¶
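Since forward returns both an attached loss and a detached copy, a training step can unpack the tuple directly. A minimal sketch (the names dlrm_model and batch are assumed to come from the example above and a torchrec.datasets loader; the optimizer choice is illustrative):

    import torch

    # dlrm_model: a DLRMTrain instance; batch: a torchrec.datasets Batch
    optimizer = torch.optim.SGD(dlrm_model.parameters(), lr=0.1)

    loss, (detached_loss, logits, labels) = dlrm_model(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()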
- class torchrec.models.dlrm.DLRM_DCN(embedding_bag_collection: EmbeddingBagCollection, dense_in_features: int, dense_arch_layer_sizes: List[int], over_arch_layer_sizes: List[int], dcn_num_layers: int, dcn_low_rank_dim: int, dense_device: Optional[device] = None)¶
Bases:
DLRM
Recsys model with a DCN interaction, modified from the original model in “Deep Learning Recommendation Model for Personalization and Recommendation Systems” (https://arxiv.org/abs/1906.00091). Similar to the DLRM module, but uses DeepCrossNet (https://arxiv.org/pdf/2008.13535.pdf) as the interaction layer.
The module assumes all sparse features have the same embedding dimension (i.e. each EmbeddingBagConfig uses the same embedding_dim).
The following notation is used throughout the documentation for the models:
F: number of sparse features
D: embedding_dimension of sparse features
B: batch size
num_features: number of dense features
- Parameters:
embedding_bag_collection (EmbeddingBagCollection) – collection of embedding bags used to define SparseArch.
dense_in_features (int) – the dimensionality of the dense input features.
dense_arch_layer_sizes (List[int]) – the layer sizes for the DenseArch.
over_arch_layer_sizes (List[int]) – the layer sizes for the OverArch. The output dimension of the InteractionArch should not be manually specified here.
dcn_num_layers (int) – the number of DCN layers in the interaction.
dcn_low_rank_dim (int) – the dimensionality of the low-rank approximation used in the DCN layers.
dense_device (Optional[torch.device]) – default compute device.
Example:
    B = 2
    D = 8

    eb1_config = EmbeddingBagConfig(
        name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
    )
    eb2_config = EmbeddingBagConfig(
        name="t2",
        embedding_dim=D,
        num_embeddings=100,
        feature_names=["f2"],
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    model = DLRM_DCN(
        embedding_bag_collection=ebc,
        dense_in_features=100,
        dense_arch_layer_sizes=[20, D],
        dcn_num_layers=2,
        dcn_low_rank_dim=8,
        over_arch_layer_sizes=[5, 1],
    )

    features = torch.rand((B, 100))

    #     0       1
    # 0   [1,2]   [4,5]
    # 1   [4,3]   [2,9]
    # ^
    # feature
    sparse_features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f3"],
        values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),
        offsets=torch.tensor([0, 2, 4, 6, 8]),
    )

    logits = model(
        dense_features=features,
        sparse_features=sparse_features,
    )
- sparse_arch: SparseArch¶
- training: bool¶
- class torchrec.models.dlrm.DLRM_Projection(embedding_bag_collection: EmbeddingBagCollection, dense_in_features: int, dense_arch_layer_sizes: List[int], over_arch_layer_sizes: List[int], interaction_branch1_layer_sizes: List[int], interaction_branch2_layer_sizes: List[int], dense_device: Optional[device] = None)¶
Bases:
DLRM
Recsys model modified from the original model in “Deep Learning Recommendation Model for Personalization and Recommendation Systems” (https://arxiv.org/abs/1906.00091). Similar to the DLRM module, but with additional MLPs in the interaction layer (along two branches).
The module assumes all sparse features have the same embedding dimension (i.e. each EmbeddingBagConfig uses the same embedding_dim).
The following notation is used throughout the documentation for the models:
F: number of sparse features
D: embedding_dimension of sparse features
B: batch size
num_features: number of dense features
- Parameters:
embedding_bag_collection (EmbeddingBagCollection) – collection of embedding bags used to define SparseArch.
dense_in_features (int) – the dimensionality of the dense input features.
dense_arch_layer_sizes (List[int]) – the layer sizes for the DenseArch.
over_arch_layer_sizes (List[int]) – the layer sizes for the OverArch. The output dimension of the InteractionArch should not be manually specified here.
interaction_branch1_layer_sizes (List[int]) – the layer sizes for the first branch of the interaction layer. The output dimension must be a multiple of D.
interaction_branch2_layer_sizes (List[int]) – the layer sizes for the second branch of the interaction layer. The output dimension must be a multiple of D.
dense_device (Optional[torch.device]) – default compute device.
Example:
    B = 2
    D = 8

    eb1_config = EmbeddingBagConfig(
        name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
    )
    eb2_config = EmbeddingBagConfig(
        name="t2",
        embedding_dim=D,
        num_embeddings=100,
        feature_names=["f2"],
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    model = DLRM_Projection(
        embedding_bag_collection=ebc,
        dense_in_features=100,
        dense_arch_layer_sizes=[20, D],
        interaction_branch1_layer_sizes=[3 * D + D, 4 * D],
        interaction_branch2_layer_sizes=[3 * D + D, 4 * D],
        over_arch_layer_sizes=[5, 1],
    )

    features = torch.rand((B, 100))

    #     0       1
    # 0   [1,2]   [4,5]
    # 1   [4,3]   [2,9]
    # ^
    # feature
    sparse_features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f3"],
        values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),
        offsets=torch.tensor([0, 2, 4, 6, 8]),
    )

    logits = model(
        dense_features=features,
        sparse_features=sparse_features,
    )
- sparse_arch: SparseArch¶
- training: bool¶
- class torchrec.models.dlrm.DenseArch(in_features: int, layer_sizes: List[int], device: Optional[device] = None)¶
Bases:
Module
Processes the dense features of the DLRM model.
- Parameters:
in_features (int) – dimensionality of the dense input features.
layer_sizes (List[int]) – list of layer sizes.
device (Optional[torch.device]) – default compute device.
Example:
    B = 20
    D = 3
    dense_arch = DenseArch(10, layer_sizes=[15, D])
    dense_embedded = dense_arch(torch.rand((B, 10)))
- forward(features: Tensor) Tensor ¶
- Parameters:
features (torch.Tensor) – an input tensor of dense features.
- Returns:
an output tensor of size B X D.
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.models.dlrm.InteractionArch(num_sparse_features: int)¶
Bases:
Module
Processes the output of both SparseArch (sparse_features) and DenseArch (dense_features). Returns the pairwise dot product of each sparse feature pair, the dot product of each sparse feature with the output of the dense layer, and the dense layer itself (all concatenated).
Note
The dimensionality of the dense_features (D) is expected to match the dimensionality of the sparse_features so that the dot products between them can be computed.
- Parameters:
num_sparse_features (int) – F, the number of sparse features.
Example:
    D = 3
    B = 10
    keys = ["f1", "f2"]
    F = len(keys)
    inter_arch = InteractionArch(num_sparse_features=len(keys))

    dense_features = torch.rand((B, D))
    sparse_features = torch.rand((B, F, D))

    # B X (D + F + F choose 2)
    concat_dense = inter_arch(dense_features, sparse_features)
- forward(dense_features: Tensor, sparse_features: Tensor) Tensor ¶
- Parameters:
dense_features (torch.Tensor) – an input tensor of size B X D.
sparse_features (torch.Tensor) – an input tensor of size B X F X D.
- Returns:
an output tensor of size B X (D + F + F choose 2).
- Return type:
torch.Tensor
- training: bool¶
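The output width D + F + (F choose 2) counts the D dense passthrough dims, the F dense-sparse dot products, and the F * (F - 1) / 2 pairwise sparse-sparse dot products. A minimal shape check, assuming the import path documented above:

    import torch
    from torchrec.models.dlrm import InteractionArch

    B, D = 10, 3
    F = 2  # number of sparse features

    inter_arch = InteractionArch(num_sparse_features=F)
    dense_features = torch.rand((B, D))
    sparse_features = torch.rand((B, F, D))

    out = inter_arch(dense_features, sparse_features)
    # D passthrough + F dense-sparse dots + F*(F-1)/2 sparse-sparse dots = 3 + 2 + 1
    assert out.shape == (B, D + F + F * (F - 1) // 2)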
- class torchrec.models.dlrm.InteractionDCNArch(num_sparse_features: int, crossnet: Module)¶
Bases:
Module
Processes the output of both SparseArch (sparse_features) and DenseArch (dense_features). Returns the output of a Deep Cross Net v2 https://arxiv.org/pdf/2008.13535.pdf with a low rank approximation for the weight matrix. The input and output sizes are the same for this interaction layer (F*D + D).
Note
The dimensionality of the dense_features (D) is expected to match the dimensionality of the sparse_features so that the dot products between them can be computed.
- Parameters:
num_sparse_features (int) – F, the number of sparse features.
Example:
    D = 3
    B = 10
    keys = ["f1", "f2"]
    F = len(keys)
    DCN = LowRankCrossNet(
        in_features=F * D + D,
        num_layers=2,
        low_rank=4,
    )
    inter_arch = InteractionDCNArch(
        num_sparse_features=len(keys),
        crossnet=DCN,
    )

    dense_features = torch.rand((B, D))
    sparse_features = torch.rand((B, F, D))

    # B X (F*D + D)
    concat_dense = inter_arch(dense_features, sparse_features)
- forward(dense_features: Tensor, sparse_features: Tensor) Tensor ¶
- Parameters:
dense_features (torch.Tensor) – an input tensor of size B X D.
sparse_features (torch.Tensor) – an input tensor of size B X F X D.
- Returns:
an output tensor of size B X (F*D + D).
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.models.dlrm.InteractionProjectionArch(num_sparse_features: int, interaction_branch1: Module, interaction_branch2: Module)¶
Bases:
Module
Processes the output of both SparseArch (sparse_features) and DenseArch (dense_features). Returns the product Y*Z and the dense layer itself (all concatenated), where Y is the output of interaction branch 1 and Z is the output of interaction branch 2. Y and Z are of size B X (F1 * D) and B X (D * F2), respectively, for some F1 and F2.
Note
The dimensionality of the dense_features (D) is expected to match the dimensionality of the sparse_features so that the dot products between them can be computed. The output dimension of the 2 interaction branches should be a multiple of D.
- Parameters:
num_sparse_features (int) – F, the number of sparse features.
interaction_branch1 (nn.Module) – MLP module for the first branch of the interaction layer.
interaction_branch2 (nn.Module) – MLP module for the second branch of the interaction layer.
Example:
    D = 3
    B = 10
    keys = ["f1", "f2"]
    F = len(keys)
    # The output (last) layer size of each branch must be a multiple of D.
    I1 = DenseArch(
        in_features=F * D + D,
        layer_sizes=[4 * D, 4 * D],  # F1 = 4
    )
    I2 = DenseArch(
        in_features=F * D + D,
        layer_sizes=[4 * D, 4 * D],  # F2 = 4
    )
    inter_arch = InteractionProjectionArch(
        num_sparse_features=len(keys),
        interaction_branch1=I1,
        interaction_branch2=I2,
    )

    dense_features = torch.rand((B, D))
    sparse_features = torch.rand((B, F, D))

    # B X (D + F1 * F2)
    concat_dense = inter_arch(dense_features, sparse_features)
- forward(dense_features: Tensor, sparse_features: Tensor) Tensor ¶
- Parameters:
dense_features (torch.Tensor) – an input tensor of size B X D.
sparse_features (torch.Tensor) – an input tensor of size B X F X D.
- Returns:
an output tensor of size B X (D + F1 * F2), where F1 * D and F2 * D are the output dimensions of the two interaction MLPs.
- Return type:
torch.Tensor
- training: bool¶
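The shape bookkeeping behind Y*Z can be sketched with plain tensors; torch.bmm stands in here for the module's batched product, and the F1 and F2 values are illustrative:

    import torch

    B, D, F1, F2 = 10, 3, 4, 5

    # Branch outputs reshaped as described above: Y is B X F1 X D, Z is B X D X F2.
    Y = torch.rand((B, F1, D))
    Z = torch.rand((B, D, F2))

    pairwise = torch.bmm(Y, Z)  # batched matrix product: B X F1 X F2
    dense_features = torch.rand((B, D))

    # Concatenate the dense layer with the flattened interaction terms.
    out = torch.cat([dense_features, pairwise.flatten(start_dim=1)], dim=1)
    assert out.shape == (B, D + F1 * F2)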
- class torchrec.models.dlrm.OverArch(in_features: int, layer_sizes: List[int], device: Optional[device] = None)¶
Bases:
Module
Final arch of DLRM - a simple MLP over the output of the interaction arch.
- Parameters:
in_features (int) – size of the input.
layer_sizes (List[int]) – sizes of the layers of the OverArch.
device (Optional[torch.device]) – default compute device.
Example:
    B = 20
    D = 3
    over_arch = OverArch(10, [5, 1])
    logits = over_arch(torch.rand((B, 10)))
- forward(features: Tensor) Tensor ¶
- Parameters:
features (torch.Tensor) – an input tensor of size B X in_features.
- Returns:
an output tensor of size B X layer_sizes[-1].
- Return type:
torch.Tensor
- training: bool¶
- class torchrec.models.dlrm.SparseArch(embedding_bag_collection: EmbeddingBagCollection)¶
Bases:
Module
Processes the sparse features of DLRM. Performs embedding lookups for all EmbeddingBags and embedding features of the collection.
- Parameters:
embedding_bag_collection (EmbeddingBagCollection) – represents a collection of pooled embeddings.
Example:
    # Note: all tables share the same embedding_dim so the output can be
    # reshaped to B X F X D.
    eb1_config = EmbeddingBagConfig(
        name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
    )
    eb2_config = EmbeddingBagConfig(
        name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    sparse_arch = SparseArch(ebc)

    #     0       1       2   <-- batch
    # 0   [0,1]   None    [2]
    # 1   [3]     [4]     [5,6,7]
    # ^
    # feature
    features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2"],
        values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
        offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
    )

    sparse_embeddings = sparse_arch(features)
- forward(features: KeyedJaggedTensor) Tensor ¶
- Parameters:
features (KeyedJaggedTensor) – an input tensor of sparse features.
- Returns:
tensor of shape B X F X D.
- Return type:
torch.Tensor
- property sparse_feature_names: List[str]¶
- training: bool¶
- torchrec.models.dlrm.choose(n: int, k: int) int ¶
Simple implementation of math.comb for Python 3.7 compatibility.
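A minimal sketch of such a helper (the exact torchrec implementation may differ); it uses the multiplicative formula so it avoids math.comb, which is only available on Python 3.8+:

    def choose(n: int, k: int) -> int:
        """Compute C(n, k) = n! / (k! * (n - k)!) without math.comb."""
        if k < 0 or k > n:
            return 0
        k = min(k, n - k)  # C(n, k) == C(n, n - k)
        result = 1
        for i in range(k):
            result = result * (n - i) // (i + 1)  # stays integral at every step
        return result

    assert choose(4, 2) == 6  # e.g. pairwise interactions of 4 sparse features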