
FSDP Notes

FSDP Prefetch Nuances

For overlapping forward all-gathers with forward compute, there are two possible mechanisms:

  1. Implicit forward prefetching (always enabled)

  2. Explicit forward prefetching (forward_prefetch=True)

Implicit forward prefetching refers to relying on issuing the all-gathers from a separate CUDA stream to allow for overlapping an all-gather with forward compute issued before it (from the CPU perspective). For example, if we have layer 0 all-gather -> layer 0 forward compute -> layer 1 all-gather -> …, then layer 1 all-gather can overlap with layer 0 forward compute even though the CPU thread issued it afterwards. (The 1st all-gather will not be able to overlap with anything.)

Explicit forward prefetching refers to changing the CPU thread’s issue order: e.g. layer 0 all-gather -> layer 1 all-gather -> layer 0 forward compute -> …. In eager mode, there is generally no way to know which layer is the next layer (e.g. layer 1 in the example) while still executing layer 0. Therefore, explicit forward prefetching should only be used for models whose execution order is fixed from iteration to iteration (which we sometimes call “static graph”). An example of a model that does not satisfy this constraint is FLAVA.

Explicit forward prefetching only saves the time taken to issue a layer’s forward compute kernels, at the cost that the next all-gather’s output tensor must be allocated while the current one is still in use. By issuing the next all-gather before the current forward compute kernels, the next all-gather can start sooner on the GPU. This only pays off when the CPU thread issuing kernels is the bottleneck; for most LLM workloads this is not the case, so there is no motivation for enabling forward_prefetch=True.

In contrast, for backward, we must use explicit backward prefetching or else there will be 0 overlap of communication and computation. The reason is because we use a single NCCL process group for both all-gather and reduce-scatter (partially because in earlier NCCL versions, it was not safe to use multiple concurrently on the same device over the same ranks). A single NCCL process group means a single internal NCCL stream on which reduce-scatters and all-gathers run serially. As such, unless we explicitly reorder the CPU issue order to be next all-gather -> current reduce-scatter, then the current reduce-scatter would block the next all-gather and hence the next backward computation, preventing the current reduce-scatter from overlapping.
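Both knobs are set on the FullyShardedDataParallel constructor. Below is a minimal sketch, assuming torch.distributed is already initialized and that model is an existing module whose transformer blocks are already wrapped in nested FSDP instances (model itself is an assumption, not part of the FSDP API):

```python
# Minimal sketch of configuring prefetching; `model` is assumed to exist and
# to already have its transformer blocks wrapped in nested FSDP instances.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import BackwardPrefetch

fsdp_model = FSDP(
    model,
    # Implicit forward prefetching only (the default); set True to also change
    # the CPU issue order as described above.
    forward_prefetch=False,
    # Explicit backward prefetching: issue the next all-gather before the
    # current reduce-scatter so communication can overlap with backward compute.
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)
```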

Communication payload size

In FSDP the communications are:

  1. all-gather on parameters in forward

  2. all-gather on parameters in backward

  3. reduce-scatter on gradients in backward

If activation checkpointing (checkpoint()) is used, there is no additional communication, since the parameters are prefetched anyway during backward.

In the FSDP design, the communication payload per rank is determined as follows: Each call to FullyShardedDataParallel creates one communication group consisting of the parameters in module.parameters() except any already assigned to a nested FullyShardedDataParallel instance. For example, for Llama, if you apply FullyShardedDataParallel to every transformer block and also to the root module, then there is one communication group for each transformer block and finally one communication group with the initial embedding and final linear. Each communication group corresponds to a single all-gather call and single reduce-scatter call. In that way, how you apply FullyShardedDataParallel determines the communication size. In general, applying FSDP to each transformer block is a good heuristic for LLMs, and it is hard to do better than that given the current design.
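For example, here is a minimal sketch of block-level wrapping using the auto-wrap policy; Transformer and TransformerBlock are hypothetical class names standing in for your model’s own classes:

```python
# Sketch only: `Transformer` and `TransformerBlock` are placeholders for your
# model's classes, and torch.distributed is assumed to be initialized already.
import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

# Each TransformerBlock becomes its own FSDP instance (one all-gather and one
# reduce-scatter per block); any leftover parameters (e.g. the embedding and
# final linear) fall into the root FSDP instance's communication group.
model = FSDP(Transformer(), auto_wrap_policy=auto_wrap_policy)
```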

Let’s consider an example where we have a Transformer-based model sharded over 8 GPUs, where the sharding happens at the transformer-block level only, each transformer block contains 1.6B parameters, and the parameters are in fp32 (4 bytes each). Once sharded, each transformer block holds 0.2B parameters on each rank.

  • The forward pass will all-gather parameters in chunks of 0.2*4 = 0.8GB per transformer block

  • The backward pass will communicate 2 times 0.8GB (1x all-gather and 1x reduce-scatter of 0.8GB each)

In other words, there will be 3 communications per transformer block with a payload of 0.8GB each. If the model comprised 10 transformer blocks, there would be a total of 30 communications moving 30*0.8=24GB.

To formalize, the payload size per communication per rank is total_transformer_block_params_in_B*dtype_bytes/num_gpus (in GB).

Note that this example does not include the additional communications required for the embedding, which should be accounted for as well. The math depends on whether the input and output embeddings are tied or not; if they are not tied, there will be 2x more communications.
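As a quick sanity check of the formula, using the numbers from the example (the helper name below is illustrative, not part of FSDP):

```python
# Per-rank payload, in GB, for one communication of one communication group
# (here: one transformer block). Numbers match the example above.
def payload_gb(params_in_billions: float, dtype_bytes: int, num_gpus: int) -> float:
    return params_in_billions * dtype_bytes / num_gpus

per_comm = payload_gb(1.6, 4, 8)   # 0.8 GB per all-gather or reduce-scatter
per_block = 3 * per_comm           # forward AG + backward AG + backward RS = 2.4 GB
total = 10 * per_block             # 10 transformer blocks -> 24.0 GB
print(per_comm, per_block, total)
```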

FSDP buffer sizes

First, let’s cover the buffers allocated for communications:

forward currently requires 2x all-gather buffer size. Here is why:

As explained in FSDP Prefetch Nuances, in the explicit forward prefetching (forward_prefetch=True) case of layer 0 all-gather -> layer 0 forward compute -> layer 1 all-gather, there is a need for 2 all-gather-sized buffers, because one buffer is used in the current forward while the other is used for prefetching.

While the implicit forward prefetching (forward_prefetch=False, default) case of the same sequence in theory should need only 1 buffer, in reality it’s still 2x all-gather-sized buffers. The reason is that in the flat-parameter FSDP design, we do not copy-out of the all-gather buffer. The parameters used for compute are directly viewed into the all-gather buffer (in fact, the main benefit of the “flat parameter” is exactly this reason). In that case, while ‘layer 1 all-gather’ is overlapping with ‘layer 0 forward compute’, the ‘layer 0 forward compute’ is using the parameters viewed into the ‘layer 0 all-gather’ buffer.

A natural question then is: when would you want forward_prefetch=False? For static-graph models (like most LLMs), there is not really a major technical reason. It is more that, practically, we added the forward_prefetch=True option quickly for some CPU-bound internal models and have not tested every code path with it in unit testing, so we are less confident in it. forward_prefetch=False can also be slightly easier to reason about, since we do not have to check the recorded forward order as a possible ‘failure mode’; a module’s all-gather can always be found under its own record_function label in its profiler trace.

backward currently requires at least 2x all-gather buffer size and potentially a bit more. Here is why:

The current FSDP design uses recordStream to manage allocations that are produced in one stream and consumed in another, which can lead to more memory usage than expected. How much more can be “non-deterministic” in that it depends on GPU kernel timing relative to the CPU. The limit_all_gathers=True argument is a mitigation for that; for more details, refer to the discussion FSDP & CUDACachingAllocator.
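The rate limiter is enabled at construction time. A one-line sketch, assuming model is an existing module and the distributed process group is already set up:

```python
# Sketch: cap how far ahead the CPU thread can run, so that only a bounded
# number of all-gather buffers are alive at once. `model` is assumed to exist.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(model, limit_all_gathers=True)
```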

The way existing FSDP works with autograd:

  • Existing FSDP all-gathers the flat_param, which is the autograd leaf.

  • It calls torch.split to get 1D views into the flat_param corresponding to its constituent original parameters.

  • It calls torch.view on each 1D split to view back to ND.

  • This means that in backward, we end up with ViewBackward (ND -> 1D) and SplitWithSizesBackward (which is a concat). In particular, each individual gradient is computed as a separate allocation, and an explicit concat happens to construct the reduce-scatter input buffer. That implies roughly a 2x buffer size for reduce-scatter at that peak memory point (see the sketch below).
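The split/view mechanism can be reproduced in a toy example; the shapes and names below are illustrative only and not FSDP internals:

```python
# Toy illustration of the flat-parameter split/view pattern; not FSDP code.
import torch

# Pretend these are two original parameters: a (4, 3) weight and a (4,) bias.
shapes = [(4, 3), (4,)]
numels = [12, 4]

# The flat parameter is a single 1D autograd leaf.
flat_param = torch.randn(sum(numels), requires_grad=True)

# torch.split gives 1D views into flat_param ...
splits = torch.split(flat_param, numels)
# ... and .view restores each split to its original ND shape.
params = [s.view(shape) for s, shape in zip(splits, shapes)]

# Backward now goes through ViewBackward and SplitWithSizesBackward: each
# gradient is materialized separately and then concatenated into a single
# flat gradient (the reduce-scatter input in FSDP).
loss = sum(p.sum() for p in params)
loss.backward()
print(flat_param.grad.shape)  # torch.Size([16])
```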

In summary, for backward, it is about 2x buffer size for reduce-scatter plus any recordStream effects.

Second, let’s discuss the additional buffers:

Once the sharded parameters are gathered from all ranks, they require an additional buffer of total_transformer_block_params_in_B*dtype_bytes for the full parameters - so continuing the example from earlier, if each transformer block has 1.6B parameters in fp32, that is a 1.6*4=6.4GB buffer.

And there is a need for 2 of those buffers, since there is one currently being used and another being prefetched.

To summarize, we have:

  1. 2 times communication buffers of total_transformer_block_params_in_B*dtype_bytes/num_gpus

  2. 2 times the unsharded transformer block parameters buffer of total_transformer_block_params_in_B*dtype_bytes

or if you have been following the example:

  1. 2*1.6*4/8=1.6GB

  2. 2*1.6*4=12.8GB

for a total of 14.4GB.
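These numbers can be reproduced with a few lines of Python. This is only a back-of-the-envelope estimate (the helper name is illustrative); it ignores recordStream effects and the extra reduce-scatter input buffer discussed above:

```python
# Rough FSDP buffer estimate for the running example; names are illustrative.
def fsdp_buffer_estimate_gb(params_in_billions: float, dtype_bytes: int, num_gpus: int) -> float:
    comm_buffers = 2 * params_in_billions * dtype_bytes / num_gpus  # 2 sharded communication buffers
    unsharded_buffers = 2 * params_in_billions * dtype_bytes        # current + prefetched block
    return comm_buffers + unsharded_buffers

print(fsdp_buffer_estimate_gb(1.6, 4, 8))  # 1.6 + 12.8 = 14.4 GB
```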

Now let’s briefly discuss what happens with the embeddings, which we have left out of the calculations so far:

Given the rule discussed above (in the paragraph starting with “In the FSDP design, the communication payload per rank is determined as follows”), we can analyze as follows:

  • Suppose we apply FSDP to the root module (e.g. the Transformer class). Suppose we further apply FSDP to each transformer block (e.g. the TransformerBlock class).

  • Most commonly, the embedding and final linear projection are direct children of the root Transformer class.

  • Following our rule, that means that the embedding and final linear projection are assigned to the root Transformer’s flat parameter.

  • We have _another_ special rule, which is that the root does not free its parameters after forward, because they would anyway be immediately all-gathered again in backward.

  • Putting this together, the root’s flat parameter, including the embedding and final projection, is all-gathered at the start of forward and kept in GPU memory until the end of backward.

  • If the embedding and final linear are not weight-tied, then we _could_ further apply FSDP to the embedding and to the final linear (see the sketch after this list). For weight-tied parameters, we require them to be part of the same flat parameter (or else they would get double-counted). That would allow the embedding to be freed after its usage in forward and only all-gathered again toward the end of backward.

  • Hopefully, this gives a better sense: each FSDP module gets assigned the parameters in its module.parameters except those already assigned to another nested FSDP module, and the FSDP module’s forward defines the ‘live’ interval for its parameters. Hence, the nested nn.Module structure can affect the all-gather/free schedule and thus the memory and throughput performance.
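Here is a sketch of the untied case; the attribute names tok_embeddings and output are hypothetical placeholders for your model’s own submodules:

```python
# Hypothetical sketch: wrap the (untied) embedding and final linear in their
# own FSDP instances so they are freed after their use in forward and only
# re-gathered when needed in backward. Attribute names are placeholders.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model.tok_embeddings = FSDP(model.tok_embeddings)
model.output = FSDP(model.output)
# Wrap the root last: it picks up whatever parameters are not already assigned
# to a nested FSDP instance (e.g. the final norm).
model = FSDP(model)
```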
