TorchRec Concepts¶
In this section, we will learn about the key concepts of TorchRec, designed to optimize large-scale recommendation systems using PyTorch. We will learn how each concept works in detail and review how it is used with the rest of TorchRec.
TorchRec has specific input/output data types for its modules to efficiently represent sparse features, including:
JaggedTensor: a wrapper around the lengths/offsets and values tensors for a single sparse feature.
KeyedJaggedTensor: efficiently represents multiple sparse features; you can think of it as multiple JaggedTensors.
KeyedTensor: a wrapper around torch.Tensor that allows access to tensor values through keys.
The canonical torch.Tensor is highly inefficient for representing sparse data. TorchRec introduces these new data types because they provide efficient storage and representation of sparse input data. As you will see later on, the KeyedJaggedTensor makes communication of input data in a distributed environment very efficient, leading to one of the key performance advantages that TorchRec provides.
In the end-to-end training loop, TorchRec comprises the following main components:
Planner: Takes in the configuration of embedding tables and the environment setup, and generates an optimized sharding plan for the model.
Sharder: Shards the model according to the sharding plan, with different sharding strategies including data-parallel, table-wise, row-wise, table-wise-row-wise, column-wise, and table-wise-column-wise sharding.
DistributedModelParallel: Combines the sharder and optimizer, and provides an entry point into training the model in a distributed manner.
JaggedTensor¶
A JaggedTensor
represents a sparse feature through lengths, values,
and offsets. It is called “jagged” because it efficiently represents
data with variable-length sequences. In contrast, a canonical
torch.Tensor
assumes that each sequence has the same length, which is often not the case with real-world data. A JaggedTensor facilitates the representation of such data without padding, making it highly efficient.
Key Components:
Lengths: A list of integers representing the number of elements for each entity.
Offsets: A list of integers representing the starting index of each sequence in the flattened values tensor. These provide an alternative to lengths.
Values: A 1D tensor containing the actual values for each entity, stored contiguously.
Here is a simple example demonstrating how each of the components would look:
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

# User interactions:
# - User 1 interacted with 2 items
# - User 2 interacted with 3 items
# - User 3 interacted with 1 item
lengths = torch.tensor([2, 3, 1])
offsets = torch.tensor([0, 2, 5, 6])  # Cumulative start of each user's interactions, with the total length at the end
values = torch.Tensor([101, 102, 201, 202, 203, 301])  # Item IDs interacted with

jt = JaggedTensor(lengths=lengths, values=values)
# OR
jt = JaggedTensor(offsets=offsets, values=values)
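Once constructed, the lengths, offsets, and values can be read back from the JaggedTensor. The following is a small sketch using the jt built above:
print(jt.lengths())   # tensor([2, 3, 1])
print(jt.offsets())   # tensor([0, 2, 5, 6])
print(jt.values())    # tensor([101., 102., 201., 202., 203., 301.])
print(jt.to_dense())  # per-user value tensors, e.g. [tensor([101., 102.]), tensor([201., 202., 203.]), tensor([301.])]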
KeyedJaggedTensor¶
A KeyedJaggedTensor
extends the functionality of JaggedTensor
by
introducing keys (which are typically feature names) to label different
groups of features, for example, user features and item features. This
is the data type used in forward
of EmbeddingBagCollection
and
EmbeddingCollection
as they are used to represent multiple features
in a table.
A KeyedJaggedTensor has an implied batch size, which is the length of the lengths tensor divided by the number of keys. The example below has a batch size of 2. Similar to a JaggedTensor, the offsets and lengths function in the same manner. You can also access the lengths, offsets, and values of a feature by accessing its key on the KeyedJaggedTensor.
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

keys = ["user_features", "item_features"]
# Lengths of interactions:
# - User features: 2 users, with 2 and 3 interactions respectively
# - Item features: 2 items, with 1 and 2 interactions respectively
lengths = torch.tensor([2, 3, 1, 2])
values = torch.Tensor([11, 12, 21, 22, 23, 101, 102, 201])

# Create a KeyedJaggedTensor
kjt = KeyedJaggedTensor(keys=keys, lengths=lengths, values=values)

# Access the features by key
print(kjt["user_features"])  # Outputs the user features as a JaggedTensor
print(kjt["item_features"])
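Indexing a KeyedJaggedTensor by key returns the corresponding JaggedTensor, so the per-feature components can be read back as well. A small sketch using the kjt above:
user_jt = kjt["user_features"]         # JaggedTensor holding just the user features
print(user_jt.lengths())               # tensor([2, 3])
print(user_jt.values())                # tensor([11., 12., 21., 22., 23.])
print(kjt.keys())                      # ['user_features', 'item_features']
print(kjt.to_dict()["item_features"])  # to_dict() gives a key -> JaggedTensor mapping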
Planner¶
The TorchRec planner helps determine the best sharding configuration for a model. It evaluates multiple possibilities for sharding embedding tables and optimizes for performance. The planner performs the following:
Assesses the memory constraints of the hardware.
Estimates compute requirements based on memory fetches, such as embedding lookups.
Addresses data-specific factors.
Considers other hardware specifics, such as bandwidth, to generate an optimal sharding plan.
To ensure accurate consideration of these factors, the Planner can incorporate data about the embedding tables, constraints, hardware information, and topology to help in generating an optimal plan.
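As a rough sketch of how the planner is typically driven (the world size and device below are illustrative assumptions, and model and sharders are assumed to be defined elsewhere):
import torch.distributed as dist
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

# Describe the target hardware; these values are assumptions for illustration.
topology = Topology(world_size=2, compute_device="cuda")
planner = EmbeddingShardingPlanner(topology=topology)

# `model` contains TorchRec embedding modules and `sharders` is a list such as
# [EmbeddingBagCollectionSharder()]; both are assumed to exist already.
# collective_plan() generates the plan on one rank and broadcasts it over the
# process group so that every rank shards the model identically.
plan = planner.collective_plan(model, sharders, dist.GroupMember.WORLD)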
Distributed Training with TorchRec Sharded Modules¶
With many sharding strategies available, how do we determine which one to use? There is a cost associated with each sharding scheme, which in conjunction with model size and number of GPUs determines which sharding strategy is best for a model.
Without sharding, where each GPU keeps a copy of the embedding table (DP), the main cost is computation in which each GPU looks up the embedding vectors in its memory in the forward pass and updates the gradients in the backward pass.
With sharding, there is an added communication cost: each GPU needs to
ask the other GPUs for embedding vector lookup and communicate the
gradients computed as well. This is typically referred to as all2all
communication. In TorchRec, for input data on a given GPU, we determine
where the embedding shard for each part of the data is located and send
it to the target GPU. That target GPU then returns the embedding vectors
back to the original GPU. In the backward pass, the gradients are sent
back to the target GPU and the shards are updated accordingly with the
optimizer.
As described above, sharding requires us to communicate the input data and embedding lookups. TorchRec handles this in three main stages; we will refer to this as the sharded embedding module forward, which is used in training and inference of a TorchRec model:
Feature All to All / Input Distribution (input_dist): Communicates input data (in the form of a KeyedJaggedTensor) to the appropriate device containing the relevant embedding table shard.
Embedding Lookup: Looks up embeddings with the new input data formed after the feature all-to-all exchange.
Embedding All to All / Output Distribution (output_dist): Communicates the embedding lookup data back to the appropriate device that asked for it (in accordance with the input data the device received).
The backward pass does the same operations but in reverse order.
The diagram below demonstrates how it works:
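As a rough illustration of these stages in code, assuming sharded_ebc is a sharded EmbeddingBagCollection (produced by the machinery described in the next section) and kjt is the KeyedJaggedTensor on the local rank:
# A single forward call runs the three stages internally:
#   1. input_dist:  routes `kjt` to the ranks holding the relevant embedding shards
#   2. lookup:      each rank looks up embeddings for the data it received
#   3. output_dist: the looked-up embeddings are sent back to the requesting ranks
awaitable = sharded_ebc(kjt)
pooled_embeddings = awaitable.wait()  # a KeyedTensor of pooled embeddings for this rank's batch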
DistributedModelParallel¶
All of the above culminates in the main entry point that TorchRec uses to shard the model and integrate the plan. At a high level, DistributedModelParallel does the following:
Initializes the environment by setting up process groups and assigning the device type.
Uses the default sharders if no sharders are provided; the defaults include EmbeddingBagCollectionSharder.
Takes in the provided sharding plan; if none is provided, it generates one.
Creates a sharded version of the modules and replaces the original modules with them, for example, converts EmbeddingCollection to ShardedEmbeddingCollection.
By default, wraps the DistributedModelParallel with DistributedDataParallel to make the module both model and data parallel.
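A minimal sketch of wrapping a module with DistributedModelParallel; the table configuration and devices are illustrative assumptions, and the process group is assumed to already be initialized:
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# Assumes the process group was set up earlier, e.g.:
# dist.init_process_group(backend="nccl")

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        )
    ],
    device=torch.device("meta"),  # parameters are materialized when sharded
)

# With no plan or sharders provided, DistributedModelParallel generates a plan with
# the default sharders and swaps ebc for a ShardedEmbeddingBagCollection.
model = DistributedModelParallel(ebc, device=torch.device("cuda"))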
Optimizer¶
TorchRec modules provide a seamless API to fuse the backward pass and optimizer step in training, providing a significant optimization in performance and a decrease in memory use, along with granularity in assigning distinct optimizers to distinct model parameters.
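A common way to use this is TorchRec's apply_optimizer_in_backward helper, which registers an optimizer on specific parameters so the update runs as part of the backward pass. A rough sketch, where the learning rate and parameter selection are assumptions and ebc is an EmbeddingBagCollection as above:
import torch
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

# Fuse an SGD step into the backward pass of the embedding parameters only;
# dense parameters can keep a regular optimizer. This is typically applied
# before wrapping the module with DistributedModelParallel.
apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=ebc.parameters(),
    optimizer_kwargs={"lr": 0.02},
)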
Inference¶
Inference environments are different from training; they are very sensitive to performance and the size of the model. There are two key differences that TorchRec inference optimizes for:
Quantization: inference models are quantized for lower latency and reduced model size. This optimization lets us use as few devices as possible for inference to minimize latency.
C++ environment: to minimize latency even further, the model is run in a C++ environment.
TorchRec provides the following to make a TorchRec model inference-ready:
APIs for quantizing the model, including automatic optimizations with FBGEMM TBE
Sharding embeddings for distributed inference
Compiling the model to TorchScript (compatible with C++)
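As a rough sketch of that workflow; the helper functions below come from torchrec.inference.modules in recent releases, and the trained model variable is an assumption:
import torch
from torchrec.inference.modules import quantize_inference_model, shard_quant_model

# Quantize the trained model's embedding tables (e.g. to int8 FBGEMM TBE kernels).
quant_model = quantize_inference_model(model)

# Shard the quantized embeddings for the available inference devices.
sharded_model, _ = shard_quant_model(
    quant_model,
    compute_device="cuda",
    sharding_device="cuda",
)

# The sharded, quantized model can then be compiled to TorchScript
# (e.g. via torch.fx tracing followed by torch.jit.script) for C++ serving.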