torchrec.models¶
Torchrec Models
Torchrec provides the architectures for two popular recsys models: DeepFM and DLRM (Deep Learning Recommendation Model).
Along with the overall model, the individual architectures of each layer are also provided (e.g. SparseArch, DenseArch, InteractionArch, and OverArch).
Examples can be found within each model.
The following notation is used throughout the documentation for the models:
F: number of sparse features
D: embedding_dimension of sparse features
B: batch size
num_features: number of dense features
torchrec.models.deepfm¶
- class torchrec.models.deepfm.DenseArch(in_features: int, hidden_layer_size: int, embedding_dim: int)¶
Bases:
torch.nn.modules.module.Module
Processes the dense features of the DeepFMNN model. The output layer is sized to the embedding_dimension of the EmbeddingBagCollection embeddings.
- Parameters
in_features (int) – dimensionality of the dense input features.
hidden_layer_size (int) – the size of the hidden layer in the DenseArch.
embedding_dim (int) – the dimension of the output embeddings; must match the embedding_dimension of the SparseArch.
Example:
    B = 20
    D = 3
    in_features = 10
    dense_arch = DenseArch(in_features=10, hidden_layer_size=10, embedding_dim=D)
    dense_embedded = dense_arch(torch.rand((B, 10)))
- forward(features: torch.Tensor) torch.Tensor ¶
- Parameters
features (torch.Tensor) – size B X num_features.
- Returns
an output tensor of size B X D.
- Return type
torch.Tensor
- training: bool¶
- class torchrec.models.deepfm.FMInteractionArch(fm_in_features: int, sparse_feature_names: List[str], deep_fm_dimension: int)¶
Bases:
torch.nn.modules.module.Module
Processes the output of both SparseArch (sparse_features) and DenseArch (dense_features) and applies the general DeepFM interaction from the DeepFM paper: https://arxiv.org/pdf/1703.04247.pdf
The output is the concatenation of the dense features, the deep interaction (DI) output, and the factorization machine output, a tensor of size B X (D + DI + 1).
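The concatenated output can be pictured with a small shape-only sketch (a hedged illustration in plain PyTorch; the tensors below are stand-ins for the module's internals, not its actual code):

    import torch

    # Notation as above: B batch size, D dense dimension, DI deep interaction dimension.
    B, D, DI = 10, 3, 2
    dense = torch.rand(B, D)   # dense features, B X D
    deep = torch.rand(B, DI)   # deep interaction output, B X DI
    fm = torch.rand(B, 1)      # factorization machine output, B X 1
    out = torch.cat([dense, deep, fm], dim=1)
    print(out.shape)           # B X (D + DI + 1) -> torch.Size([10, 6])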
- Parameters
fm_in_features (int) – the input dimension of the dense_module in DeepFM. For example, if the input embeddings are [randn(3, 2, 3), randn(3, 4, 5)], then fm_in_features should be 2 * 3 + 4 * 5 (see the sketch after this list).
sparse_feature_names (List[str]) – the list of F sparse feature names.
deep_fm_dimension (int) – the output dimension of the deep interaction (DI) module in the DeepFM arch.
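As referenced above, fm_in_features can be derived by flattening each input embedding over its last two dimensions (a minimal sketch; the list of embeddings here is purely illustrative):

    import torch

    embeddings = [torch.randn(3, 2, 3), torch.randn(3, 4, 5)]
    fm_in_features = sum(e.shape[1] * e.shape[2] for e in embeddings)
    assert fm_in_features == 2 * 3 + 4 * 5  # 26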
Example:
    D = 3
    B = 10
    keys = ["f1", "f2"]
    F = len(keys)
    fm_inter_arch = FMInteractionArch(
        fm_in_features=D + D * F,
        sparse_feature_names=keys,
        deep_fm_dimension=D,
    )
    dense_features = torch.rand((B, D))
    sparse_features = KeyedTensor(
        keys=keys,
        length_per_key=[D, D],
        values=torch.rand((B, D * F)),
    )
    cat_fm_output = fm_inter_arch(dense_features, sparse_features)
- forward(dense_features: torch.Tensor, sparse_features: torchrec.sparse.jagged_tensor.KeyedTensor) torch.Tensor ¶
- Parameters
dense_features (torch.Tensor) – tensor of size B X D.
sparse_features (KeyedTensor) – KeyedTensor of size B X (F * D).
- Returns
an output tensor of size B X (D + DI + 1).
- Return type
torch.Tensor
- training: bool¶
- class torchrec.models.deepfm.OverArch(in_features: int)¶
Bases:
torch.nn.modules.module.Module
Final arch of DeepFM: a simple MLP that produces a single target (logit).
- Parameters
in_features (int) – the output dimension of the interaction arch.
Example:
    B = 20
    over_arch = OverArch(in_features=10)
    logits = over_arch(torch.rand((B, 10)))
- forward(features: torch.Tensor) torch.Tensor ¶
- Parameters
features (torch.Tensor) – an input tensor of size B X in_features.
- Returns
an output tensor of size B X 1.
- Return type
torch.Tensor
- training: bool¶
- class torchrec.models.deepfm.SimpleDeepFMNN(num_dense_features: int, embedding_bag_collection: torchrec.modules.embedding_modules.EmbeddingBagCollection, hidden_layer_size: int, deep_fm_dimension: int)¶
Bases:
torch.nn.modules.module.Module
Basic recsys module with the DeepFM arch. Processes sparse features by learning pooled embeddings for each feature. Learns the relationship between dense features and sparse features by projecting dense features into the same embedding space. Learns the interactions among those dense and sparse features via the DeepFM architecture proposed in this paper: https://arxiv.org/pdf/1703.04247.pdf
The module assumes all sparse features have the same embedding dimension (i.e., each EmbeddingBagConfig uses the same embedding_dim).
The following notation is used throughout the documentation for the models:
F: number of sparse features
D: embedding_dimension of sparse features
B: batch size
num_features: number of dense features
- Parameters
num_dense_features (int) – the number of input dense features.
embedding_bag_collection (EmbeddingBagCollection) – collection of embedding bags used to define SparseArch.
hidden_layer_size (int) – the hidden layer size used in dense module.
deep_fm_dimension (int) – the output layer size used in deep_fm’s deep interaction module.
Example:
    B = 2
    D = 8

    eb1_config = EmbeddingBagConfig(
        name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
    )
    eb2_config = EmbeddingBagConfig(
        name="t2", embedding_dim=D, num_embeddings=100, feature_names=["f2"],
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    sparse_nn = SimpleDeepFMNN(
        num_dense_features=100,
        embedding_bag_collection=ebc,
        hidden_layer_size=20,
        deep_fm_dimension=5,
    )

    features = torch.rand((B, 100))

    #     0       1
    # 0   [1,2]   [4,5]
    # 1   [4,3]   [2,9]
    # ^
    # feature
    sparse_features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f3"],
        values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),
        offsets=torch.tensor([0, 2, 4, 6, 8]),
    )

    logits = sparse_nn(
        dense_features=features,
        sparse_features=sparse_features,
    )
- forward(dense_features: torch.Tensor, sparse_features: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torch.Tensor ¶
- Parameters
dense_features (torch.Tensor) – the dense features.
sparse_features (KeyedJaggedTensor) – the sparse features.
- Returns
logits with size B X 1.
- Return type
torch.Tensor
- training: bool¶
- class torchrec.models.deepfm.SparseArch(embedding_bag_collection: torchrec.modules.embedding_modules.EmbeddingBagCollection)¶
Bases:
torch.nn.modules.module.Module
Processes the sparse features of the DeepFMNN model. Performs embedding lookups for all EmbeddingBags and embedding features in the collection.
- Parameters
embedding_bag_collection (EmbeddingBagCollection) – represents a collection of pooled embeddings.
Example:
    eb1_config = EmbeddingBagConfig(
        name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
    )
    eb2_config = EmbeddingBagConfig(
        name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    sparse_arch = SparseArch(ebc)

    #     0       1       2       <-- batch
    # 0   [0,1]   None    [2]
    # 1   [3]     [4]     [5,6,7]
    # ^
    # feature
    features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2"],
        values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
        offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
    )

    sparse_embeddings = sparse_arch(features)
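The offsets in the example above can be decoded directly: the per-entry lengths are the consecutive differences of the offsets, and the values are the flattened ids. A minimal check:

    import torch

    offsets = torch.tensor([0, 2, 2, 3, 4, 5, 8])
    lengths = offsets[1:] - offsets[:-1]
    # [2, 0, 1, 1, 1, 3] -> f1: [0,1], None, [2]; f2: [3], [4], [5,6,7]
    print(lengths.tolist())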
- forward(features: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torchrec.sparse.jagged_tensor.KeyedTensor ¶
- Parameters
features (KeyedJaggedTensor) – an input KJT of sparse features.
- Returns
an output KeyedTensor of size B X (F * D).
- Return type
KeyedTensor
- training: bool¶
torchrec.models.dlrm¶
- class torchrec.models.dlrm.DLRM(embedding_bag_collection: torchrec.modules.embedding_modules.EmbeddingBagCollection, dense_in_features: int, dense_arch_layer_sizes: List[int], over_arch_layer_sizes: List[int], dense_device: Optional[torch.device] = None)¶
Bases:
torch.nn.modules.module.Module
Recsys model from “Deep Learning Recommendation Model for Personalization and Recommendation Systems” (https://arxiv.org/abs/1906.00091). Processes sparse features by learning pooled embeddings for each feature. Learns the relationship between dense features and sparse features by projecting dense features into the same embedding space. Also, learns the pairwise relationships between sparse features.
The module assumes all sparse features have the same embedding dimension (i.e. each EmbeddingBagConfig uses the same embedding_dim).
The following notation is used throughout the documentation for the models:
F: number of sparse features
D: embedding_dimension of sparse features
B: batch size
num_features: number of dense features
- Parameters
embedding_bag_collection (EmbeddingBagCollection) – collection of embedding bags used to define SparseArch.
dense_in_features (int) – the dimensionality of the dense input features.
dense_arch_layer_sizes (List[int]) – the layer sizes for the DenseArch.
over_arch_layer_sizes (List[int]) – the layer sizes for the OverArch. The input dimension of the OverArch (the output dimension of the InteractionArch) is computed internally and should not be specified here (see the sketch after this list).
dense_device (Optional[torch.device]) – default compute device.
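As noted in the over_arch_layer_sizes description, the OverArch input size follows from the notation above: D for the dense output, F for its dot products with each sparse feature, and F choose 2 for the pairwise sparse dot products. A minimal sketch (math.comb requires Python 3.8+; torchrec.models.dlrm.choose, documented below, serves the same purpose on 3.7):

    import math

    D, F = 8, 3  # embedding dimension and number of sparse features
    over_in_features = D + F + math.comb(F, 2)
    print(over_in_features)  # 14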
Example:
    B = 2
    D = 8

    eb1_config = EmbeddingBagConfig(
        name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
    )
    eb2_config = EmbeddingBagConfig(
        name="t2", embedding_dim=D, num_embeddings=100, feature_names=["f2"],
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    model = DLRM(
        embedding_bag_collection=ebc,
        dense_in_features=100,
        dense_arch_layer_sizes=[20, D],
        over_arch_layer_sizes=[5, 1],
    )

    features = torch.rand((B, 100))

    #     0       1
    # 0   [1,2]   [4,5]
    # 1   [4,3]   [2,9]
    # ^
    # feature
    sparse_features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f3"],
        values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),
        offsets=torch.tensor([0, 2, 4, 6, 8]),
    )

    logits = model(
        dense_features=features,
        sparse_features=sparse_features,
    )
- forward(dense_features: torch.Tensor, sparse_features: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torch.Tensor ¶
- Parameters
dense_features (torch.Tensor) – the dense features.
sparse_features (KeyedJaggedTensor) – the sparse features.
- Returns
logits.
- Return type
torch.Tensor
- training: bool¶
- class torchrec.models.dlrm.DenseArch(in_features: int, layer_sizes: List[int], device: Optional[torch.device] = None)¶
Bases:
torch.nn.modules.module.Module
Processes the dense features of the DLRM model.
- Parameters
in_features (int) – dimensionality of the dense input features.
layer_sizes (List[int]) – list of layer sizes.
device (Optional[torch.device]) – default compute device.
Example:
    B = 20
    D = 3
    dense_arch = DenseArch(in_features=10, layer_sizes=[15, D])
    dense_embedded = dense_arch(torch.rand((B, 10)))
- forward(features: torch.Tensor) torch.Tensor ¶
- Parameters
features (torch.Tensor) – an input tensor of dense features.
- Returns
an output tensor of size B X D.
- Return type
torch.Tensor
- training: bool¶
- class torchrec.models.dlrm.InteractionArch(num_sparse_features: int)¶
Bases:
torch.nn.modules.module.Module
Processes the output of both SparseArch (sparse_features) and DenseArch (dense_features). Returns the pairwise dot products of the sparse features, the dot product of each sparse feature with the output of the dense layer, and the dense layer output itself, all concatenated.
Note
The dimensionality of the dense_features (D) is expected to match the dimensionality of the sparse_features so that the dot products between them can be computed.
- Parameters
num_sparse_features (int) – F, the number of sparse features.
Example:
    D = 3
    B = 10
    keys = ["f1", "f2"]
    F = len(keys)
    inter_arch = InteractionArch(num_sparse_features=len(keys))

    dense_features = torch.rand((B, D))
    sparse_features = torch.rand((B, F, D))

    # output of size B X (D + F + F choose 2)
    concat_dense = inter_arch(dense_features, sparse_features)
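The shape of the concatenated output can be reproduced with plain PyTorch ops (a hedged sketch of a pairwise dot-product interaction, not necessarily the module's exact implementation):

    import torch

    B, F, D = 10, 2, 3
    dense = torch.rand(B, D)
    sparse = torch.rand(B, F, D)

    # Stack the dense output with the sparse embeddings, then take every
    # pairwise dot product via a batched matrix product.
    combined = torch.cat([dense.unsqueeze(1), sparse], dim=1)  # B X (F + 1) X D
    dots = torch.bmm(combined, combined.transpose(1, 2))       # B X (F + 1) X (F + 1)
    rows, cols = torch.triu_indices(F + 1, F + 1, offset=1)    # strictly upper triangle
    out = torch.cat([dense, dots[:, rows, cols]], dim=1)
    print(out.shape)  # B X (D + F + F choose 2) -> torch.Size([10, 6])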
- forward(dense_features: torch.Tensor, sparse_features: torch.Tensor) torch.Tensor ¶
- Parameters
dense_features (torch.Tensor) – an input tensor of size B X D.
sparse_features (torch.Tensor) – an input tensor of size B X F X D.
- Returns
an output tensor of size B X (D + F + F choose 2).
- Return type
torch.Tensor
- training: bool¶
- class torchrec.models.dlrm.OverArch(in_features: int, layer_sizes: List[int], device: Optional[torch.device] = None)¶
Bases:
torch.nn.modules.module.Module
Final arch of DLRM: a simple MLP over the output of the InteractionArch.
- Parameters
in_features (int) – size of the input.
layer_sizes (List[int]) – sizes of the layers of the OverArch.
device (Optional[torch.device]) – default compute device.
Example:
    B = 20
    D = 3
    over_arch = OverArch(10, [5, 1])
    logits = over_arch(torch.rand((B, 10)))
- forward(features: torch.Tensor) torch.Tensor ¶
- Parameters
features (torch.Tensor) – an input tensor of size B X in_features.
- Returns
an output tensor of size B X layer_sizes[-1].
- Return type
torch.Tensor
- training: bool¶
- class torchrec.models.dlrm.SparseArch(embedding_bag_collection: torchrec.modules.embedding_modules.EmbeddingBagCollection)¶
Bases:
torch.nn.modules.module.Module
Processes the sparse features of the DLRM model. Performs embedding lookups for all EmbeddingBags and embedding features in the collection.
- Parameters
embedding_bag_collection (EmbeddingBagCollection) – represents a collection of pooled embeddings.
Example:
    D = 3

    eb1_config = EmbeddingBagConfig(
        name="t1", embedding_dim=D, num_embeddings=10, feature_names=["f1"]
    )
    eb2_config = EmbeddingBagConfig(
        name="t2", embedding_dim=D, num_embeddings=10, feature_names=["f2"]
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    sparse_arch = SparseArch(ebc)

    #     0       1       2       <-- batch
    # 0   [0,1]   None    [2]
    # 1   [3]     [4]     [5,6,7]
    # ^
    # feature
    features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2"],
        values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
        offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
    )

    sparse_embeddings = sparse_arch(features)
- forward(features: torchrec.sparse.jagged_tensor.KeyedJaggedTensor) torch.Tensor ¶
- Parameters
features (KeyedJaggedTensor) – an input tensor of sparse features.
- Returns
tensor of shape B X F X D.
- Return type
torch.Tensor
- property sparse_feature_names: List[str]¶
- training: bool¶
- torchrec.models.dlrm.choose(n: int, k: int) int ¶
Computes the binomial coefficient n choose k. A simple implementation of math.comb, for Python 3.7 compatibility (math.comb was added in Python 3.8).
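For reference, a minimal sketch of such an implementation (math.comb itself is only available from Python 3.8 onward):

    def choose(n: int, k: int) -> int:
        # Binomial coefficient n! / (k! * (n - k)!), computed iteratively so
        # intermediate values stay small.
        if k < 0 or k > n:
            return 0
        k = min(k, n - k)
        result = 1
        for i in range(k):
            result = result * (n - i) // (i + 1)
        return result

    assert choose(4, 2) == 6  # e.g. F = 4 sparse features yield 6 pairwise terms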