Shortcuts

Source code for torch.distributed.fsdp.fully_sharded_data_parallel

# mypy: ignore-errors

import contextlib
import copy
import functools
import math
import traceback
import warnings
from contextlib import contextmanager
from enum import auto, Enum
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_WRAPPED_MODULE,
    ActivationWrapper,
)
from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _get_param_to_fqns,
    FSDP_PREFIX,
    FSDP_WRAPPED_MODULE,
    HandleTrainingState,
    TrainingState,
)
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
    _check_orig_params_flattened,
    _init_buffer_state,
    _init_core_state,
    _init_device_handle,
    _init_extension,
    _init_ignored_module_states,
    _init_param_handle_from_module,
    _init_prefetching_state,
    _init_process_group_state,
    _init_runtime_state,
    _init_state_dict_state,
    HYBRID_SHARDING_STRATEGIES,
    ProcessGroupType,
)
from torch.distributed.fsdp._runtime_utils import (
    _get_fsdp_root_states,
    _is_fsdp_root,
    _lazy_init,
    _post_forward,
    _post_forward_reshard,
    _pre_forward,
    _pre_forward_unshard,
    _root_pre_forward,
    _unshard,
    _wait_for_computation_stream,
)
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    LocalOptimStateDictConfig,
    LocalStateDictConfig,
    MixedPrecision,
    OptimStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictConfig,
    StateDictSettings,
    StateDictType,
)
from torch.distributed.utils import _p_assert
from ._flat_param import FlatParameter, FlatParamHandle

from ._optim_utils import (
    _flatten_optim_state_dict,
    _get_param_id_to_param_from_optim_input,
    _get_param_key_to_param,
    _get_param_to_param_id_from_optim_input,
    _get_param_to_param_key,
    _optim_state_dict,
    _rekey_sharded_optim_state_dict,
    _set_optim_use_dtensor,
)
from ._state_dict_utils import _register_all_state_dict_hooks
from ._unshard_param_utils import (
    _deregister_orig_params,
    _register_flat_param,
    _register_orig_params,
    _unshard_params,
    _unshard_params_for_summon,
)
from .wrap import CustomPolicy, ModuleWrapPolicy


# Public API of this module: the FSDP wrapper class and the optimizer
# state-dict key-type enum.
__all__ = ["FullyShardedDataParallel", "OptimStateKeyType"]


# String constant equal to the name of the ``_flat_param`` attribute/property
# exposed by ``FullyShardedDataParallel``.
FLAT_PARAM = "_flat_param"


class OptimStateKeyType(Enum):
    """Enumerate the kinds of key an optimizer state-dict can be keyed by.

    Two members exist, ``PARAM_NAME`` and ``PARAM_ID``, in that order.
    Their values are spelled out explicitly and equal what ``enum.auto()``
    would assign (1 and 2).
    """

    PARAM_NAME = 1
    PARAM_ID = 2


