FullyShardedDataParallel
- class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)
A wrapper for sharding module parameters across data parallel workers.
This is inspired by Xu et al. as well as the ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.
To understand FSDP internals, refer to the FSDP Notes.
Example:
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()
Using FSDP involves wrapping your module first and initializing your optimizer afterward. This order is required because FSDP changes the parameter variables.
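To see why the parameter variables change, here is a pure-Python sketch of the idea behind FSDP's flattening and sharding: a wrapped module's parameters are concatenated into one flat 1D buffer (the FlatParameter), padded to a multiple of the world size, and each rank keeps one contiguous chunk. The helper shard_flat_params is hypothetical and models element counts only; FSDP's actual layout (for example, any extra alignment padding) may differ.

```python
from typing import List, Tuple

# (param_index, start, end): a slice of one original parameter's flattened elements
Piece = Tuple[int, int, int]


def shard_flat_params(param_numels: List[int], world_size: int) -> Tuple[int, List[List[Piece]]]:
    """Conceptual sketch of FSDP-style flattening and sharding, not the FSDP API.

    Parameters (described only by their element counts) are laid out back to back
    in one flat 1D buffer, padded so its length divides evenly by world_size.
    Rank r keeps the r-th contiguous chunk. For every rank, return the slices of
    the original parameters that fall inside its chunk.
    """
    total = sum(param_numels)
    shard_size = (total + world_size - 1) // world_size  # padded length / world_size
    # [start, end) offsets of each original parameter in the flat buffer
    offsets = []
    running = 0
    for numel in param_numels:
        offsets.append((running, running + numel))
        running += numel

    layout: List[List[Piece]] = []
    for rank in range(world_size):
        lo, hi = rank * shard_size, (rank + 1) * shard_size
        pieces: List[Piece] = []
        for idx, (p_lo, p_hi) in enumerate(offsets):
            start, end = max(lo, p_lo), min(hi, p_hi)
            if start < end:
                # Express the overlap in the parameter's own flattened coordinates.
                pieces.append((idx, start - p_lo, end - p_lo))
        layout.append(pieces)
    return shard_size, layout


shard_size, layout = shard_flat_params([6, 3, 4], world_size=4)
for rank, pieces in enumerate(layout):
    print(rank, pieces)
```

With sizes [6, 3, 4] on four ranks, parameter 0 is split across ranks 0 and 1, rank 0 holds nothing of parameters 1 and 2, and the last three elements of rank 3's chunk are padding. This is also why, with use_orig_params=True, an original parameter may have all, some, or none of its data on a given rank.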
When setting up FSDP, you need to consider the destination CUDA device. If the device has an ID (dev_id), you have three options:
1. Place the module on that device.
2. Set the device using torch.cuda.set_device(dev_id).
3. Pass dev_id into the device_id constructor argument.
This ensures that the FSDP instance's compute device is the destination device. For options 1 and 3, the FSDP initialization always occurs on GPU. For option 2, the FSDP initialization happens on the module's current device, which may be a CPU.
If you're using the sync_module_states=True flag, you need to ensure that the module is on a GPU, or use the device_id argument to specify a CUDA device to which FSDP moves the module in the FSDP constructor. This is necessary because sync_module_states=True requires GPU communication.
FSDP also takes care of moving the forward method's input tensors to the GPU compute device, so you don't need to manually move them from CPU.
With use_orig_params=True, ShardingStrategy.SHARD_GRAD_OP exposes the unsharded parameters after forward rather than the sharded ones, unlike ShardingStrategy.FULL_SHARD. If you want to inspect the gradients, you can use the summon_full_params method with with_grads=True.
With limit_all_gathers=True, you may see a gap in the FSDP pre-forward where the CPU thread is not issuing any kernels. This is intentional and shows the rate limiter in effect. Synchronizing the CPU thread in this way prevents over-allocating memory for subsequent all-gathers, and it should not actually delay GPU kernel execution.
FSDP replaces managed modules' parameters with torch.Tensor views during forward and backward computation for autograd-related reasons. If your module's forward relies on saved references to the parameters instead of reacquiring the references each iteration, then it will not see FSDP's newly created views, and autograd will not work correctly.
Finally, when using sharding_strategy=ShardingStrategy.HYBRID_SHARD with the sharding process group being intra-node and the replication process group being inter-node, setting NCCL_CROSS_NIC=1 can help improve the all-reduce times over the replication process group for some cluster setups.
Limitations
There are several limitations to be aware of when using FSDP:
- FSDP currently does not support gradient accumulation outside no_sync() when using CPU offloading. This is because FSDP uses the newly reduced gradient instead of accumulating with any existing gradient, which can lead to incorrect results.
- FSDP does not support running the forward pass of a submodule that is contained in an FSDP instance. This is because the submodule's parameters will be sharded, but the submodule itself is not an FSDP instance, so its forward pass will not all-gather the full parameters appropriately.
- FSDP does not work with double backwards due to the way it registers backward hooks.
- FSDP has some constraints when freezing parameters. For use_orig_params=False, each FSDP instance must manage parameters that are all frozen or all non-frozen. For use_orig_params=True, FSDP supports mixing frozen and non-frozen parameters, but it's recommended to avoid doing so to prevent higher than expected gradient memory usage.
- As of PyTorch 1.12, FSDP offers limited support for shared parameters. If enhanced shared parameter support is needed for your use case, please post in this issue.
- You should avoid modifying the parameters between forward and backward without using the summon_full_params context, as the modifications may not persist.
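The last limitation can be pictured with a toy model. In the pure-Python sketch below, which is a simplification and not FSDP's implementation, a parameter lives as per-rank shards, "unshard" builds a temporary full copy, and "reshard" frees it. Edits to the temporary copy survive only if they are written back into the shards, which is what the summon_full_params context does when its writeback argument is True (the default). The class ToyShardedParam is hypothetical.

```python
from typing import List


class ToyShardedParam:
    """Hypothetical toy model of one parameter sharded across ranks (not the FSDP API).

    It mimics only unshard/reshard, to show why edits made to a temporary
    unsharded copy are lost unless they are written back into the shards.
    """

    def __init__(self, values: List[float], world_size: int) -> None:
        self.chunk = -(-len(values) // world_size)  # ceiling division
        self.shards = [list(values[r * self.chunk:(r + 1) * self.chunk]) for r in range(world_size)]

    def unshard(self) -> List[float]:
        # Stand-in for an all-gather: build a fresh full copy from the shards.
        return [v for shard in self.shards for v in shard]

    def reshard(self, full: List[float], writeback: bool) -> None:
        # The full copy is discarded; only with writeback do its edits reach the shards.
        if writeback:
            self.shards = [full[r * self.chunk:(r + 1) * self.chunk] for r in range(len(self.shards))]


param = ToyShardedParam([1.0, 2.0, 3.0, 4.0], world_size=2)
full = param.unshard()
full[0] = 100.0                       # modify the temporary unsharded copy
param.reshard(full, writeback=False)
print(param.unshard())                # [1.0, 2.0, 3.0, 4.0]: the edit was lost

full = param.unshard()
full[0] = 100.0
param.reshard(full, writeback=True)
print(param.unshard())                # [100.0, 2.0, 3.0, 4.0]: the edit persisted
```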
- Parameters
module (nn.Module) – This is the module to be wrapped with FSDP.
process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – This is the process group over which the model is sharded and thus the one used for FSDP's all-gather and reduce-scatter collective communications. If None, then FSDP uses the default process group. For hybrid sharding strategies such as ShardingStrategy.HYBRID_SHARD, users can pass in a tuple of process groups, representing the groups over which to shard and replicate, respectively. If None, then FSDP constructs process groups for the user to shard intra-node and replicate inter-node. (Default: None)
sharding_strategy (Optional[ShardingStrategy]) – This configures the sharding strategy, which may trade off memory saving and communication overhead. See ShardingStrategy for details. (Default: FULL_SHARD)
cpu_offload (Optional[CPUOffload]) – This configures CPU offloading. If this is set to None, then no CPU offloading happens. See CPUOffload for details. (Default: None)
auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) –
This specifies a policy to apply FSDP to submodules of module, which is needed for communication and computation overlap and thus affects performance. If None, then FSDP only applies to module, and users should manually apply FSDP to parent modules themselves (proceeding bottom-up). For convenience, this accepts ModuleWrapPolicy directly, which allows users to specify the module classes to wrap (e.g. the transformer block). Otherwise, this should be a callable that takes in three arguments module: nn.Module, recurse: bool, and nonwrapped_numel: int and should return a bool specifying whether the passed-in module should have FSDP applied if recurse=False or if the traversal should continue into the module's subtree if recurse=True. Users may add additional arguments to the callable. The size_based_auto_wrap_policy in torch.distributed.fsdp.wrap.py gives an example callable that applies FSDP to a module if the parameters in its subtree exceed 100M numel. We recommend printing the model after applying FSDP and adjusting as needed.
Example:
>>> import functools
>>> def custom_auto_wrap_policy(
>>>     module: nn.Module,
>>>     recurse: bool,
>>>     nonwrapped_numel: int,
>>>     # Additional custom arguments
>>>     min_num_params: int = int(1e8),
>>> ) -> bool:
>>>     return nonwrapped_numel >= min_num_params
>>> # Configure a custom `min_num_params`
>>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
backward_prefetch (Optional[BackwardPrefetch]) – This configures explicit backward prefetching of all-gathers. If None, then FSDP does not backward prefetch, and there is no communication and computation overlap in the backward pass. See BackwardPrefetch for details. (Default: BACKWARD_PRE)
mixed_precision (Optional[MixedPrecision]) – This configures native mixed precision for FSDP. If this is set to None, then no mixed precision is used. Otherwise, parameter, buffer, and gradient reduction dtypes can be set. See MixedPrecision for details. (Default: None)
ignored_modules (Optional[Iterable[torch.nn.Module]]) – Modules whose own parameters and child modules' parameters and buffers are ignored by this instance. None of the modules directly in ignored_modules should be FullyShardedDataParallel instances, and any child modules that are already-constructed FullyShardedDataParallel instances will not be ignored if they are nested under this instance. This argument may be used to avoid sharding specific parameters at module granularity when using an auto_wrap_policy or if parameters' sharding is not managed by FSDP. (Default: None)
param_init_fn (Optional[Callable[[nn.Module], None]]) –
A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. As of v1.12, FSDP detects modules with parameters or buffers on the meta device via is_meta and either applies param_init_fn if specified or calls nn.Module.reset_parameters() otherwise. In both cases, the implementation should only initialize the parameters/buffers of the module, not those of its submodules. This is to avoid re-initialization. In addition, FSDP also supports deferred initialization via torchdistX's (https://github.com/pytorch/torchdistX) deferred_init() API, where the deferred modules are initialized by calling param_init_fn if specified or torchdistX's default materialize_module() otherwise. If param_init_fn is specified, then it is applied to all meta-device modules, meaning that it should probably branch on the module type. FSDP calls the initialization function before parameter flattening and sharding.
Example:
>>> module = MyModule(device="meta")
>>> def my_init_fn(module: nn.Module):
>>>     # E.g. initialize depending on the module type
>>>     ...
>>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
>>> print(next(fsdp_model.parameters()).device)  # current CUDA device
>>> # With torchdistX
>>> module = deferred_init.deferred_init(MyModule, device="cuda")
>>> # Will initialize via deferred_init.materialize_module().
>>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
device_id (Optional[Union[int, torch.device]]) – An int or torch.device giving the CUDA device on which FSDP initialization takes place, including the module initialization if needed and the parameter sharding. This should be specified to improve initialization speed if module is on CPU. If the default CUDA device was set (e.g. via torch.cuda.set_device), then the user may pass torch.cuda.current_device() to this. (Default: None)
sync_module_states (bool) – If True, then each FSDP module will broadcast module parameters and buffers from rank 0 to ensure that they are replicated across ranks (adding communication overhead to this constructor). This can help load state_dict checkpoints via load_state_dict in a memory efficient way. See FullStateDictConfig for an example of this. (Default: False)
forward_prefetch (bool) – If True, then FSDP explicitly prefetches the next forward-pass all-gather before the current forward computation. This is only useful for CPU-bound workloads, in which case issuing the next all-gather earlier may improve overlap. This should only be used for static-graph models since the prefetching follows the first iteration's execution order. (Default: False)
limit_all_gathers (bool) – If True, then FSDP explicitly synchronizes the CPU thread so that GPU memory is used by at most two consecutive FSDP instances (the current instance running computation and the next instance whose all-gather is prefetched). If False, then FSDP allows the CPU thread to issue all-gathers without any extra synchronization. (Default: True) We often refer to this feature as the "rate limiter". This flag should only be set to False for specific CPU-bound workloads with low memory pressure, in which case the CPU thread can aggressively issue all kernels without concern for the GPU memory usage.
use_orig_params (bool) – Setting this to True has FSDP use module's original parameters. FSDP exposes those original parameters to the user via nn.Module.named_parameters() instead of FSDP's internal FlatParameters. This means that the optimizer step runs on the original parameters, enabling per-original-parameter hyperparameters. FSDP preserves the original parameter variables and manipulates their data between unsharded and sharded forms, where they are always views into the underlying unsharded or sharded FlatParameter, respectively. With the current algorithm, the sharded form is always 1D, losing the original tensor structure. An original parameter may have all, some, or none of its data present for a given rank. In the none case, its data will be like a size-0 empty tensor. Users should not author programs relying on what data is present for a given original parameter in its sharded form. True is required to use torch.compile(). Setting this to False exposes FSDP's internal FlatParameters to the user via nn.Module.named_parameters(). (Default: False)
ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]) – Ignored parameters or modules that will not be managed by this FSDP instance, meaning that the parameters are not sharded and their gradients are not reduced across ranks. This argument unifies with the existing ignored_modules argument, and we may deprecate ignored_modules soon. For backward compatibility, we keep both ignored_states and ignored_modules, but FSDP only allows one of them to be specified as not None.
device_mesh (Optional[DeviceMesh]) – DeviceMesh can be used as an alternative to process_group. When device_mesh is passed, FSDP uses the underlying process groups for all-gather and reduce-scatter collective communications. Therefore, the two arguments are mutually exclusive. For hybrid sharding strategies such as ShardingStrategy.HYBRID_SHARD, users can pass in a 2D DeviceMesh instead of a tuple of process groups. For 2D FSDP + TP, users are required to pass in device_mesh instead of process_group. For more DeviceMesh info, please visit: https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
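As a rough illustration of the default HYBRID_SHARD layout described under process_group (shard intra-node, replicate inter-node), the hypothetical helper below computes which ranks end up in each group, assuming ranks are numbered node by node. The same grid is what a 2D DeviceMesh of shape (num_nodes, gpus_per_node) describes; as we understand it, FSDP replicates along the first mesh dimension and shards along the second, so parameters are sharded within a row and sharded gradients are all-reduced within a column.

```python
from typing import List, Tuple


def hybrid_shard_groups(world_size: int, gpus_per_node: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Sketch of the rank layout described for HYBRID_SHARD (not the FSDP API).

    Ranks on the same node form a shard group (intra-node); ranks with the same
    local index on different nodes form a replicate group (inter-node). Assumes
    ranks are numbered node by node.
    """
    if world_size % gpus_per_node != 0:
        raise ValueError("world_size must be a multiple of gpus_per_node")
    num_nodes = world_size // gpus_per_node
    # Rows of a (num_nodes, gpus_per_node) grid of ranks: one shard group per node.
    shard_groups = [list(range(n * gpus_per_node, (n + 1) * gpus_per_node)) for n in range(num_nodes)]
    # Columns of the same grid: one replicate group per local GPU index.
    replicate_groups = [list(range(local, world_size, gpus_per_node)) for local in range(gpus_per_node)]
    return shard_groups, replicate_groups


shard_groups, replicate_groups = hybrid_shard_groups(world_size=8, gpus_per_node=4)
print(shard_groups)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replicate_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

With two nodes of four GPUs, each replicate group spans both nodes, which is the inter-node traffic that the NCCL_CROSS_NIC=1 remark above concerns.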
- apply(fn)
Apply fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).
Compared to torch.nn.Module.apply, this version additionally gathers the full parameters before applying fn. It should not be called from within another summon_full_params context.
- Parameters
fn (Module -> None) – function to be applied to each submodule
- Returns
self
- Return type
Module
- clip_grad_norm_(max_norm, norm_type=2.0)
Clip the gradient norm of all parameters.
The norm is computed over all parameters' gradients as viewed as a single vector, and the gradients are modified in-place.
- Parameters
max_norm (float or int) – max norm of the gradients
norm_type (float or int) – type of the used p-norm. Can be 'inf' for infinity norm.
- Returns
Total norm of the parameters' gradients (viewed as a single vector).
- Return type
Tensor
Note
If every FSDP instance uses NO_SHARD, meaning that no gradients are sharded across ranks, then you may directly use torch.nn.utils.clip_grad_norm_().
If at least some FSDP instance uses a sharded strategy (i.e. one other than NO_SHARD), then you should use this method instead of torch.nn.utils.clip_grad_norm_() since this method handles the fact that gradients are sharded across ranks.
The total norm returned will have the "largest" dtype across all parameters/gradients as defined by PyTorch's type promotion semantics. For example, if all parameters/gradients use a low precision dtype, then the returned norm's dtype will be that low precision dtype, but if there exists at least one parameter/gradient using FP32, then the returned norm's dtype will be FP32.
Warning
This needs to be called on all ranks since it uses collective communications.
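Why the sharded case needs a collective can be shown with a pure-Python simulation. In the sketch below, each simulated rank reduces only the gradient elements it owns, and the per-rank results are combined the way an all-reduce would (a sum of p-th powers for p-norms, a max for the inf-norm). We believe this mirrors the general pattern of this method, including the clip factor min(1, max_norm / (total_norm + eps)), but the helper names and the eps value are assumptions, not FSDP's code.

```python
import math
from typing import List


def sharded_total_norm(local_grad_shards: List[List[float]], norm_type: float = 2.0) -> float:
    """Simulate the total gradient norm when gradients are sharded across ranks.

    local_grad_shards[r] holds the gradient elements owned by rank r. Each rank
    reduces only its own shard; the per-rank results are then combined the way
    an all-reduce would (SUM of p-th powers for p-norms, MAX for the inf-norm).
    """
    if norm_type == math.inf:
        local = [max((abs(g) for g in shard), default=0.0) for shard in local_grad_shards]
        return max(local)  # all-reduce with MAX
    local = [sum(abs(g) ** norm_type for g in shard) for shard in local_grad_shards]
    return sum(local) ** (1.0 / norm_type)  # all-reduce with SUM, then the p-th root


def clip_coefficient(total_norm: float, max_norm: float, eps: float = 1e-6) -> float:
    # Every gradient element is scaled in place by this factor, which never exceeds 1.
    return min(1.0, max_norm / (total_norm + eps))


shards = [[3.0], [4.0]]                # rank 0 owns 3.0, rank 1 owns 4.0
total = sharded_total_norm(shards)     # sqrt(3**2 + 4**2) = 5.0
print(total, clip_coefficient(total, max_norm=1.0))
```

Clipping with each rank's local norm (3.0 or 4.0 here) would scale the ranks' gradients by different factors, which is the inconsistency that calling torch.nn.utils.clip_grad_norm_() on sharded gradients would introduce.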
- static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)
Flatten a sharded optimizer state-dict.
The API is similar to shard_full_optim_state_dict(). The only difference is that the input sharded_optim_state_dict should be returned from sharded_optim_state_dict(). Therefore, there will be all-gather calls on each rank to gather ShardedTensors.
- Parameters
sharded_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the sharded optimizer state.
model (torch.nn.Module) – Refer to shard_full_optim_state_dict().
optim (torch.optim.Optimizer) – Optimizer for model's parameters.
- Returns
Refer to shard_full_optim_state_dict().
- Return type
Dict[str, Any]
- forward(*args, **kwargs)
Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic.
- Return type
Any
- static fsdp_modules(module, root_only=False)
Return all nested FSDP instances.
This possibly includes module itself and only includes FSDP root modules if root_only=True.
- Parameters
module (torch.nn.Module) – Root module, which may or may not be an FSDP module.
root_only (bool) – Whether to return only FSDP root modules. (Default: False)
- Returns
FSDP modules that are nested in the input module.
- Return type
List[FullyShardedDataParallel]
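The difference between the two modes can be modelled on a toy module tree. The sketch below assumes that an FSDP root is an FSDP instance with no FSDP ancestor; ToyModule and toy_fsdp_modules are hypothetical stand-ins written for illustration, not the implementation of fsdp_modules.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyModule:
    """Hypothetical stand-in for an nn.Module tree; is_fsdp marks FSDP-wrapped nodes."""
    name: str
    is_fsdp: bool = False
    children: List["ToyModule"] = field(default_factory=list)


def toy_fsdp_modules(module: ToyModule, root_only: bool = False) -> List[str]:
    """Collect FSDP-marked nodes in pre-order, including module itself.

    With root_only=True, keep only nodes that have no FSDP ancestor, which is a
    conceptual model of an "FSDP root".
    """
    found: List[str] = []

    def visit(node: ToyModule, under_fsdp: bool) -> None:
        if node.is_fsdp and not (root_only and under_fsdp):
            found.append(node.name)
        for child in node.children:
            visit(child, under_fsdp or node.is_fsdp)

    visit(module, False)
    return found


model = ToyModule("model", children=[
    ToyModule("encoder", is_fsdp=True, children=[ToyModule("layer0", is_fsdp=True), ToyModule("layer1", is_fsdp=True)]),
    ToyModule("decoder", is_fsdp=True),
])
print(toy_fsdp_modules(model))                  # ['encoder', 'layer0', 'layer1', 'decoder']
print(toy_fsdp_modules(model, root_only=True))  # ['encoder', 'decoder']
```

Here the plain top-level module is not returned, and both top-level FSDP instances count as roots because neither is nested under another FSDP instance.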
- static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)
Return the full optimizer state-dict.
Consolidates the full optimizer state on rank 0 and returns it as a dict following the convention of torch.optim.Optimizer.state_dict(), i.e. with keys "state" and "param_groups". The flattened parameters in FSDP modules contained in model are mapped back to their unflattened parameters.
This needs to be called on all ranks since it uses collective communications. However, if rank0_only=True, then the state dict is only populated on rank 0, and all other ranks return an empty dict.
Unlike torch.optim.Optimizer.state_dict(), this method uses full parameter names as keys instead of parameter IDs.
Like in torch.optim.Optimizer.state_dict(), the tensors contained in the optimizer state dict are not cloned, so there may be aliasing surprises. For best practices, consider saving the returned optimizer state dict immediately, e.g. using torch.save().
- Parameters
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model's parameters.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer optim representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
rank0_only (bool) – If True, saves the populated dict only on rank 0; if False, saves it on all ranks. (Default: True)
group (dist.ProcessGroup) – Model's process group or None if using the default process group. (Default: None)
- Returns
A dict containing the optimizer state for model's original unflattened parameters and including keys "state" and "param_groups" following the convention of torch.optim.Optimizer.state_dict(). If rank0_only=True, then nonzero ranks return an empty dict.
- Return type
Dict[str, Any]
- static get_state_dict_type(module)
Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at module. The target module does not have to be an FSDP module.
- Returns
A StateDictSettings containing the state_dict_type and state_dict / optim_state_dict configs that are currently set.
- Raises
AssertionError – if the StateDictSettings for different FSDP submodules differ.
- Return type
StateDictSettings
- named_buffers(*args, **kwargs)
Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix when inside the summon_full_params() context manager.
- named_parameters(*args, **kwargs)
Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix when inside the summon_full_params() context manager.
- no_sync()
Disable gradient synchronizations across FSDP instances.
Within this context, gradients will be accumulated in module variables, which will later be synchronized in the first forward-backward pass after exiting the context. This should only be used on the root FSDP instance and will recursively apply to all children FSDP instances.
Note
This likely results in higher memory usage because FSDP will accumulate the full model gradients (instead of gradient shards) until the eventual sync.
Note
When used with CPU offloading, the gradients will not be offloaded to CPU when inside the context manager. Instead, they will only be offloaded right after the eventual sync.
- Return type
Generator
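The accumulate-then-sync pattern can be simulated in pure Python for one scalar gradient per rank. This is a sketch, not FSDP code: it assumes the cross-rank reduction is an average (the usual data-parallel convention), and it ignores that, inside no_sync(), FSDP keeps full unsharded gradients as noted above. The function name is hypothetical.

```python
from typing import List, Tuple


def accumulate_with_no_sync(per_rank_grads: List[List[float]]) -> Tuple[List[float], int]:
    """Simulate gradient accumulation with no_sync() for a single scalar gradient.

    per_rank_grads[r][k] is rank r's local gradient from micro-batch k. Every
    micro-batch except the last runs inside no_sync(), so each rank only adds to
    its own gradient. The last backward runs after exiting the context and
    triggers one cross-rank reduction (an average over ranks in this sketch).
    Returns the per-rank gradients after that sync and the number of syncs.
    """
    world_size = len(per_rank_grads)
    num_micro_batches = len(per_rank_grads[0])
    local = [0.0] * world_size
    num_syncs = 0
    for k in range(num_micro_batches):
        for r in range(world_size):
            local[r] += per_rank_grads[r][k]  # local accumulation, no communication
        if k == num_micro_batches - 1:        # first backward outside no_sync()
            averaged = sum(local) / world_size
            local = [averaged] * world_size
            num_syncs += 1
    return local, num_syncs


grads, syncs = accumulate_with_no_sync([[1.0, 2.0], [3.0, 4.0]])
print(grads, syncs)  # [5.0, 5.0] 1
```

In real training code, the first N-1 micro-batch backward passes would run inside a with model.no_sync(): block on the root FSDP instance and the last one outside it, followed by the optimizer step, so only one round of gradient communication happens per optimizer step.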
- static optim_state_dict(model, optim, optim_state_dict=None, group=None)
Transform the state-dict of an optimizer corresponding to a sharded model.
The given state-dict can be transformed to one of three types: 1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.
For full optimizer state_dict, all states are unflattened and not sharded. Rank0 only and CPU only can be specified via state_dict_type() to avoid OOM.
For sharded optimizer state_dict, all states are unflattened but sharded. CPU only can be specified via state_dict_type() to further save memory.
For local state_dict, no transformation will be performed. But a state will be converted from torch.Tensor to ShardedTensor to represent its sharding nature (this is not supported yet).
Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
- Parameters
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model's parameters.
optim_state_dict (Dict[str, Any]) – The target optimizer state_dict to transform. If the value is None, optim.state_dict() will be used. (Default: None)
group (dist.ProcessGroup) – Model's process group across which parameters are sharded or None if using the default process group. (Default: None)
- Returns
A dict containing the optimizer state for model. The sharding of the optimizer state is based on state_dict_type.
- Return type
Dict[str, Any]
- static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)
Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
Given an optim_state_dict that is transformed through optim_state_dict(), it gets converted to the flattened optimizer state_dict that can be loaded to optim, which is the optimizer for model. model must be sharded by FullyShardedDataParallel.
Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> original_osd = optim.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(
>>>     model,
>>>     optim,
>>>     optim_state_dict=original_osd
>>> )
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
- Parameters
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model's parameters.
optim_state_dict (Dict[str, Any]) – The optimizer states to be loaded.
is_named_optimizer (bool) – Whether this optimizer is a NamedOptimizer or KeyedOptimizer. Only set to True if optim is TorchRec's KeyedOptimizer or torch.distributed's NamedOptimizer.
load_directly (bool) – If this is set to True, this API will also call optim.load_state_dict(result) before returning the result. Otherwise, users are responsible for calling optim.load_state_dict(). (Default: False)
group (dist.ProcessGroup) – Model's process group across which parameters are sharded or None if using the default process group. (Default: None)
- Return type
Dict[str, Any]
- register_comm_hook(state, hook)
Register a communication hook.
This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates gradients across multiple workers. This hook can be used to implement several algorithms like GossipGrad and gradient compression, which involve different communication strategies for parameter syncs while training with FullyShardedDataParallel.
Warning
The FSDP communication hook should be registered before running an initial forward pass, and only once.
- Parameters
state (object) – Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. It is locally stored by each worker and shared by all the gradient tensors on the worker.
hook (Callable) – Callable, which has one of the following signatures: 1) hook: Callable[torch.Tensor] -> None: This function takes in a Python tensor, which represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). It then performs all necessary processing and returns None; 2) hook: Callable[torch.Tensor, torch.Tensor] -> None: This function takes in two Python tensors. The first one represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). The second represents a pre-sized tensor to store a chunk of a sharded gradient after reduction. In both cases, the callable performs all necessary processing and returns None. Callables with signature 1 are expected to handle gradient communication for a NO_SHARD case. Callables with signature 2 are expected to handle gradient communication for sharded cases.
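To make the signature-2 contract concrete, the sketch below computes, for all simulated ranks at once, what each rank's pre-sized output shard should contain if the hook performs a plain averaged reduce-scatter. It is not a hook that can be passed to register_comm_hook: a real hook runs on each rank separately, is also given the state object described above, and would communicate with a collective such as torch.distributed.reduce_scatter_tensor. Here the collective is emulated with lists, and the averaging is an assumption (a compression hook, for example, would transform the gradient before communicating).

```python
from typing import List


def simulate_reduce_scatter_hook(full_grads: List[List[float]]) -> List[List[float]]:
    """Model what a signature-2 communication hook must produce, in pure Python.

    full_grads[r] is rank r's full, flattened, unsharded gradient (the first
    hook argument). Entry r of the result is what rank r's pre-sized output
    tensor (the second hook argument) should hold afterwards: rank r's chunk of
    the element-wise sum over ranks, divided by the world size.
    """
    world_size = len(full_grads)
    numel = len(full_grads[0])
    if numel % world_size != 0:
        raise ValueError("sketch assumes the flat gradient is padded to a multiple of world_size")
    shard = numel // world_size
    summed = [sum(grad[i] for grad in full_grads) for i in range(numel)]  # the "reduce"
    return [
        [v / world_size for v in summed[r * shard:(r + 1) * shard]]  # the "scatter"
        for r in range(world_size)
    ]


print(simulate_reduce_scatter_hook([[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]]))
# [[2.0, 3.0], [4.0, 5.0]]
```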
- static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)
Re-keys the optimizer state dict optim_state_dict to use the key type optim_state_key_type.
This can be used to achieve compatibility between optimizer state dicts from models with FSDP instances and ones without.
To re-key an FSDP full optimizer state dict (i.e. from full_optim_state_dict()) to use parameter IDs and be loadable to a non-wrapped model:
>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)
To re-key a normal optimizer state dict from a non-wrapped model to be loadable to a wrapped model:
>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
- Returns
The optimizer state dict re-keyed using the parameter keys specified by optim_state_key_type.
- Return type
Dict[str, Any]
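Re-keying swaps only the parameter keys and leaves the per-parameter state untouched. The pure-Python helper below is a hypothetical sketch on a plain dict in the torch.optim.Optimizer.state_dict() layout, mapping integer parameter IDs to parameter names and back; it is not FSDP's implementation, which additionally has to map between FSDP's flattened parameters and the original parameter names.

```python
from typing import Any, Dict, Hashable, Mapping


def rekey(osd: Dict[str, Any], key_map: Mapping[Hashable, Hashable]) -> Dict[str, Any]:
    """Return a new (shallow) optimizer state dict with every parameter key mapped.

    osd follows the torch.optim.Optimizer.state_dict() layout: "state" maps a
    parameter key to its per-parameter state, and each entry of "param_groups"
    lists parameter keys under "params". Other group fields (lr, ...) are kept.
    """
    state = {key_map[k]: v for k, v in osd["state"].items()}
    groups = [dict(group, params=[key_map[k] for k in group["params"]]) for group in osd["param_groups"]]
    return {"state": state, "param_groups": groups}


# Parameter IDs follow the order in which parameters were passed to the optimizer.
names = ["fc.weight", "fc.bias"]
id_to_name = dict(enumerate(names))
name_to_id = {name: i for i, name in enumerate(names)}

osd_by_id = {"state": {0: {"step": 1}, 1: {"step": 1}}, "param_groups": [{"lr": 0.1, "params": [0, 1]}]}
osd_by_name = rekey(osd_by_id, id_to_name)
print(osd_by_name["param_groups"][0]["params"])  # ['fc.weight', 'fc.bias']
```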
- static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)
Scatter the full optimizer state dict from rank 0 to all other ranks.
Returns the sharded optimizer state dict on each rank. The return value is the same as shard_full_optim_state_dict(), and on rank 0, the first argument should be the return value of full_optim_state_dict().
Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)
Note
Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.
- Parameters
full_optim_state_dict (Optional[Dict[str, Any]]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state if on rank 0; the argument is ignored on nonzero ranks.
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)
group (dist.ProcessGroup) – Model's process group or None if using the default process group. (Default: None)
- Returns
The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank's part of the optimizer state.
- Return type
Dict[str, Any]
- static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source][source]¶
Set the state_dict_type of all the descendant FSDP modules of the target module.
Also takes (optional) configuration for the model’s and optimizer’s state dict. The target module does not have to be an FSDP module. If the target module is an FSDP module, its state_dict_type will also be changed.
Note
This API should be called for only the top-level (root) module.
Note
This API enables users to transparently use the conventional state_dict API to take model checkpoints in cases where the root FSDP module is wrapped by another nn.Module. For example, the following will ensure state_dict is called on all non-FSDP instances, while dispatching into the sharded_state_dict implementation for FSDP:
Example:
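>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType, ShardedStateDictConfig, ShardedOptimStateDictConfig
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> model = DDP(FSDP(...))
>>> FSDP.set_state_dict_type(
...     model,
...     StateDictType.SHARDED_STATE_DICT,
...     state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
...     optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
... )
>>> param_state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)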
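The propagation rule described above can be modeled with a small toy tree. Under this rule, every FSDP module under the target is updated, the target itself need not be FSDP, and the previous setting is returned. `ToyModule`, `ToyFSDP`, and `toy_set_state_dict_type` are hypothetical stand-ins, not torch classes. The real method also returns the previous configs inside a StateDictSettings; this sketch returns only the type string.

```python
# Toy model (hypothetical classes, not torch.nn.Module / FSDP) of how
# set_state_dict_type() reaches every FSDP module below the target.
from dataclasses import dataclass, field
from typing import Iterator, List, Optional


@dataclass
class ToyModule:
    name: str
    children: List["ToyModule"] = field(default_factory=list)

    def modules(self) -> Iterator["ToyModule"]:
        """Yield this module and all descendants, like nn.Module.modules()."""
        yield self
        for child in self.children:
            yield from child.modules()


@dataclass
class ToyFSDP(ToyModule):
    state_dict_type: str = "FULL_STATE_DICT"


def toy_set_state_dict_type(module: ToyModule, new_type: str) -> Optional[str]:
    """Set new_type on every ToyFSDP in the tree (including `module` itself
    if it is one) and return the previous type, or None if there is none."""
    fsdp_modules = [m for m in module.modules() if isinstance(m, ToyFSDP)]
    previous = {m.state_dict_type for m in fsdp_modules}
    # Simplifying assumption: all FSDP modules share one setting beforehand.
    if len(previous) > 1:
        raise RuntimeError("FSDP modules disagree on state_dict_type")
    for m in fsdp_modules:
        m.state_dict_type = new_type
    return previous.pop() if previous else None


# A non-FSDP wrapper (standing in for DDP) around nested FSDP modules.
root = ToyModule("ddp", [ToyFSDP("outer", [ToyModule("linear"), ToyFSDP("inner")])])
prev = toy_set_state_dict_type(root, "SHARDED_STATE_DICT")
```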
- Parameters
module (torch.nn.Module) – Root module.
state_dict_type (StateDictType) – the desired state_dict_type to set.
state_dict_config (Optional[StateDictConfig]) – the configuration for the target state_dict_type.
optim_state_dict_config (Optional[OptimStateDictConfig]) – the configuration for the optimizer state dict.
- Returns
A StateDictSettings that includes the previous state_dict type and configuration for the module.
- Return type
StateDictSettings
- static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)
Shard a full optimizer state-dict.
Remaps the state in full_optim_state_dict to flattened parameters instead of unflattened parameters and restricts to only this rank’s part of the optimizer state. The first argument should be the return value of full_optim_state_dict().
Example:
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model, optim=new_optim)
>>> new_optim.load_state_dict(sharded_osd)
Note
Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to the appropriate ranks. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.
- Parameters
full_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state.
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)
- Returns
The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.
- Return type
Dict[str, Any]
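The remapping described above can be pictured with a simplified model. Per-parameter optimizer state (for example, Adam’s exp_avg) is concatenated in the FSDP unit’s parameter order, padded so its length divides evenly by the world size, and split into one contiguous chunk per rank. Python lists stand in for tensors, and `flatten_and_shard` is a hypothetical helper. FSDP’s real code also handles multiple state keys, dtypes, and non-tensor state.

```python
# Simplified model (lists stand in for tensors; the helper is hypothetical) of
# "remap to flattened parameters and keep only this rank's part".
from typing import Dict, List


def flatten_and_shard(
    per_param_state: Dict[str, List[float]],
    param_order: List[str],
    rank: int,
    world_size: int,
    pad_value: float = 0.0,
) -> List[float]:
    """Concatenate per-parameter state in the unit's parameter order,
    right-pad to a multiple of world_size, and return `rank`'s chunk."""
    flat: List[float] = []
    for name in param_order:
        flat.extend(per_param_state[name])
    padded_len = -(-len(flat) // world_size) * world_size  # round up
    flat.extend([pad_value] * (padded_len - len(flat)))
    chunk = padded_len // world_size
    return flat[rank * chunk:(rank + 1) * chunk]


# Two parameters with 4 + 1 state elements, sharded over 2 ranks:
state = {"w": [1.0, 2.0, 3.0, 4.0], "b": [5.0]}
rank0 = flatten_and_shard(state, ["w", "b"], rank=0, world_size=2)
rank1 = flatten_and_shard(state, ["w", "b"], rank=1, world_size=2)
```

Here rank 0 receives [1.0, 2.0, 3.0] and rank 1 receives [4.0, 5.0, 0.0]. The trailing 0.0 is padding.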
- static sharded_optim_state_dict(model, optim, group=None)
Return the optimizer state-dict in its sharded form.
The API is similar to full_optim_state_dict() but this API chunks all non-zero-dimension states to ShardedTensor to save memory. This API should only be used when the model state_dict is derived inside a with state_dict_type(SHARDED_STATE_DICT): block.
For the detailed usage, refer to full_optim_state_dict().
Warning
The returned state dict contains ShardedTensor and cannot be directly used by the regular optim.load_state_dict.
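The distinction between zero-dimensional and non-zero-dimensional state can be illustrated with a toy helper. This sketch rests on three assumptions. First, Python lists stand in for tensors. Second, chunking happens along the first dimension with torch.chunk-style sizes: ceil(n / world_size) per rank, so the last chunk may be shorter. Third, `chunk_non_scalar_state` is hypothetical. The real API wraps each chunk in a ShardedTensor rather than returning a bare slice.

```python
# Illustrative sketch (hypothetical helper; lists stand in for tensors):
# non-zero-dimensional state is chunked across ranks, while zero-dimensional
# entries such as Adam's "step" counter are kept whole on every rank.
from typing import Any, Dict


def chunk_non_scalar_state(state: Dict[str, Any], rank: int, world_size: int) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for key, value in state.items():
        if isinstance(value, list):  # stands in for a tensor with dim() > 0
            size = -(-len(value) // world_size)  # ceil division, like torch.chunk
            out[key] = value[rank * size:(rank + 1) * size]
        else:  # a 0-dim tensor / Python scalar is kept as-is
            out[key] = value
    return out


adam_state = {"step": 10, "exp_avg": [1, 2, 3, 4, 5]}
print(chunk_non_scalar_state(adam_state, rank=1, world_size=2))
```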
- static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)
Set the state_dict_type of all the descendant FSDP modules of the target module.
This context manager has the same functions as set_state_dict_type(). See set_state_dict_type() for details.
Example:
>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(
...     model,
...     StateDictType.SHARDED_STATE_DICT,
... ):
...     checkpoint = model.state_dict()
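A context manager like state_dict_type() can be thought of as a set-then-restore wrapper around set_state_dict_type(). It applies the new setting on entry and puts the previous one back on exit, even if the body raises. The sketch below shows only that pattern on a plain dict. `toy_state_dict_type` is a hypothetical name, not FSDP’s implementation.

```python
# Minimal sketch (hypothetical names, not FSDP's implementation) of the
# set-then-restore pattern behind a state_dict_type()-style context manager.
import contextlib
from typing import Dict, Iterator


@contextlib.contextmanager
def toy_state_dict_type(settings: Dict[str, str], new_type: str) -> Iterator[None]:
    """Temporarily set settings["state_dict_type"] to new_type."""
    previous = settings["state_dict_type"]
    settings["state_dict_type"] = new_type
    try:
        yield
    finally:
        # Restore the previous setting on normal exit and on exceptions.
        settings["state_dict_type"] = previous


settings = {"state_dict_type": "FULL_STATE_DICT"}
with toy_state_dict_type(settings, "SHARDED_STATE_DICT"):
    inside = settings["state_dict_type"]
after = settings["state_dict_type"]
```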
- Parameters
module (torch.nn.Module) – Root module.
state_dict_type (StateDictType) – the desired state_dict_type to set.
state_dict_config (Optional[StateDictConfig]) – the model state_dict configuration for the target state_dict_type.
optim_state_dict_config (Optional[OptimStateDictConfig]) – the optimizer state_dict configuration for the target state_dict_type.
- Return type
Generator
- static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)
Expose full params for FSDP instances with this context manager.
Can be useful after forward/backward for a model to get the params for additional processing or checking. It can take a non-FSDP module and will summon full params for all contained FSDP modules as well as their children, depending on the recurse argument.
Note
This can be used on inner FSDPs.
Note
This cannot be used within a forward or backward pass, nor can forward and backward be started from within this context.
Note
Parameters will revert to their local shards after the context manager exits; storage behavior is the same as in forward.
Note
The full parameters can be modified, but only the portion corresponding to the local param shard will persist after the context manager exits (unless writeback=False, in which case changes will be discarded). In the case where FSDP does not shard the parameters, currently only when world_size == 1 or with the NO_SHARD config, the modification is persisted regardless of writeback.
Note
This method works on modules which are not FSDP themselves but may contain multiple independent FSDP units. In that case, the given arguments will apply to all contained FSDP units.
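The writeback rule in the notes above can be modeled with plain lists. Inside the context, each rank sees the full parameter. On exit, the rank keeps only its own slice of any edits, or discards them when writeback=False. `toy_summon_and_modify` is a hypothetical helper. It assumes the shards are contiguous, in rank order, and unpadded, and it models a sharded configuration, not world_size == 1 or NO_SHARD.

```python
# Toy model (lists stand in for tensors; the helper is hypothetical) of the
# summon_full_params() writeback rule for a sharded parameter.
from typing import List


def toy_summon_and_modify(
    shards: List[List[float]], rank: int, scale: float, writeback: bool = True
) -> List[float]:
    """Return `rank`'s local shard after the full parameter is scaled
    inside the context."""
    full = [x for shard in shards for x in shard]  # stands in for the all-gather
    full = [x * scale for x in full]  # the user's edit inside the context
    if not writeback:
        return list(shards[rank])  # edits are discarded on exit
    start = sum(len(s) for s in shards[:rank])
    return full[start:start + len(shards[rank])]  # only the local portion persists


shards = [[1.0, 2.0], [3.0, 4.0]]
print(toy_summon_and_modify(shards, rank=1, scale=10.0))
print(toy_summon_and_modify(shards, rank=1, scale=10.0, writeback=False))
```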
Warning
Note that rank0_only=True in conjunction with writeback=True is not currently supported and will raise an error. This is because model parameter shapes would be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.
Warning
Note that offload_to_cpu=True with rank0_only=False will result in full parameters being redundantly copied to CPU memory for GPUs that reside on the same machine, which risks CPU OOM. It is recommended to use offload_to_cpu with rank0_only=True.
- Parameters
recurse (bool, Optional) – recursively summon all params for nested FSDP instances (default: True).
writeback (bool, Optional) – if False, modifications to params are discarded after the context manager exits; disabling this can be slightly more efficient (default: True).
rank0_only (bool, Optional) – if True, full parameters are materialized on only global rank 0. This means that within the context, only rank 0 will have full parameters and the other ranks will have sharded parameters. Note that setting rank0_only=True with writeback=True is not supported, as model parameter shapes will be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.
offload_to_cpu (bool, Optional) – If True, full parameters are offloaded to CPU. Note that this offloading currently only occurs if the parameter is sharded (parameters are unsharded only when world_size == 1 or with the NO_SHARD config). It is recommended to use offload_to_cpu with rank0_only=True to avoid redundant copies of model parameters being offloaded to the same CPU memory.
with_grads (bool, Optional) – If True, gradients are also unsharded with the parameters. Currently, this is only supported when passing use_orig_params=True to the FSDP constructor and offload_to_cpu=False to this method. (Default: False)
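The argument constraints documented above can be collected into one hypothetical validator for quick reference. `check_summon_args` is not part of FSDP, and the ValueError exception type is an assumption, since the text only says an error is raised. The warning for offload_to_cpu=True with rank0_only=False is also this sketch’s own choice; the text only calls that combination a CPU OOM risk.

```python
# Hypothetical validator (not FSDP's code) restating the documented
# constraints on summon_full_params() arguments.
import warnings


def check_summon_args(
    rank0_only: bool,
    writeback: bool,
    offload_to_cpu: bool,
    with_grads: bool,
    use_orig_params: bool,
) -> None:
    """Raise or warn for the argument combinations the docs call out."""
    if rank0_only and writeback:
        # Parameter shapes differ across ranks inside the context.
        raise ValueError("rank0_only=True with writeback=True is not supported")
    if with_grads and (not use_orig_params or offload_to_cpu):
        raise ValueError(
            "with_grads=True requires use_orig_params=True and offload_to_cpu=False"
        )
    if offload_to_cpu and not rank0_only:
        warnings.warn(
            "offload_to_cpu=True with rank0_only=False copies full parameters to "
            "CPU once per GPU on the same machine; consider rank0_only=True"
        )


# The recommended offloading combination passes silently:
check_summon_args(rank0_only=True, writeback=False, offload_to_cpu=True,
                  with_grads=False, use_orig_params=True)
```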
- Return type
Generator
- class torch.distributed.fsdp.BackwardPrefetch(value)
This configures explicit backward prefetching, which improves throughput by enabling communication and computation overlap in the backward pass at the cost of slightly increased memory usage.
BACKWARD_PRE: This enables the most overlap but increases memory usage the most. This prefetches the next set of parameters before the current set of parameters’ gradient computation. This overlaps the next all-gather and the current gradient computation, and at the peak, it holds the current set of parameters, next set of parameters, and current set of gradients in memory.
BACKWARD_POST: This enables less overlap but requires less memory usage. This prefetches the next set of parameters after the current set of parameters’ gradient computation. This overlaps the current reduce-scatter and the next gradient computation, and it frees the current set of parameters before allocating memory for the next set of parameters, only holding the next set of parameters and current set of gradients in memory at the peak.
FSDP’s backward_prefetch argument accepts None, which disables the backward prefetching altogether. This has no overlap and does not increase memory usage. In general, we do not recommend this setting since it may degrade throughput significantly.
For more technical context: For a single process group using NCCL backend, any collectives, even if issued from different streams, contend for the same per-device NCCL stream, which implies that the relative order in which the collectives are issued matters for overlapping. The two backward prefetching values correspond to different issue orders.
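The issue orders and peak memory described above can be summarized in a simplified model of one FSDP unit during backward. This model is an illustration, not FSDP’s scheduler. The op names are informal labels. Peak memory is counted in "sets", assuming every unit’s unsharded parameters and gradients are the same size.

```python
# Simplified model (an illustration, not FSDP's scheduler) of the order in
# which one FSDP unit's backward step issues work, plus the peak number of
# resident "sets" (unsharded params or grads) that the text describes.
from typing import List, Optional, Tuple


def backward_step(prefetch: Optional[str]) -> Tuple[List[str], int]:
    if prefetch == "BACKWARD_PRE":
        # The next all-gather is issued before the current gradient computation.
        order = ["all_gather(next)", "grad(cur)", "reduce_scatter(cur)"]
        peak = 3  # current params + next params + current grads
    elif prefetch == "BACKWARD_POST":
        # Current params are freed, then the next all-gather is issued ahead of
        # the current reduce-scatter on the shared NCCL stream.
        order = ["grad(cur)", "free(cur)", "all_gather(next)", "reduce_scatter(cur)"]
        peak = 2  # next params + current grads
    elif prefetch is None:
        # No prefetch: the next unit's all-gather is issued only when its own
        # backward begins, after this unit's reduce-scatter.
        order = ["grad(cur)", "reduce_scatter(cur)", "all_gather(next)"]
        peak = 2  # current params + current grads
    else:
        raise ValueError(f"unknown prefetch setting: {prefetch!r}")
    return order, peak


for setting in ("BACKWARD_PRE", "BACKWARD_POST", None):
    print(setting, backward_step(setting))
```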
- class torch.distributed.fsdp.ShardingStrategy(value)
This specifies the sharding strategy to be used for distributed training by FullyShardedDataParallel.
FULL_SHARD: Parameters, gradients, and optimizer states are sharded. For the parameters, this strategy unshards (via all-gather) before the forward, reshards after the forward, unshards before the backward computation, and reshards after the backward computation. For gradients, it synchronizes and shards them (via reduce-scatter) after the backward computation. The sharded optimizer states are updated locally per rank.
SHARD_GRAD_OP: Gradients and optimizer states are sharded during computation, and additionally, parameters are sharded outside computation. For the parameters, this strategy unshards before the forward, does not reshard them after the forward, and only reshards them after the backward computation. The sharded optimizer states are updated locally per rank. Inside no_sync(), the parameters are not resharded after the backward computation.
NO_SHARD: Parameters, gradients, and optimizer states are not sharded but instead replicated across ranks similar to PyTorch’s DistributedDataParallel API. For gradients, this strategy synchronizes them (via all-reduce) after the backward computation. The unsharded optimizer states are updated locally per rank.
HYBRID_SHARD: Apply FULL_SHARD within a node, and replicate parameters across nodes. This results in reduced communication volume as expensive all-gathers and reduce-scatters are only done within a node, which can be more performant for medium-sized models.
_HYBRID_SHARD_ZERO2: Apply SHARD_GRAD_OP within a node, and replicate parameters across nodes. This is like HYBRID_SHARD, except this may provide even higher throughput since the unsharded parameters are not freed after the forward pass, saving the all-gathers in the pre-backward.
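The per-strategy communication described above can be restated as a lookup table. This table is an illustrative summary for one FSDP unit and one iteration outside no_sync(), not an FSDP API, and it ignores FSDP-internal optimizations that may skip some collectives. The intra-node and inter-node labels for the hybrid strategies reflect an assumption: replicas across nodes synchronize gradients with an all-reduce.

```python
# Illustrative summary (not an FSDP API) of the collectives each sharding
# strategy issues for one FSDP unit in one iteration outside no_sync().
from typing import Dict, List

COLLECTIVES: Dict[str, List[str]] = {
    # Unshard for forward and again for backward, then shard gradients.
    "FULL_SHARD": ["all_gather(fwd)", "all_gather(bwd)", "reduce_scatter(grads)"],
    # Parameters stay unsharded from forward through backward.
    "SHARD_GRAD_OP": ["all_gather(fwd)", "reduce_scatter(grads)"],
    # DDP-like: parameters are replicated, gradients are all-reduced.
    "NO_SHARD": ["all_reduce(grads)"],
    "HYBRID_SHARD": [
        "all_gather(fwd, intra-node)",
        "all_gather(bwd, intra-node)",
        "reduce_scatter(grads, intra-node)",
        "all_reduce(grads, inter-node)",
    ],
    "_HYBRID_SHARD_ZERO2": [
        "all_gather(fwd, intra-node)",
        "reduce_scatter(grads, intra-node)",
        "all_reduce(grads, inter-node)",
    ],
}


def param_all_gathers(strategy: str) -> int:
    """Number of parameter all-gathers per iteration for one FSDP unit."""
    return sum(op.startswith("all_gather") for op in COLLECTIVES[strategy])


for name in COLLECTIVES:
    print(f"{name}: {param_all_gathers(name)} all-gather(s)")
```

The counts make the trade-off visible. SHARD_GRAD_OP and _HYBRID_SHARD_ZERO2 save one all-gather per unit per iteration by keeping parameters unsharded between forward and backward.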
- class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))
This configures FSDP-native mixed precision training.
- Variables
param_dtype (Optional[torch.dtype]) – This specifies the dtype for model parameters during forward and backward and thus the dtype for forward and backward computation. Outside forward and backward, the sharded parameters are kept in full precision (e.g. for the optimizer step), and for model checkpointing, the parameters are always saved in full precision. (Default: None)
reduce_dtype (Optional[torch.dtype]) – This specifies the dtype for gradient reduction (i.e. reduce-scatter or all-reduce). If this is None but param_dtype is not None, then this takes on the param_dtype value, still running gradient reduction in low precision. This is permitted to differ from param_dtype, e.g. to force gradient reduction to run in full precision. (Default: None)
buffer_dtype (Optional[torch.dtype]) – This specifies the dtype for buffers. FSDP does not shard buffers. Rather, FSDP casts them to buffer_dtype in the first forward pass and keeps them in that dtype thereafter. For model checkpointing, the buffers are saved in full precision except for LOCAL_STATE_DICT. (Default: None)
keep_low_precision_grads (bool) – If False, then FSDP upcasts gradients to full precision after the backward pass in preparation for the optimizer step. If True, then FSDP keeps the gradients in the dtype used for gradient reduction, which can save memory if using a custom optimizer that supports running in low precision. (Default: False)
cast_forward_inputs (bool) – If True, then this FSDP module casts its forward args and kwargs to param_dtype. This is to ensure that parameter and input dtypes match for forward computation, as required by many ops. This may need to be set to True when only applying mixed precision to some but not all FSDP modules, in which case a mixed-precision FSDP submodule needs to recast its inputs. (Default: False)
cast_root_forward_inputs (bool) – If True, then the root FSDP module casts its forward args and kwargs to param_dtype, overriding the value of cast_forward_inputs. For non-root FSDP modules, this does not do anything. (Default: True)
_module_classes_to_ignore (Sequence[Type[torch.nn.modules.module.Module]]) – This specifies module classes to ignore for mixed precision when using an auto_wrap_policy: Modules of these classes will have FSDP applied to them separately with mixed precision disabled (meaning that the final FSDP construction would deviate from the specified policy). If auto_wrap_policy is not specified, then this does not do anything. This API is experimental and subject to change. (Default: (_BatchNorm,))
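The defaulting rules for reduce_dtype and keep_low_precision_grads can be written as two small functions. Dtypes are plain strings, and the helpers are hypothetical. `original_dtype` (assumed float32 here) stands for the dtype the model’s parameters were created in, which is what "full precision" means in the descriptions above.

```python
# Sketch of the documented dtype defaulting rules (hypothetical helpers;
# dtypes are strings, and original_dtype is assumed to be float32).
from typing import Optional


def effective_reduce_dtype(
    param_dtype: Optional[str],
    reduce_dtype: Optional[str],
    original_dtype: str = "float32",
) -> str:
    """Dtype used for the gradient reduce-scatter / all-reduce."""
    if reduce_dtype is not None:
        return reduce_dtype  # explicit setting wins, e.g. full-precision reduction
    if param_dtype is not None:
        return param_dtype  # unset reduce_dtype inherits param_dtype
    return original_dtype  # no mixed precision for gradients at all


def optimizer_grad_dtype(
    param_dtype: Optional[str],
    reduce_dtype: Optional[str],
    keep_low_precision_grads: bool,
    original_dtype: str = "float32",
) -> str:
    """Dtype of the gradients the optimizer sees after backward."""
    reduced = effective_reduce_dtype(param_dtype, reduce_dtype, original_dtype)
    return reduced if keep_low_precision_grads else original_dtype


# bfloat16 compute with an unset reduce_dtype reduces in bfloat16, but the
# optimizer still sees float32 gradients unless keep_low_precision_grads=True.
print(effective_reduce_dtype("bfloat16", None))
print(optimizer_grad_dtype("bfloat16", None, keep_low_precision_grads=False))
```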
Note
This API is experimental and subject to change.
Note
Only floating point tensors are cast to their specified dtypes.
Note
In summon_full_params, parameters are forced to full precision, but buffers are not.
Note
Layer norm and batch norm accumulate in float32 even when their inputs are in a low precision like float16 or bfloat16. Disabling FSDP’s mixed precision for those norm modules only means that the affine parameters are kept in float32. However, this incurs separate all-gathers and reduce-scatters for those norm modules, which may be inefficient, so if the workload permits, the user should prefer to still apply mixed precision to those modules.
Note
By default, if the user passes a model with any _BatchNorm modules and specifies an auto_wrap_policy, then the batch norm modules will have FSDP applied to them separately with mixed precision disabled. See the _module_classes_to_ignore argument.
Note
MixedPrecision has cast_root_forward_inputs=True and cast_forward_inputs=False by default. For the root FSDP instance, its cast_root_forward_inputs takes precedence over its cast_forward_inputs. For non-root FSDP instances, their cast_root_forward_inputs values are ignored. The default setting is sufficient for the typical case where each FSDP instance has the same MixedPrecision configuration and only needs to cast inputs to the param_dtype at the beginning of the model’s forward pass.
Note
For nested FSDP instances with different MixedPrecision configurations, we recommend setting individual cast_forward_inputs values to configure casting inputs or not before each instance’s forward. In such a case, since the casts happen before each FSDP instance’s forward, a parent FSDP instance should have its non-FSDP submodules run before its FSDP submodules to avoid the activation dtype being changed due to a different MixedPrecision configuration.
Example:
>>> import torch
>>> import torch.nn as nn
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
...     model[1],
...     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
... )
>>> model = FSDP(
...     model,
...     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
... )
The above shows a working example. On the other hand, if model[1] were replaced with model[0], meaning that the submodule using a different MixedPrecision ran its forward first, then model[1] would incorrectly see float16 activations instead of bfloat16 ones.
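The precedence rule and the ordering hazard in the notes above can be simulated with a toy model. The helpers are hypothetical, dtypes are strings, and each child’s output is assumed to keep the dtype of its input. In the bad ordering, the plain layer, whose parameters belong to the bfloat16 parent, receives float16 activations.

```python
# Toy simulation (hypothetical helpers, dtypes as strings) of FSDP input
# casting: root precedence and the effect of child ordering.
from typing import List, Optional, Tuple


def casts_inputs(is_root: bool, cast_forward_inputs: bool,
                 cast_root_forward_inputs: bool = True) -> bool:
    """Whether an FSDP instance casts its forward inputs to its param_dtype."""
    # The root uses cast_root_forward_inputs; non-root instances ignore it.
    return cast_root_forward_inputs if is_root else cast_forward_inputs


def activation_dtypes(children: List[Tuple[str, Optional[str]]], input_dtype: str) -> List[str]:
    """Run a parent's children in order and record the input dtype each sees.

    Each child is (name, cast_to): cast_to is None for a plain (non-FSDP)
    layer, or the param_dtype of an FSDP child with cast_forward_inputs=True.
    Every child's output is assumed to keep the dtype it computed in.
    """
    seen = []
    dtype = input_dtype
    for _name, cast_to in children:
        if cast_to is not None:
            dtype = cast_to  # the FSDP child recasts its inputs
        seen.append(dtype)
    return seen


# The bfloat16 root casts the float32 input before running its children.
root_dtype = "bfloat16" if casts_inputs(is_root=True, cast_forward_inputs=True) else "float32"
good = activation_dtypes([("model[0]", None), ("model[1]", "float16")], root_dtype)
bad = activation_dtypes([("model[0]", "float16"), ("model[1]", None)], root_dtype)
print(good)  # plain layer sees bfloat16, FSDP child sees float16
print(bad)   # plain layer (bfloat16 params) now sees float16
```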
- class torch.distributed.fsdp.CPUOffload(offload_params=False)
This configures CPU offloading.
- Variables
offload_params (bool) – This specifies whether to offload parameters to CPU when not involved in computation. If True, then this offloads gradients to CPU as well, meaning that the optimizer step runs on CPU.
- class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)
StateDictConfig is the base class for all state_dict configuration classes. Users should instantiate a child class (e.g. FullStateDictConfig) in order to configure settings for the corresponding state_dict type supported by FSDP.
- Variables
offload_to_cpu (bool) – If True, then FSDP offloads the state dict values to CPU, and if False, then FSDP keeps them on GPU. (Default: False)
- class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)
FullStateDictConfig is a config class meant to be used with StateDictType.FULL_STATE_DICT. We recommend enabling both offload_to_cpu=True and rank0_only=True when saving full state dicts to save GPU memory and CPU memory, respectively. This config class is meant to be used via the state_dict_type() context manager as follows:
>>> import torch
>>> import torch.distributed as dist
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import FullStateDictConfig, StateDictType
>>> fsdp = FSDP(model, auto_wrap_policy=...)
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
...     state = fsdp.state_dict()
>>> # `state` will be empty on nonzero ranks and contain CPU tensors on rank 0.
>>> if dist.get_rank() == 0:
...     torch.save(state, "my_checkpoint.pt")
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn()  # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0:
...     # Load checkpoint only on rank 0 to avoid memory redundancy
...     state_dict = torch.load("my_checkpoint.pt")
...     model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
- Variables
rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False)
- class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)
ShardedStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT.
- Variables
_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False)
Warning
_use_dtensor is a private field of ShardedStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor.
- class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)
OptimStateDictConfig is the base class for all optim_state_dict configuration classes. Users should instantiate a child class (e.g. FullOptimStateDictConfig) in order to configure settings for the corresponding optim_state_dict type supported by FSDP.
- Variables
offload_to_cpu (bool) – If True, then FSDP offloads the state dict’s tensor values to CPU, and if False, then FSDP keeps them on the original device (which is GPU unless parameter CPU offloading is enabled). (Default: True)
- class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)
- Variables
rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False)
- class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)
ShardedOptimStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT.
- Variables
_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False)
Warning
_use_dtensor is a private field of ShardedOptimStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor.
- class torch.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu: bool = False)
- class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)