.. _distributed-autograd-design:

Distributed Autograd Design
===========================

This note presents the detailed design of distributed autograd and walks
through its internals. Make sure you're familiar with
:ref:`autograd-mechanics` and the :ref:`distributed-rpc-framework` before
proceeding.

Background
^^^^^^^^^^

Let's say you have two nodes and a very simple model partitioned across them.
This can be implemented using :mod:`torch.distributed.rpc` as follows:

.. code::

  import torch
  import torch.distributed.rpc as rpc

  def my_add(t1, t2):
    return torch.add(t1, t2)

  # On worker 0:
  t1 = torch.rand((3, 3), requires_grad=True)
  t2 = torch.rand((3, 3), requires_grad=True)

  # Perform some computation remotely.
  t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

  # Perform some computation locally based on remote result.
  t4 = torch.rand((3, 3), requires_grad=True)
  t5 = torch.mul(t3, t4)

  # Compute some loss.
  loss = t5.sum()

The main motivation behind distributed autograd is to enable running a backward
pass on such distributed models with the ``loss`` that we've computed and
record appropriate gradients for all tensors that require gradients.

.. _attaching_send_recv_functions:

Autograd recording during the forward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

PyTorch builds the autograd graph during the forward pass and this graph is
used to execute the backward pass. For more details see
:ref:`how-autograd-encodes-history`.

For distributed autograd, we need to keep track of all RPCs during the forward
pass to ensure the backward pass is executed appropriately. For this purpose,
we attach ``send`` and ``recv`` functions to the autograd graph when we perform
an RPC.

- The ``send`` function is attached to the source of the RPC and its output
  edges point to the autograd function for the input tensors of the RPC.
  The input for this function during the backward pass is received from the
  destination as the output of the appropriate ``recv`` function.
- The ``recv`` function is attached to the destination of the RPC and its
  inputs are retrieved from operators executed on the destination using the
  input tensors. The output gradients of this function are sent to the source
  node to the appropriate ``send`` function during the backward pass.
- Each ``send-recv`` pair is assigned a globally unique ``autograd_message_id``
  to uniquely identify the pair. This is useful for looking up the
  corresponding function on a remote node during the backward pass.
- For :ref:`rref`, whenever we call :meth:`torch.distributed.rpc.RRef.to_here`
  we attach an appropriate ``send-recv`` pair for the tensors involved.

As an example, this is what the autograd graph for our example above would look
like (``t5.sum()`` excluded for simplicity):

.. image:: ../_static/img/distributed_autograd/send_recv_functions.png
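
The bookkeeping described above can be sketched in plain Python. This is a toy
model, not PyTorch's implementation: ``ToyWorker``, ``SendFunction``,
``RecvFunction`` and ``record_rpc`` are hypothetical names, and a single shared
counter stands in for however globally unique ids are actually generated. It
only shows that each RPC message leaves a ``send`` entry on its source and a
``recv`` entry on its destination, both keyed by the same
``autograd_message_id``.

.. code::

  import itertools
  from dataclasses import dataclass

  @dataclass
  class SendFunction:          # hypothetical stand-in for the real send function
      message_id: int
      input_names: tuple       # tensors carried by the RPC message

  @dataclass
  class RecvFunction:          # hypothetical stand-in for the real recv function
      message_id: int
      source_worker: str       # where gradients go during the backward pass

  class ToyWorker:
      def __init__(self, name):
          self.name = name
          self.send_functions = {}  # autograd_message_id -> SendFunction
          self.recv_functions = {}  # autograd_message_id -> RecvFunction

  # Assumption: one shared counter is enough to make ids unique in this toy.
  _next_message_id = itertools.count()

  def record_rpc(src, dst, input_names):
      """Attach a send function on the source and a recv function on the destination."""
      message_id = next(_next_message_id)
      src.send_functions[message_id] = SendFunction(message_id, tuple(input_names))
      dst.recv_functions[message_id] = RecvFunction(message_id, src.name)
      return message_id

  worker0, worker1 = ToyWorker("worker0"), ToyWorker("worker1")
  # Request: worker0 -> worker1 carrying t1 and t2.
  request_id = record_rpc(worker0, worker1, ["t1", "t2"])
  # Response: worker1 -> worker0 carrying t3.
  response_id = record_rpc(worker1, worker0, ["t3"])

During the backward pass, a ``recv`` entry knows which worker to contact and
which id to quote, and the receiving worker can find the matching ``send``
entry with a dictionary lookup.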

.. _autograd_context:

Distributed Autograd Context
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Each forward and backward pass that uses distributed autograd is assigned a
unique :class:`torch.distributed.autograd.context` and this context has a
globally unique ``autograd_context_id``. This context is created on each node
as needed.

This context serves the following purposes:

1. Multiple nodes running distributed backward passes might accumulate
   gradients on the same tensor and as a result the ``.grad`` field of the
   tensor would have gradients from a variety of distributed backward passes
   before we have the opportunity to run the optimizer. This is similar to
   calling :meth:`torch.autograd.backward` multiple times locally. In order to
   provide a way of separating out the gradients for each backward pass, the
   gradients are accumulated in the :class:`torch.distributed.autograd.context`
   for each backward pass.
2. During the forward pass we store the ``send`` and ``recv`` functions for
   each autograd pass in this context. This ensures we hold references to the
   appropriate nodes in the autograd graph to keep it alive. It also makes it
   easy to look up the appropriate ``send`` and ``recv`` functions during the
   backward pass.
3. In general we also use this context to store some metadata for each
   distributed autograd pass.

|

From the user's perspective, the autograd context is set up as follows:

.. code::

  import torch.distributed.autograd as dist_autograd
  with dist_autograd.context() as context_id:
    loss = model.forward()
    dist_autograd.backward(context_id, loss)

It is important to note that your model's forward pass must be invoked within
the distributed autograd context manager, as a valid context is needed in
order to ensure that all ``send`` and ``recv`` functions are stored properly
to run the backward pass across all participating nodes.
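
To illustrate why gradients are kept per context rather than on ``.grad``,
here is a minimal pure-Python sketch. It is not the real
:class:`torch.distributed.autograd.context`: ``toy_context``, ``accumulate``
and ``toy_get_gradients`` are hypothetical helpers, gradients are plain
floats, and releasing the store on exit is a choice made for this toy.

.. code::

  import itertools
  from collections import defaultdict
  from contextlib import contextmanager

  _context_ids = itertools.count(1)
  _contexts = {}  # autograd_context_id -> {parameter name: accumulated gradient}

  @contextmanager
  def toy_context():
      """Create a gradient store on entry and release it on exit."""
      context_id = next(_context_ids)
      _contexts[context_id] = defaultdict(float)
      try:
          yield context_id
      finally:
          del _contexts[context_id]

  def accumulate(context_id, name, grad):
      """Add a gradient to the store of one context only."""
      _contexts[context_id][name] += grad

  def toy_get_gradients(context_id):
      return dict(_contexts[context_id])

  # Two backward passes touch the same parameter "w" without mixing gradients.
  with toy_context() as ctx_a, toy_context() as ctx_b:
      accumulate(ctx_a, "w", 1.0)
      accumulate(ctx_a, "w", 0.5)
      accumulate(ctx_b, "w", 10.0)
      grads_a = toy_get_gradients(ctx_a)
      grads_b = toy_get_gradients(ctx_b)

Each context id sees only the gradients accumulated under it, which is the
separation the first purpose in the list above asks for.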

Distributed Backward Pass
^^^^^^^^^^^^^^^^^^^^^^^^^

In this section we outline the challenge of computing dependencies accurately
during a distributed backward pass and describe a couple of algorithms (with
tradeoffs) for executing a distributed backward pass.

Computing dependencies
----------------------

Consider the following piece of code being run on a single machine:

.. code::

  import torch
  a = torch.rand((3, 3), requires_grad=True)
  b = torch.rand((3, 3), requires_grad=True)
  c = torch.rand((3, 3), requires_grad=True)
  d = a + b
  e = b * c
  d.sum().backward()

This is what the autograd graph for the code above would look like:

.. image:: ../_static/img/distributed_autograd/local_dependencies.png
  :scale: 80%

The first step the autograd engine performs as part of the backward pass is
computing the number of dependencies for each node in the autograd graph. This
helps the autograd engine know when a node in the graph is ready for execution.
The numbers in brackets for ``add(1)`` and ``mul(0)`` denote the number of
dependencies. As you can see, this means during the backward pass the ``add``
node needs 1 input and the ``mul`` node doesn't need any inputs (in other
words doesn't need to be executed). The local autograd engine computes these
dependencies by traversing the graph from the root nodes (``d`` in this case).
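
The dependency count can be reproduced with a small breadth-first traversal.
The sketch below is an illustration in plain Python, not the engine's actual
code; the node names (``sum``, ``accumulate_a`` and so on) are labels chosen
for this example.

.. code::

  from collections import deque

  # Edges point from a function to the functions that receive its output
  # gradients during the backward pass.
  graph = {
      "sum": ["add"],
      "add": ["accumulate_a", "accumulate_b"],
      "mul": ["accumulate_b", "accumulate_c"],
      "accumulate_a": [],
      "accumulate_b": [],
      "accumulate_c": [],
  }

  def compute_dependencies(graph, roots):
      """Count, for every node, the incoming edges reachable from the roots."""
      dependencies = {node: 0 for node in graph}
      seen = set(roots)
      queue = deque(roots)
      while queue:
          node = queue.popleft()
          for nxt in graph[node]:
              dependencies[nxt] += 1
              if nxt not in seen:
                  seen.add(nxt)
                  queue.append(nxt)
      return dependencies

  deps = compute_dependencies(graph, ["sum"])

Because ``mul`` is never reached from the root, it keeps a count of 0, and
``accumulate_b`` waits for a single gradient (from ``add``) rather than two.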

The fact that certain nodes in the autograd graph might not be executed in the
backward pass poses a challenge for distributed autograd. Consider this piece
of code which uses RPC:

.. code::

  import torch
  import torch.distributed.rpc as rpc

  a = torch.rand((3, 3), requires_grad=True)
  b = torch.rand((3, 3), requires_grad=True)
  c = torch.rand((3, 3), requires_grad=True)

  d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
  e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
  loss = d.sum()

The associated autograd graph for the code above would be:

.. image:: ../_static/img/distributed_autograd/distributed_dependencies.png

Computing dependencies of this distributed autograd graph is much more
challenging and requires some overhead (either in terms of computation or
network communication).

For performance-sensitive applications we can avoid a lot of overhead by
assuming that every ``send`` and ``recv`` function is valid as part of the
backward pass (most applications don't perform RPCs that aren't used). This
simplifies the distributed autograd algorithm and is much more efficient, but
at the cost that the application needs to be aware of the limitations. This
algorithm is called the `FAST mode algorithm`_ and is described in detail
below.

In the general case, not every ``send`` and ``recv`` function is necessarily
valid as part of the backward pass. To address this, we have proposed a
`SMART mode algorithm`_ which is described in a later section. Please note
that currently, only the FAST mode algorithm is implemented.

.. _fast-mode-algorithm:

FAST mode algorithm
-------------------

The key assumption of this algorithm is that each ``send`` function has a
dependency of 1 when we run a backward pass. In other words, we assume we'll
receive a gradient over RPC from another node.

The algorithm is as follows:

1. We start from the worker which has the roots for the backward pass
   (all roots must be local).
2. Look up all the ``send`` functions for the current
   `Distributed Autograd Context`_.
3. Compute dependencies locally starting from the provided roots and all the
   ``send`` functions we retrieved.
4. After computing dependencies, kick off the local autograd engine with the
   provided roots.
5. When the autograd engine executes the ``recv`` function, the ``recv``
   function sends the input gradients via RPC to the appropriate worker.
   Each ``recv`` function knows the destination worker id since it is recorded
   as part of the forward pass. The ``recv`` function also sends over the
   ``autograd_context_id`` and ``autograd_message_id`` to the remote host.
6. When this request is received on the remote host, we use the
   ``autograd_context_id`` and ``autograd_message_id`` to look up the
   appropriate ``send`` function.
7. If this is the first time a worker has received a request for the given
   ``autograd_context_id``, it will compute dependencies locally as described
   in points 1-3 above.
8. The ``send`` function retrieved in 6. is then enqueued for execution on the
   local autograd engine for that worker.
9. Finally, instead of accumulating the gradients on the ``.grad`` field of the
   Tensor, we accumulate the gradients separately per
   `Distributed Autograd Context`_. The gradients are stored in a
   ``Dict[Tensor, Tensor]``, which is basically a map from Tensor to its
   associated gradient and this map can be retrieved using the
   :meth:`~torch.distributed.autograd.get_gradients` API.

|

As an example, the complete code with distributed autograd would be as follows:

.. code::

  import torch
  import torch.distributed.autograd as dist_autograd
  import torch.distributed.rpc as rpc

  def my_add(t1, t2):
    return torch.add(t1, t2)

  # On worker 0:

  # Set up the autograd context. Computations that take
  # part in the distributed backward pass must be within
  # the distributed autograd context manager.
  with dist_autograd.context() as context_id:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)

    # Perform some computation remotely.
    t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))

    # Perform some computation locally based on remote result.
    t4 = torch.rand((3, 3), requires_grad=True)
    t5 = torch.mul(t3, t4)

    # Compute some loss.
    loss = t5.sum()

    # Run the backward pass.
    dist_autograd.backward(context_id, [loss])

    # Retrieve the gradients from the context.
    dist_autograd.get_gradients(context_id)

The distributed autograd graph with dependencies would be as follows:

.. image:: ../_static/img/distributed_autograd/distributed_dependencies_computed.png

The `FAST mode algorithm`_ applied to the above example would be as follows:

1. On ``Worker 0`` we start from the roots ``loss`` and ``send1`` to compute
   dependencies. As a result ``send1`` is marked with a dependency of 1 and ``mul``
   on ``Worker 0`` is marked with a dependency of 1.
2. Now, we kick off the local autograd engine on ``Worker 0``. We first execute
   the ``mul`` function and accumulate its output in the autograd context as
   the gradient for ``t4``. Then, we execute ``recv2`` which sends the
   gradients to ``Worker 1``.
3. Since this is the first time ``Worker 1`` has heard about this backward pass,
   it starts dependency computation and marks the dependencies for ``send2``,
   ``add`` and ``recv1`` appropriately.
4. Next, we enqueue ``send2`` on the local autograd engine of ``Worker 1``, which
   in turn executes ``add`` and ``recv1``.
5. When ``recv1`` is executed it sends the gradients over to ``Worker 0``.
6. Since ``Worker 0`` has already computed dependencies for this backward pass,
   it just enqueues and executes ``send1`` locally.
7. Finally, gradients for ``t1``, ``t2`` and ``t4`` are accumulated in the
   `Distributed Autograd Context`_.
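
The walkthrough above can be mimicked with a toy, single-process simulation.
This is a sketch under several assumptions and not how PyTorch implements the
engine: ``ToyWorker`` is a hypothetical class, tensors are replaced by scalars,
RPCs are replaced by direct method calls, and the two outputs of ``recv1`` are
collapsed into one edge because ``add`` passes the same gradient to both of
its inputs.

.. code::

  from collections import deque

  # Scalar stand-ins for the tensors in the example (values are arbitrary).
  T1, T2, T4 = 1.0, 2.0, 4.0
  T3 = T1 + T2  # computed on Worker 1 during the forward pass

  class ToyWorker:
      def __init__(self, edges, sends, remote):
          self.edges = edges        # function -> [(next function, local derivative)]
          self.sends = sends        # send functions stored in the context
          self.remote = remote      # recv function -> (worker name, send function)
          self.dependencies = None  # computed once per backward pass
          self.pending = {}         # function -> gradient buffered so far
          self.gradients = {}       # toy per-context gradient store
          self.log = []             # execution order, for inspection

      def compute_dependencies(self, roots):
          starts = list(roots) + list(self.sends)
          deps = {fn: 0 for fn in self.edges}
          seen, queue = set(starts), deque(starts)
          while queue:
              for nxt, _ in self.edges[queue.popleft()]:
                  deps[nxt] += 1
                  if nxt not in seen:
                      seen.add(nxt)
                      queue.append(nxt)
          for send in self.sends:
              deps[send] = 1  # FAST mode: expect exactly one gradient over RPC
          self.dependencies = deps

      def receive(self, fn, grad, workers):
          """Buffer a gradient for fn and run it once all dependencies arrived."""
          self.pending[fn] = self.pending.get(fn, 0.0) + grad
          self.dependencies[fn] -= 1
          if self.dependencies[fn] == 0:
              self.run(fn, self.pending.pop(fn), workers)

      def run(self, fn, grad, workers):
          self.log.append(fn)
          if fn in self.remote:  # recv: hand the gradient to the paired send
              name, send = self.remote[fn]
              target = workers[name]
              if target.dependencies is None:  # first request for this context
                  target.compute_dependencies(roots=[])
              target.receive(send, grad, workers)
          elif not self.edges[fn]:  # leaf: accumulate in the context
              self.gradients[fn] = self.gradients.get(fn, 0.0) + grad
          else:
              for nxt, local in self.edges[fn]:
                  self.receive(nxt, grad * local, workers)

  workers = {
      "worker0": ToyWorker(
          edges={"loss": [("mul", 1.0)], "mul": [("t4", T3), ("recv2", T4)],
                 "recv2": [], "send1": [("t1", 1.0), ("t2", 1.0)],
                 "t1": [], "t2": [], "t4": []},
          sends=["send1"], remote={"recv2": ("worker1", "send2")}),
      "worker1": ToyWorker(
          edges={"send2": [("add", 1.0)], "add": [("recv1", 1.0)], "recv1": []},
          sends=["send2"], remote={"recv1": ("worker0", "send1")}),
  }

  workers["worker0"].compute_dependencies(roots=["loss"])
  workers["worker0"].run("loss", 1.0, workers)

After the run, ``workers["worker0"].gradients`` holds entries for ``t1``,
``t2`` and ``t4``, and the two ``log`` lists record the order in which each
worker executed its functions, which can be compared with the numbered steps.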

SMART mode algorithm
--------------------
Full details of this algorithm are still in the works, but for the general idea
you can refer to the **Distributed Autograd Algorithm Smart mode** section in
the `RFC`_.

Distributed Optimizer
^^^^^^^^^^^^^^^^^^^^^

The :class:`~torch.distributed.optim.DistributedOptimizer` operates as follows:

1. Takes a list of remote parameters (:class:`~torch.distributed.rpc.RRef`) to
   optimize. These could also be local parameters wrapped within a local
   ``RRef``.
2. Takes a :class:`~torch.optim.Optimizer` class as the local
   optimizer to run on all distinct ``RRef`` owners.
3. The distributed optimizer creates an instance of the local ``Optimizer`` on
   each of the worker nodes and holds an ``RRef`` to them.
4. When :meth:`torch.distributed.optim.DistributedOptimizer.step` is invoked,
   the distributed optimizer uses RPC to remotely execute all the local
   optimizers on the appropriate remote workers. A distributed autograd
   ``context_id`` must be provided as input to
   :meth:`torch.distributed.optim.DistributedOptimizer.step`. This is used
   by local optimizers to apply gradients stored in the corresponding
   context.
5. If multiple concurrent distributed optimizers are updating the same
   parameters on a worker, these updates are serialized via a lock.
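
The grouping and dispatch logic in the list above can be sketched without any
RPC. Everything here is hypothetical and simplified: ``ToyRRef``, ``ToySGD``
and ``ToyDistributedOptimizer`` are stand-ins, parameters are floats, the
per-context gradients live in a plain dictionary, and the lock from step 5 is
not modeled.

.. code::

  from collections import defaultdict

  class ToyRRef:
      """Stand-in for an RRef: a value plus the name of the worker that owns it."""
      def __init__(self, owner, value):
          self.owner = owner
          self.value = value

  class ToySGD:
      """Local optimizer: plain gradient descent on the parameters it was given."""
      def __init__(self, params, lr):
          self.params = params
          self.lr = lr

      def step(self, gradients):
          for param in self.params:
              if param in gradients:
                  param.value -= self.lr * gradients[param]

  class ToyDistributedOptimizer:
      def __init__(self, optimizer_class, params_rref, **kwargs):
          per_owner = defaultdict(list)
          for rref in params_rref:
              per_owner[rref.owner].append(rref)
          # One local optimizer instance per distinct owner.
          self.local_optimizers = {
              owner: optimizer_class(params, **kwargs)
              for owner, params in per_owner.items()
          }

      def step(self, context_id):
          # Each local optimizer applies the gradients stored for this context.
          for owner, local_optimizer in self.local_optimizers.items():
              local_optimizer.step(contexts[context_id].get(owner, {}))

  w1, w2 = ToyRRef("worker1", 1.0), ToyRRef("worker1", 2.0)
  w3 = ToyRRef("worker2", 3.0)

  # autograd_context_id -> owner -> {parameter: gradient}
  contexts = {1: {"worker1": {w1: 1.0, w2: 1.0}, "worker2": {w3: 2.0}}}

  dist_optim = ToyDistributedOptimizer(ToySGD, [w1, w2, w3], lr=0.5)
  dist_optim.step(context_id=1)

In the real setting each local optimizer lives on the worker that owns the
parameters and is invoked over RPC; the toy only shows that one optimizer is
created per owner and that ``step`` is driven by a context id.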

Simple end-to-end example
^^^^^^^^^^^^^^^^^^^^^^^^^

Putting it all together, the following is a simple end-to-end example using
distributed autograd and the distributed optimizer. If the code is placed into a
file called ``dist_autograd_simple.py``, it can be run with the command
:code:`MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py`:

.. code::

  import torch
  import torch.multiprocessing as mp
  import torch.distributed.autograd as dist_autograd
  from torch.distributed import rpc
  from torch import optim
  from torch.distributed.optim import DistributedOptimizer

  def random_tensor():
      return torch.rand((3, 3), requires_grad=True)

  def _run_process(rank, dst_rank, world_size):
      name = "worker{}".format(rank)
      dst_name = "worker{}".format(dst_rank)

      # Initialize RPC.
      rpc.init_rpc(
          name=name,
          rank=rank,
          world_size=world_size
      )

      # Use a distributed autograd context.
      with dist_autograd.context() as context_id:
          # Forward pass (create references on remote nodes).
          rref1 = rpc.remote(dst_name, random_tensor)
          rref2 = rpc.remote(dst_name, random_tensor)
          loss = rref1.to_here() + rref2.to_here()

          # Backward pass (run distributed autograd).
          dist_autograd.backward(context_id, [loss.sum()])

          # Build DistributedOptimizer.
          dist_optim = DistributedOptimizer(
              optim.SGD,
              [rref1, rref2],
              lr=0.05,
          )

          # Run the distributed optimizer step.
          dist_optim.step(context_id)

  def run_process(rank, world_size):
      dst_rank = (rank + 1) % world_size
      _run_process(rank, dst_rank, world_size)
      rpc.shutdown()

  if __name__ == '__main__':
      # Run world_size workers.
      world_size = 2
      mp.spawn(run_process, args=(world_size,), nprocs=world_size)

.. _RFC: https://github.com/pytorch/pytorch/issues/23110