class FullyShardedDataParallel(nn.Module, _FSDPState):
    """A wrapper for sharding module parameters across data parallel workers.

    This is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
    FullyShardedDataParallel is commonly shortened to FSDP.

    .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
    .. _DeepSpeed: https://www.deepspeed.ai/

    To understand FSDP internals, refer to the
    :ref:`fsdp_notes`.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> torch.cuda.set_device(device_id)
        >>> sharded_module = FSDP(my_module)
        >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
        >>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
        >>> loss = x.sum()
        >>> loss.backward()
        >>> optim.step()

    Using FSDP involves wrapping your module and then initializing your
    optimizer after. This is required since FSDP changes the parameter
    variables.

    When setting up FSDP, you need to consider the destination CUDA
    device. If the device has an ID (``dev_id``), you have three options:

    * Place the module on that device
    * Set the device using ``torch.cuda.set_device(dev_id)``
    * Pass ``dev_id`` into the ``device_id`` constructor argument.

    This ensures that the FSDP instance's compute device is the
    destination device. For options 1 and 3, the FSDP initialization
    always occurs on GPU. For option 2, the FSDP initialization
    happens on module's current device, which may be a CPU.

    If you're using the ``sync_module_states=True`` flag, you need to
    ensure that the module is on a GPU or use the ``device_id``
    argument to specify a CUDA device that FSDP will move the module
    to in the FSDP constructor. This is necessary because
    ``sync_module_states=True`` requires GPU communication.

    FSDP also takes care of moving input tensors to the forward method
    to the GPU compute device, so you don't need to manually move them
    from CPU.

    For ``use_orig_params=True``,
    ``ShardingStrategy.SHARD_GRAD_OP`` exposes the unsharded
    parameters, not the sharded parameters after forward, unlike
    ``ShardingStrategy.FULL_SHARD``. If you want
    to inspect the gradients, you can use the ``summon_full_params``
    method with ``with_grads=True``.

    With ``limit_all_gathers=True``, you may see a gap in the FSDP
    pre-forward where the CPU thread is not issuing any kernels. This is
    intentional and shows the rate limiter in effect. Synchronizing the CPU
    thread in that way prevents over-allocating memory for subsequent
    all-gathers, and it should not actually delay GPU kernel execution.

    FSDP replaces managed modules' parameters with ``torch.Tensor``
    views during forward and backward computation for autograd-related
    reasons. If your module's forward relies on saved references to
    the parameters instead of reacquiring the references each
    iteration, then it will not see FSDP's newly created views,
    and autograd will not work correctly.

    Finally, when using ``sharding_strategy=ShardingStrategy.HYBRID_SHARD``
    with the sharding process group being intra-node and the
    replication process group being inter-node, setting
    ``NCCL_CROSS_NIC=1`` can help improve the all-reduce times over
    the replication process group for some cluster setups.

    **Limitations**

    There are several limitations to be aware of when using FSDP:

    * FSDP currently does not support gradient accumulation outside
      ``no_sync()`` when using CPU offloading. This is because FSDP
      uses the newly-reduced gradient instead of accumulating with any
      existing gradient, which can lead to incorrect results.
    * FSDP does not support running the forward pass of a submodule
      that is contained in an FSDP instance. This is because the
      submodule's parameters will be sharded, but the submodule itself
      is not an FSDP instance, so its forward pass will not all-gather
      the full parameters appropriately.
    * FSDP does not work with double backwards due to the way it
      registers backward hooks.
    * FSDP has some constraints when freezing parameters.
      For ``use_orig_params=False``, each FSDP instance must manage
      parameters that are all frozen or all non-frozen. For
      ``use_orig_params=True``, FSDP supports mixing frozen and
      non-frozen parameters, but it's recommended to avoid doing so to
      prevent higher than expected gradient memory usage.
    * As of PyTorch 1.12, FSDP offers limited support for shared
      parameters. If enhanced shared parameter support is needed for
      your use case, please post in
      `this issue <https://github.com/pytorch/pytorch/issues/77724>`__.
    * You should avoid modifying the parameters between forward and
      backward without using the ``summon_full_params`` context, as
      the modifications may not persist.

    Args:
        module (nn.Module):
            This is the module to be wrapped with FSDP.
        process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]):
            This is the process group over which the model is sharded and thus
            the one used for FSDP's all-gather and reduce-scatter collective
            communications. If ``None``, then FSDP uses the default process
            group. For hybrid sharding strategies such as
            ``ShardingStrategy.HYBRID_SHARD``, users can pass in a tuple of
            process groups, representing the groups over which to shard and
            replicate, respectively. If ``None``, then FSDP constructs process
            groups for the user to shard intra-node and replicate inter-node.
            (Default: ``None``)
        sharding_strategy (Optional[ShardingStrategy]):
            This configures the sharding strategy, which may trade off memory
            saving and communication overhead. See :class:`ShardingStrategy`
            for details. (Default: ``FULL_SHARD``)
        cpu_offload (Optional[CPUOffload]):
            This configures CPU offloading. If this is set to ``None``, then
            no CPU offloading happens. See :class:`CPUOffload` for details.
            (Default: ``None``)
        auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]):
            This specifies a policy to apply FSDP to submodules of ``module``,
            which is needed for communication and computation overlap and thus
            affects performance. If ``None``, then FSDP only applies to
            ``module``, and users should manually apply FSDP to parent modules
            themselves (proceeding bottom-up). For convenience, this accepts
            ``ModuleWrapPolicy`` directly, which allows users to specify the
            module classes to wrap (e.g. the transformer block). Otherwise,
            this should be a callable that takes in three arguments
            ``module: nn.Module``, ``recurse: bool``, and
            ``nonwrapped_numel: int`` and should return a ``bool`` specifying
            whether the passed-in ``module`` should have FSDP applied if
            ``recurse=False`` or if the traversal should continue into the
            module's subtree if ``recurse=True``. Users may add additional
            arguments to the callable. The ``size_based_auto_wrap_policy`` in
            ``torch.distributed.fsdp.wrap.py`` gives an example callable that
            applies FSDP to a module if the parameters in its subtree exceed
            100M numel. We recommend printing the model after applying FSDP
            and adjusting as needed.

            Example::

                >>> def custom_auto_wrap_policy(
                >>>     module: nn.Module,
                >>>     recurse: bool,
                >>>     nonwrapped_numel: int,
                >>>     # Additional custom arguments
                >>>     min_num_params: int = int(1e8),
                >>> ) -> bool:
                >>>     return nonwrapped_numel >= min_num_params
                >>> # Configure a custom `min_num_params`
                >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))

        backward_prefetch (Optional[BackwardPrefetch]):
            This configures explicit backward prefetching of all-gathers. If
            ``None``, then FSDP does not backward prefetch, and there is no
            communication and computation overlap in the backward pass. See
            :class:`BackwardPrefetch` for details. (Default: ``BACKWARD_PRE``)
        mixed_precision (Optional[MixedPrecision]):
            This configures native mixed precision for FSDP. If this is set to
            ``None``, then no mixed precision is used. Otherwise, parameter,
            buffer, and gradient reduction dtypes can be set. See
            :class:`MixedPrecision` for details. (Default: ``None``)
        ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose
            own parameters and child modules' parameters and buffers are
            ignored by this instance. None of the modules directly in
            ``ignored_modules`` should be :class:`FullyShardedDataParallel`
            instances, and any child modules that are already-constructed
            :class:`FullyShardedDataParallel` instances will not be ignored if
            they are nested under this instance. This argument may be used to
            avoid sharding specific parameters at module granularity when using an
            ``auto_wrap_policy`` or if parameters' sharding is not managed by
            FSDP. (Default: ``None``)
        param_init_fn (Optional[Callable[[nn.Module], None]]):
            A ``Callable[torch.nn.Module] -> None`` that
            specifies how modules that are currently on the meta device should
            be initialized onto an actual device. As of v1.12, FSDP detects
            modules with parameters or buffers on meta device via ``is_meta``
            and either applies ``param_init_fn`` if specified or calls
            ``nn.Module.reset_parameters()`` otherwise. For both cases, the
            implementation should *only* initialize the parameters/buffers of
            the module, not those of its submodules. This is to avoid
            re-initialization. In addition, FSDP also supports deferred
            initialization via torchdistX's (https://github.com/pytorch/torchdistX)
            ``deferred_init()`` API, where the deferred modules are initialized
            by calling ``param_init_fn`` if specified or torchdistX's default
            ``materialize_module()`` otherwise. If ``param_init_fn`` is
            specified, then it is applied to all meta-device modules, meaning
            that it should probably case on the module type. FSDP calls the
            initialization function before parameter flattening and sharding.

            Example::

                >>> # xdoctest: +SKIP("undefined variables")
                >>> module = MyModule(device="meta")
                >>> def my_init_fn(module: nn.Module):
                >>>     # E.g. initialize depending on the module type
                >>>     ...
                >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
                >>> print(next(fsdp_model.parameters()).device) # current CUDA device
                >>> # With torchdistX
                >>> module = deferred_init.deferred_init(MyModule, device="cuda")
                >>> # Will initialize via deferred_init.materialize_module().
                >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)

        device_id (Optional[Union[int, torch.device]]): An ``int`` or
            ``torch.device`` giving the CUDA device on which FSDP
            initialization takes place, including the module initialization
            if needed and the parameter sharding. This should be specified to
            improve initialization speed if ``module`` is on CPU. If the
            default CUDA device was set (e.g. via ``torch.cuda.set_device``),
            then the user may pass ``torch.cuda.current_device`` to this.
            (Default: ``None``)
        sync_module_states (bool): If ``True``, then each FSDP module will
            broadcast module parameters and buffers from rank 0 to ensure that
            they are replicated across ranks (adding communication overhead to
            this constructor). This can help load ``state_dict`` checkpoints
            via ``load_state_dict`` in a memory efficient way. See
            :class:`FullStateDictConfig` for an example of this. (Default:
            ``False``)
        forward_prefetch (bool): If ``True``, then FSDP *explicitly* prefetches
            the next forward-pass all-gather before the current forward
            computation. This is only useful for CPU-bound workloads, in which
            case issuing the next all-gather earlier may improve overlap. This
            should only be used for static-graph models since the prefetching
            follows the first iteration's execution order. (Default: ``False``)
        limit_all_gathers (bool): If ``True``, then FSDP explicitly
            synchronizes the CPU thread to ensure GPU memory usage from only
            *two* consecutive FSDP instances (the current instance running
            computation and the next instance whose all-gather is prefetched).
            If ``False``, then FSDP allows the CPU thread to issue all-gathers
            without any extra synchronization. (Default: ``True``) We often
            refer to this feature as the "rate limiter". This flag should only
            be set to ``False`` for specific CPU-bound workloads with low
            memory pressure in which case the CPU thread can aggressively issue
            all kernels without concern for the GPU memory usage.
        use_orig_params (bool): Setting this to ``True`` has FSDP use
            ``module`` 's original parameters. FSDP exposes those original
            parameters to the user via :meth:`nn.Module.named_parameters`
            instead of FSDP's internal :class:`FlatParameter` s. This means
            that the optimizer step runs on the original parameters, enabling
            per-original-parameter hyperparameters. FSDP preserves the original
            parameter variables and manipulates their data between unsharded
            and sharded forms, where they are always views into the underlying
            unsharded or sharded :class:`FlatParameter`, respectively. With the
            current algorithm, the sharded form is always 1D, losing the
            original tensor structure. An original parameter may have all,
            some, or none of its data present for a given rank. In the none
            case, its data will be like a size-0 empty tensor. Users should not
            author programs relying on what data is present for a given
            original parameter in its sharded form. ``True`` is required to
            use ``torch.compile()``. Setting this to ``False`` exposes FSDP's
            internal :class:`FlatParameter` s to the user via
            :meth:`nn.Module.named_parameters`. (Default: ``False``)
        ignored_states (Union[Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]]):
            Ignored parameters or modules that will not be managed by this FSDP
            instance, meaning that the parameters are not sharded and their
            gradients are not reduced across ranks. This argument unifies with
            the existing ``ignored_modules`` argument, and we may deprecate
            ``ignored_modules`` soon. For backward compatibility, we keep both
            ``ignored_states`` and ``ignored_modules``, but FSDP only allows one
            of them to be specified as not ``None``. (Default: ``None``)
        device_mesh (Optional[DeviceMesh]): DeviceMesh can be used as an
            alternative to process_group. When device_mesh is passed, FSDP will
            use the underlying process groups for all-gather and reduce-scatter
            collective communications. Therefore, these two args need to be
            mutually exclusive. For hybrid sharding strategies such as
            ``ShardingStrategy.HYBRID_SHARD``, users can pass in a 2D DeviceMesh
            instead of a tuple of process groups. For 2D FSDP + TP, users are
            required to pass in device_mesh instead of process_group. For more
            DeviceMesh info, please visit:
            https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
            (Default: ``None``)
    """

    def __init__(
        self,
        module: nn.Module,
        process_group: ProcessGroupType = None,
        sharding_strategy: Optional[ShardingStrategy] = None,
        cpu_offload: Optional[CPUOffload] = None,
        auto_wrap_policy: Optional[
            Union[Callable, ModuleWrapPolicy, CustomPolicy]
        ] = None,
        backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
        mixed_precision: Optional[MixedPrecision] = None,
        ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
        param_init_fn: Optional[Callable[[nn.Module], None]] = None,
        device_id: Optional[Union[int, torch.device]] = None,
        sync_module_states: bool = False,
        forward_prefetch: bool = False,
        limit_all_gathers: bool = True,
        use_orig_params: bool = False,
        ignored_states: Union[
            Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
        ] = None,
        device_mesh: Optional[DeviceMesh] = None,
    ):
        # NOTE: the `_init_*` calls below run in a fixed order; several of them
        # read attributes that earlier calls set (e.g. `self._ignored_params`,
        # `self.process_group`). Do not reorder without checking each helper.
        torch._C._log_api_usage_once("torch.distributed.fsdp")
        super().__init__()
        # Containers such as ModuleList/ModuleDict are accepted, but only with
        # a warning, since FSDP will not all-gather parameters for containers
        # that do not implement forward.
        if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
            warnings.warn(
                "FSDP will not all-gather parameters for containers that do "
                f"not implement forward: {module}",
                stacklevel=2,
            )
        # Resolves `ignored_modules` / `ignored_states`; the calls below read
        # `self._ignored_modules` and `self._ignored_params` afterwards.
        _init_ignored_module_states(self, module, ignored_modules, ignored_states)
        _init_device_handle(self, module, self._ignored_params, device_id)

        # Add module annotations for Dynamo support (see function for details)
        _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params)

        # Initializes self.process_group, along with rank and world size. This will
        # also set another attribute, _inter_node_pg, to control the process group
        # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}.
        # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up
        # the same process group state as the root FSDP module.
        self._device_mesh = device_mesh
        _init_process_group_state(
            self,
            process_group,
            sharding_strategy,
            auto_wrap_policy,
            device_mesh,
        )
        if auto_wrap_policy is not None:
            # Constructor kwargs handed to `_auto_wrap`, which builds child
            # `FullyShardedDataParallel` instances with them, so children get
            # this root's settings. `ignored_states` forwards the already
            # resolved `self._ignored_params`, not the raw user argument.
            root_kwargs = {
                "process_group": process_group,
                "sharding_strategy": sharding_strategy,
                "cpu_offload": cpu_offload,
                "backward_prefetch": backward_prefetch,
                "mixed_precision": mixed_precision,
                "param_init_fn": param_init_fn,
                "device_id": device_id,
                "sync_module_states": sync_module_states,
                "forward_prefetch": forward_prefetch,
                "limit_all_gathers": limit_all_gathers,
                "use_orig_params": use_orig_params,
                "ignored_states": self._ignored_params,
                "device_mesh": device_mesh,
            }
            if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None:
                # Share root process groups with children to maintain
                # the invariant that all FSDP modules will have the same
                # process groups.
                root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)

            _auto_wrap(
                module,
                auto_wrap_policy,
                self._ignored_modules,
                self._ignored_params,
                root_kwargs,
                FullyShardedDataParallel,
            )

        # Prefetch limits are fixed at 1 for this wrapper and passed to
        # `_init_core_state` below.
        backward_prefetch_limit = 1
        forward_prefetch_limit = 1
        _init_core_state(
            self,
            sharding_strategy,
            mixed_precision,
            cpu_offload,
            limit_all_gathers,
            use_orig_params,
            backward_prefetch_limit,
            forward_prefetch_limit,
        )
        _init_runtime_state(self)
        _init_prefetching_state(self, backward_prefetch, forward_prefetch)
        _init_buffer_state(self, module)
        # extension needs to be set before `_init_param_handle_from_module()`
        _init_extension(self, device_mesh)
        _init_param_handle_from_module(
            self,
            module,
            device_id,
            param_init_fn,
            sync_module_states,
        )
        # Read by `module`, `__getattr__`, `__getitem__`, and `forward`.
        self._fsdp_wrapped_module = module
        if not use_orig_params:
            # With `use_orig_params=False`, the `FlatParameter` is what gets
            # exposed on this module (presumably registered by
            # `_register_flat_param`; confirm in `_unshard_param_utils`).
            _check_orig_params_flattened(self, self._ignored_params)
            _register_flat_param(self, self)

        # `_state_dict_type` controls the `state_dict()` behavior, which is
        # implemented using post-save and pre-load hooks
        _init_state_dict_state(self)
        _register_all_state_dict_hooks(self)

    @property
    def module(self) -> nn.Module:
        """Return the wrapped module, unwrapping an ``ActivationWrapper`` if present."""
        # FSDP's `.module` must refer to the innermost wrapped module when
        # composing with other module wrappers in order for state dict to work
        if isinstance(self._fsdp_wrapped_module, ActivationWrapper):
            return getattr(self._fsdp_wrapped_module, _CHECKPOINT_WRAPPED_MODULE)
        return self._fsdp_wrapped_module

    @property
    def _has_params(self) -> bool:
        """Returns whether this FSDP instance manages any parameters."""
        # `_handle` may not exist yet, hence the `hasattr` guard.
        return hasattr(self, "_handle") and self._handle is not None

    @property
    def _flat_param(self) -> Optional[FlatParameter]:
        """Return the handle's ``FlatParameter``, or ``None`` if there is no handle."""
        return self._handle.flat_param if self._handle else None

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            # Not found on this wrapper: look it up on the wrapped module.
            return getattr(self._fsdp_wrapped_module, name)

    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls in case the module is an ``nn.Sequential``."""
        # `FSDP_WRAPPED_MODULE` is presumably the attribute name
        # "_fsdp_wrapped_module"; before it is set, defer to the MRO.
        if hasattr(self, FSDP_WRAPPED_MODULE):
            return self._fsdp_wrapped_module.__getitem__(key)  # type: ignore[operator]
        return super().__getitem__(key)
[docs] def check_is_root(self) -> bool: """Check if this instance is a root FSDP module.""" return _is_fsdp_root(self, self)
[docs] @staticmethod def fsdp_modules( module: nn.Module, root_only: bool = False, ) -> List["FullyShardedDataParallel"]: """Return all nested FSDP instances. This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``. Args: module (torch.nn.Module): Root module, which may or may not be an ``FSDP`` module. root_only (bool): Whether to return only FSDP root modules. (Default: ``False``) Returns: List[FullyShardedDataParallel]: FSDP modules that are nested in the input ``module``. """ if root_only: return _get_fsdp_root_states(module) return traversal_utils._get_fsdp_states(module)
[docs] def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel": r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Compared to ``torch.nn.Module.apply``, this version additionally gathers the full parameters before applying ``fn``. It should not be called from within another ``summon_full_params`` context. Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self """ uninitialized = self._is_root is None self._assert_state(TrainingState.IDLE) # Use `_unshard_params_for_summon()` with `recurse=False` instead of # `_unshard_fsdp_state_params()` directly to perform lazy # initialization, which is needed to initialize `FlatParameter` # parameter attributes as required by the unshard logic with _unshard_params_for_summon( self, self, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False, ): ret = super().apply(fn) # Reset lazy init called in `_unshard_params_for_summon()` since # `apply()` may have been called on FSDP instance that is not truly a # root, in which case it will be incorrectly marked as one. if uninitialized and self._is_root: for module in traversal_utils._get_fsdp_states(self): module._reset_lazy_init() return ret
def _mixed_precision_enabled_for_buffers(self) -> bool: """Return whether the user explicitly enabled buffer mixed precision. NOTE: Unlike parameters and gradient reduction, buffer mixed precision is applied at the FSDP instance level, not the ``FlatParameter`` level, which may be different for the composable code path. """ return self.mixed_precision.buffer_dtype is not None def _low_precision_hook_enabled(self) -> bool: """Whether a low precision hook is registered or not.""" return self._comm_hook is not None and self._comm_hook in LOW_PRECISION_HOOKS def _reset_lazy_init(self) -> None: """Reset instance so :func:`_lazy_init` will run on the next forward.""" self._is_root: Optional[bool] = None
[docs] @staticmethod def set_state_dict_type( module: nn.Module, state_dict_type: StateDictType, state_dict_config: Optional[StateDictConfig] = None, optim_state_dict_config: Optional[OptimStateDictConfig] = None, ) -> StateDictSettings: """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module. Also takes (optional) configuration for the model's and optimizer's state dict. The target module does not have to be a FSDP module. If the target module is a FSDP module, its ``state_dict_type`` will also be changed. .. note:: This API should be called for only the top-level (root) module. .. note:: This API enables users to transparently use the conventional ``state_dict`` API to take model checkpoints in cases where the root FSDP module is wrapped by another ``nn.Module``. For example, the following will ensure ``state_dict`` is called on all non-FSDP instances, while dispatching into `sharded_state_dict` implementation for FSDP: Example:: >>> # xdoctest: +SKIP("undefined variables") >>> model = DDP(FSDP(...)) >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), >>> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), >>> ) >>> param_state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim) Args: module (torch.nn.Module): Root module. state_dict_type (StateDictType): the desired ``state_dict_type`` to set. state_dict_config (Optional[StateDictConfig]): the configuration for the target ``state_dict_type``. optim_state_dict_config (Optional[OptimStateDictConfig]): the configuration for the optimizer state dict. Returns: A StateDictSettings that include the previous state_dict type and configuration for the module. """ warnings.warn( "FSDP.state_dict_type() and FSDP.set_state_dict_type() are being " "deprecated. 
Please use APIs, get_state_dict() and set_state_dict(), " "which can support different parallelisms, FSDP1, FSDP2, DDP. " "API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html" "#torch.distributed.checkpoint.state_dict.get_state_dict ." "Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .", FutureWarning, ) _state_dict_type_to_config = { StateDictType.FULL_STATE_DICT: FullStateDictConfig, StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig, StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig, } _optim_state_dict_type_to_config = { StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig, StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig, StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig, } # Use the default config if a state_dict config is not set. state_dict_config_type = _state_dict_type_to_config[state_dict_type] optim_state_dict_config_type = _optim_state_dict_type_to_config[state_dict_type] if state_dict_config is None: state_dict_config = state_dict_config_type() if optim_state_dict_config is None: optim_state_dict_config = optim_state_dict_config_type() if state_dict_config_type != type(state_dict_config): raise RuntimeError( f"Expected state_dict_config of type {state_dict_config_type} " f"but got {type(state_dict_config)}" ) if optim_state_dict_config_type != type(optim_state_dict_config): raise RuntimeError( f"Expected optim_state_dict_config of type {optim_state_dict_config_type} " f"but got {type(optim_state_dict_config)}" ) # Set the state_dict type and configurations. prev_state_dict_type = None prev_state_dict_config = None prev_optim_state_dict_config = None for submodule in traversal_utils._get_fsdp_states(module): if prev_state_dict_type is None: prev_state_dict_type = submodule._state_dict_type else: assert ( prev_state_dict_type == submodule._state_dict_type ), "All FSDP modules should have the same state_dict_type." 
if prev_state_dict_config is None: prev_state_dict_config = submodule._state_dict_config else: assert isinstance( submodule._state_dict_config, type(prev_state_dict_config) ), "All FSDP modules must have the same type of state_dict_config." if prev_optim_state_dict_config is None: prev_optim_state_dict_config = submodule._optim_state_dict_config else: assert isinstance( submodule._optim_state_dict_config, type(prev_optim_state_dict_config), ), "All FSDP modules must have the same type of optim_state_dict_config." submodule._state_dict_type = state_dict_type submodule._state_dict_config = state_dict_config submodule._optim_state_dict_config = optim_state_dict_config return StateDictSettings( prev_state_dict_type, prev_state_dict_config, prev_optim_state_dict_config )
[docs] @staticmethod def get_state_dict_type(module: nn.Module) -> StateDictSettings: """Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at ``module``. The target module does not have to be an FSDP module. Returns: A ``StateDictSettings`` containing the state_dict_type and state_dict / optim_state_dict configs that are currently set. Raises: ``AssertionError`` if the ``StateDictSettings`` for different FSDP submodules differ. """ state_dict_settings: Optional[StateDictSettings] = None for submodule in FullyShardedDataParallel.fsdp_modules(module): if state_dict_settings is None: state_dict_settings = StateDictSettings( state_dict_type=submodule._state_dict_type, state_dict_config=submodule._state_dict_config, optim_state_dict_config=submodule._optim_state_dict_config, ) _set_optim_use_dtensor(submodule, state_dict_settings) else: submodule_settings = StateDictSettings( submodule._state_dict_type, submodule._state_dict_config, submodule._optim_state_dict_config, ) assert state_dict_settings == submodule_settings, ( "All FSDP modules must have the same state dict settings." f"Got {submodule_settings} and {state_dict_settings}." ) _set_optim_use_dtensor(submodule, submodule_settings) return state_dict_settings
[docs] @staticmethod @contextlib.contextmanager def state_dict_type( module: nn.Module, state_dict_type: StateDictType, state_dict_config: Optional[StateDictConfig] = None, optim_state_dict_config: Optional[OptimStateDictConfig] = None, ) -> Generator: """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module. This context manager has the same functions as :meth:`set_state_dict_type`. Read the document of :meth:`set_state_dict_type` for the detail. Example:: >>> # xdoctest: +SKIP("undefined variables") >>> model = DDP(FSDP(...)) >>> with FSDP.state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> ): >>> checkpoint = model.state_dict() Args: module (torch.nn.Module): Root module. state_dict_type (StateDictType): the desired ``state_dict_type`` to set. state_dict_config (Optional[StateDictConfig]): the model ``state_dict`` configuration for the target ``state_dict_type``. optim_state_dict_config (Optional[OptimStateDictConfig]): the optimizer ``state_dict`` configuration for the target ``state_dict_type``. """ prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type( module, state_dict_type, state_dict_config, optim_state_dict_config, ) yield FullyShardedDataParallel.set_state_dict_type( module, prev_state_dict_settings.state_dict_type, prev_state_dict_settings.state_dict_config, prev_state_dict_settings.optim_state_dict_config, )
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Call the wrapped module, wrapped in FSDP's pre- and post-forward hooks.

    The root pre-forward and the handle pre-forward (with
    ``_pre_forward_unshard``) may rewrite ``args``/``kwargs`` before the
    wrapped module runs. The module's output is then routed through
    ``_post_forward`` (with ``_post_forward_reshard``), whose result is
    returned. Everything runs inside a profiler range named
    ``"FullyShardedDataParallel.forward"``.
    """
    flat_handle = self._handle
    with torch.autograd.profiler.record_function("FullyShardedDataParallel.forward"):
        args, kwargs = _root_pre_forward(self, self, args, kwargs)
        args, kwargs = _pre_forward(
            self,
            flat_handle,
            _pre_forward_unshard,
            self._fsdp_wrapped_module,
            args,
            kwargs,
        )
        if flat_handle:
            # The flat parameter is expected to live on the compute device
            # once the pre-forward logic has finished.
            actual_device = flat_handle.flat_param.device
            _p_assert(
                actual_device == self.compute_device,
                "Expected `FlatParameter` to be on the compute device "
                f"{self.compute_device} but got {actual_device}",
            )
        result = self._fsdp_wrapped_module(*args, **kwargs)
        return _post_forward(
            self, flat_handle, _post_forward_reshard, self, None, result
        )
@staticmethod
@contextlib.contextmanager
def summon_full_params(
    module: nn.Module,
    recurse: bool = True,
    writeback: bool = True,
    rank0_only: bool = False,
    offload_to_cpu: bool = False,
    with_grads: bool = False,
) -> Generator:
    r"""Expose the full (unsharded) parameters of FSDP instances inside this context.

    Handy *after* forward/backward to inspect or post-process the
    parameters. ``module`` need not be an FSDP instance itself: full
    parameters are summoned for every contained FSDP module and, depending
    on ``recurse``, for their FSDP children as well.

    .. note:: This can be used on inner FSDPs.

    .. note:: This can *not* be used within a forward or backward pass, and
        forward and backward must not be started from within this context.

    .. note:: Parameters revert to their local shards once the context
        exits; storage behavior is the same as in forward.

    .. note:: The full parameters may be modified, but only the portion
        that corresponds to the local parameter shard persists after the
        context exits (unless ``writeback=False``, in which case changes are
        discarded). When FSDP does not shard the parameters (currently only
        when ``world_size == 1`` or with the ``NO_SHARD`` config), the
        modification persists regardless of ``writeback``.

    .. note:: When ``module`` is not FSDP itself but contains several
        independent FSDP units, the given arguments apply to all of them.

    .. warning:: ``rank0_only=True`` together with ``writeback=True`` is not
        currently supported and raises an error: parameter shapes differ
        across ranks inside the context, and writing to them could make the
        ranks inconsistent when the context exits.

    .. warning:: ``offload_to_cpu`` with ``rank0_only=False`` copies the
        full parameters redundantly into CPU memory for GPUs on the same
        machine, which risks CPU OOM. Prefer ``offload_to_cpu`` together
        with ``rank0_only=True``.

    Args:
        recurse (bool, Optional): recursively summon all params for nested
            FSDP instances (default: True).
        writeback (bool, Optional): if ``False``, modifications to params are
            discarded after the context manager exits; disabling this can be
            slightly more efficient (default: True).
        rank0_only (bool, Optional): if ``True``, full parameters are
            materialized only on global rank 0, so inside the context only
            rank 0 holds full parameters while the other ranks keep sharded
            ones. ``rank0_only=True`` with ``writeback=True`` is not
            supported, for the reason given in the warning above.
        offload_to_cpu (bool, Optional): If ``True``, full parameters are
            offloaded to CPU. This offloading currently only happens if the
            parameter is sharded (i.e. not for world_size = 1 or the
            ``NO_SHARD`` config). Combine with ``rank0_only=True`` to avoid
            redundant copies of model parameters in the same CPU memory.
        with_grads (bool, Optional): If ``True``, gradients are also
            unsharded with the parameters. Currently, this is only supported
            when passing ``use_orig_params=True`` to the FSDP constructor and
            ``offload_to_cpu=False`` to this method. (Default: ``False``)
    """
    unshard_ctx = _unshard_params(
        module, recurse, writeback, rank0_only, offload_to_cpu, with_grads
    )
    with unshard_ctx:
        yield
@contextlib.contextmanager
def _deregister_orig_params_ctx(self):
    """Temporarily swap the original parameters out for the :class:`FlatParameter` s.

    On entry, ``_deregister_orig_params`` is called for every FSDP state
    reachable from ``self``. On exit (normal or exceptional),
    ``_register_orig_params`` is called for each of them. Per the original
    design, re-registration refreshes the sharded views of a sharded
    :class:`FlatParameter`. Only valid when ``self._use_orig_params`` is set.
    """
    _p_assert(
        self._use_orig_params,
        "`_deregister_orig_params_ctx()` should only be called when "
        "`_use_orig_params=True`",
    )
    for state in traversal_utils._get_fsdp_states(self):
        _deregister_orig_params(state, state)
    try:
        yield
    finally:
        # Always restore the original parameters, even if the body raised
        for state in traversal_utils._get_fsdp_states(self):
            _register_orig_params(state, state)


def _apply(self, *args, **kwargs):
    """Run ``nn.Module._apply()``, operating on the :class:`FlatParameter` s when ``use_orig_params`` is set."""
    # With original parameters, the `FlatParameter`s own the storage, and
    # `_apply()` underlies the common storage-changing ops such as `to()` and
    # `cuda()`. Exposing the `FlatParameter`s first makes the storage change
    # land directly on them, instead of changing the original parameters and
    # then writing back into the `FlatParameter`s.
    if self._use_orig_params:
        context = self._deregister_orig_params_ctx()
    else:
        context = contextlib.nullcontext()
    with context:
        return super()._apply(*args, **kwargs)
def named_buffers(
    self,
    *args,
    **kwargs,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, buffer)`` pairs for the module's buffers.

    While this instance is in ``TrainingState.SUMMON_FULL_PARAMS`` (i.e.
    inside :meth:`summon_full_params`), every occurrence of ``FSDP_PREFIX``
    is stripped from the yielded names. Nested FSDP wrapping can insert the
    prefix more than once. Arguments are forwarded to
    ``nn.Module.named_buffers``.
    """
    strip_prefix = self.training_state == TrainingState.SUMMON_FULL_PARAMS
    for name, buf in super().named_buffers(*args, **kwargs):
        yield (name.replace(FSDP_PREFIX, "") if strip_prefix else name, buf)
def named_parameters(
    self,
    *args,
    **kwargs,
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
    """Yield ``(name, parameter)`` pairs for the module's parameters.

    While this instance is in ``TrainingState.SUMMON_FULL_PARAMS`` (i.e.
    inside :meth:`summon_full_params`), every occurrence of ``FSDP_PREFIX``
    is stripped from the yielded names. Nested FSDP wrapping can insert the
    prefix more than once. Arguments are forwarded to
    ``nn.Module.named_parameters``.
    """
    strip_prefix = self.training_state == TrainingState.SUMMON_FULL_PARAMS
    for name, param in super().named_parameters(*args, **kwargs):
        yield (name.replace(FSDP_PREFIX, "") if strip_prefix else name, param)
def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
    """Raise ``ValueError`` unless ``self.training_state`` is among ``state``.

    ``state`` may be a single ``TrainingState`` or a list of them. An
    explicit exception is used instead of ``assert`` because this check is
    important and must not be disabled by ``python -O``.
    """
    allowed = [state] if isinstance(state, TrainingState) else state
    if self.training_state in allowed:
        return
    msg = (
        f"expected to be in states {allowed} but current state "
        f"is {self.training_state}"
    )
    # When this fails inside an autograd hook, the raised error may not
    # surface a useful message, so rank 0 also prints diagnostics.
    if self.rank == 0:
        print(f"Asserting FSDP instance is: {self}")
        print(f"ERROR: {msg}")
        traceback.print_stack()
    raise ValueError(msg)
@contextmanager
def no_sync(self) -> Generator:
    """Suspend gradient synchronization across FSDP instances within this context.

    Inside the context, gradients accumulate in module variables and are
    synchronized in the first forward-backward pass after the context
    exits. Call this on the root FSDP instance only; it applies recursively
    to every FSDP instance among ``self.modules()``.

    .. note:: Expect higher memory usage: FSDP accumulates full model
        gradients (instead of gradient shards) until the eventual sync.

    .. note:: With CPU offloading, gradients are not offloaded to CPU inside
        the context; they are offloaded only right after the eventual sync.

    Raises:
        RuntimeError: if called on a non-root FSDP instance.
        ValueError: if this instance is not in ``TrainingState.IDLE``.
    """
    _lazy_init(self, self)
    if not self._is_root:
        raise RuntimeError(
            "`no_sync()` on inner FSDP instances is not supported. Please call `no_sync()` on root FSDP module."
        )
    self._assert_state(TrainingState.IDLE)
    # Remember every FSDP module's current flag so it can be restored on exit
    saved_flags = [
        (fsdp_mod, fsdp_mod._sync_gradients)
        for fsdp_mod in self.modules()
        if isinstance(fsdp_mod, FullyShardedDataParallel)
    ]
    for fsdp_mod, _ in saved_flags:
        fsdp_mod._sync_gradients = False
    try:
        yield
    finally:
        for fsdp_mod, prev_flag in saved_flags:
            assert not fsdp_mod._sync_gradients, (
                "`_sync_gradients` was incorrectly set to "
                "`True` while in the `no_sync()` context manager"
            )
            fsdp_mod._sync_gradients = prev_flag
@torch.no_grad()
def clip_grad_norm_(
    self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
) -> torch.Tensor:
    """Clip the gradient norm of all parameters.

    The norm is computed over all parameters' gradients as viewed as a single vector, and the
    gradients are modified in-place.

    Args:
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'``
            for infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).

    If every FSDP instance uses ``NO_SHARD``, meaning that no
    gradients are sharded across ranks, then you may directly use
    :func:`torch.nn.utils.clip_grad_norm_`.

    If at least some FSDP instance uses a sharded strategy (i.e.
    one other than ``NO_SHARD``), then you should use this method
    instead of :func:`torch.nn.utils.clip_grad_norm_` since this method
    handles the fact that gradients are sharded across ranks.

    The total norm returned will have the "largest" dtype across
    all parameters/gradients as defined by PyTorch's type promotion
    semantics. For example, if *all* parameters/gradients use a low
    precision dtype, then the returned norm's dtype will be that low
    precision dtype, but if there exists at least one parameter/
    gradient using FP32, then the returned norm's dtype will be FP32.

    In the sharded case, every collected gradient is scaled in place by
    ``min(1, max_norm / (total_norm + 1e-6))``. The returned value is the
    norm measured *before* this scaling. If ``cpu_offload.offload_params``
    is set, the returned norm is on CPU.

    .. warning:: This needs to be called on all ranks since it uses
        collective communications.

    Raises:
        RuntimeError: if called on a non-root FSDP instance.
        ValueError: if this instance is not in ``TrainingState.IDLE``.
    """
    _lazy_init(self, self)
    if not self._is_root:
        raise RuntimeError(
            "`clip_grad_norm_()` should only be called on the root FSDP instance"
        )
    self._assert_state(TrainingState.IDLE)
    # If every FSDP instance uses `NO_SHARD`, then we can directly use
    # the normal `nn.utils` one targeting local gradients
    all_no_shard = all(
        not handle.uses_sharded_strategy for handle in self._all_handles
    )
    if all_no_shard:
        return torch.nn.utils.clip_grad_norm_(
            self.parameters(), max_norm, norm_type
        )
    # Otherwise, there exists some FSDP instance using a sharded strategy,
    # where sharded and non-sharded parameters must be handled separately
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    sharded_params = set()
    nonsharded_params = set()  # `NO_SHARD` or not FSDP-managed
    grads: List[torch.Tensor] = []
    for handle in self._all_handles:
        # Bucket each handle's parameters by whether its strategy shards them
        target_set = (
            sharded_params if handle.uses_sharded_strategy else nonsharded_params
        )
        if handle._use_orig_params:
            # With original params, the per-parameter views carry the grads
            for param in handle.flat_param._params:
                target_set.add(param)
                if param.grad is not None:
                    grads.append(param.grad)
        else:
            target_set.add(handle.flat_param)
            if handle.flat_param.grad is not None:
                grads.append(handle.flat_param.grad)
    # Parameters not owned by any handle are treated as non-sharded
    for param in self.parameters():
        not_fsdp_managed = (
            param not in sharded_params and param not in nonsharded_params
        )
        if not_fsdp_managed:
            nonsharded_params.add(param)
            if param.grad is not None:
                grads.append(param.grad)
    # Compute local norms (forced to be in FP32)
    local_sharded_norm = _get_grad_norm(sharded_params, norm_type).to(
        self.compute_device
    )
    local_nonsharded_norm = (
        _get_grad_norm(nonsharded_params, norm_type).to(self.compute_device)
        if nonsharded_params
        else None
    )
    # Reconstruct the total gradient norm depending on the norm type
    if norm_type == math.inf:
        # Inf-norm: the global max is the MAX all-reduce of local maxima
        total_norm = (
            torch.maximum(local_sharded_norm, local_nonsharded_norm)
            if local_nonsharded_norm is not None
            else local_sharded_norm
        )
        dist.all_reduce(
            total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group
        )
    else:
        # p-norm: sum the p-th powers of the sharded norms across ranks
        # (presumably each rank's shard holds a disjoint piece of the
        # gradients), then add the local non-sharded contribution once.
        total_norm = local_sharded_norm**norm_type
        dist.all_reduce(total_norm, group=self.process_group)
        # All-reducing the local non-sharded norm would count it an extra
        # world-size-many times
        if local_nonsharded_norm is not None:
            total_norm += local_nonsharded_norm**norm_type
        total_norm = total_norm ** (1.0 / norm_type)
    if self.cpu_offload.offload_params:
        total_norm = total_norm.cpu()

    # The 1e-6 term keeps the division finite when the total norm is zero
    clip_coef = max_norm / (total_norm + 1e-6)
    # Multiplying by the clamped coefficient is meaningless when it is
    # equal to 1, but it avoids the host-device sync that would result from
    # `if clip_coef < 1`
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for grad in grads:
        grad.mul_(clip_coef_clamped.to(grad.device, grad.dtype))
    # Use the "largest" dtype by type promotion semantics to use the same
    # dtype as if we did not force local norm computation to be in FP32
    if len(grads) == 0:
        # If this rank has no gradients, then we must default to FP32
        # unless we use additional communication, which we prefer to avoid
        # since `clip_grad_norm_()` is called in the training loop
        warnings.warn(
            f"Called FSDP.clip_grad_norm_() on rank {self.rank} with no "
            "gradients -- returning the total norm in the default dtype "
            f"{total_norm.dtype}"
        )  # warn since this is generally unexpected
        return total_norm
    total_norm_dtype = functools.reduce(
        torch.promote_types,
        [grad.dtype for grad in grads],
    )
    return total_norm.to(total_norm_dtype)
@staticmethod
def _warn_optim_input(optim_input, *, stacklevel: int = 1):
    """Emit a ``FutureWarning`` if the deprecated ``optim_input`` argument was given.

    Args:
        optim_input: The value the caller received for ``optim_input``. No
            warning is emitted when it is ``None``.
        stacklevel (int): Stack level relative to the caller of this
            method; ``1`` attributes the warning to that caller.
    """
    if optim_input is not None:
        warnings.warn(
            "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. "
            "You may remove it from your code without changing its functionality.",
            FutureWarning,
            stacklevel=stacklevel + 1,
        )


@staticmethod
def _is_using_optim_input(optim_input, optim) -> bool:
    """Return whether the legacy ``optim_input`` code path should be used.

    The legacy path is used when ``optim_input`` is given. It is also used
    when neither ``optim_input`` nor ``optim`` is given, which is the
    historical default behavior of ``optim_input``. Otherwise the ``optim``
    path is used.
    """
    return optim_input is not None or optim is None


@staticmethod
def _warn_legacy_optim_state_dict(curr: str, new: str, *, stacklevel: int = 1):
    """Emit a ``FutureWarning`` that ``FullyShardedDataParallel.<curr>`` is replaced by ``<new>``.

    Args:
        curr (str): Name of the deprecated API.
        new (str): Name of the replacement API.
        stacklevel (int): Stack level relative to the caller of this
            method; ``1`` attributes the warning to that caller.
    """
    warnings.warn(
        f"``FullyShardedDataParallel.{curr}`` is being deprecated and is "
        f"replaced by ``FullyShardedDataParallel.{new}``. "
        f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2.",
        FutureWarning,
        stacklevel=stacklevel + 1,
    )


@staticmethod
def _get_uniform_use_orig_params(model: torch.nn.Module) -> bool:
    """Return the ``_use_orig_params`` value shared by all FSDP modules in ``model``.

    The FSDP modules are traversed only once. The first FSDP module is
    indexed directly, so ``model`` must contain at least one.

    Raises:
        AssertionError: if the FSDP modules disagree on ``_use_orig_params``.
    """
    fsdp_modules = FullyShardedDataParallel.fsdp_modules(model)
    use_orig_params = fsdp_modules[0]._use_orig_params
    assert all(
        use_orig_params == m._use_orig_params for m in fsdp_modules
    ), "Not all FSDP modules have the same _use_orig_params value"
    return use_orig_params


@staticmethod
def _optim_state_dict_impl(
    model: torch.nn.Module,
    optim: torch.optim.Optimizer,
    optim_state_dict: Dict[str, Any],
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    rank0_only: bool = True,
    full_state_dict: bool = True,
    group: Optional[dist.ProcessGroup] = None,
    cpu_offload: bool = True,
    *,
    _stacklevel: int = 1,
) -> Dict[str, Any]:
    """Transform the state-dict of an optimizer corresponding to a sharded model.

    This is the internal API that is used by all the optim_state_dict implementations.
    Given model, optim, the original optim_state_dict, this API removes the
    FSDP internal information and internal sharding from the optim_state_dict.

    Args:
        full_state_dict (bool): If ``True``, produce a full (unsharded)
            state dict, and warn if the deprecated ``optim_input`` is given.
            If ``False``, produce a sharded one. In that case ``optim_input``
            must be ``None`` and ``rank0_only`` must be ``False``.
        _stacklevel (int): Stack level for the deprecation warning, relative
            to the caller of this method.
    """
    if full_state_dict:
        FullyShardedDataParallel._warn_optim_input(
            optim_input, stacklevel=_stacklevel + 1
        )
        using_optim_input = FullyShardedDataParallel._is_using_optim_input(
            optim_input,
            optim,
        )
    else:
        using_optim_input = False
        assert optim_input is None and not rank0_only

    use_orig_params = FullyShardedDataParallel._get_uniform_use_orig_params(model)
    return _optim_state_dict(
        model=model,
        optim=optim,
        optim_state_dict=optim_state_dict,
        optim_input=optim_input,
        rank0_only=rank0_only,
        shard_state=not full_state_dict,
        group=group,
        using_optim_input=using_optim_input,
        use_orig_params=use_orig_params,
        cpu_offload=cpu_offload,
    )


@staticmethod
def _optim_state_dict_to_load_impl(
    optim_state_dict: Dict[str, Any],
    model: torch.nn.Module,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    full_state_dict: bool = True,
    rank0_only: bool = False,
    is_named_optimizer: bool = False,
    group: Optional[dist.ProcessGroup] = None,
    *,
    _stacklevel: int = 1,
) -> Dict[str, Any]:
    """
    Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.

    This is the internal API that is used by all the load optim_state_dict implementations.
    Given model, optim, and the saved optim_state_dict, this API adds the FSDP
    internal information and internal sharding to the optim_state_dict.

    Args:
        full_state_dict (bool): Whether ``optim_state_dict`` is a full
            (unsharded) state dict. If ``False``, ``optim_input`` must be
            ``None`` and ``rank0_only`` must be ``False``.
        rank0_only (bool): If ``True``, the provided ``optim_state_dict`` is
            replaced by an empty dict on every rank other than rank 0 of
            ``group``.
        _stacklevel (int): Stack level for the deprecation warning, relative
            to the caller of this method. Mirrors ``_optim_state_dict_impl``.
    """
    if full_state_dict:
        FullyShardedDataParallel._warn_optim_input(
            optim_input, stacklevel=_stacklevel + 1
        )
        using_optim_input = FullyShardedDataParallel._is_using_optim_input(
            optim_input,
            optim,
        )
    else:
        using_optim_input = False
        assert optim_input is None and not rank0_only

    use_orig_params = FullyShardedDataParallel._get_uniform_use_orig_params(model)

    # With rank0_only, only rank 0's input is used; other ranks start empty
    if rank0_only and dist.get_rank(group) > 0:
        optim_state_dict = {}
    sharded_osd = _flatten_optim_state_dict(
        optim_state_dict,
        model=model,
        use_orig_params=use_orig_params,
        optim=(optim if is_named_optimizer else None),
        rank0_only=rank0_only,
        group=group,
    )
    return _rekey_sharded_optim_state_dict(
        sharded_osd,
        model=model,
        optim=optim,
        optim_input=optim_input,
        using_optim_input=using_optim_input,
        is_named_optimizer=is_named_optimizer,
    )
@staticmethod
def full_optim_state_dict(
    model: torch.nn.Module,
    optim: torch.optim.Optimizer,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    rank0_only: bool = True,
    group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """Return the full, consolidated optimizer state-dict.

    The optimizer state is gathered onto rank 0 and returned as a
    :class:`dict` in the :meth:`torch.optim.Optimizer.state_dict` format,
    i.e. with keys ``"state"`` and ``"param_groups"``. Flattened parameters
    of the ``FSDP`` modules inside ``model`` are mapped back to their
    unflattened parameters.

    Every rank must call this because it uses collective communications.
    With ``rank0_only=True``, only rank 0 gets a populated dict and all
    other ranks get an empty :class:`dict`.

    Unlike ``torch.optim.Optimizer.state_dict()``, the keys are full
    parameter names rather than parameter IDs.

    As with :meth:`torch.optim.Optimizer.state_dict`, the tensors in the
    returned dict are not cloned, so aliasing surprises are possible. As a
    best practice, save the returned dict right away, e.g. via
    ``torch.save()``.

    .. note:: Deprecated in favor of :meth:`optim_state_dict`; calling this
        emits a ``FutureWarning``.

    Args:
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance) whose parameters
            were passed into the optimizer ``optim``.
        optim (torch.optim.Optimizer): Optimizer for ``model`` 's
            parameters.
        optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
            Input passed into the optimizer ``optim``: either a
            :class:`list` of parameter groups or an iterable of parameters.
            If ``None``, ``model.parameters()`` is assumed. Deprecated; it
            no longer needs to be passed. (Default: ``None``)
        rank0_only (bool): If ``True``, the populated :class:`dict` is
            returned only on rank 0; if ``False``, on all ranks.
            (Default: ``True``)
        group (dist.ProcessGroup): Model's process group or ``None`` for
            the default process group. (Default: ``None``)

    Returns:
        Dict[str, Any]: The optimizer state for ``model`` 's original
        unflattened parameters, with keys "state" and "param_groups" as in
        :meth:`torch.optim.Optimizer.state_dict`. With ``rank0_only=True``,
        nonzero ranks return an empty :class:`dict`.
    """
    FullyShardedDataParallel._warn_legacy_optim_state_dict(
        "full_optim_state_dict", "optim_state_dict", stacklevel=2
    )
    raw_osd = optim.state_dict()
    return FullyShardedDataParallel._optim_state_dict_impl(
        model=model,
        optim=optim,
        optim_state_dict=raw_osd,
        optim_input=optim_input,
        rank0_only=rank0_only,
        full_state_dict=True,
        group=group,
        _stacklevel=2,
    )
@staticmethod
def sharded_optim_state_dict(
    model: torch.nn.Module,
    optim: torch.optim.Optimizer,
    group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """Return the optimizer state-dict in sharded form.

    Works like :meth:`full_optim_state_dict`, except that every
    non-zero-dimension state is chunked into a :class:`ShardedTensor` to
    save memory. Use this only when the model ``state_dict`` was produced
    under ``with state_dict_type(SHARDED_STATE_DICT):``.

    See :meth:`full_optim_state_dict` for detailed usage.

    .. note:: Deprecated in favor of :meth:`optim_state_dict`; calling this
        emits a ``FutureWarning``.

    .. warning:: The returned state dict contains ``ShardedTensor`` and
        cannot be passed directly to the regular ``optim.load_state_dict``.
    """
    FullyShardedDataParallel._warn_legacy_optim_state_dict(
        "sharded_optim_state_dict", "optim_state_dict", stacklevel=2
    )
    raw_osd = optim.state_dict()
    return FullyShardedDataParallel._optim_state_dict_impl(
        model=model,
        optim=optim,
        optim_state_dict=raw_osd,
        optim_input=None,
        rank0_only=False,
        full_state_dict=False,
        group=group,
        _stacklevel=2,
    )
@staticmethod
def shard_full_optim_state_dict(
    full_optim_state_dict: Dict[str, Any],
    model: torch.nn.Module,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
    """Shard a full optimizer state-dict for this rank.

    The state in ``full_optim_state_dict`` is remapped from unflattened to
    flattened parameters and restricted to this rank's part of the
    optimizer state. The first argument should be the return value of
    :meth:`full_optim_state_dict`.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> model, optim = ...
        >>> full_osd = FSDP.full_optim_state_dict(model, optim)
        >>> torch.save(full_osd, PATH)
        >>> # Define new model with possibly different world size
        >>> new_model, new_optim = ...
        >>> full_osd = torch.load(PATH)
        >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
        >>> new_optim.load_state_dict(sharded_osd)

    .. note:: Either :meth:`shard_full_optim_state_dict` or
        :meth:`scatter_full_optim_state_dict` can produce the sharded
        optimizer state dict to load. Assuming the full dict lives in CPU
        memory, this method needs the full dict on every rank, and each
        rank shards it on its own without communication. The scatter
        variant needs the full dict only on rank 0, which moves each shard
        to GPU memory (for NCCL) and sends it to the right rank. So this
        method costs more aggregate CPU memory, while the scatter variant
        costs more communication.

    .. note:: Deprecated in favor of :meth:`optim_state_dict_to_load`;
        calling this emits a ``FutureWarning``.

    Args:
        full_optim_state_dict (Dict[str, Any]): Optimizer state dict
            for the unflattened parameters, holding the full non-sharded
            optimizer state.
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance) whose parameters
            correspond to the optimizer state in ``full_optim_state_dict``.
        optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
            Input passed into the optimizer: either a :class:`list` of
            parameter groups or an iterable of parameters. If ``None``,
            ``model.parameters()`` is assumed. Deprecated; it no longer
            needs to be passed. (Default: ``None``)
        optim (Optional[torch.optim.Optimizer]): Optimizer that will load
            the returned state dict. Preferred over ``optim_input``.
            (Default: ``None``)

    Returns:
        Dict[str, Any]: The full optimizer state dict, remapped to flattened
        parameters and restricted to this rank's part of the optimizer
        state.
    """
    FullyShardedDataParallel._warn_legacy_optim_state_dict(
        "shard_full_optim_state_dict", "optim_state_dict_to_load", stacklevel=2
    )
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
        optim_state_dict=full_optim_state_dict,
        model=model,
        optim=optim,
        optim_input=optim_input,
        is_named_optimizer=False,
        full_state_dict=True,
    )
@staticmethod
def flatten_sharded_optim_state_dict(
    sharded_optim_state_dict: Dict[str, Any],
    model: torch.nn.Module,
    optim: torch.optim.Optimizer,
) -> Dict[str, Any]:
    """Flatten a sharded optimizer state-dict.

    Mirrors :meth:`shard_full_optim_state_dict`, except that the input
    ``sharded_optim_state_dict`` must come from
    :meth:`sharded_optim_state_dict`. Each rank therefore performs
    all-gather calls to collect the ``ShardedTensor`` s.

    .. note:: Deprecated in favor of :meth:`optim_state_dict_to_load`;
        calling this emits a ``FutureWarning``.

    Args:
        sharded_optim_state_dict (Dict[str, Any]): Optimizer state dict
            for the unflattened parameters, holding the sharded optimizer
            state.
        model (torch.nn.Module):
            See :meth:`shard_full_optim_state_dict`.
        optim (torch.optim.Optimizer): Optimizer for ``model`` 's
            parameters.

    Returns:
        See :meth:`shard_full_optim_state_dict`.
    """
    FullyShardedDataParallel._warn_legacy_optim_state_dict(
        "flatten_sharded_optim_state_dict", "optim_state_dict_to_load", stacklevel=2
    )
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
        optim_state_dict=sharded_optim_state_dict,
        model=model,
        optim=optim,
        optim_input=None,
        is_named_optimizer=False,
        full_state_dict=False,
    )
@staticmethod
def scatter_full_optim_state_dict(
    full_optim_state_dict: Optional[Dict[str, Any]],
    model: torch.nn.Module,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    group: Optional[Any] = None,
) -> Dict[str, Any]:
    """Scatter the full optimizer state dict from rank 0 to every rank.

    Each rank receives its sharded optimizer state dict. The return value
    matches :meth:`shard_full_optim_state_dict`. On rank 0, the first
    argument should be the return value of :meth:`full_optim_state_dict`.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> model, optim = ...
        >>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
        >>> # Define new model with possibly different world size
        >>> new_model, new_optim, new_group = ...
        >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
        >>> new_optim.load_state_dict(sharded_osd)

    .. note:: Either :meth:`shard_full_optim_state_dict` or
        :meth:`scatter_full_optim_state_dict` can produce the sharded
        optimizer state dict to load. Assuming the full dict lives in CPU
        memory, the shard variant needs the full dict on every rank, and
        each rank shards it on its own without communication. This method
        needs the full dict only on rank 0, which moves each shard to GPU
        memory (for NCCL) and sends it to the right rank. So the shard
        variant costs more aggregate CPU memory, while this method costs
        more communication.

    .. note:: Deprecated in favor of :meth:`optim_state_dict_to_load`;
        calling this emits a ``FutureWarning``.

    Args:
        full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state
            dict for the unflattened parameters, holding the full
            non-sharded optimizer state, on rank 0. Ignored on nonzero
            ranks.
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance) whose parameters
            correspond to the optimizer state in ``full_optim_state_dict``.
        optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
            Input passed into the optimizer: either a :class:`list` of
            parameter groups or an iterable of parameters. If ``None``,
            ``model.parameters()`` is assumed. Deprecated; it no longer
            needs to be passed. (Default: ``None``)
        optim (Optional[torch.optim.Optimizer]): Optimizer that will load
            the returned state dict. Preferred over ``optim_input``.
            (Default: ``None``)
        group (dist.ProcessGroup): Model's process group or ``None`` for
            the default process group. (Default: ``None``)

    Returns:
        Dict[str, Any]: The full optimizer state dict, remapped to flattened
        parameters and restricted to this rank's part of the optimizer
        state.
    """
    FullyShardedDataParallel._warn_legacy_optim_state_dict(
        "scatter_full_optim_state_dict", "optim_state_dict_to_load", stacklevel=2
    )
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
        optim_state_dict=full_optim_state_dict,
        model=model,
        optim=optim,
        optim_input=optim_input,
        group=group,
        is_named_optimizer=False,
        full_state_dict=True,
        rank0_only=True,
    )
@staticmethod
def rekey_optim_state_dict(
    optim_state_dict: Dict[str, Any],
    optim_state_key_type: OptimStateKeyType,
    model: torch.nn.Module,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
    """Re-key the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``.

    This can be used to achieve compatibility between optimizer state dicts from models with FSDP
    instances and ones without.

    To re-key an FSDP full optimizer state dict (i.e. from
    :meth:`full_optim_state_dict`) to use parameter IDs and be loadable to
    a non-wrapped model::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> wrapped_model, wrapped_optim = ...
        >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
        >>> nonwrapped_model, nonwrapped_optim = ...
        >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
        >>> nonwrapped_optim.load_state_dict(rekeyed_osd)

    To re-key a normal optimizer state dict from a non-wrapped model to be
    loadable to a wrapped model::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> nonwrapped_model, nonwrapped_optim = ...
        >>> osd = nonwrapped_optim.state_dict()
        >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
        >>> wrapped_model, wrapped_optim = ...
        >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
        >>> wrapped_optim.load_state_dict(sharded_osd)

    Args:
        optim_state_dict (Dict[str, Any]): Optimizer state dict with
            ``"state"`` and ``"param_groups"`` entries. Its parameter keys
            must be either all parameter names (``str``) or all parameter
            IDs (``int``).
        optim_state_key_type (OptimStateKeyType): Target key type, either
            ``OptimStateKeyType.PARAM_NAME`` or ``OptimStateKeyType.PARAM_ID``.
        model (torch.nn.Module): Module whose parameters the state dict
            refers to. Used to map parameters to and from their names.
        optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
            Deprecated input passed into the optimizer, used to derive
            parameter IDs on the legacy code path. (Default: ``None``)
        optim (Optional[torch.optim.Optimizer]): Optimizer used to derive
            parameter IDs. Preferred over ``optim_input``.
            (Default: ``None``)

    Returns:
        Dict[str, Any]: The optimizer state dict re-keyed using the
        parameter keys specified by ``optim_state_key_type``. If the keys
        already have the target type, ``optim_state_dict`` itself is
        returned unchanged.

    Raises:
        ValueError: If the existing parameter keys mix parameter names and
            parameter IDs, or if ``optim_state_key_type`` is unsupported.
    """
    FullyShardedDataParallel._warn_optim_input(optim_input)
    using_optim_input = FullyShardedDataParallel._is_using_optim_input(
        optim_input,
        optim,
    )
    assert optim_state_key_type in (
        OptimStateKeyType.PARAM_NAME,
        OptimStateKeyType.PARAM_ID,
    )
    osd = optim_state_dict  # alias
    # Validate that the existing parameter keys are uniformly typed
    uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]]
    uses_param_id_mask = [type(param_key) is int for param_key in osd["state"]]
    if (any(uses_param_name_mask) and not all(uses_param_name_mask)) or (
        any(uses_param_id_mask) and not all(uses_param_id_mask)
    ):
        error_msg = f"Invalid parameter keys: {osd['state'].keys()}"
        raise ValueError(error_msg)
    # Return directly if the existing key type matches the target key type
    if (
        optim_state_key_type == OptimStateKeyType.PARAM_NAME
        and all(uses_param_name_mask)
    ) or (
        optim_state_key_type == OptimStateKeyType.PARAM_ID
        and all(uses_param_id_mask)
    ):
        return osd
    # Otherwise, actually perform the re-keying
    new_osd: Dict[str, Any] = {}
    if optim_state_key_type == OptimStateKeyType.PARAM_NAME:  # ID -> name
        param_id_to_param = (
            _get_param_id_to_param_from_optim_input(model, optim_input)
            if using_optim_input
            else _get_param_key_to_param(optim)
        )
        param_to_param_name = _get_param_to_fqn(model)
        # Look names up by parameter ID (dict key) rather than by list
        # position, so correctness does not depend on the IDs being exactly
        # 0..N-1 in insertion order.
        param_id_to_param_name: Dict[Union[int, str], str] = {
            param_id: param_to_param_name[param]
            for param_id, param in param_id_to_param.items()
        }
        new_osd["state"] = {
            param_id_to_param_name[param_id]: param_state
            for param_id, param_state in osd["state"].items()
        }
        new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
        for param_group in new_osd["param_groups"]:
            param_group["params"] = sorted(
                param_id_to_param_name[param_id]
                for param_id in param_group["params"]
            )
        return new_osd
    elif optim_state_key_type == OptimStateKeyType.PARAM_ID:  # name -> ID
        param_name_to_param = _get_fqn_to_param(model)
        param_to_param_id = (
            _get_param_to_param_id_from_optim_input(model, optim_input)
            if using_optim_input
            else _get_param_to_param_key(optim)
        )
        # Because not all model parameters may be passed as the optimizer
        # input, we may need to drop some parameters from this mapping
        param_name_to_param_id = {
            param_name: param_to_param_id[param]
            for param_name, param in param_name_to_param.items()
            if param in param_to_param_id
        }
        new_osd["state"] = {
            param_name_to_param_id[param_name]: param_state
            for param_name, param_state in osd["state"].items()
        }
        new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
        for param_group in new_osd["param_groups"]:
            param_group["params"] = sorted(
                param_name_to_param_id[param_name]
                for param_name in param_group["params"]
            )
        return new_osd
    # Only reachable when the ``assert`` above is stripped (``python -O``);
    # fail loudly instead of silently returning an empty state dict.
    raise ValueError(f"Unsupported optim_state_key_type: {optim_state_key_type}")
@staticmethod
def optim_state_dict(
    model: torch.nn.Module,
    optim: torch.optim.Optimizer,
    optim_state_dict: Optional[Dict[str, Any]] = None,
    group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    Transform an optimizer state-dict that belongs to a sharded model.

    Depending on the state-dict type configured on ``model`` (see
    :meth:`state_dict_type`), the result is one of:

    1) a full optimizer state_dict -- every state is unflattened and
       unsharded; ``rank0_only`` and CPU offload can be requested through
       :meth:`state_dict_type` to avoid OOM;
    2) a sharded optimizer state_dict -- every state is unflattened but
       remains sharded; CPU offload can be requested to save memory;
    3) a local optimizer state_dict -- no transformation is applied. (A
       future version would convert each state from ``torch.Tensor`` to
       ``ShardedTensor`` to express its sharding; this is not supported
       yet.)

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.distributed.fsdp import StateDictType
        >>> from torch.distributed.fsdp import FullStateDictConfig
        >>> from torch.distributed.fsdp import FullOptimStateDictConfig
        >>> # Save a checkpoint
        >>> model, optim = ...
        >>> FSDP.set_state_dict_type(
        >>>     model,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     FullStateDictConfig(rank0_only=False),
        >>>     FullOptimStateDictConfig(rank0_only=False),
        >>> )
        >>> state_dict = model.state_dict()
        >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
        >>> save_a_checkpoint(state_dict, optim_state_dict)
        >>> # Load a checkpoint
        >>> model, optim = ...
        >>> state_dict, optim_state_dict = load_a_checkpoint()
        >>> FSDP.set_state_dict_type(
        >>>     model,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     FullStateDictConfig(rank0_only=False),
        >>>     FullOptimStateDictConfig(rank0_only=False),
        >>> )
        >>> model.load_state_dict(state_dict)
        >>> optim_state_dict = FSDP.optim_state_dict_to_load(
        >>>     model, optim, optim_state_dict
        >>> )
        >>> optim.load_state_dict(optim_state_dict)

    Args:
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance) whose parameters
            were passed into the optimizer ``optim``.
        optim (torch.optim.Optimizer): Optimizer for ``model`` 's
            parameters.
        optim_state_dict (Dict[str, Any]): The optimizer state_dict to
            transform. When ``None``, ``optim.state_dict()`` is used.
            (Default: ``None``)
        group (dist.ProcessGroup): Model's process group across which
            parameters are sharded, or ``None`` to use the default process
            group. (Default: ``None``)

    Returns:
        Dict[str, Any]: The optimizer state for ``model``, sharded
        according to the configured ``state_dict_type``.
    """
    settings = FullyShardedDataParallel.get_state_dict_type(model)
    osd_config = settings.optim_state_dict_config
    # Fall back to the optimizer's own state only when the caller gave none.
    source_osd = optim.state_dict() if optim_state_dict is None else optim_state_dict
    wants_full = settings.state_dict_type == StateDictType.FULL_STATE_DICT
    # Configs lacking these attributes get rank0_only=False / offload=True.
    rank0_only = getattr(osd_config, "rank0_only", False)
    offload_to_cpu = getattr(osd_config, "offload_to_cpu", True)
    return FullyShardedDataParallel._optim_state_dict_impl(
        model=model,
        optim=optim,
        optim_state_dict=source_osd,
        optim_input=None,
        rank0_only=rank0_only,
        full_state_dict=wants_full,
        group=group,
        cpu_offload=offload_to_cpu,
        _stacklevel=2,
    )
@staticmethod
def optim_state_dict_to_load(
    model: torch.nn.Module,
    optim: torch.optim.Optimizer,
    optim_state_dict: Dict[str, Any],
    is_named_optimizer: bool = False,
    load_directly: bool = False,
    group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    Convert an optimizer state-dict into a form loadable by the optimizer
    of an FSDP model.

    Given an ``optim_state_dict`` produced by :meth:`optim_state_dict`,
    this returns the flattened optimizer state_dict that can be loaded into
    ``optim``, the optimizer for ``model``. ``model`` must be sharded by
    FullyShardedDataParallel.

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.distributed.fsdp import StateDictType
        >>> from torch.distributed.fsdp import FullStateDictConfig
        >>> from torch.distributed.fsdp import FullOptimStateDictConfig
        >>> # Save a checkpoint
        >>> model, optim = ...
        >>> FSDP.set_state_dict_type(
        >>>     model,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     FullStateDictConfig(rank0_only=False),
        >>>     FullOptimStateDictConfig(rank0_only=False),
        >>> )
        >>> state_dict = model.state_dict()
        >>> original_osd = optim.state_dict()
        >>> optim_state_dict = FSDP.optim_state_dict(
        >>>     model,
        >>>     optim,
        >>>     optim_state_dict=original_osd
        >>> )
        >>> save_a_checkpoint(state_dict, optim_state_dict)
        >>> # Load a checkpoint
        >>> model, optim = ...
        >>> state_dict, optim_state_dict = load_a_checkpoint()
        >>> FSDP.set_state_dict_type(
        >>>     model,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     FullStateDictConfig(rank0_only=False),
        >>>     FullOptimStateDictConfig(rank0_only=False),
        >>> )
        >>> model.load_state_dict(state_dict)
        >>> optim_state_dict = FSDP.optim_state_dict_to_load(
        >>>     model, optim, optim_state_dict
        >>> )
        >>> optim.load_state_dict(optim_state_dict)

    Args:
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance) whose parameters
            were passed into the optimizer ``optim``.
        optim (torch.optim.Optimizer): Optimizer for ``model`` 's
            parameters.
        optim_state_dict (Dict[str, Any]): The optimizer states to load.
        is_named_optimizer (bool): Whether ``optim`` is a NamedOptimizer or
            KeyedOptimizer. Set to ``True`` only for TorchRec's
            KeyedOptimizer or torch.distributed's NamedOptimizer.
        load_directly (bool): When ``True``, also call
            ``optim.load_state_dict(result)`` before returning. Otherwise
            the caller is responsible for ``optim.load_state_dict()``.
            (Default: ``False``)
        group (dist.ProcessGroup): Model's process group across which
            parameters are sharded, or ``None`` to use the default process
            group. (Default: ``None``)

    Returns:
        Dict[str, Any]: The converted, loadable optimizer state_dict.
    """
    settings = FullyShardedDataParallel.get_state_dict_type(model)
    wants_full = settings.state_dict_type == StateDictType.FULL_STATE_DICT
    # Configs lacking ``rank0_only`` are treated as rank0_only=False.
    rank0_only = getattr(settings.optim_state_dict_config, "rank0_only", False)
    converted = FullyShardedDataParallel._optim_state_dict_to_load_impl(
        optim_state_dict=optim_state_dict,
        model=model,
        optim_input=None,
        optim=optim,
        full_state_dict=wants_full,
        rank0_only=rank0_only,
        is_named_optimizer=is_named_optimizer,
        group=group,
    )
    if not load_directly:
        return converted
    optim.load_state_dict(converted)
    return converted
