torch.distributed.fsdp.fully_shard¶
PyTorch FSDP2 (fully_shard)¶
PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation targeting performant eager-mode while using per-parameter sharding for improved usability.
If you are new to FSDP, we recommend that you start with FSDP2 due to improved usability.
If you are currently using FSDP1, consider evaluating the following differences to see if you should switch to FSDP2:
Compared to PyTorch FSDP1 (FullyShardedDataParallel):
- FSDP2 uses DTensor-based dim-0 per-parameter sharding for a simpler sharding representation compared to FSDP1's flat-parameter sharding, while preserving similar throughput performance. More specifically, FSDP2 chunks each parameter on dim-0 across the data parallel workers (using torch.chunk(dim=0)), whereas FSDP1 flattens, concatenates, and chunks a group of tensors together, making it complex to reason about what data is present on each worker and to reshard to different parallelisms. Per-parameter sharding provides a more intuitive user experience, relaxes constraints around frozen parameters, and allows for communication-free (sharded) state dicts, which otherwise require all-gathers in FSDP1.
- FSDP2 implements a different memory management approach to handle the multi-stream usages that avoids torch.Tensor.record_stream. This ensures deterministic and expected memory usage and does not require blocking the CPU like FSDP1's limit_all_gathers=True.
- FSDP2 exposes APIs for manual control over prefetching and collective scheduling, allowing power users more customization. See the methods on FSDPModule below for details.
- FSDP2 simplifies some of the API surface: e.g., FSDP2 does not directly support full state dicts. Instead, users can reshard the sharded state dicts containing DTensors to full state dicts themselves using DTensor APIs like DTensor.full_tensor() or by using higher-level APIs like PyTorch Distributed Checkpoint's distributed state dict APIs. Also, some other args have been removed; see here for details.
If you are onboarding FSDP for the first time or if any of the above appeals to your use case, we recommend that you consider using FSDP2.
See this RFC for details on system design and implementation.
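As a hedged illustration of the state dict point above, a sharded state dict of DTensors can be resharded to a full state dict via DTensor.full_tensor(). A minimal sketch, assuming model is already wrapped with fully_shard and a recent PyTorch where DTensor is importable from torch.distributed.tensor:

    from torch.distributed.tensor import DTensor

    sharded_sd = model.state_dict()  # communication-free sharded state dict (DTensor values)
    full_sd = {
        key: value.full_tensor() if isinstance(value, DTensor) else value
        for key, value in sharded_sd.items()
    }  # full_tensor() all-gathers each DTensor into a regular torch.Tensor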
Note
torch.distributed.fsdp.fully_shard is currently in prototype state and under development. The core API will likely not change, but we may make some API changes if necessary.
The frontend API is fully_shard, which can be called on a module:
- torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=True, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy())[source]¶
Apply fully sharded data parallelism (FSDP) to module, where FSDP shards module parameters, gradients, and optimizer states across data parallel workers to save memory at the cost of communication.
At initialization, FSDP shards the module's parameters across the data parallel workers given by mesh. Before forward, FSDP all-gathers the sharded parameters across the data-parallel workers to get the unsharded parameters for forward computation. If reshard_after_forward is True, then FSDP frees the unsharded parameters after forward and re-all-gathers them in backward before gradient computation. After gradient computation, FSDP frees the unsharded parameters and reduce-scatters the unsharded gradients across data-parallel workers.
This implementation represents the sharded parameters as DTensors sharded on dim-0, while the unsharded parameters are like the original parameters on module (e.g. torch.Tensor if originally torch.Tensor). A module forward pre-hook on module all-gathers the parameters, and a module forward hook on module frees them (if needed). Similar backward hooks all-gather parameters and later free parameters and reduce-scatter gradients.
Since grouping multiple tensors together for one collective is critical for communication efficiency, this implementation makes this grouping first class. Calling fully_shard() on module constructs one group that includes the parameters in module.parameters() except those already assigned to a group from an earlier call on a submodule. This means that fully_shard() should be called bottom-up on your model. Each group's parameters are all-gathered in one collective, and its gradients are reduce-scattered in one collective. Partitioning the model into multiple groups ("layer by layer") allows for peak memory savings and communication/computation overlap. Users generally should not call fully_shard() only on the topmost root module.
- Parameters
module (Union[nn.Module, List[nn.Module]]) – The module or modules to shard with FSDP and group together for communication.
mesh (Optional[DeviceMesh]) – This data parallel mesh defines the sharding and device. If 1D, then parameters are fully sharded across the 1D mesh (FSDP) with (Shard(0),) placement. If 2D, then parameters are sharded across the 1st dim and replicated across the 0th dim (HSDP) with (Replicate(), Shard(0)) placement. The mesh's device type gives the device type used for communication; if a CUDA or CUDA-like device type, then we use the current device.
reshard_after_forward (Union[bool, int]) –
This controls the parameter behavior after forward and can trade off memory and communication:
- If True, then this reshards parameters after forward and re-all-gathers in backward.
- If False, then this keeps the unsharded parameters in memory after forward and avoids the all-gather in backward.
- If an int, then this represents the world size to reshard to after forward. It should be a non-trivial divisor of the mesh shard dim size (i.e. excluding 1 and the dim size itself). A choice may be the intra-node size (e.g. torch.cuda.device_count()). This allows the all-gather in backward to be over a smaller world size at the cost of higher memory usage than setting to True.
The root FSDP state has its value specially set to False as a heuristic since its parameters would typically be immediately all-gathered for backward.
After forward, the parameters registered to the module depend on this: the registered parameters are the sharded parameters if True; the unsharded parameters if False; and the parameters resharded to the smaller mesh otherwise. To modify the parameters between forward and backward, the registered parameters must be the sharded parameters. For False or an int, this can be done by manually resharding via reshard().
shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – This callable can be used to override the sharding placement for a parameter to shard a parameter on a dimension other than dim-0. If this callable returns a Shard placement (not None), then FSDP will shard according to that placement (e.g. Shard(1)). If sharding on a nonzero dim, we currently require even sharding, i.e. the tensor dim size on that dim must be divisible by the FSDP shard mesh size.
mp_policy (MixedPrecisionPolicy) – This controls the mixed precision policy, which offers parameter/reduction mixed precision for this module. See MixedPrecisionPolicy for details.
offload_policy (OffloadPolicy) – This controls the offloading policy, which offers parameter/gradient/optimizer state offloading. See OffloadPolicy and its subclasses for details.
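For illustration, here is a minimal sketch of applying fully_shard bottom-up over a 1D mesh. Transformer and its layers attribute are hypothetical placeholders, and the sketch assumes the default process group is already initialized (e.g. via torchrun):

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import fully_shard

    model = Transformer()  # hypothetical module with a `layers` ModuleList
    mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

    # Call fully_shard bottom-up: each layer becomes its own all-gather/reduce-scatter group.
    for layer in model.layers:
        fully_shard(layer, mesh=mesh, reshard_after_forward=True)
    # The root call groups the remaining parameters (e.g. embeddings, final norm).
    fully_shard(model, mesh=mesh)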
Calling fully_shard(module) dynamically constructs a new class that subclasses type(module) and an FSDP class FSDPModule. For example, if we call fully_shard(linear) on a module linear: nn.Linear, then FSDP constructs a new class FSDPLinear and changes linear's type to this. Otherwise, fully_shard does not change the module structure and parameter fully-qualified names. The class FSDPModule allows providing some FSDP-specific methods on the module.
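A short sketch of the resulting type change (assuming the default process group is initialized):

    import torch
    from torch.distributed.fsdp import FSDPModule, fully_shard

    linear = torch.nn.Linear(16, 16)
    fully_shard(linear)
    print(type(linear))                    # a dynamically created subclass, e.g. FSDPLinear
    assert isinstance(linear, FSDPModule)  # the FSDP-specific methods below are now available
    linear.reshard()                       # e.g. manually free the unsharded parameters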
- class torch.distributed.fsdp.FSDPModule(*args, **kwargs)¶
- reshard()[source]¶
Reshards the module’s parameters, freeing the unsharded parameters if they are allocated and registering the sharded parameters to the module. This method is not recursive.
- set_is_last_backward(is_last_backward)[source]¶
Sets whether the next backward is the last one. On the last backward, FSDP waits on pending gradient reduction and clears internal data structures for backward prefetching. This can be useful for microbatching.
- set_modules_to_backward_prefetch(modules)[source]¶
Sets the FSDP modules for which this FSDP module should explicitly prefetch all-gathers in backward. This overrides the default backward prefetching implementation that prefetches the next FSDP module based on the reverse post-forward order.
Passing a singleton list containing the previous FSDP module gives the same all-gather overlap behavior as the default overlap behavior. Passing a list with at least length two is required for more aggressive overlap and will use more reserved memory.
- Parameters
modules (List[FSDPModule]) – FSDP modules to prefetch.
- set_modules_to_forward_prefetch(modules)[source]¶
Sets the FSDP modules for which this FSDP module should explicitly prefetch all-gathers in forward. The prefetching runs after this module’s all-gather copy-out.
Passing a singleton list containing the next FSDP module gives the same all-gather overlap behavior as the default overlap behavior, except the prefetched all-gather is issued earlier from the CPU. Passing a list with at least length two is required for more aggressive overlap and will use more reserved memory.
- Parameters
modules (List[FSDPModule]) – FSDP modules to prefetch.
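A hedged sketch of explicit prefetching in both directions, where layers is a hypothetical list of submodules each already wrapped with fully_shard:

    # Forward: each layer prefetches the all-gathers of the next two layers.
    for i, layer in enumerate(layers):
        layer.set_modules_to_forward_prefetch(layers[i + 1 : i + 3])
    # Backward: each layer prefetches the all-gathers of the previous two layers.
    for i, layer in enumerate(layers):
        layer.set_modules_to_backward_prefetch(layers[max(i - 2, 0) : i][::-1])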
- set_post_optim_event(event)[source]¶
Sets a post-optimizer-step event for the root FSDP module to wait the all-gather streams on.
By default, the root FSDP module waits the all-gather streams on the current stream to ensure that the optimizer step has finished before all-gathering. However, this may introduce false dependencies if there is unrelated computation after the optimizer step. This API allows the user to provide their own event to wait on. After the root waits on the event, the event is discarded, so this API should be called with a new event each iteration.
- Parameters
event (torch.Event) – Event recorded after the optimizer step to wait all-gather streams on.
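A hedged sketch on CUDA, recording an event right after the optimizer step so that unrelated work afterward does not delay the next iteration's all-gathers (model is the root FSDPModule and optim is a hypothetical optimizer):

    import torch

    optim.step()
    post_optim_event = torch.cuda.current_stream().record_event()
    model.set_post_optim_event(post_optim_event)
    # ... unrelated work on the current stream no longer gates the next all-gathers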
- set_reduce_scatter_divide_factor(factor)[source]¶
Sets a custom divide factor for the reduce-scatter. This becomes a custom reduce op using NCCL’s PreMulSum, which allows multiplying by the factor before reduction.
- Parameters
factor (float) – Custom divide factor.
- set_requires_all_reduce(requires_all_reduce, *, recurse=True)[source]¶
Sets if the module should all-reduce gradients. This can be used to implement gradient accumulation with only reduce-scatter but not all-reduce for HSDP.
- set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)[source]¶
Sets if the module should sync gradients. This can be used to implement gradient accumulation without communication. For HSDP, this controls both reduce-scatter and all-reduce together.
- set_reshard_after_backward(reshard_after_backward, *, recurse=True)[source]¶
Sets if the module should reshard parameters after backward. This can be used during gradient accumulation to trade off higher memory for reduced communication since the unsharded parameters do not need to be re-all-gathered before the next forward.
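Together with set_requires_gradient_sync() and set_is_last_backward(), this enables gradient accumulation that communicates only on the last microbatch. A hedged sketch (model is the root FSDPModule; microbatches and optim are hypothetical):

    for i, microbatch in enumerate(microbatches):
        is_last = i == len(microbatches) - 1
        model.set_requires_gradient_sync(is_last)   # reduce-scatter only on the last microbatch
        model.set_is_last_backward(is_last)
        model.set_reshard_after_backward(is_last)   # keep params unsharded between microbatches
        loss = model(microbatch).sum()
        loss.backward()
    optim.step()
    optim.zero_grad()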
- set_unshard_in_backward(unshard_in_backward)[source]¶
Sets whether the FSDP module’s parameters need to be unsharded in backward. This can be used in expert cases when the user knows that all parameters in this FSDP module’s parameter group are not needed for backward computation (e.g. embedding).
- unshard(async_op=False)[source]¶
Unshards the module's parameters by allocating memory and all-gathering the parameters. This method is not recursive. The unshard follows the MixedPrecisionPolicy, so it will all-gather following param_dtype if set.
- Parameters
async_op (bool) – If True, then returns an UnshardHandle that has a wait() method to wait on the unshard op. If False, then returns None and waits on the handle inside this function.
- Return type
Optional[UnshardHandle]
Note
If async_op=True, then FSDP will wait on the pending unshard in the module's pre-forward for the user. The user only needs to call wait() explicitly if the wait should happen before pre-forward.
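A hedged sketch of explicitly overlapping a layer's all-gather with other work (layer is a hypothetical FSDPModule and x a hypothetical input):

    handle = layer.unshard(async_op=True)  # issue the all-gather without blocking
    # ... other work ...
    handle.wait()                          # optional: only if the wait must happen before pre-forward
    out = layer(x)
    layer.reshard()                        # free the unsharded parameters again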
- class torch.distributed.fsdp.UnshardHandle¶
A handle to wait on an FSDPModule.unshard() op.
- torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)[source]¶
Registers a method on module to be considered a forward method for FSDP.
FSDP all-gathers parameters pre-forward and optionally frees parameters post-forward (depending on reshard_after_forward). FSDP only knows to do this for nn.Module.forward() by default. This function patches a user-specified method to run the pre/post-forward hooks before/after the method, respectively. If module is not an FSDPModule, then this is a no-op.
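For example, to run a hypothetical generate() method under FSDP's pre/post-forward hooks (a sketch):

    from torch.distributed.fsdp import fully_shard, register_fsdp_forward_method

    model = Transformer()  # hypothetical module with a `generate` method
    fully_shard(model)
    # All-gather parameters before `generate` and free them after, just like for `forward`.
    register_fsdp_forward_method(model, "generate")
    tokens = model.generate(prompt)  # `prompt` is a hypothetical input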
- class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)¶
This configures FSDP’s mixed precision. Unlike autocast, this applies mixed precision at the module level, not op level, which means low-precision activations are saved for backward and high-to-low-precision casts are incurred only at module boundaries.
FSDP works well with module-level mixed precision since it keeps the high-precision sharded parameters in memory anyway. In other words, FSDP does not require any extra memory to keep a high-precision copy of the parameters for the optimizer step.
- Variables
param_dtype (Optional[torch.dtype]) – This specifies the dtype for the unsharded parameter and hence the dtype for forward/backward computation and the parameter all-gather. If this is None, then the unsharded parameter uses the original dtype. The optimizer step uses the sharded parameter in the original dtype. (Default: None)
reduce_dtype (Optional[torch.dtype]) – This specifies the dtype for gradient reduction (i.e. reduce-scatter or all-reduce). If this is None but param_dtype is not None, then the reduction uses the compute dtype. This can be used to run gradient reduction in full precision while using low precision for compute. If gradient reduction is also disabled via set_requires_gradient_sync(), then FSDP will accumulate gradients using reduce_dtype. (Default: None)
output_dtype (Optional[torch.dtype]) – This specifies the dtype for casting floating-point forward outputs. This can be used to help implement cases where different modules have different mixed precision policies. (Default: None)
cast_forward_inputs (bool) – This specifies whether FSDP should cast the forward's floating-point input tensors to param_dtype or not.
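A common configuration is bf16 parameters/compute with fp32 gradient reduction. A sketch, reusing the hypothetical model from the earlier example:

    import torch
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,   # all-gather and compute in bf16
        reduce_dtype=torch.float32,   # reduce-scatter gradients in fp32
    )
    for layer in model.layers:
        fully_shard(layer, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)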
- class torch.distributed.fsdp.OffloadPolicy¶
This base class represents the policy of no offloading and is only used as the default value for the offload_policy arg.
- class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True)¶
This offload policy offloads parameters, gradients, and optimizer states to CPU. Sharded parameters are copied host-to-device before all-gather. The all-gathered parameters are freed according to reshard_after_forward. Sharded gradients are copied device-to-host in backward, and the optimizer step runs on CPU with CPU optimizer states.
- Variables
pin_memory (bool) – Whether to pin sharded parameter and gradient memory. Pinning memory allows both more efficient H2D/D2H copies and for the copies to overlap with compute. However, the pinned memory cannot be used by other processes. Set this to False if you have insufficient CPU memory. (Default: True)
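A sketch of enabling CPU offloading, reusing the hypothetical model from the earlier examples:

    import torch
    from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard

    # Sharded parameters, gradients, and optimizer states live on CPU;
    # set pin_memory=False if host memory is tight.
    fully_shard(model, offload_policy=CPUOffloadPolicy(pin_memory=True))
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)  # optimizer states on CPU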