torch.distributed.fsdp.fully_shard

PyTorch FSDP2 (fully_shard)

PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation targeting performant eager-mode while using per-parameter sharding for improved usability.

  • If you are new to FSDP, we recommend that you start with FSDP2 due to improved usability.

  • If you are currently using FSDP1, consider evaluating the following differences to see if you should switch to FSDP2:

Compared to PyTorch FSDP1 (FullyShardedDataParallel):

  • FSDP2 uses DTensor-based dim-0 per-parameter sharding for a simpler sharding representation compared to FSDP1’s flat-parameter sharding, while preserving similar throughput performance. More specifically, FSDP2 chunks each parameter on dim-0 across the data parallel workers (using torch.chunk(dim=0)), whereas FSDP1 flattens, concatenates, and chunks a group of tensors together, which makes it complex to reason about what data is present on each worker and to reshard to different parallelisms. Per-parameter sharding provides a more intuitive user experience, relaxes constraints around frozen parameters, and allows for communication-free (sharded) state dicts, which otherwise require all-gathers in FSDP1.

  • FSDP2 implements a different memory management approach for its multi-stream usage that avoids torch.Tensor.record_stream. This ensures deterministic and expected memory usage and does not require blocking the CPU like FSDP1’s limit_all_gathers=True.

  • FSDP2 exposes APIs for manual control over prefetching and collective scheduling, allowing power users more customization. See the methods on FSDPModule below for details.

  • FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly support full state dicts. Instead, users can reshard the sharded state dicts containing DTensors to full state dicts themselves using DTensor APIs like DTensor.full_tensor() or by using higher-level APIs like PyTorch Distributed Checkpoint’s distributed state dict APIs. Also, some other args have been removed; see here for details.

If you are onboarding FSDP for the first time or if any of the above applies to your use case, we recommend that you consider using FSDP2.

See this RFC for details on system design and implementation.

Note

torch.distributed.fsdp.fully_shard is currently in prototype state and under development. The core API will likely not change, but we may make some API changes if necessary.

The frontend API is fully_shard, which can be called on a module:

torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=True, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy())[source]

Apply fully sharded data parallelism (FSDP) to module, where FSDP shards module parameters, gradients, and optimizer states across data parallel workers to save memory at the cost of communication.

At initialization, FSDP shards the module’s parameters across the data parallel workers given by mesh. Before forward, FSDP all-gathers the sharded parameters across the data-parallel workers to get the unsharded parameters for forward computation. If reshard_after_forward is True, then FSDP frees the unsharded parameters after forward and re-all-gathers them in backward before gradient computation. After gradient computation, FSDP frees the unsharded parameters and reduce-scatters the unsharded gradients across data-parallel workers.

This implementation represents the sharded parameters as DTensors sharded on dim-0, while the unsharded parameters will be like the original parameters on module (e.g. torch.Tensor if originally torch.Tensor). A module forward pre-hook on module all-gathers the parameters, and a module forward hook on module frees them (if needed). Similar backward hooks all-gather parameters and later free parameters and reduce-scatter gradients.

Since grouping multiple tensors together for one collective is critical for communication efficiency, this implementation makes this grouping first class. Calling fully_shard() on module constructs one group that includes the parameters in module.parameters() except those already assigned to a group from an earlier call on a submodule. This means that fully_shard() should be called bottom-up on your model. Each group’s parameters are all-gathered in one collective, and its gradients are reduce-scattered in one collective. Partitioning the model into multiple groups (“layer by layer”) allows for peak memory savings and communication/computation overlap. Users generally should not call fully_shard() on only the topmost root module.
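
For illustration, here is a minimal sketch of this bottom-up application on a toy model (the model structure is an assumption for the example, and the process group is assumed to already be initialized, e.g. under torchrun):

import torch.nn as nn
from torch.distributed.fsdp import fully_shard

# Hypothetical model: a stack of linear layers standing in for transformer blocks.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(4)])

# Call fully_shard bottom-up: each layer becomes its own parameter group
# (one all-gather and one reduce-scatter per layer), and the final call on
# the root groups any remaining parameters (none in this toy model).
for layer in model:
    fully_shard(layer)
fully_shard(model)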

Parameters
  • module (Union[nn.Module, List[nn.Module]]) – The module or modules to shard with FSDP and group together for communication.

  • mesh (Optional[DeviceMesh]) – This data parallel mesh defines the sharding and device. If 1D, then parameters are fully sharded across the 1D mesh (FSDP) with (Shard(0),) placement. If 2D, then parameters are sharded across the 1st dim and replicated across the 0th dim (HSDP) with (Replicate(), Shard(0)) placement. The mesh’s device type gives the device type used for communication; if a CUDA or CUDA-like device type, then we use the current device. See the sketch after this parameter list for example mesh constructions.

  • reshard_after_forward (Union[bool, int]) –

    This controls the parameter behavior after forward and can trade off memory and communication:

    • If True, then this reshards parameters after forward and re-all-gathers in backward.

    • If False, then this keeps the unsharded parameters in memory after forward and avoids the all-gather in backward.

    • If an int, then this represents the world size to reshard to after forward. It should be a non-trivial divisor of the mesh shard dim size (i.e. excluding 1 and the dim size itself). A choice may be the intra-node size (e.g. torch.cuda.device_count()). This allows the all-gather in backward to be over a smaller world size at the cost of higher memory usage than setting to True.

    • The root FSDP state has its value specially set to False as a heuristic since its parameters would typically be immediately all-gathered for backward.

    • After forward, the parameters registered to the module depend on this: the registered parameters are the sharded parameters if True; the unsharded parameters if False; and the parameters resharded to the smaller mesh otherwise. To modify the parameters between forward and backward, the registered parameters must be the sharded parameters. For False or an int, this can be done by manually resharding via reshard().

  • shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – This callable can be used to override the sharding placement for a parameter to shard a parameter on a dimension other than dim-0. If this callable returns a Shard placement (not None), then FSDP will shard according to that placement (e.g. Shard(1)). If sharding on a nonzero dim, we currently require even sharding, i.e. the tensor dim size on that dim must be divisible by the FSDP shard mesh size.

  • mp_policy (MixedPrecisionPolicy) – This controls the mixed precision policy, which offers parameter/reduction mixed precision for this module. See MixedPrecisionPolicy for details.

  • offload_policy (OffloadPolicy) – This controls the offloading policy, which offers parameter/gradient/optimizer state offloading. See OffloadPolicy and its subclasses for details.
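
As a hedged sketch of the mesh, reshard_after_forward, and shard_placement_fn arguments (the world sizes, module, and placement function below are illustrative assumptions, not requirements):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import Shard

# Assume 16 GPUs total: 2 nodes with 8 GPUs each (illustrative sizes).
fsdp_mesh = init_device_mesh("cuda", (16,))  # 1D mesh: fully shard (FSDP)
# Alternatively, a 2D mesh replicates across nodes and shards within a node (HSDP):
# hsdp_mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("replicate", "shard"))

def shard_on_dim_1(param: torch.nn.Parameter):
    # Illustrative override: shard 2D weights on dim-1 (requires even sharding);
    # returning None keeps the default dim-0 sharding.
    return Shard(1) if param.ndim == 2 else None

fully_shard(
    module,                    # hypothetical nn.Module
    mesh=fsdp_mesh,
    reshard_after_forward=8,   # a non-trivial divisor of the shard dim size (16)
    shard_placement_fn=shard_on_dim_1,
)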

Calling fully_shard(module) dynamically constructs a new class that subclasses type(module) and an FSDP class FSDPModule. For example, if we call fully_shard(linear) on a module linear: nn.Linear, then FSDP constructs a new class FSDPLinear and changes linear’s type to this. Otherwise, fully_shard does not change the module structure or parameter fully-qualified names. The FSDPModule class provides FSDP-specific methods on the module.
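
For example (a minimal sketch; the constructed class name is an implementation detail):

import torch.nn as nn
from torch.distributed.fsdp import FSDPModule, fully_shard

linear = nn.Linear(16, 16)
fully_shard(linear)

assert isinstance(linear, nn.Linear)   # still an nn.Linear ...
assert isinstance(linear, FSDPModule)  # ... and now also an FSDPModule
print(type(linear).__name__)           # e.g. "FSDPLinear"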

class torch.distributed.fsdp.FSDPModule(*args, **kwargs)
reshard()[source]

Reshards the module’s parameters, freeing the unsharded parameters if they are allocated and registering the sharded parameters to the module. This method is not recursive.
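
For instance, when using reshard_after_forward=False, one might reshard manually between forward and backward so that the sharded (DTensor) parameters are registered and can be modified (a sketch; model and batch are assumed to exist, with model already wrapped by fully_shard):

from torch.distributed.fsdp import FSDPModule

loss = model(batch).sum()
# reshard() is not recursive, so call it on each FSDPModule whose parameter
# group should return to sharded (DTensor) parameters before backward.
for submodule in model.modules():
    if isinstance(submodule, FSDPModule):
        submodule.reshard()
# The registered (sharded) parameters may now be modified before backward.
loss.backward()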

set_is_last_backward(is_last_backward)[source]

Sets whether the next backward is the last one. On the last backward, FSDP waits on pending gradient reduction and clears internal data structures for backward prefetching. This can be useful for microbatching.

set_modules_to_backward_prefetch(modules)[source]

Sets the FSDP modules for which this FSDP module should explicitly prefetch all-gathers in backward. This overrides the default backward prefetching implementation that prefetches the next FSDP module based on the reverse post-forward order.

Passing a singleton list containing the previous FSDP module gives the same all-gather overlap behavior as the default overlap behavior. Passing a list with at least length two is required for more aggressive overlap and will use more reserved memory.

Parameters

modules (List[FSDPModule]) – FSDP modules to prefetch.

set_modules_to_forward_prefetch(modules)[source]

Sets the FSDP modules for which this FSDP module should explicitly prefetch all-gathers in forward. The prefetching runs after this module’s all-gather copy-out.

Passing a singleton list containing the next FSDP module gives the same all-gather overlap behavior as the default overlap behavior, except the prefetched all-gather is issued earlier from the CPU. Passing a list with at least length two is required for more aggressive overlap and will use more reserved memory.

Parameters

modules (List[FSDPModule]) – FSDP modules to prefetch.
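
A hedged sketch of explicit prefetch scheduling in both directions, assuming layers is a list of modules already wrapped with fully_shard (prefetching two modules ahead is an illustrative choice that trades more reserved memory for more overlap):

num_to_prefetch = 2  # illustrative; lists of length >= 2 use more reserved memory

for i, layer in enumerate(layers):
    # Forward runs in order, so prefetch the next layers' all-gathers.
    layer.set_modules_to_forward_prefetch(layers[i + 1 : i + 1 + num_to_prefetch])
    # Backward runs in reverse, so prefetch the previous layers' all-gathers.
    layer.set_modules_to_backward_prefetch(
        list(reversed(layers[max(i - num_to_prefetch, 0) : i]))
    )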

set_post_optim_event(event)[source]

Sets a post-optimizer-step event for the root FSDP module’s all-gather streams to wait on.

By default, the root FSDP module has its all-gather streams wait on the current stream to ensure that the optimizer step has finished before all-gathering. However, this may introduce false dependencies if there is unrelated computation after the optimizer step. This API allows the user to provide their own event to wait on. After the root waits on the event, the event is discarded, so this API should be called with a new event each iteration.

Parameters

event (torch.Event) – Event recorded after the optimizer step to wait all-gather streams on.
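
A sketch of how this might be used in a training loop on a CUDA device (model, optim, and dataloader are assumed, with model being the root FSDP module):

import torch

for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
    # Record an event right after the optimizer step so that the next
    # iteration's all-gathers wait only on this event rather than on any
    # unrelated work later enqueued on the current stream.
    post_optim_event = torch.cuda.current_stream().record_event()
    model.set_post_optim_event(post_optim_event)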

set_reduce_scatter_divide_factor(factor)[source]

Sets a custom divide factor for the reduce-scatter. This becomes a custom reduce op using NCCL’s PreMulSum, which allows multiplying by the factor before reduction.

Parameters

factor (float) – Custom divide factor.

set_requires_all_reduce(requires_all_reduce, *, recurse=True)[source]

Sets if the module should all-reduce gradients. This can be used to implement gradient accumulation with only reduce-scatter but not all-reduce for HSDP.

set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)[source]

Sets if the module should sync gradients. This can be used to implement gradient accumulation without communication. For HSDP, this controls both reduce-scatter and all-reduce together.

Parameters
  • requires_gradient_sync (bool) – Whether to reduce gradients for the module’s parameters.

  • recurse (bool) – Whether to set for all FSDP submodules or just the passed-in module.

set_reshard_after_backward(reshard_after_backward, *, recurse=True)[source]

Sets if the module should reshard parameters after backward. This can be used during gradient accumulation to trade off higher memory for reduced communication since the unsharded parameters do not need to be re-all-gathered before the next forward.

Parameters
  • reshard_after_backward (bool) – Whether to reshard parameters after backward.

  • recurse (bool) – Whether to set for all FSDP submodules or just the passed-in module.
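
Together with set_is_last_backward() and set_requires_gradient_sync(), this enables a simple gradient accumulation pattern. A hedged sketch (the microbatching loop and names are assumptions):

for i, microbatch in enumerate(microbatches):
    is_last = i == len(microbatches) - 1
    # Skip gradient reduction, keep parameters unsharded, and keep backward
    # prefetch state on all but the last microbatch; restore normal behavior
    # on the last one.
    model.set_requires_gradient_sync(is_last)
    model.set_reshard_after_backward(is_last)
    model.set_is_last_backward(is_last)
    loss = model(microbatch).sum()
    loss.backward()
optim.step()
optim.zero_grad()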

set_unshard_in_backward(unshard_in_backward)[source]

Sets whether the FSDP module’s parameters need to be unsharded in backward. This can be used in expert cases when the user knows that all parameters in this FSDP module’s parameter group are not needed for backward computation (e.g. embedding).

unshard(async_op=False)[source]

Unshards the module’s parameters by allocating memory and all-gathering the parameters. This method is not recursive. The unshard follows the MixedPrecisionPolicy, so it all-gathers using param_dtype if set.

Parameters

async_op (bool) – If True, then returns an UnshardHandle that has a wait() method to wait on the unshard op. If False, then returns None and waits on the handle inside this function.

Return type

Optional[UnshardHandle]

Note

If async_op=True, then FSDP will wait on the pending unshard in the module’s pre-forward for the user. The user only needs to call wait() explicitly if the wait should happen before pre-forward.
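
For example, an explicit asynchronous unshard can overlap the all-gather with unrelated work (a sketch; fsdp_module is assumed to be a module wrapped with fully_shard, and x an input tensor):

handle = fsdp_module.unshard(async_op=True)
# ... unrelated work that does not need the unsharded parameters ...
handle.wait()            # optional here: FSDP would also wait in pre-forward
output = fsdp_module(x)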

class torch.distributed.fsdp.UnshardHandle

A handle to wait on an FSDPModule.unshard() op.

wait()[source]

Waits on the unshard op. This ensures that the current stream can use the unsharded parameters, which are now registered to the module.

torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)[source]

Registers a method on module to be considered a forward method for FSDP.

FSDP all-gathers parameters pre-forward and optionally frees parameters post-forward (depending on reshard_after_forward). FSDP only knows to do this for nn.Module.forward() by default. This function patches a user-specified method to run the pre/post-forward hooks before/after the method, respectively. If module is not an FSDPModule, then this is a no-op.

Parameters
  • module (nn.Module) – Module to register the forward method on.

  • method_name (str) – Name of the forward method.
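
For example, if an inference path calls a custom method instead of forward() (the encode method and vision_encoder module below are hypothetical, not PyTorch APIs):

from torch.distributed.fsdp import fully_shard, register_fsdp_forward_method

fully_shard(vision_encoder)
# Hypothetical: vision_encoder.encode(...) bypasses forward(), so register it
# to get the pre-forward all-gather and post-forward reshard hooks.
register_fsdp_forward_method(vision_encoder, "encode")

features = vision_encoder.encode(images)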

class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)

This configures FSDP’s mixed precision. Unlike autocast, this applies mixed precision at the module level, not op level, which means low-precision activations are saved for backward and high-to-low-precision casts are incurred only at module boundaries.

FSDP works well with module-level mixed precision since it keeps the high-precision sharded parameters in memory anyway. In other words, FSDP does not require any extra memory to keep a high-precision copy of the parameters for the optimizer step.

Variables
  • param_dtype (Optional[torch.dtype]) – This specifies the dtype for the unsharded parameter and hence the dtype for forward/backward computation and the parameter all-gather. If this is None, then the unsharded parameter uses the original dtype. The optimizer step uses the sharded parameter in the original dtype. (Default: None)

  • reduce_dtype (Optional[torch.dtype]) – This specifies the dtype for gradient reduction (i.e. reduce-scatter or all-reduce). If this is None but param_dtype is not None, then the reduction uses the compute dtype. This can be used to run gradient reduction in full precision while using low precision for compute. If gradient reduction is disabled via set_requires_gradient_sync(), then FSDP will accumulate gradients using reduce_dtype. (Default: None)

  • output_dtype (Optional[torch.dtype]) – This specifies the dtype for casting floating-point forward outputs. This can be used to help implement cases where different modules have different mixed precision policies. (Default: None)

  • cast_forward_inputs (bool) – This specifies whether FSDP should cast the forward’s floating-point input tensors to param_dtype. (Default: True)
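
For example, to run compute and the parameter all-gather in bfloat16 while reducing gradients in float32 (the dtype choices and the model.blocks layout are illustrative assumptions):

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # unsharded parameters, all-gather, and compute in bf16
    reduce_dtype=torch.float32,  # gradient reduce-scatter in fp32
)
for block in model.blocks:       # hypothetical submodule layout
    fully_shard(block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)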

class torch.distributed.fsdp.OffloadPolicy

This base class represents the policy of no offloading and is only used as the default value for the offload_policy arg.

class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True)

This offload policy offloads parameters, gradients, and optimizer states to CPU. Sharded parameters are copied host-to-device before all-gather. The all-gathered parameters are freed according to reshard_after_forward. Sharded gradients are copied device-to-host in backward, and the optimizer step runs on CPU with CPU optimizer states.

Variables

pin_memory (bool) – Whether to pin sharded parameter and gradient memory. Pinning memory allows both more efficient H2D/D2H copies and for the copies to overlap with compute. However, the pinned memory cannot be used by other processes. Set this to False if you have insufficient CPU memory. (Default: True)
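
For example (a sketch; the optimizer choice and learning rate are illustrative):

import torch
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard

fully_shard(model, offload_policy=CPUOffloadPolicy(pin_memory=True))
# The optimizer step runs on CPU with CPU optimizer states.
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)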