def register_comm_hook(self, state: object, hook: Callable):
    """Register a communication hook.

    This is an enhancement that provides a flexible hook to users where they
    can specify how FSDP aggregates gradients across multiple workers.
    This hook can be used to implement several algorithms like
    `GossipGrad <https://arxiv.org/abs/1803.05880>`_ and gradient compression
    which involve different communication strategies for
    parameter syncs while training with :class:`FullyShardedDataParallel`.

    .. warning ::
        FSDP communication hook should be registered before running an initial
        forward pass and only once.

    The hook is installed on every FSDP state reachable from this root
    instance. All states are validated before any of them is modified, so a
    failed registration leaves no state partially registered.

    Args:
        state (object): Passed to the hook to maintain any state information
            during the training process. Examples include error feedback in
            gradient compression, peers to communicate with next in
            `GossipGrad <https://arxiv.org/abs/1803.05880>`_, etc.
            It is locally stored by each worker
            and shared by all the gradient tensors on the worker.
        hook (Callable): Callable, which has one of the following signatures:
            1) ``hook: Callable[torch.Tensor] -> None``:
            This function takes in a Python tensor, which represents
            the full, flattened, unsharded gradient with respect to all variables
            corresponding to the model this FSDP unit is wrapping
            (that are not wrapped by other FSDP sub-units).
            It then performs all necessary processing and returns ``None``;
            2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
            This function takes in two Python tensors, the first one represents
            the full, flattened, unsharded gradient with respect to all variables
            corresponding to the model this FSDP unit is wrapping
            (that are not wrapped by other FSDP sub-units). The latter
            represents a pre-sized tensor to store a chunk of a sharded gradient after
            reduction.
            In both cases, callable performs all necessary processing and returns ``None``.
            Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case.
            Callables with signature 2 are expected to handle gradient communication for sharded cases.

    Raises:
        AssertionError: If called on a non-root instance, if any FSDP state
            uses a hybrid sharding strategy, or if any FSDP state already
            has a communication hook.
        ValueError: If ``hook`` is not callable.
    """
    if not self.check_is_root():
        raise AssertionError(
            "register_comm_hook can only be called on a root instance."
        )
    fsdp_states = list(traversal_utils._get_fsdp_states(self))
    # Pass 1: validate everything (same check order as before, so the same
    # error is raised first). Nothing is mutated here.
    for fsdp_state in fsdp_states:
        if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
            raise AssertionError(
                f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
            )
        if fsdp_state._comm_hook is not None:
            raise AssertionError("A communication hook is already registered")
        if not callable(hook):
            raise ValueError(
                f"The communication hook must be callable but got {hook}"
            )
    # Pass 2: all checks passed, so register on every state.
    for fsdp_state in fsdp_states:
        fsdp_state._comm_hook = hook
        fsdp_state._comm_hook_state = state
