.. _fsdp_notes:
FSDP Notes
==========
.. _fsdp_prefetch:
FSDP Prefetch Nuances
---------------------
For overlapping ``forward`` all-gathers with ``forward`` compute, there are two possible mechanisms:
1. Implicit forward prefetching (always enabled)
2. Explicit forward prefetching (``forward_prefetch=True``)
Implicit ``forward`` prefetching refers to relying on issuing the all-gathers from a separate CUDA
stream to allow for overlapping an all-gather with ``forward`` compute issued before it (from the CPU
perspective). For example, if we have layer 0 all-gather -> layer 0 ``forward`` compute -> layer 1
all-gather -> …, then layer 1 all-gather can overlap with layer 0 ``forward`` compute even though the
CPU thread issued it afterwards. (The 1st all-gather will not be able to overlap with anything.)
Explicit ``forward`` prefetching refers to changing the CPU thread’s issue order: e.g. layer 0
all-gather -> layer 1 all-gather -> layer 0 ``forward`` compute -> …. In eager mode, there is no way to
know in general which layer is the next layer (e.g. layer 1 in the example) when still executing on
layer 0. Therefore, explicit ``forward`` prefetching should only be used for models whose execution
order is fixed from iteration to iteration (which we sometimes call “static graph”). An example of a
model that does not satisfy this constraint is `FLAVA
`_).
Explicit ``forward`` prefetching only saves the time taken to issue a layer’s ``forward`` compute kernels at
the cost that the next all-gather’s output tensor must be allocated while the current one is still
in use. By issuing the next all- gather before the current ``forward`` compute kernels, the next
all-gather can start sooner on GPU. For most LLM workloads, this is not the case, so there is no
motivation for enabling ``forward_prefetch=True``.
In contrast, for ``backward``, we must use explicit ``backward`` prefetching or else there will be 0 overlap
of communication and computation. The reason is because we use a single NCCL process group for both
all-gather and reduce-scatter (partially because in earlier NCCL versions, it was not safe to use
multiple concurrently on the same device over the same ranks). A single NCCL process group means a
single internal NCCL stream on which reduce-scatters and all-gathers run serially. As such, unless
we explicitly reorder the CPU issue order to be next all-gather -> current reduce-scatter, then the
current reduce-scatter would block the next all-gather and hence the next ``backward`` computation,
preventing the current reduce-scatter from overlapping.
.. _fsdp_comms_payload_size:
Communication payload size
--------------------------
In FSDP the communications are:
1. all-gather on parameters in ``forward``
2. all-gather on parameters in ``backward``
3. reduce-scatter on gradients in ``backward``
If activation checkpointing (:func:`~torch.utils.checkpoint.checkpoint`) is used there is no
additional communication since the parameters are prefetched anyway during ``backward``.
In the FSDP design, the communication payload per rank is determined as follows: Each call to
:class:`FullyShardedDataParallel` creates one communication group consisting of the parameters in
``module.parameters()`` except any already assigned to a nested :class:`FullyShardedDataParallel`
instance. For example, for Llama, if you apply :class:`FullyShardedDataParallel` to every
transformer block and also to the root module, then there is one communication group for each
transformer block and finally one communication group with the initial embedding and final linear.
Each communication group corresponds to a single all-gather call and single reduce-scatter call. In
that way, how you apply :class:`FullyShardedDataParallel` determines the communication size. In
general, applying FSDP to each transformer block is a good heuristic for LLMs, and it is hard to do
better than that given the current design.
Let's consider an example where we have a Transformer-based model sharded over 8 GPUs, where the
sharding happens at the transformer block-level only, and each transformer block contains 1.6B
parameters and the parameters are in fp32 (4 bytes each). Which means that once sharded, each
transformer block will contain 0.2B parameters on each rank.
* The ``forward`` pass will communicate in chunks of ``0.2*4 = 0.8GB`` in all-gather
* The ``backward`` pass will communicate 2 times ``0.8GB`` each (1x all-gather and 1x reduce-scatter)
In other words there will be 3 communications with a payload of ``0.8GB`` each. If the model was
comprised of 10 transformer blocks there would be a total of 30 communications for a total of
``30*0.8=24GB``.
To formalize the payload size per communication per rank is
``total_transformer_block_params_in_B*dtype_bytes/num_gpus`` (GBs).
Please note that in this example we didn't include the additional communications required for the
embedding, which should be accounted for as well. And the math would depend on whether the input and
output embeddings are tied or not. If they aren't tied there will be 2x more communications.
.. _fsdp_buffers_sizes:
FSDP buffers sizes
------------------
First, let's cover the buffers allocated for communications:
``forward`` currently requires 2x all-gather buffer size. Here is why:
As explained in :ref:`fsdp_prefetch` in the case of explicit ``forward`` prefetching
(``forward_prefetch=True`) case of layer 0 all-gather -> layer 0 forward compute -> layer 1
all-gather there is a need for 2 all-gather-sized buffers, because one buffer is used in the current ``forward`` while the other is used to do the prefetching.
While the implicit ``forward`` prefetching (``forward_prefetch=False``, default) case of the same sequence in theory should need only 1 buffer, in reality it's still 2x all-gather-sized buffers. The reason is that in the flat-parameter FSDP design, we do not copy-out of the all-gather buffer. The parameters used for compute are directly viewed into the all-gather buffer (in fact, the main benefit of the "flat parameter" is exactly this reason). In that case, while 'layer 1 all-gather' is overlapping with 'layer 0 forward compute', the 'layer 0 forward compute' is using the parameters viewed into the 'layer 0 all-gather' buffer.
A natural question then is, when would you want ``forward_prefetch=False``? For static-graph models (like most LLMs), there is a major technical reason. It is more that, practically, we added this option quickly for some CPU-bound internal models and have not tested every code path with it in unit testing, so we are less confident in it. ``forward_prefetching=False`` can be slightly easier to reason about since we do not have to check the recorded forward order as a possible 'failure mode'; a module's all-gather can always be found under its own ``record_function`` label in its profiler trace.
``backward`` currently requires at least 2x all-gather buffer size and potentially a bit more. Here is why:
The current FSDP design uses ``recordStream`` to manage allocations produced in one stream consumed in another, which can lead to more memory usage than expected. How much more can be "non-deterministic" in that it depends on GPU kernel timing relative to the CPU. The ``limit_all_gathers=True`` argument is a mitigation to that - for more details refer to this discussion is `FSDP & CUDACachingAllocator `_.
The way existing FSDP works with autograd:
* Existing FSDP all-gathers the ``flat_param``, which is the autograd leaf.
* It calls ``torch.split`` to get 1D views into the ``flat_param`` corresponding to its constituent original parameters.
* It calls ``torch.view`` on each 1D split to view back to ND.
* This means that in ``backward``, we end up with ``ViewBackward`` (ND -> 1D) and ``SplitWithSizesBackward`` (which is a concat). In particular, each individual gradient is computed as a separate allocation, and an explicit concat happens to construct the reduce-scatter input buffer. This implies actually a 2x buffer size for reduce-scatter at that peak memory point.
In summary, for ``backward``, it is about 2x buffer size for reduce-scatter plus any ``recordStream`` effects.
Second, let's discuss the additional buffers:
Once the sharded parameters are gathered from all ranks, they require an additional buffer of `total_transformer_block_params_in_B*dtype_bytes` for the full parameters - so continuing the example from earlier if each transformer block is 1.6B parameters and the parameters are in fp32, then it'd be `1.6*4=6.4GB` buffer.
And there is a need for 2 of those buffers, since there is one currently being used and another being prefetched.
To summarize, we have:
1. 2 times communication buffers of ``total_transformer_block_params_in_B*dtype_bytes/num_gpus``
2. 2 times unsharded transformer block parameters buffer ````total_transformer_block_params_in_B*dtype_bytes``
or if you have been following the example:
1. ``2*1.6*4/8=1.6GB``
2. ``2**1.6*4=12.8GB``
and the total of ``14.4GB``.
Now let's briefly discuss what happens to the embeddings as we have left those out from the calculations:
Given the rule we discussed that you included in the note starting with "the communication buffer
size is determined as follows", we can analyze as follows:
* Suppose we apply FSDP to the root module (e.g. the ``Transformer`` class). Suppose we further apply FSDP to each transformer block (e.g. the ``TransformerBlock`` class).
* Most commonly, the embedding and final linear projection are direct children of the root ``Transformer`` class.
* Following our rule, that means that the embedding and final linear projection are assigned to the root ``Transformer``'s flat parameter.
* We have _another_ special rule, which is that the root does not free its parameters after forward because they will be anyways immediately all-gathered in backward.
* Putting this together, this means that the root's flat parameter including the embedding and final projection are all-gathered to begin forward and kept in GPU memory until the end of backward.
* If the embedding and final linear are not weight-tied, then we _could_ further apply FSDP to the embedding and to the final linear. For weight-tied parameters, we require them to be part of the same flat parameter (or else it would get double-counted). That would allow the embedding to be freed after its usage in forward and only all-gathered toward the end of backward.
* Hopefully, this gives a better sense -- each FSDP module gets assigned parameters in its ``module.parameters`` except those already assigned to another nested FSDP module, and the FSDP module's ``forward`` defines the 'live' interval for its parameters. Hence, the nested ``nn.Module`` structure can affect the all-gather/free schedule and hence the memory/throughput performance.