
torchrec.models

torchrec.models.deepfm

class torchrec.models.deepfm.DenseArch(in_features: int, hidden_layer_size: int, embedding_dim: int)

Bases: Module

Processes the dense features of the DeepFMNN model. The output layer is sized to the embedding dimension of the EmbeddingBagCollection embeddings, so the dense features are projected into the same embedding space as the sparse features.

Parameters:
  • in_features (int) – dimensionality of the dense input features.

  • hidden_layer_size (int) – size of the hidden layer in the DenseArch.

  • embedding_dim (int) – output size of the DenseArch; must equal the embedding dimension D used by SparseArch.

Example:

import torch
from torchrec.models.deepfm import DenseArch

B = 20
D = 3
in_features = 10
dense_arch = DenseArch(
    in_features=in_features, hidden_layer_size=10, embedding_dim=D
)

dense_arch_input = torch.rand((B, in_features))
dense_embedded = dense_arch(dense_arch_input)  # size B X D

forward(features: Tensor) → Tensor
Parameters:

features (torch.Tensor) – tensor of size B X in_features.

Returns:

an output tensor of size B X D.

Return type:

torch.Tensor

training: bool
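As a rough illustration of the DenseArch shape contract (B X in_features in, B X D out), the sketch below builds a small MLP in plain PyTorch. The layer stack (Linear, ReLU, Linear, ReLU) and the helper name make_dense_arch_like are assumptions for illustration only, not a statement of how DenseArch is implemented.

```python
import torch
from torch import nn


def make_dense_arch_like(in_features: int, hidden_layer_size: int, embedding_dim: int) -> nn.Module:
    # Hypothetical stand-in for DenseArch (assumed layer stack): project dense
    # inputs into the D-dimensional embedding space shared with the sparse features.
    return nn.Sequential(
        nn.Linear(in_features, hidden_layer_size),
        nn.ReLU(),
        nn.Linear(hidden_layer_size, embedding_dim),
        nn.ReLU(),
    )


B, D, in_features = 20, 3, 10
dense_arch_like = make_dense_arch_like(in_features, hidden_layer_size=10, embedding_dim=D)
dense_embedded = dense_arch_like(torch.rand((B, in_features)))
print(dense_embedded.shape)  # torch.Size([20, 3])
```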
class torchrec.models.deepfm.FMInteractionArch(fm_in_features: int, sparse_feature_names: List[str], deep_fm_dimension: int)

Bases: Module

Processes the outputs of both SparseArch (sparse_features) and DenseArch (dense_features) and applies the general DeepFM interaction described in the DeepFM paper: https://arxiv.org/pdf/1703.04247.pdf

The output is the concatenation of dense_features (D columns), the deep interaction output (DI columns), and a single FM interaction column.

Parameters:
  • fm_in_features (int) – the input dimension of dense_module in DeepFM, i.e. the total number of flattened input features. For example, if the input embeddings are [randn(3, 2, 3), randn(3, 4, 5)], then fm_in_features should be 2 * 3 + 4 * 5.

  • sparse_feature_names (List[str]) – names of the F sparse features (a list of length F).

  • deep_fm_dimension (int) – output dimension of the deep interaction (DI) in the DeepFM arch.
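The fm_in_features arithmetic can be checked by flattening each input embedding from its second dimension onward and concatenating along the feature axis. This sketch assumes that flatten-and-concatenate layout, which is what the 2 * 3 + 4 * 5 formula implies, and uses only standard PyTorch calls.

```python
import torch

# The two example inputs from the fm_in_features description (batch size 3).
embeddings = [torch.randn(3, 2, 3), torch.randn(3, 4, 5)]

# Flatten every tensor to B X (remaining dims) and concatenate column-wise.
flat = torch.cat([e.flatten(1) for e in embeddings], dim=1)

fm_in_features = 2 * 3 + 4 * 5
print(flat.shape)  # torch.Size([3, 26])
assert flat.shape[1] == fm_in_features
```

For the 2-D inputs FMInteractionArch receives (dense_features of size B X D plus F sparse embeddings of size B X D each), the same rule would give fm_in_features = D + F * D.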

Example:

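import torch
from torchrec.models.deepfm import FMInteractionArch
from torchrec.sparse.jagged_tensor import KeyedTensor

D = 3
B = 10
DI = 5  # deep interaction output size
keys = ["f1", "f2"]
F = len(keys)
fm_inter_arch = FMInteractionArch(
    fm_in_features=D + D * F,  # dense features (D) plus F sparse embeddings (D each)
    sparse_feature_names=keys,
    deep_fm_dimension=DI,
)
dense_features = torch.rand((B, D))
sparse_features = KeyedTensor(
    keys=keys,
    length_per_key=[D, D],
    values=torch.rand((B, D * F)),
)
cat_fm_output = fm_inter_arch(dense_features, sparse_features)  # size B X (D + DI + 1)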

forward(dense_features: Tensor, sparse_features: KeyedTensor) → Tensor
Parameters:
  • dense_features (torch.Tensor) – tensor of size B X D.

  • sparse_features (KeyedTensor) – pooled embeddings of the F sparse features, with values of size B X (F * D).

Returns:

an output tensor of size B X (D + DI + 1).

Return type:

torch.Tensor

training: bool
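The B X (D + DI + 1) output of FMInteractionArch can be read as three concatenated pieces. The sketch below assumes the FM column is a second-order term of the form 0.5 * ((Σx)^2 - Σx^2) taken over the flattened inputs (the sum of all pairwise products x_i * x_j), and uses a Linear + ReLU layer for the deep interaction; both choices are illustrative assumptions, not a description of FMInteractionArch's internals.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, D, DI = 10, 3, 5
keys = ["f1", "f2"]
F = len(keys)

dense_features = torch.rand((B, D))
sparse_per_key = [torch.rand((B, D)) for _ in keys]  # one B X D pooled embedding per key

# Flattened interaction input: B X (D + F * D).
flat = torch.cat([dense_features] + sparse_per_key, dim=1)

# Deep interaction (assumed Linear + ReLU): B X DI.
deep_interaction = nn.Sequential(nn.Linear(D + F * D, DI), nn.ReLU())(flat)

# Second-order term sum_{i<j} x_i * x_j, via 0.5 * ((sum x)^2 - sum x^2): B X 1.
sum_x = flat.sum(dim=1, keepdim=True)
sum_sq = (flat * flat).sum(dim=1, keepdim=True)
fm_interaction = 0.5 * (sum_x * sum_x - sum_sq)

output = torch.cat([dense_features, deep_interaction, fm_interaction], dim=1)
print(output.shape)  # torch.Size([10, 9]), i.e. B X (D + DI + 1)
```

The squared-sum identity avoids an explicit double loop over feature pairs, so the term costs time linear in the number of flattened inputs.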
class torchrec.models.deepfm.OverArch(in_features: int)

Bases: Module

Final arch: a simple MLP. The output is a single target.

Parameters:

in_features (int) – the output dimension of the interaction arch.

Example:

import torch
from torchrec.models.deepfm import OverArch

B = 20
over_arch = OverArch(in_features=10)
logits = over_arch(torch.rand((B, 10)))  # size B X 1

forward(features: Tensor) → Tensor
Parameters:

features (torch.Tensor) – tensor of size B X in_features.

Returns:

an output tensor of size B X 1.

Return type:

torch.Tensor

training: bool
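Because OverArch consumes the interaction arch's output, its in_features should equal D + DI + 1 when it follows FMInteractionArch. A minimal sketch, assuming a single Linear layer to one output (the actual OverArch layer stack may differ, for example by adding an activation); over_arch_like is a hypothetical name.

```python
import torch
from torch import nn

B, D, DI = 20, 3, 5
in_features = D + DI + 1  # width of the FMInteractionArch output

# Hypothetical stand-in for OverArch: one linear map to a single target.
over_arch_like = nn.Linear(in_features, 1)
logits = over_arch_like(torch.rand((B, in_features)))
print(logits.shape)  # torch.Size([20, 1])
```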
class torchrec.models.deepfm.SimpleDeepFMNN(num_dense_features: int, embedding_bag_collection: EmbeddingBagCollection, hidden_layer_size: int, deep_fm_dimension: int)

Bases: Module

Basic recsys module with the DeepFM architecture. Processes sparse features by learning pooled embeddings for each feature. Learns the relationship between dense features and sparse features by projecting dense features into the same embedding space. Learns the interactions among those dense and sparse features with the DeepFM architecture proposed in this paper: https://arxiv.org/pdf/1703.04247.pdf

The module assumes all sparse features have the same embedding dimension (i.e., every EmbeddingBagConfig uses the same embedding_dim).

The following notation is used throughout the documentation for the models:

  • F: number of sparse features

  • D: embedding_dimension of sparse features

  • B: batch size

  • num_features: number of dense features

Parameters:
  • num_dense_features (int) – the number of input dense features.

  • embedding_bag_collection (EmbeddingBagCollection) – collection of embedding bags used to define SparseArch.

  • hidden_layer_size (int) – the hidden layer size used in the dense module (DenseArch).

  • deep_fm_dimension (int) – the output layer size used in deep_fm’s deep interaction module.
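Under the notation above, the sizes of the intermediate arches follow from the table configuration. The helper below is hypothetical (not part of torchrec) and takes plain dicts mirroring the EmbeddingBagConfig fields; it enforces the shared-embedding-dimension assumption and derives F, D, the interaction input size (D + F * D, per the fm_in_features description) and the OverArch input size (D + DI + 1, per FMInteractionArch's documented output).

```python
from typing import Dict, List


def deepfm_sizes(tables: List[Dict], deep_fm_dimension: int) -> Dict[str, int]:
    """Hypothetical helper: derive SimpleDeepFMNN layer sizes from table configs.

    Each table is a dict with "embedding_dim" and "feature_names" keys.
    """
    dims = {t["embedding_dim"] for t in tables}
    if len(dims) != 1:
        raise ValueError("all tables must share the same embedding_dim")
    D = dims.pop()
    F = sum(len(t["feature_names"]) for t in tables)
    return {
        "F": F,
        "D": D,
        "fm_in_features": D + F * D,  # dense embedding plus F sparse embeddings
        "over_in_features": D + deep_fm_dimension + 1,  # D + DI + 1
    }


sizes = deepfm_sizes(
    [
        {"embedding_dim": 8, "feature_names": ["f1", "f3"]},
        {"embedding_dim": 8, "feature_names": ["f2"]},
    ],
    deep_fm_dimension=5,
)
print(sizes)  # {'F': 3, 'D': 8, 'fm_in_features': 32, 'over_in_features': 14}
```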

Example:

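import torch
from torchrec.models.deepfm import SimpleDeepFMNN
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

B = 2
D = 8

eb1_config = EmbeddingBagConfig(
    name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
)
eb2_config = EmbeddingBagConfig(
    name="t2",
    embedding_dim=D,
    num_embeddings=100,
    feature_names=["f2"],
)

ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
sparse_nn = SimpleDeepFMNN(
    num_dense_features=100,
    embedding_bag_collection=ebc,
    hidden_layer_size=20,
    deep_fm_dimension=5,
)

features = torch.rand((B, 100))

#       0       1     <-- batch
# f1   [1,2]   [4,5]
# f3   [4,3]   [2,9]
# f2   [3]     [6,7]
# every feature registered in the collection (f1, f3, f2) is supplied
sparse_features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f3", "f2"],
    values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 3, 6, 7]),
    offsets=torch.tensor([0, 2, 4, 6, 8, 9, 11]),
)

logits = sparse_nn(
    dense_features=features,
    sparse_features=sparse_features,
)  # size B X 1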
forward(dense_features: Tensor, sparse_features: KeyedJaggedTensor) → Tensor
Parameters:
  • dense_features (torch.Tensor) – the dense features, of size B X num_dense_features.

  • sparse_features (KeyedJaggedTensor) – the sparse features.

Returns:

logits with size B X 1.

Return type:

torch.Tensor

training: bool
class torchrec.models.deepfm.SparseArch(embedding_bag_collection: EmbeddingBagCollection)

Bases: Module

Processes the sparse features of the DeepFMNN model by performing pooled embedding lookups for every feature in the EmbeddingBagCollection.

Parameters:

embedding_bag_collection (EmbeddingBagCollection) – represents a collection of pooled embeddings.

Example:

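import torch
from torchrec.models.deepfm import SparseArch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

eb1_config = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
eb2_config = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)

ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
sparse_arch = SparseArch(embedding_bag_collection=ebc)

#     0       1        2  <-- batch
# 0   [0,1]   None     [2]
# 1   [3]     [4]      [5,6,7]
# ^
# feature
features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

# KeyedTensor of pooled embeddings; values have size 3 X (3 + 4)
pooled_embeddings = sparse_arch(features)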

forward(features: KeyedJaggedTensor) → KeyedTensor
Parameters:

features (KeyedJaggedTensor) – the jagged sparse feature IDs to look up.

Returns:

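a KeyedTensor of pooled embeddings whose values have size B X (F * D) when all F features share embedding dimension D (in general, B X the sum of the per-feature embedding dimensions).

Return type:

KeyedTensor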

training: bool
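The offsets in the SparseArch example follow the KeyedJaggedTensor layout: values are stored key-major (all samples of f1, then all samples of f2), and each pair of consecutive offsets delimits one bag per (key, sample), so an empty bag such as f1's second sample appears as a repeated offset. The pure-Python helper decode_jagged below is hypothetical and only illustrates that layout; nn.EmbeddingBag in sum mode stands in for the per-feature pooled lookup, assuming sum pooling (EmbeddingBagConfig's default pooling type).

```python
from typing import Dict, List

import torch
from torch import nn


def decode_jagged(keys: List[str], values: List[int], offsets: List[int]) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: split key-major jagged values into per-sample bags."""
    batch_size = (len(offsets) - 1) // len(keys)
    bags = [values[offsets[i]:offsets[i + 1]] for i in range(len(offsets) - 1)]
    return {key: bags[k * batch_size:(k + 1) * batch_size] for k, key in enumerate(keys)}


grid = decode_jagged(
    keys=["f1", "f2"],
    values=[0, 1, 2, 3, 4, 5, 6, 7],
    offsets=[0, 2, 2, 3, 4, 5, 8],
)
print(grid)  # {'f1': [[0, 1], [], [2]], 'f2': [[3], [4], [5, 6, 7]]}

# Pooled lookup for f1 (embedding_dim=3, num_embeddings=10): one summed row per sample.
table_t1 = nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode="sum")
pooled_f1 = table_t1(torch.tensor([0, 1, 2]), torch.tensor([0, 2, 2]))
print(pooled_f1.shape)  # torch.Size([3, 3]); the empty bag pools to zeros
```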

torchrec.models.dlrm

Module contents