def _unshard(self, async_op: bool = False):
    """Unshard this instance's flat-parameter handle (if any) on the unshard stream.

    Args:
        async_op (bool): If ``True``, return an ``UnshardHandle`` whose
            ``wait()`` makes the current stream wait on the event recorded
            after the unshard. If ``False``, perform that wait immediately.

    Returns:
        ``UnshardHandle`` when ``async_op`` is ``True``, else ``None``.
    """

    class UnshardHandle:
        # Deferred synchronization on the event recorded after the unshard.
        def __init__(
            self,
            flat_param_handle: Optional[FlatParamHandle],
            unshard_event: Optional[torch.cuda.Event],
        ):
            self._flat_param_handle = flat_param_handle
            self._unshard_event = unshard_event

        def wait(self):
            # Only the first call synchronizes; clearing the handle makes
            # subsequent calls no-ops. When there is no handle, the event
            # (None) is never used.
            if self._flat_param_handle is not None:
                current_stream = (
                    self._flat_param_handle._device_handle.current_stream()
                )
                current_stream.wait_event(self._unshard_event)
                self._flat_param_handle = None

    unshard_event = None
    if self._handle:
        with self._use_training_state(
            TrainingState.FORWARD_BACKWARD, HandleTrainingState.FORWARD
        ):
            _unshard(
                self, self._handle, self._unshard_stream, self._pre_unshard_stream
            )
            # Marks completion of the unshard work just queued on the stream.
            unshard_event = self._unshard_stream.record_event()
            self._unshard_event = unshard_event
        self._handle._prefetched = True
    # Pass the recorded *event*: ``wait()`` calls ``Stream.wait_event``, which
    # expects an event (previously the unshard stream was passed by mistake).
    unshard_handle = UnshardHandle(self._handle, unshard_event)
    if async_op:
        return unshard_handle
    unshard_handle.wait()
    return None

