Introduction to TorchRec
========================

.. tip:: To get the most out of this tutorial, we suggest using this
   `Colab Version `__. This will allow you to experiment with the
   information presented below.

Follow along with the video below or on `youtube `__.
When building recommendation systems, we frequently want to represent
entities like products or pages with embeddings. For example, see Meta AI's
`Deep learning recommendation model `__, or DLRM. As the number of entities
grows, the size of the embedding tables can exceed a single GPU's memory. A
common practice is to shard the embedding table across devices, a type of
model parallelism. To that end, TorchRec introduces its primary API called
|DistributedModelParallel|_, or DMP. Like PyTorch's DistributedDataParallel,
DMP wraps a model to enable distributed training.

Installation
------------

Requirements: python >= 3.7

We highly recommend CUDA when using TorchRec (if using CUDA: cuda >= 11.0).

.. code:: shell

   # install pytorch with cudatoolkit 11.3
   conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y
   # install TorchRec
   pip3 install torchrec-nightly

Overview
--------

This tutorial will cover three pieces of TorchRec: the ``nn.Module``
|EmbeddingBagCollection|_, the |DistributedModelParallel|_ API, and
the datastructure |KeyedJaggedTensor|_.

Distributed Setup
~~~~~~~~~~~~~~~~~

We set up our environment with torch.distributed. For more info on
distributed, see this `tutorial `__.

Here, we use one rank (the colab process) corresponding to our 1 colab GPU.

.. code:: python

   import os

   import torch
   import torchrec
   import torch.distributed as dist

   os.environ["RANK"] = "0"
   os.environ["WORLD_SIZE"] = "1"
   os.environ["MASTER_ADDR"] = "localhost"
   os.environ["MASTER_PORT"] = "29500"

   # Note - you will need a V100 or A100 to run this tutorial as is!
   # If using an older GPU (such as the colab free K80),
   # you will need to compile fbgemm with the appropriate CUDA architecture
   # or run with "gloo" on CPUs
   dist.init_process_group(backend="nccl")

From EmbeddingBag to EmbeddingBagCollection
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

PyTorch represents embeddings through |torch.nn.Embedding|_ and
|torch.nn.EmbeddingBag|_. EmbeddingBag is a pooled version of Embedding.

TorchRec extends these modules by creating collections of embeddings. We
will use |EmbeddingBagCollection|_ to represent a group of EmbeddingBags.

Here, we create an EmbeddingBagCollection (EBC) with two embedding bags.
Each table, ``product_table`` and ``user_table``, is represented by a 64
dimension embedding of size 4096. Note how we initially allocate the EBC on
device "meta". This will tell EBC to not allocate memory yet.

.. code:: python

   ebc = torchrec.EmbeddingBagCollection(
       device="meta",
       tables=[
           torchrec.EmbeddingBagConfig(
               name="product_table",
               embedding_dim=64,
               num_embeddings=4096,
               feature_names=["product"],
               pooling=torchrec.PoolingType.SUM,
           ),
           torchrec.EmbeddingBagConfig(
               name="user_table",
               embedding_dim=64,
               num_embeddings=4096,
               feature_names=["user"],
               pooling=torchrec.PoolingType.SUM,
           ),
       ]
   )
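Because we allocated the EBC on device "meta", its embedding tables have
shapes but no storage yet. As a quick sanity check, here is a minimal sketch
(not part of the original tutorial; the exact parameter names printed depend
on the TorchRec version) that lists the EBC's parameters:

.. code:: python

   # The unsharded EBC is a regular nn.Module, so we can inspect its
   # parameters. On the "meta" device, each table reports its shape
   # (num_embeddings, embedding_dim) and device "meta", with no memory
   # allocated yet.
   for name, param in ebc.named_parameters():
       print(name, tuple(param.shape), param.device)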
DistributedModelParallel
~~~~~~~~~~~~~~~~~~~~~~~~

Now, we're ready to wrap our model with |DistributedModelParallel|_ (DMP).
Instantiating DMP will:

1. Decide how to shard the model. DMP will collect the available
   'sharders' and come up with a 'plan' of the optimal way to shard the
   embedding table(s) (i.e., the EmbeddingBagCollection).
2. Actually shard the model. This includes allocating memory for each
   embedding table on the appropriate device(s).

In this toy example, since we have two EmbeddingTables and one GPU,
TorchRec will place both on the single GPU.

.. code:: python

   model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
   print(model)
   print(model.plan)

Query vanilla nn.EmbeddingBag with input and offsets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We query |nn.Embedding|_ and |nn.EmbeddingBag|_ with ``input`` and
``offsets``. Input is a 1-D tensor containing the lookup values. Offsets is
a 1-D tensor where the sequence is a cumulative sum of the number of values
to pool per example.

Let's look at an example, recreating the product EmbeddingBag above:

::

   |------------|
   | product ID |
   |------------|
   | [101, 202] |
   | []         |
   | [303]      |
   |------------|

.. code:: python

   product_eb = torch.nn.EmbeddingBag(4096, 64)
   product_eb(input=torch.tensor([101, 202, 303]), offsets=torch.tensor([0, 2, 2]))

Representing minibatches with KeyedJaggedTensor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We need an efficient representation of multiple examples of an arbitrary
number of entity IDs per feature per example. In order to enable this
"jagged" representation, we use the TorchRec datastructure
|KeyedJaggedTensor|_ (KJT).

Let's take a look at how to look up a collection of two embedding bags,
"product" and "user". Assume the minibatch is made up of three examples for
three users. The first has two product IDs, the second has none, and the
third has one product ID.

::

   |------------|------------|
   | product ID | user ID    |
   |------------|------------|
   | [101, 202] | [404]      |
   | []         | [505]      |
   | [303]      | [606]      |
   |------------|------------|

The query should be:

.. code:: python

   mb = torchrec.KeyedJaggedTensor(
       keys=["product", "user"],
       values=torch.tensor([101, 202, 303, 404, 505, 606]).cuda(),
       lengths=torch.tensor([2, 0, 1, 1, 1, 1], dtype=torch.int64).cuda(),
   )

   print(mb.to(torch.device("cpu")))

Note that the KJT batch size is ``batch_size = len(lengths)//len(keys)``.
In the above example, batch_size is 3.
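To make the jagged layout more concrete, the sketch below pulls the
minibatch apart with a few KJT accessors. The method names (``keys()``,
``values()``, ``lengths()``, ``offsets()``) and per-key indexing are
assumptions about the TorchRec sparse API and may differ slightly between
versions; treat this as an illustrative sketch:

.. code:: python

   # Global view: one flat values tensor plus a per-(key, example) lengths
   # tensor; offsets are the cumulative sum of the lengths.
   print(mb.keys())      # ['product', 'user']
   print(mb.lengths())   # 2, 0, 1 product lengths followed by 1, 1, 1 user lengths
   print(mb.offsets())   # e.g. 0, 2, 2, 3, 4, 5, 6

   # Per-key view: indexing by key should yield a JaggedTensor that lines up
   # with the "product" column of the table above.
   product_jt = mb["product"]
   print(product_jt.values())   # 101, 202, 303
   print(product_jt.lengths())  # 2, 0, 1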
Putting it all together, querying our distributed model with a KJT minibatch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Finally, we can query our model using our minibatch of products and users.

The resulting lookup will contain a KeyedTensor, where each key (or feature)
contains a 2D tensor of size 3x64 (batch_size x embedding_dim).

.. code:: python

   pooled_embeddings = model(mb)
   print(pooled_embeddings)

More resources
--------------

For more information, please see our `dlrm `__ example, which includes
multinode training on the Criteo Terabyte dataset, using Meta's `DLRM `__.

.. |DistributedModelParallel| replace:: ``DistributedModelParallel``
.. _DistributedModelParallel: https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel
.. |EmbeddingBagCollection| replace:: ``EmbeddingBagCollection``
.. _EmbeddingBagCollection: https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection
.. |KeyedJaggedTensor| replace:: ``KeyedJaggedTensor``
.. _KeyedJaggedTensor: https://pytorch.org/torchrec/torchrec.sparse.html#torchrec.sparse.jagged_tensor.KeyedJaggedTensor
.. |torch.nn.Embedding| replace:: ``torch.nn.Embedding``
.. _torch.nn.Embedding: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
.. |torch.nn.EmbeddingBag| replace:: ``torch.nn.EmbeddingBag``
.. _torch.nn.EmbeddingBag: https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html
.. |nn.Embedding| replace:: ``nn.Embedding``
.. _nn.Embedding: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
.. |nn.EmbeddingBag| replace:: ``nn.EmbeddingBag``
.. _nn.EmbeddingBag: https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html