.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/torchrec_intro_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_intermediate_torchrec_intro_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_torchrec_intro_tutorial.py:


Introduction to TorchRec
==================================

**TorchRec** is a PyTorch library tailored for building scalable and efficient recommendation systems using embeddings. 
This tutorial guides you through the installation process, introduces the concept of embeddings, and highlights their importance in
recommendation systems. It offers practical demonstrations on implementing embeddings with PyTorch
and TorchRec, focusing on handling large embedding tables through distributed training and advanced optimizations.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites
       
       * Fundamentals of embeddings and their role in recommendation systems
       * How to set up TorchRec to manage and implement embeddings in PyTorch environments
       * Explore advanced techniques for distributing large embedding tables across multiple GPUs

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch v2.5 or later with CUDA 11.8 or later
       * Python 3.9 or later
       * `FBGEMM <https://github.com/pytorch/fbgemm>`__

.. GENERATED FROM PYTHON SOURCE LINES 30-48

Install Dependencies
^^^^^^^^^^^^^^^^^^^^

Before running this tutorial in Google Colab or another environment, install the
following dependencies:

.. code-block:: sh

   !pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
   !pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
   !pip3 install torchmetrics==1.0.3
   !pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121

.. note:: 
   If you are running this in Google Colab, make sure to switch to a GPU runtime type.
   For more information,
   see `Enabling CUDA <https://pytorch.org/tutorials/beginner/colab#enabling-cuda>`__


.. GENERATED FROM PYTHON SOURCE LINES 52-114

Embeddings
~~~~~~~~~~

When building recommendation systems, categorical features such as
posts, users, and ads typically have massive cardinality.

In order to represent these entities and model these relationships,
**embeddings** are used. In machine learning, **embeddings are vectors
of real numbers in a high-dimensional space used to represent meaning in
complex data like words, images, or users**.

Embeddings in RecSys
~~~~~~~~~~~~~~~~~~~~

Now you might wonder, how are these embeddings generated in the first
place? Well, embeddings are represented as individual rows in an
**Embedding Table**, also referred to as embedding weights. The reason
for this is that embeddings or embedding table weights are trained just
like all of the other weights of the model via gradient descent!

Embedding tables are simply a large matrix for storing embeddings, with
two dimensions (B, N), where:

* B is the number of embeddings stored by the table
* N is the number of dimensions per embedding (N-dimensional embedding).

The inputs to embedding tables represent embedding lookups to retrieve
the embedding for a specific index or row. In large-scale recommendation
systems, unique IDs are used not only for specific users, but also across
entities like posts and ads, serving as lookup indices into their
respective embedding tables!

Embeddings are trained in RecSys through the following process (a
minimal sketch follows this list):

* **Input/lookup indices are fed into the model as unique IDs**. IDs are
  hashed to the total size of the embedding table to prevent issues when
  an ID exceeds the number of rows.

* Embeddings are then retrieved and **pooled, such as taking the sum or 
  mean of the embeddings**. This is required as there can be a variable number of
  embeddings per example while the model expects consistent shapes.

* The **embeddings are used in conjunction with the rest of the model to
  produce a prediction**, such as `Click-Through Rate
  (CTR) <https://support.google.com/google-ads/answer/2615875?hl=en>`__
  for an ad.

* The loss is calculated with the prediction and the label
  for an example, and **all weights of the model are updated through
  gradient descent and backpropagation, including the embedding weights**
  that were associated with the example.
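
To make this process concrete, below is a minimal sketch of one training
step using a plain ``torch.nn.EmbeddingBag``. The table size, the modulo
hash, and the linear prediction head are illustrative assumptions for this
sketch, not TorchRec APIs.

.. code-block:: python

    import torch

    num_rows, dim = 1000, 16  # illustrative table size

    table = torch.nn.EmbeddingBag(num_rows, dim, mode="sum")
    head = torch.nn.Linear(dim, 1)  # toy prediction head (for example, a CTR logit)

    # Raw IDs can exceed the table size, so hash them into range
    raw_ids = torch.tensor([123456789, 42, 987654])
    ids = raw_ids % num_rows

    # One example containing all three IDs: lookup + sum pooling
    pooled = table(ids.unsqueeze(0))      # shape: (1, dim)
    pred = torch.sigmoid(head(pooled))    # CTR-style prediction

    # The loss against a label updates the embedding rows via backpropagation
    label = torch.tensor([[1.0]])
    loss = torch.nn.functional.binary_cross_entropy(pred, label)
    loss.backward()
    print(table.weight.grad is not None)  # gradients flow to the embedding weights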

These embeddings are crucial for representing categorical features, such
as users, posts, and ads, in order to capture relationships and make
good recommendations. The `Deep learning recommendation
model <https://arxiv.org/abs/1906.00091>`__ (DLRM) paper talks more
about the technical details of using embedding tables in RecSys.

This tutorial introduces the concept of embeddings, showcases
TorchRec-specific modules and data types, and depicts how distributed
training works with TorchRec.


.. GENERATED FROM PYTHON SOURCE LINES 114-118

.. code-block:: default


    import torch









.. GENERATED FROM PYTHON SOURCE LINES 119-134

Embeddings in PyTorch
---------------------

In PyTorch, we have the following types of embeddings:  

* :class:`torch.nn.Embedding`: An embedding table where forward pass returns the
  embeddings themselves as is.

* :class:`torch.nn.EmbeddingBag`: Embedding table where forward pass returns
  embeddings that are then pooled, for example, sum or mean, otherwise known
  as **Pooled Embeddings**.

In this section, we will go over a very brief introduction to performing
embedding lookups by passing indices into the table.


.. GENERATED FROM PYTHON SOURCE LINES 134-178

.. code-block:: default


    num_embeddings, embedding_dim = 10, 4

    # Initialize our embedding table
    weights = torch.rand(num_embeddings, embedding_dim)
    print("Weights:", weights)

    # Pass in pre-generated weights just for example, typically weights are randomly initialized
    embedding_collection = torch.nn.Embedding(
        num_embeddings, embedding_dim, _weight=weights
    )
    embedding_bag_collection = torch.nn.EmbeddingBag(
        num_embeddings, embedding_dim, _weight=weights
    )

    # Print out the tables, we should see the same weights as above
    print("Embedding Collection Table: ", embedding_collection.weight)
    print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)

    # Look up rows (embedding IDs) in the embedding tables
    # 2D tensor with shape (batch_size, IDs per example)
    ids = torch.tensor([[1, 3]])
    print("Input row IDS: ", ids)

    embeddings = embedding_collection(ids)

    # Print out the embedding lookups
    # You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
    print("Embedding Collection Results: ")
    print(embeddings)
    print("Shape: ", embeddings.shape)

    # ``nn.EmbeddingBag`` default pooling is mean, so the result should be the mean of the embeddings looked up for each example above
    pooled_embeddings = embedding_bag_collection(ids)

    print("Embedding Bag Collection Results: ")
    print(pooled_embeddings)
    print("Shape: ", pooled_embeddings.shape)

    # ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
    # We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
    print("Mean: ", torch.mean(embedding_collection(ids), dim=1))






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Weights: tensor([[0.8823, 0.9150, 0.3829, 0.9593],
            [0.3904, 0.6009, 0.2566, 0.7936],
            [0.9408, 0.1332, 0.9346, 0.5936],
            [0.8694, 0.5677, 0.7411, 0.4294],
            [0.8854, 0.5739, 0.2666, 0.6274],
            [0.2696, 0.4414, 0.2969, 0.8317],
            [0.1053, 0.2695, 0.3588, 0.1994],
            [0.5472, 0.0062, 0.9516, 0.0753],
            [0.8860, 0.5832, 0.3376, 0.8090],
            [0.5779, 0.9040, 0.5547, 0.3423]])
    Embedding Collection Table:  Parameter containing:
    tensor([[0.8823, 0.9150, 0.3829, 0.9593],
            [0.3904, 0.6009, 0.2566, 0.7936],
            [0.9408, 0.1332, 0.9346, 0.5936],
            [0.8694, 0.5677, 0.7411, 0.4294],
            [0.8854, 0.5739, 0.2666, 0.6274],
            [0.2696, 0.4414, 0.2969, 0.8317],
            [0.1053, 0.2695, 0.3588, 0.1994],
            [0.5472, 0.0062, 0.9516, 0.0753],
            [0.8860, 0.5832, 0.3376, 0.8090],
            [0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
    Embedding Bag Collection Table:  Parameter containing:
    tensor([[0.8823, 0.9150, 0.3829, 0.9593],
            [0.3904, 0.6009, 0.2566, 0.7936],
            [0.9408, 0.1332, 0.9346, 0.5936],
            [0.8694, 0.5677, 0.7411, 0.4294],
            [0.8854, 0.5739, 0.2666, 0.6274],
            [0.2696, 0.4414, 0.2969, 0.8317],
            [0.1053, 0.2695, 0.3588, 0.1994],
            [0.5472, 0.0062, 0.9516, 0.0753],
            [0.8860, 0.5832, 0.3376, 0.8090],
            [0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
    Input row IDS:  tensor([[1, 3]])
    Embedding Collection Results: 
    tensor([[[0.3904, 0.6009, 0.2566, 0.7936],
             [0.8694, 0.5677, 0.7411, 0.4294]]], grad_fn=<EmbeddingBackward0>)
    Shape:  torch.Size([1, 2, 4])
    Embedding Bag Collection Results: 
    tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<EmbeddingBagBackward0>)
    Shape:  torch.Size([1, 4])
    Mean:  tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<MeanBackward1>)




.. GENERATED FROM PYTHON SOURCE LINES 179-185

Congratulations! Now you have a basic understanding of how to use
embedding tables --- one of the foundations of modern recommendation
systems! These tables represent entities and their relationships. For
example, the relationship between a given user and the pages and posts
they have liked.


.. GENERATED FROM PYTHON SOURCE LINES 188-220

TorchRec Features Overview
^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the section above we've learned how to use embedding tables, one of the foundations of
modern recommendation systems! These tables represent entities and
relationships, such as users, pages, posts, etc. Given that these
entities are always increasing, a **hash** function is typically applied
to make sure the IDs are within the bounds of a certain embedding table.
However, in order to represent a vast amount of entities and reduce hash
collisions, these tables can become quite massive (think about the number of ads
for example). In fact, these tables can become so massive that they
won't be able to fit on 1 GPU, even with 80G of memory.

In order to train models with massive embedding tables, sharding these
tables across GPUs is required, which then introduces a whole new set of
problems and opportunities in parallelism and optimization. Luckily, we have
the TorchRec library that has encountered, consolidated, and addressed
many of these concerns. TorchRec serves as a **library that provides
primitives for large scale distributed embeddings**.

Next, we will explore the major features of the TorchRec
library. We will start with ``torch.nn.Embedding``, extend that to
custom TorchRec modules, explore the distributed training environment by
generating a sharding plan for embeddings, look at inherent TorchRec
optimizations, and extend the model to be ready for inference in C++.
Below is a quick outline of what this section consists of:

* TorchRec Modules and Data Types
* Distributed Training, Sharding, and Optimizations
* Inference

Let's begin by importing TorchRec:

.. GENERATED FROM PYTHON SOURCE LINES 220-224

.. code-block:: default


    import torchrec









.. GENERATED FROM PYTHON SOURCE LINES 225-247

TorchRec Modules and Data Types
----------------------------------

This section goes over TorchRec Modules and data types including such
entities as ``EmbeddingCollection`` and ``EmbeddingBagCollection``,
``JaggedTensor``, ``KeyedJaggedTensor``, ``KeyedTensor`` and more.

From ``EmbeddingBag`` to ``EmbeddingBagCollection``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We have already explored :class:`torch.nn.Embedding` and :class:`torch.nn.EmbeddingBag`.
TorchRec extends these modules by creating collections of embeddings, in
other words modules that can have multiple embedding tables, with
``EmbeddingCollection`` and ``EmbeddingBagCollection``.
We will use ``EmbeddingBagCollection`` to represent a group of
embedding bags.

In the example code below, we create an ``EmbeddingBagCollection`` (EBC)
with two embedding bags, one representing **products** and one representing **users**.
Each table, ``product_table`` and ``user_table``, holds 4096 embeddings
of dimension 64.


.. GENERATED FROM PYTHON SOURCE LINES 247-270

.. code-block:: default


    ebc = torchrec.EmbeddingBagCollection(
        device="cpu",
        tables=[
            torchrec.EmbeddingBagConfig(
                name="product_table",
                embedding_dim=64,
                num_embeddings=4096,
                feature_names=["product"],
                pooling=torchrec.PoolingType.SUM,
            ),
            torchrec.EmbeddingBagConfig(
                name="user_table",
                embedding_dim=64,
                num_embeddings=4096,
                feature_names=["user"],
                pooling=torchrec.PoolingType.SUM,
            )
        ]
    )
    print(ebc.embedding_bags)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ModuleDict(
      (product_table): EmbeddingBag(4096, 64, mode='sum')
      (user_table): EmbeddingBag(4096, 64, mode='sum')
    )




.. GENERATED FROM PYTHON SOURCE LINES 271-274

Let’s inspect the forward method for ``EmbeddingBagCollection`` and the
module’s inputs and outputs:


.. GENERATED FROM PYTHON SOURCE LINES 274-282

.. code-block:: default


    import inspect

    # Let's look at the ``EmbeddingBagCollection`` forward method
    # What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
    print(inspect.getsource(ebc.forward))






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

        def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
            """
            Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
            and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.

            Args:
                features (KeyedJaggedTensor): Input KJT
            Returns:
                KeyedTensor
            """
            flat_feature_names: List[str] = []
            for names in self._feature_names:
                flat_feature_names.extend(names)
            inverse_indices = reorder_inverse_indices(
                inverse_indices=features.inverse_indices_or_none(),
                feature_names=flat_feature_names,
            )
            pooled_embeddings: List[torch.Tensor] = []
            feature_dict = features.to_dict()
            for i, embedding_bag in enumerate(self.embedding_bags.values()):
                for feature_name in self._feature_names[i]:
                    f = feature_dict[feature_name]
                    res = embedding_bag(
                        input=f.values(),
                        offsets=f.offsets(),
                        per_sample_weights=f.weights() if self._is_weighted else None,
                    ).float()
                    pooled_embeddings.append(res)
            return KeyedTensor(
                keys=self._embedding_names,
                values=process_pooled_embeddings(
                    pooled_embeddings=pooled_embeddings,
                    inverse_indices=inverse_indices,
                ),
                length_per_key=self._lengths_per_embedding,
            )





.. GENERATED FROM PYTHON SOURCE LINES 283-307

TorchRec Input/Output Data Types
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TorchRec has distinct data types for input and output of its modules:
``JaggedTensor``, ``KeyedJaggedTensor``, and ``KeyedTensor``. Now you
might ask, why create new data types to represent sparse features? To
answer that question, we must understand how sparse features are
represented in code.

Sparse features are otherwise known as ``id_list_feature`` and
``id_score_list_feature``, and are the **IDs** that will be used as
indices to an embedding table to retrieve the embedding for that ID. To
give a very simple example, imagine a single sparse feature being Ads
that a user interacted with. The input itself would be a set of Ad IDs
that a user interacted with, and the embeddings retrieved would be a
semantic representation of those Ads. The tricky part of representing
these features in code is that in each input example, **the number of
IDs is variable**. One day a user might have interacted with only one ad
while the next day they interact with three.

A simple representation is shown below, where we have a ``lengths``
tensor denoting how many indices are in an example for a batch and a
``values`` tensor containing the indices themselves.


.. GENERATED FROM PYTHON SOURCE LINES 307-316

.. code-block:: default


    # Batch Size 2
    # 1 ID in example 1, 2 IDs in example 2
    id_list_feature_lengths = torch.tensor([1, 2])

    # Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
    id_list_feature_values = torch.tensor([5, 7, 1])









.. GENERATED FROM PYTHON SOURCE LINES 317-319

Next, let's look at the offsets as well as what is contained in each batch.


.. GENERATED FROM PYTHON SOURCE LINES 319-393

.. code-block:: default


    # Lengths can be converted to offsets for easy indexing of values
    id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)

    print("Offsets: ", id_list_feature_offsets)
    print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
    print(
        "Second Batch: ",
        id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
    )

    from torchrec import JaggedTensor

    # ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
    jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)

    # Automatically compute offsets from lengths
    print("Offsets: ", jt.offsets())

    # Convert to list of values
    print("List of Values: ", jt.to_dense())

    # ``__str__`` representation
    print(jt)

    from torchrec import KeyedJaggedTensor

    # ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
    # That's where ``KeyedJaggedTensor`` comes in! A ``KeyedJaggedTensor`` is just multiple ``JaggedTensors``, keyed by feature name
    # From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!

    product_jt = JaggedTensor(
        values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
    )
    user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))

    # Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
    kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})

    # Look at our feature keys for the ``KeyedJaggedTensor``
    print("Keys: ", kjt.keys())

    # Look at the overall lengths for the ``KeyedJaggedTensor``
    print("Lengths: ", kjt.lengths())

    # Look at all values for ``KeyedJaggedTensor``
    print("Values: ", kjt.values())

    # Can convert ``KeyedJaggedTensor`` to dictionary representation
    print("to_dict: ", kjt.to_dict())

    # ``KeyedJaggedTensor`` string representation
    print(kjt)

    # Q2: What are the offsets for the ``KeyedJaggedTensor``?

    # Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
    result = ebc(kjt)
    result

    # Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
    print(result.keys())

    # The result's shape is [2, 128]: a batch size of 2 and an embedding dimension of 128.
    # Reread the previous section if you need a refresher on how the batch size is determined.
    # When we initialized the ``EmbeddingBagCollection``, we configured two tables "product" and "user" of dimension 64 each,
    # so the pooled embeddings of the two features are concatenated: 64 + 64 = 128
    print(result.values().shape)

    # Nice to_dict method to determine the embeddings that belong to each feature
    result_dict = result.to_dict()
    for key, embedding in result_dict.items():
        print(key, embedding.shape)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Offsets:  tensor([1, 3])
    First Batch:  tensor([5])
    Second Batch:  tensor([7, 1])
    Offsets:  tensor([0, 1, 3])
    List of Values:  [tensor([5]), tensor([7, 1])]
    JaggedTensor({
        [[5], [7, 1]]
    })

    Keys:  ['product', 'user']
    Lengths:  tensor([3, 1, 2, 2])
    Values:  tensor([1, 2, 1, 5, 2, 3, 4, 1])
    to_dict:  {'product': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f92ecd6aec0>, 'user': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f92ecd68550>}
    KeyedJaggedTensor({
        "product": [[1, 2, 1], [5]],
        "user": [[2, 3], [4, 1]]
    })

    ['product', 'user']
    torch.Size([2, 128])
    product torch.Size([2, 64])
    user torch.Size([2, 64])




.. GENERATED FROM PYTHON SOURCE LINES 394-398

Congrats! You now understand TorchRec modules and data types.
Give yourself a pat on the back for making it this far. Next, we will
learn about distributed training and sharding.


.. GENERATED FROM PYTHON SOURCE LINES 401-429

Distributed Training and Sharding
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Now that we have a grasp on TorchRec modules and data types, it's time
to take it to the next level.

Remember, the main purpose of TorchRec is to provide primitives for
distributed embeddings. So far, we've only worked with embedding tables
on a single device. This has been possible given how small the embedding tables
have been, but in a production setting this isn't generally the case.
Embedding tables often get massive, where one table can't fit on a single
GPU, creating the requirement for multiple devices and a distributed
environment.

In this section, we will explore setting up a distributed environment,
exactly how actual production training is done, and sharding
embedding tables, all with TorchRec.

**This section will only use 1 GPU, though it will be treated in a
distributed fashion. This is only a limitation for training, as training
requires a process per GPU. Inference does not run into this requirement.**

In the example code below, we set up our PyTorch distributed environment.

.. warning:: 
   If you are running this in Google Colab, you can only call this cell once,
   calling it again will cause an error as you can only initialize the process
   group once.

.. GENERATED FROM PYTHON SOURCE LINES 429-450

.. code-block:: default


    import os

    import torch.distributed as dist

    # Set up environment variables for distributed training
    # RANK is which GPU we are on, default 0
    os.environ["RANK"] = "0"
    # How many devices in our "world", colab notebook can only handle 1 process
    os.environ["WORLD_SIZE"] = "1"
    # Localhost as we are training locally
    os.environ["MASTER_ADDR"] = "localhost"
    # Port for distributed training
    os.environ["MASTER_PORT"] = "29500"

    # nccl backend is for GPUs, gloo is for CPUs
    dist.init_process_group(backend="gloo")

    print(f"Distributed environment initialized: {dist}")






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Distributed environment initialized: <module 'torch.distributed' from '/usr/local/lib/python3.10/dist-packages/torch/distributed/__init__.py'>




.. GENERATED FROM PYTHON SOURCE LINES 451-519

Distributed Embeddings
~~~~~~~~~~~~~~~~~~~~~~

We have already worked with the main TorchRec module:
``EmbeddingBagCollection``. We have examined how it works along with how
data is represented in TorchRec. However, we have not yet explored one
of the main parts of TorchRec, which is **distributed embeddings**.

GPUs are the most popular choice for ML workloads by far today, as they
can perform orders of magnitude more floating point operations per second
(`FLOPs <https://en.wikipedia.org/wiki/FLOPS>`__) than CPUs. However,
GPUs come with the limitation of scarce fast memory (HBM, which is
analogous to RAM for CPUs), typically tens of GBs.

A RecSys model can contain embedding tables that far exceed the memory
limit for 1 GPU, hence the need for distribution of the embedding tables
across multiple GPUs, otherwise known as **model parallel**. On the
other hand, **data parallel** is where the entire model is replicated on
each GPU, with each GPU taking in a distinct batch of data for
training, syncing gradients on the backwards pass.

Parts of the model that **require less compute but more memory
(embeddings) are distributed with model parallel** while parts that
**require more compute and less memory (dense layers, MLP, etc.) are
distributed with data parallel**.

Sharding
~~~~~~~~

In order to distribute an embedding table, we split up the embedding
table into parts and place those parts onto different devices, also
known as “sharding”.

There are many ways to shard embedding tables. The most common ways are listed below, with a toy sketch after the list:

* Table-Wise: the table is placed entirely onto one device
* Column-Wise: columns of embedding tables are sharded
* Row-Wise: rows of embedding tables are sharded
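
To build intuition for these layouts, the toy sketch below splits a plain
weight matrix in each of the three ways using ``torch.chunk``. The table
size and the two-way split are illustrative assumptions; this is not how
TorchRec shards tables, which is handled by the sharder and planner shown
next.

.. code-block:: python

    import torch

    # Toy embedding table: 8 rows (embeddings) x 4 columns (dimensions)
    table = torch.arange(32, dtype=torch.float32).reshape(8, 4)

    # Table-wise: the whole table lives on one device
    table_wise = [table]

    # Row-wise: rows are split across two devices
    row_shards = torch.chunk(table, 2, dim=0)  # two shards of shape (4, 4)

    # Column-wise: columns are split across two devices
    col_shards = torch.chunk(table, 2, dim=1)  # two shards of shape (8, 2)

    print([shard.shape for shard in row_shards])
    print([shard.shape for shard in col_shards])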

Sharded Modules
~~~~~~~~~~~~~~~

While all of this seems like a lot to deal with and implement, you're in
luck. **TorchRec provides all the primitives for easy distributed
training and inference**! In fact, TorchRec modules have two corresponding
classes for working with any TorchRec module in a distributed
environment:

* **The module sharder**: This class exposes a ``shard`` API
  that handles sharding a TorchRec Module, producing a sharded module.

  * For ``EmbeddingBagCollection``, the sharder is `EmbeddingBagCollectionSharder <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder>`__

* **Sharded module**: This class is a sharded variant of a TorchRec module.
  It has the same input/output as the regular TorchRec module, but is much
  more optimized and works in a distributed environment.

  * For ``EmbeddingBagCollection``, the sharded variant is `ShardedEmbeddingBagCollection <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection>`__

Every TorchRec module has an unsharded and sharded variant.

* The unsharded version is meant for prototyping and experimentation.
* The sharded version is meant to be used in a distributed environment for
  distributed training and inference.

The sharded versions of TorchRec modules, for example
``EmbeddingBagCollection``, will handle everything that is needed for Model
Parallelism, such as communication between GPUs for distributing
embeddings to the correct GPUs.

As a refresher, here is our ``EmbeddingBagCollection`` module:

.. GENERATED FROM PYTHON SOURCE LINES 519-535

.. code-block:: default

    ebc

    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.types import ShardingEnv

    # Corresponding sharder for ``EmbeddingBagCollection`` module
    sharder = EmbeddingBagCollectionSharder()

    # ``ProcessGroup`` from torch.distributed initialized 2 cells above
    pg = dist.GroupMember.WORLD
    assert pg is not None, "Process group is not initialized"

    print(f"Process Group: {pg}")






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Process Group: <torch.distributed.distributed_c10d.ProcessGroup object at 0x7f93cc367ef0>




.. GENERATED FROM PYTHON SOURCE LINES 536-577

Planner
~~~~~~~

Before we can show how sharding works, we must know about the
**planner**, which helps us determine the best sharding configuration.

Given a number of embedding tables and a number of ranks, there are many
different sharding configurations that are possible. For example, given
2 embedding tables and 2 GPUs, you can:

* Place 1 table on each GPU
* Place both tables on a single GPU and no tables on the other
* Place certain rows and columns on each GPU

Given all of these possibilities, we typically want a sharding
configuration that is optimal for performance.

That is where the planner comes in. Given the number of embedding tables
and the number of GPUs, the planner is able to determine the optimal
configuration. It turns out that this is incredibly difficult to do
manually, with tons of factors that engineers have to consider to ensure
an optimal sharding plan. Luckily, TorchRec provides an automated planner
to do this for you.

The TorchRec planner:

* Assesses memory constraints of hardware
* Estimates compute for embedding lookups based on memory fetches
* Addresses data specific factors
* Considers other hardware specifics like bandwidth to generate an optimal sharding plan

In order to take all of these variables into consideration, the TorchRec
planner can take in `various amounts of data for embedding tables,
constraints, hardware information, and
topology <https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/planner/planners.py#L147-L155>`__
to aid in generating an optimal sharding plan for a model; this
information is routinely provided across stacks.

To learn more about sharding, see our `sharding
tutorial <https://pytorch.org/tutorials/advanced/sharding.html>`__.


.. GENERATED FROM PYTHON SOURCE LINES 577-592

.. code-block:: default


    # In our case, 1 GPU and compute on CUDA device
    planner = EmbeddingShardingPlanner(
        topology=Topology(
            world_size=1,
            compute_device="cuda",
        )
    )

    # Run planner to get plan for sharding
    plan = planner.collective_plan(ebc, [sharder], pg)

    print(f"Sharding Plan generated: {plan}")






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Sharding Plan generated: module: 

        param     | sharding type | compute kernel | ranks
    ------------- | ------------- | -------------- | -----
    product_table | table_wise    | fused          | [0]  
    user_table    | table_wise    | fused          | [0]  

        param     | shard offsets | shard sizes |   placement  
    ------------- | ------------- | ----------- | -------------
    product_table | [0, 0]        | [4096, 64]  | rank:0/cuda:0
    user_table    | [0, 0]        | [4096, 64]  | rank:0/cuda:0




.. GENERATED FROM PYTHON SOURCE LINES 593-605

Planner Result
~~~~~~~~~~~~~~

As you can see above, when running the planner there is quite a bit of output.
We can see a lot of stats being calculated along with where our
tables end up being placed.

The result of running the planner is a static plan, which can be reused
for sharding! This allows sharding to be static for production models
instead of determining a new sharding plan every time. Below, we use the
sharding plan to finally generate our ``ShardedEmbeddingBagCollection``.


.. GENERATED FROM PYTHON SOURCE LINES 605-617

.. code-block:: default


    # The static plan that was generated
    plan

    env = ShardingEnv.from_process_group(pg)

    # Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
    sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))

    print(f"Sharded EBC Module: {sharded_ebc}")






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Sharded EBC Module: ShardedEmbeddingBagCollection(
      (lookups): 
       GroupedPooledEmbeddingsLookup(
          (_emb_modules): ModuleList(
            (0): BatchedFusedEmbeddingBag(
              (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
            )
          )
        )
       (_output_dists): 
       TwPooledEmbeddingDist()
      (embedding_bags): ModuleDict(
        (product_table): Module()
        (user_table): Module()
      )
    )




.. GENERATED FROM PYTHON SOURCE LINES 618-629

GPU Training with ``LazyAwaitable``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Remember that TorchRec is a highly optimized library for distributed
embeddings. A concept that TorchRec introduces to enable higher
performance for training on GPU is a
`LazyAwaitable <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable>`__.
You will see ``LazyAwaitable`` types as outputs of various sharded
TorchRec modules. A ``LazyAwaitable`` type delays calculating some
result for as long as possible by acting like an async type.


.. GENERATED FROM PYTHON SOURCE LINES 629-669

.. code-block:: default


    from typing import List

    from torchrec.distributed.types import LazyAwaitable


    # Demonstrate a ``LazyAwaitable`` type:
    class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
        def __init__(self, size: List[int]) -> None:
            super().__init__()
            self._size = size

        def _wait_impl(self) -> torch.Tensor:
            return torch.ones(self._size)


    awaitable = ExampleAwaitable([3, 2])
    awaitable.wait()

    kjt = kjt.to("cuda")
    output = sharded_ebc(kjt)
    # The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
    print(output)

    kt = output.wait()
    # Now we have our ``KeyedTensor`` after calling ``.wait()``
    # If you are confused as to why we have a ``KeyedTensor`` output,
    # give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
    print(type(kt))

    print(kt.keys())

    print(kt.values().shape)

    # Same output format as unsharded ``EmbeddingBagCollection``
    result_dict = kt.to_dict()
    for key, embedding in result_dict.items():
        print(key, embedding.shape)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    <torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable object at 0x7f92ed2463b0>
    <class 'torchrec.sparse.jagged_tensor.KeyedTensor'>
    ['product', 'user']
    torch.Size([2, 128])
    product torch.Size([2, 64])
    user torch.Size([2, 64])




.. GENERATED FROM PYTHON SOURCE LINES 670-700

Anatomy of Sharded TorchRec modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We have now successfully sharded an ``EmbeddingBagCollection`` given a
sharding plan that we generated! The sharded module has common APIs from
TorchRec which abstract away distributed communication/compute amongst
multiple GPUs. In fact, these APIs are highly optimized for performance
in training and inference. **Below are the three common APIs for
distributed training/inference** that are provided by TorchRec:

* ``input_dist``: Handles distributing inputs from GPU to GPU.
* ``lookups``: Does the actual embedding lookup in an optimized,
  batched manner using FBGEMM TBE (more on this later).
* ``output_dist``: Handles distributing outputs from GPU to GPU.

The distribution of inputs and outputs is done through `NCCL
Collectives <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html>`__,
namely
`All-to-Alls <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#all-to-all>`__,
which is where all GPUs send and receive data to and from one another.
TorchRec interfaces with PyTorch distributed for collectives and
provides clean abstractions to the end users, removing the concern for
the lower level details.

The backwards pass does all of these collectives but in the reverse
order for distribution of gradients. ``input_dist``, ``lookup``, and
``output_dist`` all depend on the sharding scheme. Since we sharded in a
table-wise fashion, these APIs are modules that are constructed by
`TwPooledEmbeddingSharding <https://pytorch.org/torchrec/torchrec.distributed.sharding.html#torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding>`__.


.. GENERATED FROM PYTHON SOURCE LINES 700-710

.. code-block:: default


    sharded_ebc

    # Distribute input KJTs to all other GPUs and receive KJTs
    sharded_ebc._input_dists

    # Distribute output embeddings to all other GPUs and receive embeddings
    sharded_ebc._output_dists






.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    [TwPooledEmbeddingDist(
      (_dist): PooledEmbeddingsAllToAll()
    )]



.. GENERATED FROM PYTHON SOURCE LINES 711-734

Optimizing Embedding Lookups
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In performing lookups for a collection of embedding tables, a trivial
solution would be to iterate through all the ``nn.EmbeddingBags`` and do
a lookup per table. This is exactly what the standard, unsharded
``EmbeddingBagCollection`` does. However, while this solution
is simple, it is extremely slow.

`FBGEMM <https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu>`__ is a
library that provides highly optimized GPU operators (otherwise known as
kernels). One of these operators, known as **Table Batched
Embedding** (TBE), provides two major optimizations:

-  Table batching, which allows you to look up multiple embeddings with
   one kernel call.
-  Optimizer Fusion, which allows the module to update itself given the
   canonical PyTorch optimizers and arguments.

The ``ShardedEmbeddingBagCollection`` uses the FBGEMM TBE as the lookup
instead of traditional ``nn.EmbeddingBags`` for optimized embedding
lookups.


.. GENERATED FROM PYTHON SOURCE LINES 734-738

.. code-block:: default


    sharded_ebc._lookups






.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    [GroupedPooledEmbeddingsLookup(
      (_emb_modules): ModuleList(
        (0): BatchedFusedEmbeddingBag(
          (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
        )
      )
    )]



.. GENERATED FROM PYTHON SOURCE LINES 739-763

``DistributedModelParallel``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We have now explored sharding a single ``EmbeddingBagCollection``! We were
able to take the ``EmbeddingBagCollectionSharder`` and use the unsharded
``EmbeddingBagCollection`` to generate a
``ShardedEmbeddingBagCollection`` module. This workflow is fine, but
typically when implementing model parallel,
`DistributedModelParallel <https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel>`__
(DMP) is used as the standard interface. When wrapping your model (in
our case ``ebc``), with DMP, the following will occur:

1. Decide how to shard the model. DMP will collect the available
   sharders and come up with a plan of the optimal way to shard the
   embedding table(s) (for example, ``EmbeddingBagCollection``)
2. Actually shard the model. This includes allocating memory for each
   embedding table on the appropriate device(s).

DMP takes in everything that we've just experimented with, like a static
sharding plan, a list of sharders, etc. However, it also has some nice
defaults to seamlessly shard a TorchRec model. In this toy example,
since we have two embedding tables and one GPU, TorchRec will place both
on the single GPU.


.. GENERATED FROM PYTHON SOURCE LINES 763-774

.. code-block:: default


    ebc

    model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))

    out = model(kjt)
    out.wait()

    model






.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    DistributedModelParallel(
      (_dmp_wrapped_module): ShardedEmbeddingBagCollection(
        (lookups): 
         GroupedPooledEmbeddingsLookup(
            (_emb_modules): ModuleList(
              (0): BatchedFusedEmbeddingBag(
                (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
              )
            )
          )
         (_input_dists): 
         TwSparseFeaturesDist(
            (_dist): KJTAllToAll()
          )
         (_output_dists): 
         TwPooledEmbeddingDist(
            (_dist): PooledEmbeddingsAllToAll()
          )
        (embedding_bags): ModuleDict(
          (product_table): Module()
          (user_table): Module()
        )
      )
    )



.. GENERATED FROM PYTHON SOURCE LINES 775-787

Sharding Best Practices
~~~~~~~~~~~~~~~~~~~~~~~

Currently, our configuration is only sharding on 1 GPU (or rank), which
is trivial: just place all the tables in 1 GPU's memory. However, in real
production use cases, embedding tables are **typically sharded on
hundreds of GPUs**, with different sharding methods such as table-wise,
row-wise, and column-wise. It is incredibly important to determine a
proper sharding configuration (to prevent out of memory issues) while
keeping it balanced not only in terms of memory but also compute for
optimal performance.
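
One way to steer these decisions is to pass per-table constraints to the
planner. The sketch below shows how that might look with
``ParameterConstraints``, assuming a hypothetical multi-GPU topology; the
specific sharding types and world size here are illustrative choices, not
tuned recommendations.

.. code-block:: python

    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.planner.types import ParameterConstraints
    from torchrec.distributed.types import ShardingType

    # Hypothetical 8-GPU topology for illustration
    topology = Topology(world_size=8, compute_device="cuda")

    # Restrict the sharding types the planner may consider for each table
    constraints = {
        "product_table": ParameterConstraints(
            sharding_types=[ShardingType.ROW_WISE.value]
        ),
        "user_table": ParameterConstraints(
            sharding_types=[
                ShardingType.TABLE_WISE.value,
                ShardingType.COLUMN_WISE.value,
            ]
        ),
    }

    constrained_planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    # The planner would then only consider the allowed sharding types per table
    # when generating a plan, for example via ``constrained_planner.plan(ebc, [sharder])``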


.. GENERATED FROM PYTHON SOURCE LINES 790-850

Adding in the Optimizer
~~~~~~~~~~~~~~~~~~~~~~~

Remember that TorchRec modules are hyperoptimized for large scale
distributed training. An important optimization concerns the
optimizer.

TorchRec modules provide a seamless API to fuse the
backwards pass and optimizer step in training, providing a significant
optimization in performance and decreasing the memory used, alongside
granularity in assigning distinct optimizers to distinct model
parameters.

Optimizer Classes
^^^^^^^^^^^^^^^^^

TorchRec uses ``CombinedOptimizer``, which contains a collection of
``KeyedOptimizers``. A ``CombinedOptimizer`` effectively makes it easy
to handle multiple optimizers for various sub groups in the model. A
``KeyedOptimizer`` extends ``torch.optim.Optimizer`` and is
initialized through a dictionary of parameters that it exposes.
Each ``TBE`` module in an ``EmbeddingBagCollection`` will have its own
``KeyedOptimizer``, and they combine into one ``CombinedOptimizer``.

Fused optimizer in TorchRec
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Using ``DistributedModelParallel``, the **optimizer is fused, which
means that the optimizer update is done in the backward pass**. This is an
optimization in TorchRec and FBGEMM, where the embedding gradients
are not materialized and the update is applied directly to the parameters.
This brings significant memory savings, as embedding gradients are
typically the same size as the parameters themselves.

You can, however, choose to make the optimizer ``dense``, which does not
apply this optimization and lets you inspect the embedding gradients or
apply computations to them as you wish. A dense optimizer in this case
would be your `canonical PyTorch model training loop with
optimizer <https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html>`__.

Once the optimizer is created through ``DistributedModelParallel``, you
still need to manage an optimizer for the other parameters not
associated with TorchRec embedding modules. To find the other
parameters,
use ``in_backward_optimizer_filter(model.named_parameters())``. 
Apply an optimizer to those parameters as you would a normal Torch
optimizer and combine this and the ``model.fused_optimizer`` into one
``CombinedOptimizer`` that you can use in your training loop to
``zero_grad`` and ``step`` through.
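
As a hedged sketch of that pattern, assuming a ``model`` wrapped in
``DistributedModelParallel`` that also contains dense (non-TorchRec)
layers, the combination might look like the following. The choice of
``torch.optim.Adam`` and its learning rate are arbitrary placeholders.

.. code-block:: python

    import torch
    from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
    from torchrec.optim.optimizers import in_backward_optimizer_filter

    # Assumes ``model`` is a DMP-wrapped module that also has dense layers;
    # the filter yields only the parameters NOT handled by the fused optimizer
    dense_optimizer = KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(model.named_parameters())),
        lambda params: torch.optim.Adam(params, lr=0.01),
    )

    # Combine the fused embedding optimizer with the dense optimizer so the
    # training loop only deals with a single optimizer object
    optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])

    # Typical loop usage (sketch):
    #     optimizer.zero_grad()
    #     loss.backward()
    #     optimizer.step()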

Adding an Optimizer to ``EmbeddingBagCollection``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We will do this in two ways, which are equivalent, but give you options
depending on your preferences:

1. Passing optimizer kwargs through ``fused_params`` in sharder.
2. Through ``apply_optimizer_in_backward``, which converts the optimizer
   parameters to ``fused_params`` to pass to the ``TBE`` in the ``EmbeddingBagCollection`` or ``EmbeddingCollection``.


.. GENERATED FROM PYTHON SOURCE LINES 850-918

.. code-block:: default


    # Option 1: Passing optimizer kwargs through fused parameters
    from torchrec.optim.optimizers import in_backward_optimizer_filter
    from fbgemm_gpu.split_embedding_configs import EmbOptimType


    # The fused parameters (optimizer kwargs) that we will initialize the sharder with
    fused_params = {
        "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
        "learning_rate": 0.02,
        "eps": 0.002,
    }

    # Initialize sharder with ``fused_params``
    sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)

    # We'll use same plan and unsharded EBC as before but this time with our new sharder
    sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[""], env, torch.device("cuda"))

    # Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
    # If outputted, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
    print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
    print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")

    print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")

    from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward
    import copy
    # Option 2: Applying optimizer through apply_optimizer_in_backward
    # Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it

    # We can achieve the same result as we did with the previous approach
    ebc_apply_opt = copy.deepcopy(ebc)
    optimizer_kwargs = {"lr": 0.5}

    for name, param in ebc_apply_opt.named_parameters():
        print(f"{name=}")
        apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)

    sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[""], env, torch.device("cuda"))

    # Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
    print(sharded_ebc_apply_opt.fused_optimizer)
    print(type(sharded_ebc_apply_opt.fused_optimizer))

    # We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
    # Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
    # there are no other parameters that aren't associated with TorchRec
    print("Non Fused Model Parameters:")
    print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())

    # Here we do a dummy backwards call and see that parameter updates for fused
    # optimizers happen as a result of the backward pass

    ebc_output = sharded_ebc_fused_params(kjt).wait().values()
    loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
    print(f"First Iteration Loss: {loss}")

    loss.backward()

    ebc_output = sharded_ebc_fused_params(kjt).wait().values()
    loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
    # We don't call an optimizer.step(), so for the loss to have changed here,
    # that means that the gradients were somehow updated, which is what the
    # fused optimizer automatically handles for us
    print(f"Second Iteration Loss: {loss}")






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (
    Parameter Group 0
        lr: 0.01
    )
    Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (
    Parameter Group 0
        lr: 0.02
    )
    Type of optimizer: <class 'torchrec.optim.keyed.CombinedOptimizer'>
    name='embedding_bags.product_table.weight'
    name='embedding_bags.user_table.weight'
    : EmbeddingFusedOptimizer (
    Parameter Group 0
        lr: 0.5
    )
    <class 'torchrec.optim.keyed.CombinedOptimizer'>
    Non Fused Model Parameters:
    dict_keys([])
    First Iteration Loss: 255.66006469726562
    Second Iteration Loss: 245.43795776367188




.. GENERATED FROM PYTHON SOURCE LINES 919-956

Inference
~~~~~~~~~

Now that we are able to train distributed embeddings, how can we take
the trained model and optimize it for inference? Inference is typically
very sensitive to **performance and size of the model**. Running just
the trained model in a Python environment is incredibly inefficient.
There are two key differences between inference and training
environments:

* **Quantization**: Inference models are typically
  quantized, where model parameters lose precision for lower latency in
  predictions and reduced model size. For example, from FP32 (4 bytes) in the
  trained model to INT8 (1 byte) for each embedding weight. This is also
  necessary given the vast scale of embedding tables, as we want to use as
  few devices as possible for inference to minimize latency.

* **C++ environment**: Inference latency is very important, so in order to ensure
  ample performance, the model is typically run in a C++ environment,
  especially in situations where we don't have a Python runtime, like on
  device.

TorchRec provides primitives for making a TorchRec model
inference ready with:

* APIs for quantizing the model, introducing
  optimizations automatically with FBGEMM TBE
* Sharding embeddings for distributed inference
* Compiling the model to `TorchScript <https://pytorch.org/docs/stable/jit.html>`__
  (compatible in C++)

In this section, we will go over this entire workflow of:

* Quantizing the model
* Sharding the quantized model
* Compiling the sharded quantized model into TorchScript


.. GENERATED FROM PYTHON SOURCE LINES 956-974

.. code-block:: default


    ebc

    class InferenceModule(torch.nn.Module):
        def __init__(self, ebc: torchrec.EmbeddingBagCollection):
            super().__init__()
            self.ebc_ = ebc

        def forward(self, kjt: KeyedJaggedTensor):
            return self.ebc_(kjt)

    module = InferenceModule(ebc)
    for name, param in module.named_parameters():
        # Here, the parameters should still be FP32, as we are using a standard EBC
        # FP32 is default, regularly used for training
        print(name, param.shape, param.dtype)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ebc_.embedding_bags.product_table.weight torch.Size([4096, 64]) torch.float32
    ebc_.embedding_bags.user_table.weight torch.Size([4096, 64]) torch.float32




.. GENERATED FROM PYTHON SOURCE LINES 975-982

Quantization
~~~~~~~~~~~~

As you can see above, the normal EBC contains embedding table weights as
FP32 precision (32 bits for each weight). Here, we will use the TorchRec
inference library to quantize the embedding weights of the model to INT8


.. GENERATED FROM PYTHON SOURCE LINES 982-1035

.. code-block:: default


    from torch import quantization as quant
    from torchrec.modules.embedding_configs import QuantConfig
    from torchrec.quant.embedding_modules import (
        EmbeddingBagCollection as QuantEmbeddingBagCollection,
    )


    quant_dtype = torch.int8


    qconfig = QuantConfig(
        # dtype of the result of the embedding lookup, post activation
        # torch.float generally for compatibility with rest of the model
        # as rest of the model here usually isn't quantized
        activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
        # quantized type for embedding weights, aka parameters to actually quantize
        weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),
    )
    qconfig_spec = {
        # Map of module type to qconfig
        torchrec.EmbeddingBagCollection: qconfig,
    }
    mapping = {
        # Map of module type to quantized module type
        torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,
    }


    module = InferenceModule(ebc)

    # Quantize the module
    qebc = quant.quantize_dynamic(
        module,
        qconfig_spec=qconfig_spec,
        mapping=mapping,
        inplace=False,
    )


    print(f"Quantized EBC: {qebc}")

    kjt = kjt.to("cpu")

    qebc(kjt)

    # Once quantized, the weights go from parameters to buffers, as they are no longer trainable
    for name, buffer in qebc.named_buffers():
        # The shapes of the tables should be the same but the dtype should be int8 now
        # post quantization
        print(name, buffer.shape, buffer.dtype)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Quantized EBC: InferenceModule(
      (ebc_): QuantizedEmbeddingBagCollection(
        (_kjt_to_jt_dict): ComputeKJTToJTDict()
        (embedding_bags): ModuleDict(
          (product_table): Module()
          (user_table): Module()
        )
      )
    )
    ebc_.embedding_bags.product_table.weight torch.Size([4096, 80]) torch.uint8
    ebc_.embedding_bags.user_table.weight torch.Size([4096, 80]) torch.uint8




.. GENERATED FROM PYTHON SOURCE LINES 1036-1043

Shard
~~~~~

Here we shard the quantized TorchRec model. This ensures that we are
using the performant FBGEMM TBE module for lookups. We are using one
device to be consistent with training (1 TBE).


.. GENERATED FROM PYTHON SOURCE LINES 1043-1063

.. code-block:: default


    from torchrec import distributed as trec_dist
    from torchrec.distributed.shard import _shard_modules


    sharded_qebc = _shard_modules(
        module=qebc,
        device=torch.device("cpu"),
        env=trec_dist.ShardingEnv.from_local(
            1,
            0,
        ),
    )


    print(f"Sharded Quantized EBC: {sharded_qebc}")

    sharded_qebc(kjt)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Sharded Quantized EBC: InferenceModule(
      (ebc_): ShardedQuantEmbeddingBagCollection(
        (lookups): 
         InferGroupedPooledEmbeddingsLookup()
        (_output_dists): ModuleList()
        (embedding_bags): ModuleDict(
          (product_table): Module()
          (user_table): Module()
        )
        (_input_dist_module): ShardedQuantEbcInputDist()
      )
    )

    <torchrec.sparse.jagged_tensor.KeyedTensor object at 0x7f92f0da4310>



.. GENERATED FROM PYTHON SOURCE LINES 1064-1076

Compilation
~~~~~~~~~~~

Now we have the optimized eager TorchRec inference model. The next step
is to ensure that this model is loadable in C++, as currently it is only
runnable in a Python runtime.

The recommended method of compilation at Meta is twofold: `torch.fx
tracing <https://pytorch.org/docs/stable/fx.html>`__ (generating an
intermediate representation of the model) and converting the result to
TorchScript, which is C++ compatible.


.. GENERATED FROM PYTHON SOURCE LINES 1076-1095

.. code-block:: default


    from torchrec.fx import Tracer


    tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"])

    graph = tracer.trace(sharded_qebc)
    gm = torch.fx.GraphModule(sharded_qebc, graph)

    print("Graph Module Created!")

    print(gm.code)

    scripted_gm = torch.jit.script(gm)
    print("Scripted Graph Module Created!")

    print(scripted_gm.code)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Graph Module Created!

    torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embeddingbag_flatten_feature_lengths")
    torch.fx._symbolic_trace.wrap("torchrec_fx_utils__fx_marker")
    torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embedding_kernel__unwrap_kjt")
    torch.fx._symbolic_trace.wrap("fbgemm_gpu_split_table_batched_embeddings_ops_inference_inputs_to_device")
    torch.fx._symbolic_trace.wrap("torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference")

    def forward(self, kjt : torchrec_sparse_jagged_tensor_KeyedJaggedTensor):
        flatten_feature_lengths = torchrec_distributed_quant_embeddingbag_flatten_feature_lengths(kjt);  kjt = None
        _fx_marker = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_BEGIN', flatten_feature_lengths);  _fx_marker = None
        split = flatten_feature_lengths.split([2])
        getitem = split[0];  split = None
        to = getitem.to(device(type='cuda', index=0), non_blocking = True);  getitem = None
        _fx_marker_1 = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_END', flatten_feature_lengths);  flatten_feature_lengths = _fx_marker_1 = None
        _unwrap_kjt = torchrec_distributed_quant_embedding_kernel__unwrap_kjt(to);  to = None
        getitem_1 = _unwrap_kjt[0]
        getitem_2 = _unwrap_kjt[1]
        getitem_3 = _unwrap_kjt[2];  _unwrap_kjt = getitem_3 = None
        inputs_to_device = fbgemm_gpu_split_table_batched_embeddings_ops_inference_inputs_to_device(getitem_1, getitem_2, None, device(type='cuda', index=0));  getitem_1 = getitem_2 = None
        getitem_4 = inputs_to_device[0]
        getitem_5 = inputs_to_device[1]
        getitem_6 = inputs_to_device[2];  inputs_to_device = None
        _tensor_constant0 = self._tensor_constant0
        _tensor_constant1 = self._tensor_constant1
        bounds_check_indices = torch.ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_4, getitem_5, 1, _tensor_constant1, getitem_6);  _tensor_constant0 = _tensor_constant1 = bounds_check_indices = None
        _tensor_constant2 = self._tensor_constant2
        _tensor_constant3 = self._tensor_constant3
        _tensor_constant4 = self._tensor_constant4
        _tensor_constant5 = self._tensor_constant5
        _tensor_constant6 = self._tensor_constant6
        _tensor_constant7 = self._tensor_constant7
        _tensor_constant8 = self._tensor_constant8
        _tensor_constant9 = self._tensor_constant9
        int_nbit_split_embedding_codegen_lookup_function = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(dev_weights = _tensor_constant2, uvm_weights = _tensor_constant3, weights_placements = _tensor_constant4, weights_offsets = _tensor_constant5, weights_tys = _tensor_constant6, D_offsets = _tensor_constant7, total_D = 128, max_int2_D = 0, max_int4_D = 0, max_int8_D = 64, max_float16_D = 0, max_float32_D = 0, indices = getitem_4, offsets = getitem_5, pooling_mode = 0, indice_weights = getitem_6, output_dtype = 0, lxu_cache_weights = _tensor_constant8, lxu_cache_locations = _tensor_constant9, row_alignment = 16, max_float8_D = 0, fp8_exponent_bits = -1, fp8_exponent_bias = -1);  _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = _tensor_constant7 = getitem_4 = getitem_5 = getitem_6 = _tensor_constant8 = _tensor_constant9 = None
        embeddings_cat_empty_rank_handle_inference = torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference([int_nbit_split_embedding_codegen_lookup_function], dim = 1, device = 'cuda:0', dtype = torch.float32);  int_nbit_split_embedding_codegen_lookup_function = None
        to_1 = embeddings_cat_empty_rank_handle_inference.to(device(type='cpu'));  embeddings_cat_empty_rank_handle_inference = None
        keyed_tensor = torchrec_sparse_jagged_tensor_KeyedTensor(keys = ['product', 'user'], length_per_key = [64, 64], values = to_1, key_dim = 1);  to_1 = None
        return keyed_tensor
    
    /usr/local/lib/python3.10/dist-packages/torch/jit/_check.py:178: UserWarning:

    The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.

    Scripted Graph Module Created!
    def forward(self,
        kjt: __torch__.torchrec.sparse.jagged_tensor.KeyedJaggedTensor) -> __torch__.torchrec.sparse.jagged_tensor.KeyedTensor:
      _0 = __torch__.torchrec.distributed.quant_embeddingbag.flatten_feature_lengths
      _1 = __torch__.torchrec.fx.utils._fx_marker
      _2 = __torch__.torchrec.distributed.quant_embedding_kernel._unwrap_kjt
      _3 = __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_inference.inputs_to_device
      _4 = __torch__.torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference
      flatten_feature_lengths = _0(kjt, )
      _fx_marker = _1("KJT_ONE_TO_ALL_FORWARD_BEGIN", flatten_feature_lengths, )
      split = (flatten_feature_lengths).split([2], )
      getitem = split[0]
      to = (getitem).to(torch.device("cuda", 0), True, None, )
      _fx_marker_1 = _1("KJT_ONE_TO_ALL_FORWARD_END", flatten_feature_lengths, )
      _unwrap_kjt = _2(to, )
      getitem_1 = (_unwrap_kjt)[0]
      getitem_2 = (_unwrap_kjt)[1]
      inputs_to_device = _3(getitem_1, getitem_2, None, torch.device("cuda", 0), )
      getitem_4 = (inputs_to_device)[0]
      getitem_5 = (inputs_to_device)[1]
      getitem_6 = (inputs_to_device)[2]
      _tensor_constant0 = self._tensor_constant0
      _tensor_constant1 = self._tensor_constant1
      ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_4, getitem_5, 1, _tensor_constant1, getitem_6)
      _tensor_constant2 = self._tensor_constant2
      _tensor_constant3 = self._tensor_constant3
      _tensor_constant4 = self._tensor_constant4
      _tensor_constant5 = self._tensor_constant5
      _tensor_constant6 = self._tensor_constant6
      _tensor_constant7 = self._tensor_constant7
      _tensor_constant8 = self._tensor_constant8
      _tensor_constant9 = self._tensor_constant9
      int_nbit_split_embedding_codegen_lookup_function = ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, _tensor_constant7, 128, 0, 0, 64, 0, 0, getitem_4, getitem_5, 0, getitem_6, 0, _tensor_constant8, _tensor_constant9, 16)
      _5 = [int_nbit_split_embedding_codegen_lookup_function]
      embeddings_cat_empty_rank_handle_inference = _4(_5, 1, "cuda:0", 6, )
      to_1 = torch.to(embeddings_cat_empty_rank_handle_inference, torch.device("cpu"))
      _6 = ["product", "user"]
      _7 = [64, 64]
      keyed_tensor = __torch__.torchrec.sparse.jagged_tensor.KeyedTensor.__new__(__torch__.torchrec.sparse.jagged_tensor.KeyedTensor)
      _8 = (keyed_tensor).__init__(_6, _7, to_1, 1, None, None, )
      return keyed_tensor





.. GENERATED FROM PYTHON SOURCE LINES 1096-1105

Conclusion
^^^^^^^^^^

In this tutorial, you have gone from training a distributed RecSys model all the way
to making it inference ready. The `TorchRec repo
<https://github.com/pytorch/torchrec/tree/main/torchrec/inference>`__ has a
full example of how to load a TorchRec TorchScript model into C++ for
inference.


.. GENERATED FROM PYTHON SOURCE LINES 1108-1117

See Also
--------------

For more information, please see our
`dlrm <https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm/>`__
example, which includes multinode training on the Criteo 1TB
dataset using the methods described in `Deep Learning Recommendation Model
for Personalization and Recommendation Systems <https://arxiv.org/abs/1906.00091>`__.



.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.847 seconds)


.. _sphx_glr_download_intermediate_torchrec_intro_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torchrec_intro_tutorial.py <torchrec_intro_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torchrec_intro_tutorial.ipynb <torchrec_intro_tutorial.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_