def _wait_unshard_streams_on_current_stream(self):
    """Make the unshard and pre-unshard streams wait for the current stream."""
    _wait_for_computation_stream(
        self._device_handle.current_stream(),
        self._unshard_stream,
        self._pre_unshard_stream,
    )

@contextlib.contextmanager
def _use_training_state(
    self, training_state: TrainingState, handle_training_state: HandleTrainingState
):
    """Temporarily set ``self.training_state`` (and the handle's training state
    when a handle exists), restoring the previous values on exit."""
    prev_training_state = self.training_state
    self.training_state = training_state
    if self._handle:
        prev_handle_training_state = self._handle._training_state
        self._handle._training_state = handle_training_state
    try:
        yield
    finally:
        # Restore even if the body raised.
        self.training_state = prev_training_state
        if self._handle:
            self._handle._training_state = prev_handle_training_state
def _get_grad_norm(
    params: Iterable[nn.Parameter],
    norm_type: float,
) -> torch.Tensor:
    """
    Compute the ``norm_type``-norm of the gradients of ``params``, treating
    all gradients as one concatenated vector.

    Parameters whose ``.grad`` is ``None`` are skipped; when none has a
    gradient, ``torch.tensor(0.0)`` is returned. The norm is computed and
    returned in FP32 even for low-precision gradients, because callers
    reduce this value across ranks.

    Raises:
        ValueError: If the gradients do not all share one dtype.
    """
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    dtypes = {g.dtype for g in grads}
    if len(dtypes) != 1:
        raise ValueError(
            f"Requires uniform dtype across all gradients but got {dtypes}"
        )
    # Norm of the per-gradient norms equals the norm of the concatenation.
    per_grad_norms = [
        torch.linalg.vector_norm(g.detach(), norm_type, dtype=torch.float32)
        for g in grads
    ]
    return torch.linalg.vector_norm(
        torch.stack(per_grad_norms),
        norm_type,
        dtype=torch.float32,
    )


def _get_param_to_fqn(
    model: torch.nn.Module,
) -> Dict[torch.nn.Parameter, str]:
    """
    Map each parameter of ``model`` to its single fully-qualified name.

    ``model`` must not contain :class:`FullyShardedDataParallel` instances,
    so no parameter is a ``FlatParameter``. Unlike :meth:`_get_param_to_fqns`,
    the singleton name lists are therefore unwrapped to plain strings.

    Args:
        model (torch.nn.Module): Root module, which should not contain any
            :class:`FullyShardedDataParallel` instances.

    Raises:
        RuntimeError: If some parameter maps to more than one name.
    """
    param_to_fqn: Dict[torch.nn.Parameter, str] = {}
    for param, fqns in _get_param_to_fqns(model).items():
        assert len(fqns) > 0, "`_get_param_to_fqns()` should not construct empty lists"
        if len(fqns) > 1:
            raise RuntimeError(
                "Each parameter should only map to one parameter name but got "
                f"{len(fqns)}: {fqns}"
            )
        param_to_fqn[param] = fqns[0]
    return param_to_fqn


def _get_fqn_to_param(
    model: torch.nn.Module,
) -> Dict[str, torch.nn.Parameter]:
    """Return the inverse of :meth:`_get_param_to_fqn` (name -> parameter)."""
    return {fqn: param for param, fqn in _get_param_to_fqn(model).items()}

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources