Source code for torch.distributed.fsdp.fully_sharded_data_parallel
import collections
import contextlib
import copy
import functools
import itertools
import math
import traceback
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
    Any,
    Callable,
    Deque,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Union,
    cast,
)

import torch
import torch.distributed as dist
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributed import ProcessGroup
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    init_from_local_shards,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)
from torch.distributed.algorithms._comm_hooks import (
    LOW_PRECISION_HOOKS,
    default_hooks,
)
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.utils import (
    _replace_by_prefix,
    _sync_params_and_buffers,
    _to_kwargs,
)
from torch.nn.parameter import Parameter

from ._optim_utils import (
    _broadcast_pos_dim_tensor_states,
    _broadcast_processed_optim_state_dict,
    _flatten_optim_state_dict,
    _get_param_id_to_param,
    _get_param_id_to_param_from_optim_input,
    _get_param_to_param_id,
    _get_param_to_param_id_from_optim_input,
    _optim_state_dict,
    _process_pos_dim_tensor_state,
    _rekey_sharded_optim_state_dict,
)
from ._fsdp_extensions import _ext_chunk_tensor, _ext_pre_load_state_dict_transform
from ._utils import (
    _apply_to_modules,
    _apply_to_tensors,
    _contains_batchnorm,
    _free_storage,
    _is_fsdp_flattened,
    _override_batchnorm_mixed_precision,
    p_assert,
)
from .flat_param import (
    FlatParameter,
    FlatParamHandle,
    HandleConfig,
    HandleShardingStrategy,
    HandleTrainingState,
)
from .flatten_params_wrapper import (
    FLAT_PARAM,
    FPW_MODULE,
    FlattenParamsWrapper,
)
from .wrap import (
    ParamExecOrderWrapPolicy,
    _or_policy,
    _recursive_wrap,
    _wrap_batchnorm_individually,
)

# torchdistX is optional; record whether its deferred-init APIs are usable.
_TORCHDISTX_AVAIL = True
try:
    from torchdistx import deferred_init, fake
except ImportError:
    _TORCHDISTX_AVAIL = False

# torch.fx is optional; only import the symbolic-tracing helpers if present.
_TORCH_FX_AVAIL = True
if not hasattr(torch, "fx"):
    _TORCH_FX_AVAIL = False
if _TORCH_FX_AVAIL:
    from ._symbolic_trace import (
        TracingConfig,
        _init_execution_info,
        _patch_tracer,
    )


__all__ = [
    "FullyShardedDataParallel",
    "ShardingStrategy",
    "MixedPrecision",
    "CPUOffload",
    "BackwardPrefetch",
    "StateDictType",
    "StateDictConfig",
    "FullStateDictConfig",
    "LocalStateDictConfig",
    "ShardedStateDictConfig",
    "OptimStateKeyType",
    "TrainingState_",
    "clean_tensor_name",
]


FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
FSDP_PREFIX = FSDP_WRAPPED_MODULE + "." + FPW_MODULE + "."

_PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)


class ShardingStrategy(Enum):
    """
    This specifies the sharding strategy to be used for distributed training by
    :class:`FullyShardedDataParallel`.

    FULL_SHARD: Parameters, gradients, and optimizer states are sharded. For
    the parameters, this algorithm all-gathers before the forward, reshards
    after the forward, all-gathers before the backward computation, and
    reshards after the backward computation. The gradients are synchronized
    and sharded via reduce-scatter after the backward computation. The sharded
    optimizer states are updated locally.

    SHARD_GRAD_OP: Gradients and optimizer states are sharded during
    computation, and additionally parameters are sharded outside computation.
    For the parameters, this algorithm all-gathers before the forward, does not
    reshard after the forward, and only reshards after the backward
    computation. The gradients are synchronized and sharded via reduce-scatter
    after the backward computation. The sharded optimizer states are updated
    locally. Inside ``no_sync()``, the parameters are not resharded after the
    backward computation.

    NO_SHARD: Parameters, gradients, and optimizer states are not sharded but
    instead replicated across ranks, similar to PyTorch's
    ``DistributedDataParallel`` API. The gradients are synchronized via
    all-reduce after the backward computation. The unsharded optimizer states
    are updated locally.

    HYBRID_SHARD (future support): Apply ``FULL_SHARD`` intra-node and
    ``NO_SHARD`` inter-node.
    """

    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    # TODO
    # HYBRID_SHARD = auto()


@dataclass
class MixedPrecision:
    """
    A config to enable mixed precision training with FullyShardedDataParallel.
    This class can be constructed with several flags:

    ``param_dtype`` controls the precision of model parameters, inputs, and
    therefore the precision under which computation happens. After forward
    and backward passes, FSDP parameters point to full precision shards
    that are kept in memory. Full precision parameters are always
    checkpointed.

    ``reduce_dtype`` controls the precision under which gradient reduction
    would occur, which can potentially be different than ``param_dtype``
    for use cases such as communication efficiency.

    ``buffer_dtype`` controls the precision that buffers are cast to. Note
    that buffers are unsharded and are cast in the first forward pass, and
    remain in their reduced precision state even after forward/backward
    passes. However, when taking checkpoints with ``state_dict``, buffers
    are checkpointed in their full precision (and then restored back to
    their reduced precision) as expected. Note that this checkpoint
    support is currently limited to ``StateDictType.FULL_STATE_DICT``.

    ``keep_low_precision_grads``: Whether to upcast gradients back to the
    full parameter precision after backwards or not. This can be disabled
    to keep the gradients in the lower precision, which can potentially
    save memory if custom Optimizers are able to perform parameter updates
    effectively with lower precision grads.

    .. note:: In ``summon_full_params``, parameters are summoned in full
        precision but buffers are not.

    .. note:: Parameters and buffers are checkpointed in full precision. For
        buffers, this is only guaranteed to work for
        ``StateDictType.FULL_STATE_DICT``.

    .. note:: This API is experimental and subject to change.

    .. note:: Specification of reduced precision types must be explicit, in that
        if, for example, ``param_dtype`` is not specified, it will not be cast by
        FSDP. Thus, a config such as ``MixedPrecision(reduce_dtype=torch.float16)``
        will not cast buffers or parameters. Note that if a ``MixedPrecision``
        config is specified without a ``reduce_dtype``, gradient communication
        would occur in the ``param_dtype`` precision, if given, otherwise, in
        the original parameter precision.
    """

    # maintain a tensor of this dtype that the fp32 param shard will be cast to.
    # Will control the precision of model params, inputs, and thus compute as
    # well.
    param_dtype: Optional[torch.dtype] = None
    # Gradient communication precision.
    reduce_dtype: Optional[torch.dtype] = None
    # Buffer precision.
    # TODO: buffer + param are usually of the same type, if user specifies
    # param but not buffer, should we automatically make buffer be the same?
    buffer_dtype: Optional[torch.dtype] = None
    keep_low_precision_grads: Optional[bool] = False


@dataclass
class CPUOffload:
    """
    CPU offloading config. Currently, only parameter and gradient CPU
    offload are supported.

    offload_params: Offloading parameters to CPUs when these parameters are
    not used for computation on GPUs. This implicitly enables gradient
    offloading to CPUs in order for parameters and gradients to be on the
    same device to work with optimizer.
    """

    offload_params: bool = False


class BackwardPrefetch(Enum):
    """
    Specify where to prefetch next layer's full parameters
    during backward pass.

    BACKWARD_PRE: prefetch right before current layer's backward computation
    starts. This approach will increase backward communication and
    computation overlapping and potentially improve training performance,
    but it may increase the peak memory usage as the prefetched full
    parameters will be kept in the GPU memory until next layer's backward
    computation is done.

    BACKWARD_POST: prefetch right after current layer's backward computation
    finishes. This approach will not increase peak memory as prefetching
    happens after current layer's full parameters are freed.
    It could potentially improve backward communication and computation
    overlapping as it avoids all_gather and reduce_scatter blocking
    each other in the single NCCL stream. However, based on our experiments,
    for some models, the post-backward hook fire order is not always
    the reversed forward computation order, so this
    approach may prefetch full parameters for layers ahead of next layer.
    This 'ahead' all_gather could delay next layer's all_gather in the
    single NCCL stream and cause the next layer's computation delay. So it may
    cause some performance regression for some models.
    """

    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    # TODO, BACKWARD_PRE_CPU, prefetch full parameters and keep them in the CPU memory


class TrainingState_(Enum):
    """
    Simple enum to indicate what state FSDP is in. Used for asserting
    to make sure APIs are called in the correct state.

    ..note::
        ``BACKWARD_PRE`` and ``BACKWARD_POST`` states are used to ensure we
        receive backward hooks in the correct order. It is used to catch
        unexpected order of hooks being called (likely due to our
        hook registration logic or autograd engine logic changes).
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


class StateDictType(Enum):
    """
    This enum indicates which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading).
    The default value is FULL_STATE_DICT to comply with the PyTorch convention.

    ..note::
        FSDP currently supports three types of ``state_dict``:
            1. ``state_dict/load_state_dict``: this pair of APIs return and load
               the non-sharded, unflattened parameters. The semantics is the
               same as using DDP.
            2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return
               and load local sharded, flattened parameters. The values returned
               by ``_local_state_dict`` can be directly used by FSDP and is only
               meaningful to FSDP (because parameters are flattened). Note that
               these APIs are meant for use via the :func:`state_dict_type`
               context manager as follows:
                   >>> # xdoctest: +SKIP("undefined variables")
                   >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT):
                   ...     state = fsdp.state_dict()  # returns the local state dict
            3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs
               return and load sharded, unflattened parameters. The ``state_dict``
               returned by ``sharded_state_dict`` can be used by all other parallel
               schemes (resharding may be required).
    """

    FULL_STATE_DICT = auto()
    LOCAL_STATE_DICT = auto()
    SHARDED_STATE_DICT = auto()


@dataclass
class StateDictConfig:
    """
    ``StateDictConfig`` is the base class for all state_dict configuration classes.
    Users should instantiate a child version (i.e. ``FullStateDictConfig``) in
    order to configure settings for the particular type of ``state_dict``
    implementation FSDP will use.
    """

    pass


@dataclass
class FullStateDictConfig(StateDictConfig):
    """
    ``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. Currently, it accepts two parameters,
    ``offload_to_cpu`` and ``rank0_only`` which can be configured to offload
    the full ``state_dict`` to CPU and to materialize the ``state_dict`` on
    rank 0 only. When used, it is recommended to enable both of these flags
    together to optimize memory savings when taking checkpoints. Note that
    this config class is meant for use via the :func:`state_dict_type`
    context manager as follows:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FullyShardedDataParallel.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()
        >>>     # state will be empty on non rank 0 and contain CPU tensors on rank 0.
        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
        >>> model = model_fn() # Initialize model on CPU in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
        >>>     state_dict = torch.load("my_checkpoint.pt")
        >>>     model.load_state_dict(state_dict)
        >>> # All ranks initialize FSDP module as usual. ``sync_module_states`` argument
        >>> # communicates loaded checkpoint states from rank 0 to rest of the world.
        >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
        >>> # After this point, all ranks have FSDP model with loaded checkpoint.
    """

    offload_to_cpu: bool = False
    rank0_only: bool = False


@dataclass
class LocalStateDictConfig(StateDictConfig):
    pass


@dataclass
class ShardedStateDictConfig(StateDictConfig):
    pass


# Maps each state dict type to the config class used to configure it.
_state_dict_type_to_config = {
    StateDictType.FULL_STATE_DICT: FullStateDictConfig,
    StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
    StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
}


class OptimStateKeyType(Enum):
    PARAM_NAME = auto()
    PARAM_ID = auto()


# A handles key represents the group of `FlatParamHandle`s involved in a given
# module's forward. These will be all-gathered together in the pre-forward and
# pre-backward.
_HandlesKey = Tuple[FlatParamHandle, ...]


class _ExecOrderWarnStatus(Enum):
    """Used internally for execution order validation."""

    NONE = auto()  # no deviation yet
    WARNING = auto()  # deviated this iteration; currently issuing warnings
    WARNED = auto()  # deviated in a previous iteration


class _ExecOrderData:
    """
    This contains the data structures to track the execution order. We track
    the pre-forward order on the *first* iteration for forward prefetching
    (which thus assumes static graph) and the post-forward order on *every*
    iteration for backward prefetching (which thus does not assume static
    graph but may provide an incorrect order).
    """

    def __init__(
        self,
        debug_level: dist.DebugLevel,
        backward_prefetch_limit: int,
        forward_prefetch_limit: int,
    ) -> None:
        # Tracks the (static) pre-forward order for execution order validation
        # and forward prefetching
        self.handles_pre_forward_order: List[_HandlesKey] = []
        # Maps each handles key to its index in `handles_pre_forward_order`
        self.handles_to_pre_forward_order_index: Dict[_HandlesKey, int] = {}
        # Tracks the post-forward order for pre-backward prefetching
        self.handles_post_forward_order: List[_HandlesKey] = []
        # Maps each handles key to its index in `handles_post_forward_order`
        self.handles_to_post_forward_order_index: Dict[_HandlesKey, int] = {}
        self.is_first_iter = True

        # Gives the max number of backward/forward prefetched all-gathers by a
        # single module
        self._backward_prefetch_limit = backward_prefetch_limit
        self._forward_prefetch_limit = forward_prefetch_limit

        # Data structures for execution order validation
        self._checking_order: bool = debug_level in [
            dist.DebugLevel.INFO,
            dist.DebugLevel.DETAIL,
        ]
        self.process_group: Optional[dist.ProcessGroup] = None
        self.world_size: Optional[int] = None
        # Set in `init()`; read when formatting execution order warnings
        self.rank: Optional[int] = None
        self.all_handles: List[FlatParamHandle] = []
        # Maps each handle to its index in `all_handles`, which must be the
        # same across ranks for the execution order validation to work
        self.handle_to_handle_index: Dict[FlatParamHandle, int] = {}
        # Names are prefixed from the root module
        self.flat_param_to_prefixed_param_names: Dict[FlatParameter, List[str]] = {}
        # Current index in the pre-forward execution order
        self.current_order_index = 0
        self.warn_status = _ExecOrderWarnStatus.NONE

    def init(
        self,
        fsdp_root: "FullyShardedDataParallel",
        process_group: dist.ProcessGroup,
    ) -> None:
        """
        Initializes the data structures needed for checking the forward order.
        This should be called after a root FSDP instance has been set during
        lazy initialization.
        """
        self.process_group = process_group
        self.rank = process_group.rank()
        self.world_size = process_group.size()
        # Fix an order over the handles, which should be the same across ranks
        for fsdp_module in fsdp_root.fsdp_modules(fsdp_root):
            for handle in fsdp_module._handles:
                index = len(self.all_handles)
                self.all_handles.append(handle)
                self.handle_to_handle_index[handle] = index
        self.flat_param_to_prefixed_param_names = cast(
            Dict[FlatParameter, List[str]],
            _get_param_to_unflat_param_names(fsdp_root),
        )
        # TODO (awgu): We can broadcast the metadata of rank 0's `all_handles`
        # to check that all ranks have the same handles in the same order.
        # https://github.com/pytorch/pytorch/issues/79620

    def get_handles_to_backward_prefetch(
        self,
        current_handles_key: _HandlesKey,
    ) -> List[_HandlesKey]:
        """
        Returns a :class:`list` of the handles keys of the handles to backward
        prefetch given the current handles key. If there are no valid handles
        keys to prefetch, then this returns an empty :class:`list`.

        At most ``backward_prefetch_limit`` keys are returned, walking the
        post-forward order backwards starting just before the current key.
        """
        current_index = self.handles_to_post_forward_order_index.get(
            current_handles_key, None
        )
        if current_index is None:
            # The key was not recorded in this iteration's post-forward order,
            # so there is nothing to prefetch. Return an empty list (not
            # `None`) so the result matches the annotated return type and
            # callers can iterate over it unconditionally.
            return []
        target_index = current_index - 1
        target_handles_keys: List[_HandlesKey] = []
        for _ in range(self._backward_prefetch_limit):
            if target_index < 0:
                break
            target_handles_keys.append(self.handles_post_forward_order[target_index])
            target_index -= 1
        return target_handles_keys

    def get_handles_to_forward_prefetch(
        self,
        current_handles_key: _HandlesKey,
    ) -> List[_HandlesKey]:
        """
        Returns a :class:`list` of the handles keys of the handles to forward
        prefetch given the current handles key. If there are no valid handles
        keys to prefetch, then this returns an empty :class:`list`.

        At most ``forward_prefetch_limit`` keys are returned, walking the
        pre-forward order forwards starting just after the current key.
        """
        current_index = self.handles_to_pre_forward_order_index.get(
            current_handles_key, None
        )
        if current_index is None:
            # Unknown key: nothing to prefetch (see the backward counterpart).
            return []
        target_index = current_index + 1
        target_handles_keys: List[_HandlesKey] = []
        for _ in range(self._forward_prefetch_limit):
            if target_index >= len(self.handles_pre_forward_order):
                break
            target_handles_keys.append(self.handles_pre_forward_order[target_index])
            target_index += 1
        return target_handles_keys

    def record_post_forward(self, handles: List[FlatParamHandle]) -> None:
        """
        Records ``handles`` in the post-forward order, where ``handles`` should
        be a group of handles used in the same module's forward. If ``handles``
        is empty, then it is omitted.

        Unlike :meth:`record_pre_forward`, this records the order *every*
        iteration with the expectation that the recorded order is reset in
        :meth:`next_iter`.
        """
        if not handles:
            return
        handles_key = tuple(handles)
        # Only record the first usage of a handles key
        if handles_key in self.handles_to_post_forward_order_index:
            return
        index = len(self.handles_post_forward_order)
        self.handles_to_post_forward_order_index[handles_key] = index
        self.handles_post_forward_order.append(handles_key)

    def record_pre_forward(
        self, handles: List[FlatParamHandle], is_training: bool
    ) -> None:
        """
        Records ``handles`` in the pre-forward order on the first iteration,
        where ``handles`` should be a group of handles used in the same
        module's forward. If ``handles`` is empty, then it is omitted.

        On the first iteration, this checks the execution order across ranks.
        See :meth:`_check_order` for details.
        """
        if not handles:
            return
        handles_key = tuple(handles)
        self._check_order(handles_key, is_training)
        # Fix the order after the first iteration and only record the first
        # usage of a handles key
        if (
            not self.is_first_iter
            or handles_key in self.handles_to_pre_forward_order_index
        ):
            return
        index = len(self.handles_pre_forward_order)
        self.handles_to_pre_forward_order_index[handles_key] = index
        self.handles_pre_forward_order.append(handles_key)

    def _check_order(self, handles_key: _HandlesKey, is_training: bool) -> None:
        """
        Checks the forward execution order as long as ``is_training`` is
        ``True`` since checking in eval mode is not supported.

        - On the first iteration, this uses all-gathers to check that all ranks
        are all-gathering the same handles and hence ``FlatParameter`` s,
        raising an error if not.
        - On subsequent iterations, if the distributed debug level is at least
        INFO, then this checks that each rank is locally consistent with its
        own forward order from the first iteration, issuing a warning if not.
        This issues a warning on the first deviating iteration and stops
        warning thereafter.

        Raises:
            RuntimeError: On the first iteration, if ranks disagree on the
                number of parameters or on which parameters they all-gather.
        """
        # Do not check order in eval mode since the post-backward callback does
        # not run so it cannot be used to mark the end of an iteration
        if not is_training:
            return
        if self.is_first_iter:
            msg_prefix = "Forward order differs across ranks:"
            local_indices: Tuple[Optional[int], ...] = self._get_handle_indices(
                handles_key
            )
            device = handles_key[0].device  # guaranteed to be non-CPU
            # Only valid (non-`None`) indices can be communicated as a tensor;
            # this also keeps the local tensor's length equal to
            # `num_valid_indices`, which sizes the all-gather output below
            valid_local_indices = [
                index for index in local_indices if index is not None
            ]
            num_valid_indices = len(valid_local_indices)
            tensor_kwargs = {"dtype": torch.int32, "device": device}
            world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs)
            local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs)
            dist._all_gather_base(
                world_num_valid_indices,
                local_num_valid_indices,
                group=self.process_group,
            )
            # Check that all ranks plan to all-gather the same number of
            # parameters
            # TODO (awgu): Since every module has at most one handle in the
            # current implementation, this should never raise the error.
            for (r1, n1), (r2, n2) in itertools.combinations(
                (
                    (rank, world_num_valid_indices[rank])
                    for rank in range(self.world_size)
                ),
                2,
            ):
                if n1 != n2:
                    raise RuntimeError(
                        f"{msg_prefix} rank {r1} is all-gathering {n1} parameters "
                        f"while rank {r2} is all-gathering {n2} parameters"
                    )
            world_indices = torch.zeros(
                self.world_size * num_valid_indices, **tensor_kwargs
            )
            local_indices_tensor = torch.tensor(valid_local_indices, **tensor_kwargs)
            dist._all_gather_base(
                world_indices, local_indices_tensor, group=self.process_group
            )
            # Check that all ranks plan to all-gather the same index parameters
            for (r1, i1), (r2, i2) in itertools.combinations(
                (
                    (
                        rank,
                        world_indices[
                            rank * num_valid_indices : (rank + 1) * num_valid_indices
                        ],
                    )
                    for rank in range(self.world_size)
                ),
                2,
            ):
                # `i1 != i2` yields an elementwise bool tensor whose truth value
                # is ambiguous for zero or multiple elements; compare whole
                # slices instead
                if not torch.equal(i1, i2):
                    r1_param_names = self._get_names_from_handle_indices(
                        tuple(i1.tolist())
                    )
                    r2_param_names = self._get_names_from_handle_indices(
                        tuple(i2.tolist())
                    )
                    raise RuntimeError(
                        f"{msg_prefix} rank {r1} is all-gathering parameters "
                        f"for {r1_param_names} while rank {r2} is all-gathering "
                        f"parameters for {r2_param_names}"
                    )
        elif self._checking_order:
            # Only issue warnings on the first deviating iteration and stop
            # checking thereafter to avoid flooding the console
            if self.warn_status == _ExecOrderWarnStatus.WARNED:
                return
            msg_prefix = None  # non-`None` means we should warn
            if self.current_order_index >= len(self.handles_pre_forward_order):
                # This iteration sees extra all-gather(s) compared to the first
                msg_prefix = (
                    "Expected to not all-gather any more parameters in the "
                    "forward but trying to all-gather parameters for "
                )
            else:
                expected_handles_key = self.handles_pre_forward_order[
                    self.current_order_index
                ]
                if expected_handles_key != handles_key:
                    expected_param_names = self._get_names_from_handles(
                        expected_handles_key
                    )
                    msg_prefix = (
                        f"Expected to all-gather for {expected_param_names} "
                        "but trying to all-gather parameters for "
                    )
            if msg_prefix is not None:
                param_names = self._get_names_from_handles(handles_key)
                msg_suffix = (
                    f"{param_names}"
                    if param_names
                    else "a newly-added parameter since construction time"
                )
                warnings.warn(
                    "Forward order differs from that of the first iteration "
                    f"on rank {self.rank}. Collectives are unchecked and may "
                    f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}"
                )
                self.warn_status = _ExecOrderWarnStatus.WARNING
            self.current_order_index += 1

    def _get_handle_indices(
        self,
        handles_key: _HandlesKey,
    ) -> Tuple[Optional[int], ...]:
        """
        Returns the handle indices (i.e. indices into ``self.all_handles``)
        corresponding to the handles in ``handles_key``. An entry in the
        returned tuple is ``None`` if the handle is invalid.
        """
        indices: List[Optional[int]] = []
        for handle in handles_key:
            if handle not in self.handle_to_handle_index:
                indices.append(None)
            else:
                indices.append(self.handle_to_handle_index[handle])
        return tuple(indices)

    def _get_names_from_handle_indices(
        self,
        handle_indices: Tuple[Optional[int], ...],
    ) -> List[List[str]]:
        """
        Returns a list of prefixed parameter names for each handle in
        ``handle_indices``. If a handle index is invalid, then its prefixed
        parameter names are omitted from the returned list.
        """
        prefixed_param_names: List[List[str]] = []
        for index in handle_indices:
            if index is None or index < 0 or index >= len(self.all_handles):
                continue
            handle = self.all_handles[index]
            flat_param = handle.flat_param
            prefixed_param_names.append(
                self.flat_param_to_prefixed_param_names[flat_param]
            )
        return prefixed_param_names

    def _get_names_from_handles(
        self,
        handles_key: _HandlesKey,
    ) -> List[List[str]]:
        """
        Returns a list of prefixed parameter names for each handle in
        ``handles_key``. If a handle is invalid, then its prefixed parameter
        names are omitted from the returned list.
        """
        prefixed_param_names: List[List[str]] = []
        for handle in handles_key:
            flat_param = handle.flat_param
            if flat_param not in self.flat_param_to_prefixed_param_names:
                continue
            prefixed_param_names.append(
                self.flat_param_to_prefixed_param_names[flat_param]
            )
        return prefixed_param_names

    def next_iter(self):
        """
        Advances the internal data structures per iteration. This should be
        called in the post-backward callback since that marks the true end of
        an iteration.
        """
        self.is_first_iter = False
        self.handles_to_post_forward_order_index.clear()
        self.handles_post_forward_order.clear()
        if self._checking_order:
            self.current_order_index = 0
            if self.warn_status == _ExecOrderWarnStatus.WARNING:
                self.warn_status = _ExecOrderWarnStatus.WARNED


class _FreeEventQueue:
    """
    This tracks all pending frees corresponding to inflight all-gathers. The
    queueing pattern is iterative enqueues with a single dequeue per iteration
    once the limit ``_max_num_inflight_all_gathers`` is reached.
    """

    def __init__(self) -> None:
        self._queue: Deque[torch.cuda.Event] = collections.deque()
        self._max_num_inflight_all_gathers = 2  # empirically chosen

    def enqueue(self, free_event: torch.cuda.Event) -> None:
        """Enqueues a free event."""
        self._queue.append(free_event)

    def dequeue_if_needed(self) -> Optional[torch.cuda.Event]:
        """Dequeues a single event if the limit is reached."""
        if len(self._queue) >= self._max_num_inflight_all_gathers:
            return self._dequeue()
        return None

    def _dequeue(self) -> Optional[torch.cuda.Event]:
        """Dequeues a free event if possible."""
        if self._queue:
            event = self._queue.popleft()
            return event
        return None


# TODO (awgu): Refactor this later
sharding_strategy_map = {
    ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
    ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
    ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
}
[docs]classFullyShardedDataParallel(nn.Module):""" A wrapper for sharding Module parameters across data parallel workers. This is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_. FullyShardedDataParallel is commonly shortened to FSDP. .. _`Xu et al.`: https://arxiv.org/abs/2004.13336 .. _DeepSpeed: https://www.deepspeed.ai/ Example:: >>> # xdoctest: +SKIP("undefined variables") >>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> torch.cuda.set_device(device_id) >>> sharded_module = FSDP(my_module) >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) >>> loss = x.sum() >>> loss.backward() >>> optim.step() .. warning:: The optimizer must be initialized *after* the module has been wrapped, since FSDP will shard parameters in-place and this will break any previously initialized optimizers. .. warning:: If the destination CUDA device has ID ``dev_id``, either (1) ``module`` should already be placed on that device, (2) the device should be set using ``torch.cuda.set_device(dev_id)``, or (3) ``dev_id`` should be passed into the ``device_id`` constructor argument. This FSDP instance's compute device will be that destination device. For (1) and (3), the FSDP initialization always occurs on GPU. For (2), the FSDP initialization happens on ``module`` 's current device, which may be CPU. .. warning:: FSDP currently does not support gradient accumulation outside ``no_sync()`` when using CPU offloading. Trying to do so yields incorrect results since FSDP will use the newly-reduced gradient instead of accumulating with any existing gradient. .. warning:: Changing the original parameter variable names after construction will lead to undefined behavior. .. warning:: Passing in `sync_module_states=True` flag requires module to be put on GPU, or to use ``device_id`` argument to specify a CUDA device that FSDP will move module to. 
This is because ``sync_module_states=True`` requires GPU communication. .. warning:: As of PyTorch 1.12, FSDP only offers limited support for shared parameters (for example, setting one ``Linear`` layer's weight to another's). In particular, modules that share parameters must be wrapped as part of the same FSDP unit. If enhanced shared parameter support is needed for your use case, please ping https://github.com/pytorch/pytorch/issues/77724 .. note:: Inputs into FSDP ``forward`` function will be moved to compute device (same device FSDP module is on) before running ``forward``, so user does not have to manually move inputs from CPU -> GPU. Args: module (nn.Module): module to be wrapped with FSDP. process_group (Optional[ProcessGroup]): process group for sharding sharding_strategy (Optional[ShardingStrategy]): Config sharding algorithm, different sharding algorithm has trade off between memory saving and communication overhead. ``FULL_SHARD`` will be chosen if sharding_strategy is not specified. cpu_offload (Optional[CPUOffload]): CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently implicitly enables gradient offloading to CPU in order for params and grads to be on same device to work with optimizer. This API is subject to change. Default is ``None`` in which case there will be no offloading. auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): A callable specifying a policy to recursively wrap layers with FSDP. Note that this policy currently will only apply to child modules of the passed in module. The remainder modules are always wrapped in the returned FSDP root instance. ``size_based_auto_wrap_policy`` written in ``torch.distributed.fsdp.wrap`` is an example of ``auto_wrap_policy`` callable, this policy wraps layers with the number of parameters larger than 100M. 
``transformer_auto_wrap_policy`` written in ``torch.distributed.fsdp.wrap`` is an example of ``auto_wrap_policy`` callable for transformer-like model architectures. Users can supply the customized ``auto_wrap_policy`` callable that should accept following arguments: ``module: nn.Module``, ``recurse: bool``, ``unwrapped_params: int``, and return a ``bool`` specifying whether the passed in ``module``` should be wrapped (if ``recurse=False``) or whether we should recurse down the subgraph of ``module`` children (if ``recurse=True``). Extra customized arguments could be added to the customized ``auto_wrap_policy`` callable as well. It is a good practice to print out the sharded model and check whether the sharded model is what the application wants and then adjust accordingly. Example:: >>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, >>> unwrapped_params: int, >>> # These are customizable for this policy function. >>> min_num_params: int = int(1e8), >>> ) -> bool: >>> return unwrapped_params >= min_num_params >>> # Configure a custom min_num_params >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=1e5) backward_prefetch (Optional[BackwardPrefetch]): This is an experimental feature that is subject to change in the the near future. It allows users to enable two different backward_prefetch algorithms to help backward communication and computation overlapping. Pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. mixed_precision (Optional[MixedPrecision]): A ``MixedPrecision`` instance describing the mixed precision training config to be used. ``MixedPrecision`` supports configuring parameter, buffer, and gradient communication dtype. Note that only floating point data is cast to the reduced precision. This allows users potential memory saving and training speedup while trading off accuracy during model training. If ``None``, no mixed precision is applied. 
Note that if ``mixed_precision`` is enabled for FSDP model that contains ``BatchNorm`` with ``auto_wrap_policy``, FSDP will take care to disable mixed precision for ``BatchNorm`` units by wrapping them separately in their own FSDP unit with ``mixed_precision=None``. This is done because several ``BatchNorm`` kernels do not implement reduced type support at the moment. If individually wrapping the model, users must take care to set ``mixed_precision=None`` for ``BatchNorm`` units. (Default: ``None``) ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose own parameters and child modules' parameters and buffers are ignored by this instance. None of the modules directly in ``ignored_modules`` should be :class:`FullyShardedDataParallel` instances, and any child modules that are already-constructed :class:`FullyShardedDataParallel` instances will not be ignored if they are nested under this instance. This argument may be used to avoid sharding specific parameters at module granularity when using an ``auto_wrap_policy`` or if parameters' sharding is not managed by FSDP. (Default: ``None``) param_init_fn (Optional[Callable[[nn.Module], None]]): A ``Callable[torch.nn.Module] -> None`` that specifies how modules that are currently on the meta device should be initialized onto an actual device. Note that as of v1.12, we detect modules on the meta device via ``is_meta`` check and apply a default initialization that calls ``reset_parameters`` method on the passed in ``nn.Module`` if ``param_init_fn`` is not specified, otherwise we run ``param_init_fn`` to initialize the passed in ``nn.Module``. In particular, this means that if ``is_meta=True`` for any module parameters for modules that will be wrapped with FSDP and ``param_init_fn`` is not specified, we assume your module properly implements a ``reset_paramters()`` and will throw errors if not. 
Additionally, we offer support for modules initialized with torchdistX's (https://github.com/pytorch/torchdistX) ``deferred_init`` API. In this case, deferred modules would be initialized by a default initialization function that calls torchdistX's ``materialize_module``, or the passed in ``param_init_fn``, if it is not ``None``. The same ``Callable`` is applied to initialize all meta modules. Note that this initialization function is applied before doing any FSDP sharding logic. Example:: >>> # xdoctest: +SKIP("undefined variables") >>> module = MyModule(device="meta") >>> def my_init_fn(module): >>> # responsible for initializing a module, such as with reset_parameters >>> ... >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) >>> print(next(fsdp_model.parameters()).device) # current CUDA device >>> # With torchdistX >>> module = deferred_init.deferred_init(MyModule, device="cuda") >>> # Will initialize via deferred_init.materialize_module(). >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy) device_id (Optional[Union[int, torch.device]]): An ``int`` or ``torch.device`` describing the CUDA device the FSDP module should be moved to, which determines where initialization such as sharding takes place. If this argument is not specified and ``module`` is on CPU, we issue a warning mentioning that this argument can be specified for faster initialization. If specified, resulting FSDP instances will reside on this device, including moving ignored modules' parameters if needed. Note that if ``device_id`` is specified but ``module`` is already on a different CUDA device, an error will be thrown. (Default: ``None``) sync_module_states (bool): If ``True``, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization. 
This helps ensure model parameters are the same across ranks before starting training, but adds communication overhead to ``__init__``, as at least one broadcast is triggered per individually wrapped FSDP unit. This can also help load checkpoints taken by ``state_dict`` via ``load_state_dict`` in a memory-efficient way. See the documentation for :class:`FullStateDictConfig` for an example of this. (Default: ``False``) forward_prefetch (bool): If ``True``, then FSDP *explicitly* prefetches the next upcoming all-gather while executing in the forward pass. This may improve communication and computation overlap for CPU-bound workloads. This should only be used for static graph models since the forward order is fixed based on the first iteration's execution. (Default: ``False``) limit_all_gathers (bool): If ``False``, then FSDP allows the CPU thread to schedule all-gathers without any extra synchronization. If ``True``, then FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. This ``bool`` only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number of CUDA malloc retries. 
"""def__init__(self,module:nn.Module,process_group:Optional[ProcessGroup]=None,sharding_strategy:Optional[ShardingStrategy]=None,cpu_offload:Optional[CPUOffload]=None,auto_wrap_policy:Optional[Callable]=None,backward_prefetch:Optional[BackwardPrefetch]=None,mixed_precision:Optional[MixedPrecision]=None,ignored_modules:Optional[Iterable[torch.nn.Module]]=None,param_init_fn:Optional[Callable[[nn.Module],None]]=None,device_id:Optional[Union[int,torch.device]]=None,sync_module_states:bool=False,forward_prefetch:bool=False,limit_all_gathers:bool=False,):ifisinstance(auto_wrap_policy,ParamExecOrderWrapPolicy):self._init_param_exec_order_wrap_policy(module=module,process_group=process_group,sharding_strategy=sharding_strategy,cpu_offload=cpu_offload,auto_wrap_policy=auto_wrap_policy,backward_prefetch=backward_prefetch,mixed_precision=mixed_precision,ignored_modules=ignored_modules,param_init_fn=param_init_fn,device_id=device_id,sync_module_states=sync_module_states,forward_prefetch=forward_prefetch,limit_all_gathers=limit_all_gathers,)returntorch._C._log_api_usage_once("torch.distributed.fsdp")super().__init__()self._ignored_modules=self._get_ignored_modules(module,ignored_modules)ignored_params,self._ignored_param_names=self._get_ignored_params(module,self._ignored_modules)self._buffer_names=self._get_buffer_names(module)ifauto_wrap_policyisnotNone:auto_wrap_kwargs={"module":module,"auto_wrap_policy":auto_wrap_policy,"wrapper_cls":FullyShardedDataParallel,"ignored_modules":self._ignored_modules,"ignored_params":ignored_params,"only_wrap_children":True,# avoid double wrapping the 
root}fsdp_kwargs={"process_group":process_group,"sharding_strategy":sharding_strategy,"cpu_offload":cpu_offload,"backward_prefetch":backward_prefetch,"mixed_precision":mixed_precision,"param_init_fn":param_init_fn,"device_id":device_id,"sync_module_states":sync_module_states,"forward_prefetch":forward_prefetch,"limit_all_gathers":limit_all_gathers,}self._auto_wrap(auto_wrap_kwargs,fsdp_kwargs)self.process_group=process_groupor_get_default_group()self.rank=self.process_group.rank()self.world_size=self.process_group.size()self.training_state=TrainingState_.IDLEself.cpu_offload=cpu_offloadorCPUOffload()self.backward_prefetch=backward_prefetchself.forward_prefetch=forward_prefetchself.limit_all_gathers=limit_all_gathersbackward_prefetch_limit=1forward_prefetch_limit=1# We clamp the strategy to `NO_SHARD` for world size of 1 since they# are currently functionally equivalent. This may change if/when we# integrate FSDP with MoE.ifself.world_size==1:sharding_strategy=ShardingStrategy.NO_SHARDself.sharding_strategy=sharding_strategyorShardingStrategy.FULL_SHARDself.mixed_precision=mixed_precisionorMixedPrecision()# Save a mapping from fully prefixed buffer name to its original dtype# since for mixed precision, buffers are restored to their original# dtype for model checkpointingself._buffer_name_to_orig_dtype:Dict[str,torch.dtype]={}self._check_single_device_module(module,ignored_params)device_from_device_id:Optional[torch.device]=self._get_device_from_device_id(device_id)self._materialize_module(module,param_init_fn,device_from_device_id)self._move_module_to_device(module,ignored_params,device_from_device_id)self.compute_device=self._get_compute_device(module,ignored_params,device_from_device_id)params_to_flatten=list(self._get_orig_params(module,ignored_params))ifsync_module_states:self._sync_module_states(module,params_to_flatten)# This FSDP instance's handles should inherit the same process group,# compute device, CPU offload, and mixed precision settings. 
However,# different sharding strategies are allowed.config=HandleConfig(sharding_strategy_map[self.sharding_strategy],self.cpu_offload.offload_params,self.mixed_precision.param_dtype,self.mixed_precision.reduce_dtype,self.mixed_precision.keep_low_precision_grads,)self._fsdp_wrapped_module=FlattenParamsWrapper(module,params_to_flatten,self.compute_device,config,)self._check_orig_params_flattened(ignored_params)# Invariant: `self.params` contains exactly the `FlatParameter`s of the# handles in `self._handles`self._handles:List[FlatParamHandle]=[]self.params:List[FlatParameter]=[]ifself._fsdp_wrapped_module.has_params:handle=self._fsdp_wrapped_module.handleself.params.append(handle.flat_param)self._register_param_handle(handle)handle.shard(self.process_group)ifself.cpu_offload.offload_paramsandhandle.flat_param.device!=torch.device("cpu"):withtorch.no_grad():handle._flat_param_to(torch.device("cpu"))self._sync_gradients=Trueself._communication_hook=self._get_default_comm_hook()self._communication_hook_state=self._get_default_comm_hook_state()self._hook_registered=False# Used to prevent running the pre-backward hook multiple timesself._ran_pre_backward_hook:Dict[_HandlesKey,bool]={}self._is_root:Optional[bool]=None# `None` indicates not yet set# The following attributes are owned by the root FSDP instance and# shared with non-root FSDP instancesself._streams:Dict[str,torch.cuda.Stream]={}self._free_event_queue=_FreeEventQueue()self._debug_level=dist.get_debug_level()self._exec_order_data=_ExecOrderData(self._debug_level,backward_prefetch_limit,forward_prefetch_limit,)self._handles_prefetched:Dict[_HandlesKey,bool]={}# Used for guarding against mistargeted backward prefetchesself._needs_pre_backward_unshard:Dict[_HandlesKey,bool]={}# Used for guarding against mistargeted forward prefetchesself._needs_pre_forward_unshard:Dict[_HandlesKey,bool]={}# The data structures use tuples of handles to generalize over the case# where a module's forward involves multiple handles.# 
`_state_dict_type` controls the `state_dict()` behavior, which is# implemented using post-save and pre-load hooksself._state_dict_type=StateDictType.FULL_STATE_DICTself._state_dict_config=FullStateDictConfig()self._register_state_dict_hook(self._post_state_dict_hook)self._post_state_dict_hook_fn={StateDictType.FULL_STATE_DICT:self._full_post_state_dict_hook,StateDictType.LOCAL_STATE_DICT:self._local_post_state_dict_hook,StateDictType.SHARDED_STATE_DICT:self._sharded_post_state_dict_hook,}self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook,with_module=True)self._pre_load_state_dict_hook_fn={StateDictType.FULL_STATE_DICT:self._full_pre_load_state_dict_hook,StateDictType.LOCAL_STATE_DICT:self._local_pre_load_state_dict_hook,StateDictType.SHARDED_STATE_DICT:self._sharded_pre_load_state_dict_hook,}self.register_load_state_dict_post_hook(self._post_load_state_dict_hook)self._post_load_state_dict_hook_fn={StateDictType.FULL_STATE_DICT:self._full_post_load_state_dict_hook,StateDictType.LOCAL_STATE_DICT:self._local_post_load_state_dict_hook,StateDictType.SHARDED_STATE_DICT:self._sharded_post_load_state_dict_hook,}def_get_ignored_modules(self,root_module:nn.Module,_ignored_modules:Optional[Iterable[torch.nn.Module]],)->Set[nn.Module]:""" Checks that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances, and returns the modules contained in their module subtrees as a :class:`set`. Nested FSDP instances are excluded, but their already-computed ignored modules are included. 
"""if_ignored_modulesisNone:returnset()msg_prefix="`ignored_modules` should be an iterable of `torch.nn.Module`s "try:ignored_root_modules=set(_ignored_modules)exceptTypeError:raiseTypeError(msg_prefix+f"but got {type(_ignored_modules)}")formoduleinignored_root_modules:ifnotisinstance(module,torch.nn.Module):raiseTypeError(msg_prefix+f"but got an iterable with {type(module)}")ifisinstance(module,FullyShardedDataParallel):raiseValueError("`ignored_modules` should not include FSDP modules")# Include child modules and exclude nested FSDP modules themselvesignored_modules=set(childformoduleinignored_root_modulesforchildinmodule.modules()ifnotisinstance(child,(FullyShardedDataParallel,FlattenParamsWrapper)))ifroot_moduleinignored_modules:warnings.warn("Trying to ignore the top-level module passed into the FSDP ""constructor itself will result in all parameters being "f"ignored and is not well-supported: {module}")# Include nested FSDP modules' ignored modulesforsubmoduleinroot_module.modules():ifisinstance(submodule,FullyShardedDataParallel):asserthasattr(submodule,"_ignored_modules")ignored_modules.update(submodule._ignored_modules)returnignored_modulesdef_get_ignored_params(self,root_module:torch.nn.Module,ignored_modules:Set[torch.nn.Module],)->Tuple[Set[torch.nn.Parameter],Set[str]]:""" Returns the parameters of the modules in ``ignored_modules``, excluding any :class:`FlatParameter` s, and their fully prefixed names, both as :class:`set` s. 
"""ignored_params=set(pforminignored_modulesforpinm.parameters()ifnot_is_fsdp_flattened(p))# Conservatively include all shared parameters' namesparam_to_unflat_param_names=_get_param_to_unflat_param_names(root_module,dedup_shared_params=False,)ignored_param_names=set()forparaminignored_params:unflat_param_names=param_to_unflat_param_names[param]clean_names=[]forkinunflat_param_names:# Clean any module wrapper prefixes in case of nested wrappingclean_names.append(clean_tensor_name(k))ignored_param_names.update(clean_names)returnignored_params,ignored_param_namesdef_get_buffer_names(self,root_module:nn.Module)->Set[str]:""" Returns the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a class:`set`. """defmodule_fn(module:nn.Module,prefix:str,buffer_names:Set[str]):# For FSDP modules, only add the entry when considering the# contained `FlattenParamsWrapper` to avoid duplicationifnotisinstance(module,FullyShardedDataParallel):forbuffer_name,_inmodule.named_buffers(recurse=False):# Clean module wrapper prefixes in case of nested wrappingprefixed_buffer_name=clean_tensor_name(prefix+buffer_name)buffer_names.add(prefixed_buffer_name)defreturn_fn(buffer_names:Set[str],*args):returnbuffer_namesbuffer_names:Set[str]=set()return_apply_to_modules(root_module,module_fn,return_fn,buffer_names,)def_auto_wrap(self,auto_wrap_kwargs:Dict[str,Any],fsdp_kwargs:Dict[str,Any],)->None:""" Recursively auto wraps the root module given by the key "module" in ``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and ``fsdp_kwargs``. Precondition: ``auto_wrap_policy`` contains the arguments expected by ``_recursive_wrap()``, where ``auto_wrap_policy`` is not ``None``. ``fsdp_kwargs`` contains all FSDP arguments except ``module``. 
"""auto_wrap_policy=auto_wrap_kwargs["auto_wrap_policy"]root_module=auto_wrap_kwargs["module"]assertauto_wrap_policyisnotNone# For auto wrapping, submodules should not already be wrapped with FSDP# since double wrapping is not supportedformodule_name,moduleinroot_module.named_modules():ifisinstance(module,FullyShardedDataParallel):raiseValueError(f"Expected {module_name} to NOT be FullyShardedDataParallel ""if using an `auto_wrap_policy`")mixed_precision=fsdp_kwargs["mixed_precision"]ifmixed_precisionisnotNoneand_contains_batchnorm(root_module):_override_batchnorm_mixed_precision(root_module)auto_wrap_policy=functools.partial(_or_policy,policies=[_wrap_batchnorm_individually,auto_wrap_policy])warnings.warn("Both mixed precision and an `auto_wrap_policy` were specified ""for FSDP, where the wrapped module has batch norm submodules. ""The batch norm submodules will be wrapped as separate FSDP ""instances with mixed precision disabled since some batch norm ""kernels do not support low precision.")auto_wrap_kwargs["auto_wrap_policy"]=auto_wrap_policy_recursive_wrap(**auto_wrap_kwargs,**fsdp_kwargs)def_check_single_device_module(self,module:nn.Module,ignored_params:Set[nn.Parameter],)->None:""" Raises an error if ``module`` has original parameters on multiple devices, ignoring the parameters in ``ignored_params``. Thus, after this method, the module must be either fully on the CPU or fully on a non-CPU device. """devices=set(param.deviceforparaminself._get_orig_params(module,ignored_params))iflen(devices)>1:raiseRuntimeError(f"FSDP only supports single device modules but got params on {devices}")def_get_device_from_device_id(self,device_id:Optional[Union[int,torch.device]],)->Optional[torch.device]:""" """ifdevice_idisNone:returnNonedevice=(device_idifisinstance(device_id,torch.device)elsetorch.device(device_id))ifdevice==torch.device("cuda"):warnings.warn(f"FSDP got the argument `device_id` {device_id} on rank "f"{self.rank}, which does not have an explicit index. 
"f"FSDP will use the current device {torch.cuda.current_device()}. ""If this is incorrect, please explicitly call `torch.cuda.set_device()` ""before FSDP initialization or pass in the explicit device ""index as the `device_id` argument.")device=torch.device("cuda",torch.cuda.current_device())returndevicedef_materialize_module(self,module:nn.Module,param_init_fn:Optional[Callable[[nn.Module],None]],device_from_device_id:Optional[torch.device],)->None:""" Materializes the wrapped module ``module`` in place if needed: either if the module has parameters that use meta device or are torchdistX fake tensors. This method uses ``param_init_fn`` to materialize the module if the function is not ``None`` and falls back to default behavior otherwise. For meta device, this moves the module to ``device_from_device_id`` if it is not ``None`` or the current device otherwise and calls ``reset_parameters()``, and for torchdistX fake tensors, this calls ``deferred_init.materialize_module()``. """is_meta_module=any(p.is_metaforpinmodule.parameters())is_torchdistX_deferred_init=(notis_meta_moduleand_TORCHDISTX_AVAILandany(fake.is_fake(p)forpinmodule.parameters()))if(is_meta_moduleoris_torchdistX_deferred_init)andparam_init_fnisnotNone:ifnotcallable(param_init_fn):raiseValueError(f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}")param_init_fn(module)elifis_meta_module:# Run default meta device initializationmaterialization_device=device_from_device_idortorch.cuda.current_device()module.to_empty(device=materialization_device)try:withtorch.no_grad():module.reset_parameters()exceptBaseExceptionase:warnings.warn("Unable to call `reset_parameters()` for module on meta "f"device with error {str(e)}. 
Please ensure your ""module implements a `reset_parameters()` method.")raiseeelifis_torchdistX_deferred_init:# Run default torchdistX initializationdeferred_init.materialize_module(module,check_fn=lambdak:notisinstance(k,FullyShardedDataParallel),)def_move_module_to_device(self,module:nn.Module,ignored_params:Set[nn.Parameter],device_from_device_id:Optional[torch.device],):""" Moves ``module`` depending on ``device_from_device_id`` and its current device. This includes moving ignored modules' parameters. - If ``device_from_device_id`` is not ``None``, then this moves ``module`` to the device. - If ``device_from_device_id`` is ``None``, then this does not move ``module`` but warns the user if it is on CPU. Precondition: ``_check_single_device_module()``. """cpu_device=torch.device("cpu")param=next(self._get_orig_params(module,ignored_params),None)ifparamisNone:return# no original parameters to manageifdevice_from_device_idisnotNone:ifparam.device==cpu_device:# NOTE: This includes moving ignored modules' parameters.module=module.to(device_from_device_id)# TODO: This is a temporary fix to move already- constructed# `FlatParameter`s back to CPU if needed. This is needed to# make CPU offload work with `device_id`.forsubmoduleinmodule.modules():if(isinstance(submodule,FullyShardedDataParallel)andsubmodule.cpu_offload.offload_params):withtorch.no_grad():forhandleinsubmodule._handles:handle._flat_param_to(torch.device("cpu"))elifparam.device==cpu_device:warnings.warn("Module is put on CPU and will thus have flattening and sharding"" run on CPU, which is less efficient than on GPU. 
We recommend passing in ""`device_id` argument which will enable FSDP to put module on GPU device,"" module must also be on GPU device to work with `sync_module_states=True` flag"" which requires GPU communication.")def_get_compute_device(self,module:nn.Module,ignored_params:Set[nn.Parameter],device_from_device_id:Optional[torch.device],)->torch.device:""" Determines and returns this FSDP instance's compute device. If the module is already on a non-CPU device, then the compute device is that non-CPU device. If the module is on CPU, then the compute device is the current device. Since this method should be called after materializing the module, any non-CPU device should not be meta device. For now, the compute device is always a CUDA GPU device with its explicit index. Precondition: ``_check_single_device_module()`` and ``_move_module_to_device()``. """# If the module is on GPU already, then that GPU device has priority# over the current deviceparam=next(self._get_orig_params(module,ignored_params),None)ifparamisnotNoneandparam.device.type=="cuda":compute_device=param.deviceelse:compute_device=torch.device("cuda",torch.cuda.current_device())if(device_from_device_idisnotNoneandcompute_device!=device_from_device_id):raiseValueError("Inconsistent compute device and `device_id` on rank "f"{self.rank}: {compute_device} vs {device_from_device_id}")returncompute_devicedef_sync_module_states(self,module:nn.Module,params:List[nn.Parameter])->None:""" Synchronizes module states (i.e. parameters ``params`` and all not-yet-synced buffers) by broadcasting from rank 0 to all ranks. Precondition: ``sync_module_states == True`` and ``self.process_group`` has been set. 
"""ifparamsandany(param.device==torch.device("cpu")forparaminparams):raiseValueError("Module has CPU parameters, but sync_module_states=True is specified.""This only works for GPU module, please specify `device_id` argument or move"" module to GPU before init.")module_states:List[torch.Tensor]=[]# TODO (awgu): When exposing the original parameters, we need to also# use this attribute to prevent re-synchronizing parameters.forbufferinmodule.buffers():# Avoid re-synchronizing buffers in case of nested wrappingifnotgetattr(buffer,"_fsdp_synced",False):buffer._fsdp_synced=Truemodule_states.append(buffer.detach())module_states.extend(param.detach()forparaminparams)_sync_params_and_buffers(self.process_group,module_states,_PARAM_BROADCAST_BUCKET_SIZE,src=0,)def_get_orig_params(self,module:nn.Module,ignored_params:Set[nn.Parameter],)->Iterator[nn.Parameter]:""" Returns an iterator over the original parameters in ``module``, ignoring the parameters in ``ignored_params`` and any ``FlatParameter`` s (which may be present due to nested FSDP wrapping). """param_gen=module.parameters()try:whileTrue:param=next(param_gen)ifparamnotinignored_paramsandnot_is_fsdp_flattened(param):yieldparamexceptStopIteration:passdef_check_orig_params_flattened(self,ignored_params:Set[nn.Parameter])->None:""" Checks that all original parameters have been flattened and hence made invisible to ``named_parameters()``. This should be called as a sanity check after flattening the wrapped module's parameters. """forparam_name,paraminself.named_parameters():ifparamnotinignored_paramsandnot_is_fsdp_flattened(param):raiseRuntimeError(f"Found an unflattened parameter: {param_name}; "f"{param.size()}{param.__class__}")def_register_param_handle(self,handle:FlatParamHandle)->None:"""Registers the parameter handle to this FSDP instance."""ifhandlenotinself._handles:self._handles.append(handle)@torch.no_grad()def_unshard(self,handles:List[FlatParamHandle],)->None:""" Unshards the handles in ``handles``. 
If the handles are in :meth:`summon_full_params` and are using mixed precision, then they are forced to full precision. Postcondition: Each handle's ``FlatParameter`` 's data is the padded unsharded flattened parameter on the compute device. """ifnothandles:returnifself.limit_all_gathers:event=self._free_event_queue.dequeue_if_needed()ifevent:event.synchronize()any_ran_pre_unshard=Falsewithtorch.cuda.stream(self._streams["pre_all_gather"]):forhandleinhandles:ran_pre_unshard=handle.pre_unshard()any_ran_pre_unshard=any_ran_pre_unshardorran_pre_unshardifany_ran_pre_unshard:self._streams["all_gather"].wait_stream(self._streams["pre_all_gather"])withtorch.cuda.stream(self._streams["all_gather"]):forhandleinhandles:handle.unshard()handle.post_unshard()def_reshard(self,# unusedhandles:List[FlatParamHandle],free_unsharded_flat_params:List[bool],)->None:""" Reshards the handles in ``handles``. ``free_unsharded_flat_params`` should have the same length as ``handles``, and each element should give whether the corresponding handle should free its padded unsharded flattened parameter. """ifnothandles:returnp_assert(len(handles)==len(free_unsharded_flat_params),"Expects both lists to have equal length but got "f"{len(handles)} and {len(free_unsharded_flat_params)}")forhandle,free_unsharded_flat_paraminzip(handles,free_unsharded_flat_params,):handle.reshard(free_unsharded_flat_param)ifself.limit_all_gathersandfree_unsharded_flat_param:free_event=torch.cuda.Event()free_event.record()self._free_event_queue.enqueue(free_event)handle.post_reshard()# Since we prefetch entire handles keys at a time, conservatively mark# the entire key as no longer prefetched once we free at least onehandles_key=tuple(handles)ifany(free_unsharded_flat_params):self._handles_prefetched.pop(handles_key,None)@propertydefmodule(self)->nn.Module:""" Returns the wrapped module (like :class:`DistributedDataParallel`). 
"""assertisinstance(self._fsdp_wrapped_module,FlattenParamsWrapper)returnself._fsdp_wrapped_module.moduledef__getattr__(self,name:str)->Any:"""Forward missing attributes to wrapped module."""try:returnsuper().__getattr__(name)# defer to nn.Module's logicexceptAttributeError:returngetattr(self._fsdp_wrapped_module,name)def__getitem__(self,key:int)->Any:"""Forward indexing calls in case the module is a nn.Sequential."""returnself._fsdp_wrapped_module.__getitem__(key)# type: ignore[operator]defcheck_is_root(self)->bool:self._lazy_init()assertself._is_rootisnotNonereturnself._is_root
[docs]@staticmethoddeffsdp_modules(module:nn.Module,root_only:bool=False,)->List["FullyShardedDataParallel"]:""" Returns all nested FSDP instances, possibly including ``module`` itself and only including FSDP root modules if ``root_only=True``. Args: module (torch.nn.Module): Root module, which may or may not be an ``FSDP`` module. root_only (bool): Whether to return only FSDP root modules. (Default: ``False``) Returns: List[FullyShardedDataParallel]: FSDP modules that are nested in the input ``module``. """return[submoduleforsubmoduleinmodule.modules()ifisinstance(submodule,FullyShardedDataParallel)and(notroot_onlyorsubmodule.check_is_root())]
[docs]defapply(self,fn:Callable[[nn.Module],None])->"FullyShardedDataParallel":r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). Compared to ``torch.nn.Module.apply``, this version additionally gathers the full parameters before applying ``fn``. It should not be called from within another ``summon_full_params`` context. Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self """uninitialized=self._is_rootisNoneself._assert_state(TrainingState_.IDLE)withself._summon_full_params(recurse=False,writeback=True):ret=super().apply(fn)# Reset lazy init that might be called by _summon_full_params, since# it could have set is_root incorrectly for non-root FSDP instances.ifuninitializedandself._is_root:formoduleinself.fsdp_modules(self):module._reset_lazy_init()returnret
def_mixed_precision_enabled_for_params(self)->bool:""" Whether user explicitly enabled mixed precision for parameters or not. """returnself.mixed_precision.param_dtypeisnotNonedef_mixed_precision_enabled_for_buffers(self)->bool:""" Whether user explicitly enabled mixed precision for buffers or not. """returnself.mixed_precision.buffer_dtypeisnotNonedef_mixed_precision_enabled_for_reduce(self)->bool:""" Whether user explicitly enabled mixed precision for gradient reduction or not. """returnself.mixed_precision.reduce_dtypeisnotNonedef_mixed_precision_keep_low_precision_grads(self)->bool:return(self.mixed_precisionisnotNoneandself.mixed_precision.keep_low_precision_grads)def_low_precision_hook_enabled(self)->bool:""" Wether a low precision hook is registered or not. """return(self._communication_hookisnotNoneandself._communication_hookinLOW_PRECISION_HOOKS)def_cast_fp_inputs_to_dtype(self,dtype:torch.dtype,*args:Any,**kwargs:Any)->Tuple[Any,Any]:""" Casts floating point tensors in ``args`` and ``kwargs`` to the precision given by ``dtype``, while respecting the existing ``requires_grad`` on the tensors. """defcast_fn(x:torch.Tensor)->torch.Tensor:ifnottorch.is_floating_point(x):returnxy=x.to(dtype)# Explicitly copy over `requires_grad` since this runs inside# `torch.no_grad()`ifx.is_leaf:y.requires_grad=x.requires_gradreturnywithtorch.no_grad():return(_apply_to_tensors(cast_fn,args),_apply_to_tensors(cast_fn,kwargs))def_cast_buffers(self,device:Optional[torch.device]=None,dtype:Optional[Dict[str,torch.dtype]]=None,memo:Optional[Set]=None,recurse:bool=True,)->None:"""Move all buffers to the given *device* and *dtype*. If *device* is not given, then it will default to ``self.compute_device``, otherwise buffer will be moved to ``device``. In the case of nested FSDP instances, we will respect the child instance's ``compute_device`` configuration. 
If *dtype* is given, it must be a mapping of buffer name to buffer dtype, and this argument is currently only given to restore back to original buffer types during checkpoint. If *dtype* is not given, and we are in mixed precision training, the buffer will be cast to buffer_dtype, otherwise the buffer will not be cast. Args: device (torch.device, Optional): device to cast buffers to (defaults to compute_device) dtype: (Dict[str, torch.dtype], Optional): Mapping of buffer name to their dtype to cast to. memo (Set, Optional): set of modules that have already been processed recurse (bool, Optional): Whether to call _cast_buffers recursively on nested FSDP instances (default is True). """ifmemoisNone:memo=set()formoduleinself.modules():ifmoduleisnotselfandisinstance(module,FullyShardedDataParallel)andrecurse:# Allow any child FSDP instances to handle their own buffers.module._cast_buffers(device=device,dtype=dtype,memo=memo,recurse=recurse)elifmodulenotinmemo:memo.add(module)forname,bufinmodule.named_buffers(recurse=False):ifbufisNone:continuebuf=buf.to(device=deviceorself.compute_device)ifnamenotinself._buffer_name_to_orig_dtype:self._buffer_name_to_orig_dtype[name]=buf.dtype# If given, cast buffer to the given dtype. This is used to# suppport mixed precision for buffers# (given by self.mixed_precision.buffer_dtype) and also used# to restore the buffer dtype to the original precision for# state_dict() calls.# Note that non-floating point buffers are not casted.iftorch.is_floating_point(buf):# We are restoring the original buffer type in# preparation for checkpoint.ifdtype:buf=buf.to(dtype=dtype[name])# Note that we don't pass in self.mixed_precision.buffer_dtype# recursively into _cast_buffers, as we want to respect# mp config for child FSDP instances.elifself._mixed_precision_enabled_for_buffers():buf=buf.to(self.mixed_precision.buffer_dtype)setattr(module,name,buf)def_reset_lazy_init(self)->None:""" Reset instance so :func:`_lazy_init` will run on the next forward. 
"""self._is_root:Optional[bool]=Noneforpinself.params:ifhasattr(p,"_local_shard"):# We only need to `del` `_local_shard` because# `_init_param_attributes()` gates the logic based on its# existence (and not any of the other attributes).delp._local_shard# type: ignore[attr-defined]def_lazy_init(self)->None:""" Performs initialization lazily, typically right before the first forward pass. The laziness is needed to ensure that the parameter device/dtype and the FSDP hierarchy have finalized. This method's actual logic only runs on the root FSDP instance, which performs initialization for all non-root FSDP instances to avoid partial initialization. """ifself._is_rootisnotNone:return# no-op: already initializedifnottorch.cuda.is_available():# Allow the FSDP constructor to run even with CUDA but check this# once we start real executionraiseRuntimeError("FSDP does not support CPU only execution")# The following logic is only run on the root FSDP instance since it# will set `_is_root=False` for the non-root instancesself._is_root=Trueself._assert_state(TrainingState_.IDLE)self._init_streams()self._cast_buffers(recurse=True)forhandleinself._handles:self._init_param_attributes(handle)self._exec_order_data.init(self,self.process_group)# Initialize non-root FSDP instances and share attributes from the root# to non-root instancesinconsistent_limit_all_gathers=Falseforfsdp_moduleinself.fsdp_modules(self):iffsdp_moduleisnotself:# Relax the assert for non-root FSDP instances in case the# nested initialized module is wrapped again in FSDP later (e.g.# after training to run inference)assertfsdp_module._is_rootisNoneornotfsdp_module._is_root,("Non-root FSDP instance's `_is_root` should not have been ""set yet or should have been set to `False`")fsdp_module._is_root=Falsefsdp_module._streams=self._streamsfsdp_module._exec_order_data=self._exec_order_dataiffsdp_module.limit_all_gathers!=self.limit_all_gathers:# Prefer the root's 
valueinconsistent_limit_all_gathers=Truefsdp_module.limit_all_gathers=self.limit_all_gathersfsdp_module._free_event_queue=self._free_event_queuefsdp_module._handles_prefetched=self._handles_prefetchedfsdp_module._needs_pre_backward_unshard=self._needs_pre_backward_unshardforhandleinfsdp_module._handles:fsdp_module._init_param_attributes(handle)ifinconsistent_limit_all_gathers:warnings.warn("Found inconsistent `limit_all_gathers` values across FSDP "f"instances on rank {self.rank}. Using the root FSDP's value "f"of {self.limit_all_gathers} for all instances.")# TODO (awgu): Move this to the `FlatParamHandle` class later@torch.no_grad()def_init_param_attributes(self,handle:FlatParamHandle)->None:""" We manage several attributes on each Parameter instance. A few attributes are set here: ``_local_shard``: a single shard of the parameter. This is needed to recover the shard after rebuilding full parameter in forward and backward. ``_full_param_padded``: the full weight (padded to be evenly divisible by ``world_size``), used for computation in the forward and backward pass. It is initialized with the appropriate size and then has its storage freed. This will be resized in place and only materialized (via all-gather) as needed. Another attribute is set by :func:`_register_post_backward_hooks`: ``_post_backward_hook_state``: it holds the parameter's AccumulateGrad object and the registered post hook handle. 
"""p=handle.flat_param# If _local_shard has been set in the first lazy init and# current parameter is pointed to _local_shard, no need to# set the _local_shard again.ifhasattr(p,"_local_shard"):# If CPU offloading, p._local_shard should have been placed on CPU# during its first lazy construction.ifself.cpu_offload.offload_params:assertp._local_shard.device==torch.device(# type: ignore[attr-defined]"cpu"),("Expected p._local_shard to be on CPU, "# type: ignore[attr-defined]f"but it's on {p._local_shard.device}"# type: ignore[attr-defined])return# A single shard of the parameters. Also makes p._local_shard to be on# CPU if we are CPU offloading, since p.data would be on CPU during# init.ifself.cpu_offload.offload_params:assertp.device==torch.device("cpu"),("Expected param to be on CPU when cpu_offloading is enabled. ""If CPU offloading is enabled correctly, you may be ""accidentally moving the model to CUDA after FSDP initialization.")p._local_shard=p.data# type: ignore[attr-defined]# If CPU offloading, pin the memory to enable faster CPU -> GPU device# transfer.ifself.cpu_offload.offload_params:assertp._local_shard.device==torch.device("cpu")# type: ignore[attr-defined]p._local_shard=p._local_shard.pin_memory()# type: ignore[attr-defined]# When offloading parameters, also move the grad shard to CPU during# backward pass. In this case, it's important to pre-allocate the# CPU grad shard in pinned memory so that we can do a non-blocking# transfer.p._cpu_grad=torch.zeros_like(# type: ignore[attr-defined]p,device=torch.device("cpu")).pin_memory()# If mixed_precision, maintain reduced precision param shard on# compute_device for computation in fwd/bwd. We resize storage to 0 here# and rematerialize before building the full param when needed. 
After# fwd/bwd, it is freed and we only hold on to the full precision shard.# As a result, this reduced precision shard is not allocated if we are# not in the forward/backward pass.if(self._mixed_precision_enabled_for_params()):p._mp_shard=torch.zeros_like(p._local_shard,device=self.compute_device,dtype=self.mixed_precision.param_dtype)_free_storage(p._mp_shard)# We also maintain a full-sized parameter of type self.compute_dtype.# We resize the storage to size 0 at init (here) and only materialize# as needed. The storage may contain padding elements so that it is# evenly divisible by world_size, although these padding elements will# be removed before the relevant computation.ifhandle.uses_sharded_strategy:# type: ignore[attr-defined]# We set p._full_param_padded's dtype to the desired parameter dtype# in the case of mixed precision. This is so that when we all_gather# into full_param_padded it can occur without issues and result in# full_param_padded having the expected param_dtype.full_param_dtype=(p.dtypeifnotself._mixed_precision_enabled_for_params()elseself.mixed_precision.param_dtype)p._full_param_padded=torch.zeros(# type: ignore[attr-defined]p.numel()*self.world_size,device=self.compute_device,dtype=full_param_dtype,)p._padded_unsharded_size=p._full_param_padded.size()# type: ignore[attr-defined]_free_storage(p._full_param_padded)# type: ignore[attr-defined]ifself._mixed_precision_enabled_for_params():p._full_prec_full_param_padded=torch.zeros(# type: ignore[attr-defined]p.numel()*self.world_size,device=self.compute_device,dtype=p.dtype,# full precision)_free_storage(p._full_prec_full_param_padded)# Track whether the `FlatParameter`'s post-backward hook has been# called for validation in `_wait_for_post_backward()`p._post_backward_called=Falsedef_init_streams(self)->None:"""Initializes CUDA streams for overlapping data transfer and computation. 
This should only be called on the root FSDP instance."""assertself._is_rootasserttorch.cuda.is_available()# Stream for all-gathering parameters.self._streams["all_gather"]=torch.cuda.Stream()# Stream for overlapping grad reduction with the backward pass.self._streams["post_backward"]=torch.cuda.Stream()# Stream for pre-all-gather copies (e.g. H2D or precision cast).self._streams["pre_all_gather"]=torch.cuda.Stream()def_wait_for_previous_optim_step(self)->None:""" The root :class:`FullyShardedDataParallel` instance needs to synchronize with the default stream to ensure that the previous optimizer step is done. """ifnotself._is_root:returncurrent_stream=torch.cuda.current_stream()self._streams["all_gather"].wait_stream(current_stream)# Having the pre-all-gather stream wait for the current stream even if# we do not leverage the pre-all-gather stream is tolerable since this# only runs once per iterationself._streams["pre_all_gather"].wait_stream(current_stream)def_prefetch_handles(self,current_handles_key:_HandlesKey,)->None:""" Prefetches the next handles if needed (without synchronization). An empty handles key cannot prefetch. """ifnotcurrent_handles_key:returnhandles_to_prefetch=self._get_handles_to_prefetch(current_handles_key)forhandles_keyinhandles_to_prefetch:# Prefetch the next set of handles without synchronizing to allow# the sync to happen as late as possible to maximize overlapself._unshard(handles_key)self._handles_prefetched[handles_key]=Truedef_get_handles_to_prefetch(self,current_handles_key:_HandlesKey,)->List[_HandlesKey]:""" Returns a :class:`list` of the handles keys to prefetch for the next module(s), where ``current_handles_key`` represents the current module. "Prefetching" refers to running the unshard logic early (without synchronization), and the "next" modules depend on the recorded execution order and the current training state. 
"""training_state=self._get_training_state(current_handles_key)valid_training_states=(HandleTrainingState.BACKWARD_PRE,HandleTrainingState.BACKWARD_POST,HandleTrainingState.FORWARD,)p_assert(training_stateinvalid_training_states,f"Prefetching is only supported in {valid_training_states} but "f"currently in {training_state}")eod=self._exec_order_datatarget_handles_keys:List[_HandlesKey]=[]if((training_state==HandleTrainingState.BACKWARD_PREandself.backward_prefetch==BackwardPrefetch.BACKWARD_PRE)or(training_state==HandleTrainingState.BACKWARD_POSTandself.backward_prefetch==BackwardPrefetch.BACKWARD_POST)):target_handles_keys=[target_handles_keyfortarget_handles_keyineod.get_handles_to_backward_prefetch(current_handles_key)ifself._needs_pre_backward_unshard.get(target_handles_key,False)andnotself._handles_prefetched.get(target_handles_key,False)]elif(training_state==HandleTrainingState.FORWARDandself.forward_prefetch):target_handles_keys=[target_handles_keyfortarget_handles_keyineod.get_handles_to_forward_prefetch(current_handles_key)ifself._needs_pre_forward_unshard.get(target_handles_key,False)andnotself._handles_prefetched.get(target_handles_key,False)]returntarget_handles_keysdef_get_training_state(self,handles_key:_HandlesKey,)->HandleTrainingState:"""Returns the training state of the handles in ``handles_key``."""p_assert(len(handles_key)>0,"Expects a non-empty handles key")training_states=set(handle._training_stateforhandleinhandles_key)p_assert(len(training_states)==1,f"Expects uniform training state but got {training_states}")returnnext(iter(training_states))
@staticmethod
@contextlib.contextmanager
def state_dict_type(
    module: nn.Module,
    state_dict_type: StateDictType,
    state_dict_config: Optional[StateDictConfig] = None,
) -> Generator:
    """
    A context manager to set the ``state_dict_type`` of all the descendant
    FSDP modules of the target module. The target module does not have to
    be a FSDP module. If the target module is a FSDP module, its
    ``state_dict_type`` will also be changed. On exit, every affected FSDP
    module gets back the ``state_dict_type``/``state_dict_config`` it had on
    entry.

    .. note:: This API should be called for only the top-level (root)
        module.

    .. note:: This API enables users to transparently use the conventional
        ``state_dict`` API to take model checkpoints in cases where the
        root FSDP module is wrapped by another ``nn.Module``. For example,
        the following will ensure ``state_dict`` is called on all non-FSDP
        instances, while dispatching into `local_state_dict` implementation
        for FSDP:

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = DDP(FSDP(...))
        >>> with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
        >>>     checkpoint = model.state_dict()

    Args:
        module (torch.nn.Module): Root module.
        state_dict_type (StateDictType): the desired ``state_dict_type`` to
            set.
        state_dict_config (Optional[StateDictConfig]): the configuration for
            ``state_dict_type``. If ``None``, a default-constructed config of
            the type registered for ``state_dict_type`` is used.

    Raises:
        RuntimeError: if ``state_dict_config`` has the wrong type for
            ``state_dict_type``, or if the FSDP modules do not currently
            share the same ``state_dict_type`` and config type. Validation
            happens before any module is modified, so no module is left in
            a partially updated state.
    """
    expected_state_dict_config_type = _state_dict_type_to_config[state_dict_type]
    # Use a default config if a state_dict config is not set.
    if state_dict_config is None:
        state_dict_config = expected_state_dict_config_type()
    if expected_state_dict_config_type != type(state_dict_config):
        raise RuntimeError(
            f"Expected state_dict_config of type {expected_state_dict_config_type} "
            f"but got {type(state_dict_config)}"
        )

    fsdp_modules = list(FullyShardedDataParallel.fsdp_modules(module))
    prev_state_dict_type = None
    prev_state_dict_config = None
    # First pass: validate only. Raising here (before any mutation) keeps
    # all modules' settings untouched on error.
    for submodule in fsdp_modules:
        if prev_state_dict_type is None:
            prev_state_dict_type = submodule._state_dict_type
        if prev_state_dict_config is None:
            prev_state_dict_config = submodule._state_dict_config
        if prev_state_dict_type != submodule._state_dict_type:
            raise RuntimeError("All FSDP module should the same state_dict_type.")
        if type(prev_state_dict_config) != type(submodule._state_dict_config):
            raise RuntimeError(
                "All FSDP modules should have the same type of state_dict_config."
            )

    # Second pass: apply the new settings.
    for submodule in fsdp_modules:
        submodule._state_dict_type = state_dict_type
        submodule._state_dict_config = state_dict_config
    try:
        yield
    finally:
        for submodule in fsdp_modules:
            assert prev_state_dict_type is not None  # Avoid mypy warning
            assert prev_state_dict_config is not None  # Avoid mypy warning
            submodule._state_dict_type = prev_state_dict_type
            submodule._state_dict_config = prev_state_dict_config
def _convert_to_wrapped_module_name(self, module_name: str) -> str:
    """Strips the flatten-params-wrapper and activation-checkpoint prefixes
    from ``module_name`` and returns it with a trailing ``.`` (or ``""`` if
    nothing remains after stripping the wrapper prefix)."""
    module_name = module_name.replace(f"{FPW_MODULE}.", "")
    module_name = module_name.replace(f"{FPW_MODULE}", "")
    if module_name:
        module_name = f"{module_name}."
    # Activation checkpoint adds a prefix that has to be
    # removed as well.
    module_name = module_name.replace(
        f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", ""
    )
    return module_name


@property
def _param_fqns(self) -> Iterator[Tuple[str, str, str]]:
    """Yields ``(fqn, param_name, module_name)`` for each original parameter
    in the wrapped module's handle, with wrapper prefixes stripped from the
    module name."""
    for param_name, module_name in (
        self._fsdp_wrapped_module.handle.parameter_module_names()
    ):
        module_name = self._convert_to_wrapped_module_name(module_name)
        fqn = f"{module_name}{param_name}"
        yield fqn, param_name, module_name


def _full_post_state_dict_hook(
    self,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    Hook that runs after model.state_dict() is called before returning the
    result to the user. For FSDP, we may have to clone the tensors in
    state_dict as params go back to sharded version after
    _summon_full_params ends, and also remove "_fsdp_wrapped_module" prefix.
    If the config requests CPU offload, buffers in the state dict are also
    moved to CPU.
    """
    _replace_by_prefix(state_dict, prefix + f"{FSDP_WRAPPED_MODULE}.", prefix)
    self._assert_state([TrainingState_.SUMMON_FULL_PARAMS])
    # Return early for trivial cases
    if not state_dict or not self._fsdp_wrapped_module.has_params:
        return state_dict

    # If the `FlatParameter` is registered, then this rank only needed to
    # participate in the all-gather but does not actually save the state
    # dict (e.g. when `rank0_only=True` and `self.rank != 0`)
    if hasattr(self._fsdp_wrapped_module, "flat_param"):
        return state_dict

    offload_to_cpu = self._state_dict_config.offload_to_cpu
    cpu_device = torch.device("cpu")

    # Loop-invariant: the cleaned prefix only depends on `prefix`.
    clean_prefix = clean_tensor_name(prefix)
    # Loop only the parameters saved in self._fsdp_wrapped_module to avoid
    # processing buffers.
    for fqn, param_name, module_name in self._param_fqns:
        fqn = f"{prefix}{fqn}"
        clean_key = fqn
        # Strip prefix out of key if needed as buffer names and param names
        # do not have prefix considered as they are not computed in
        # `state_dict` call.
        if clean_key.startswith(clean_prefix):
            clean_key = clean_key[len(clean_prefix):]

        # Clone non-ignored parameters before exiting the
        # `_summon_full_params()` context
        assert fqn in state_dict, (
            f"FSDP assumes {fqn} is in the state_dict but the state_dict "
            f"only has {state_dict.keys()}. prefix={prefix}, "
            f"module_name={module_name} param_name={param_name} rank={self.rank}."
        )
        if clean_key not in self._ignored_param_names and not getattr(
            state_dict[fqn], "_has_been_cloned", False
        ):
            try:
                state_dict[fqn] = state_dict[fqn].clone().detach()
                state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined]
            # Best-effort: warn on ordinary errors, but do not swallow
            # KeyboardInterrupt / SystemExit (which BaseException would).
            except Exception as e:
                warnings.warn(
                    f"Failed to clone() tensor with name {fqn}. This may mean "
                    "that this state_dict entry could point to invalid memory "
                    "regions after returning from state_dict() call if this "
                    "parameter is managed by FSDP. Please check clone "
                    f"implementation of {fqn}. Error: {str(e)}"
                )

    # Offload the buffer to CPU if needed -- we do not do this in
    # `_summon_full_params()` since without care, that would free
    # the original buffer's GPU memory and require reallocating
    # that memory later; this only affects the state dict's buffer
    # variable and leaves the original buffer's GPU memory intact
    if offload_to_cpu:
        for clean_key in self._buffer_names:
            # This is a hack to support activation checkpoint.
            clean_key = clean_key.replace(
                f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", ""
            )
            fqn = f"{prefix}{clean_key}"
            if fqn not in state_dict:
                # A buffer can be registered as non-persistent.
                continue
            if state_dict[fqn].device != cpu_device:
                state_dict[fqn] = state_dict[fqn].to(cpu_device)
    return state_dict


def _local_post_state_dict_hook(
    self,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    This hook creates a ShardedTensor from the local flat_param and replaces
    ``state_dict[f"{prefix}{FLAT_PARAM}"]`` with the ShardedTensor. No copy
    happens; the underlying storage is the same. Padding at the end of the
    local shard is excluded via ``narrow``.
    """
    _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix)
    if not self._fsdp_wrapped_module.has_params:
        return state_dict

    # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
    # value as the flat_param but it is a pure Tensor because
    # nn.Module.state_dict() will detach the parameter. Therefore, we need
    # to get flat_param from the FlattenParamsWrapper to get the metadata.
    flat_param = getattr(self._fsdp_wrapped_module, FLAT_PARAM, None)
    assert flat_param is not None
    # Construct a ShardedTensor from the flat_param.
    full_numel = flat_param._unpadded_unsharded_size.numel()  # type: ignore[attr-defined]
    shard_offset = flat_param.numel() * self.rank
    valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
    if valid_data_size > 0 and flat_param._shard_numel_padded > 0:
        flat_param = flat_param.narrow(0, 0, valid_data_size)
    local_shards = [
        Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank)
    ]
    state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards(
        local_shards, full_numel, process_group=self.process_group
    )  # type: ignore[assignment]
    return state_dict


@torch.no_grad()
def _sharded_post_state_dict_hook(
    self,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    The hook replaces the unflattened, unsharded parameter in the state_dict
    with a unflattened, sharded parameter (a ShardedTensor), and removes the
    flat parameter entry.
    """
    _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix)
    if not self._fsdp_wrapped_module.has_params:
        return state_dict

    assert self.training_state != TrainingState_.SUMMON_FULL_PARAMS, (
        "Inside _sharded_post_state_dict_hook, the training_state must "
        "not be SUMMON_FULL_PARAMS."
    )
    with self._summon_full_params(recurse=False, writeback=False):
        for fqn, _, _ in self._param_fqns:
            # Create a ShardedTensor for the unflattened, non-sharded
            # parameter.
            param = functools.reduce(getattr, fqn.split("."), self.module)
            state_dict[f"{prefix}{fqn}"] = _ext_chunk_tensor(
                tensor=param,
                rank=self.rank,
                world_size=self.world_size,
                num_devices_per_node=torch.cuda.device_count(),
                pg=self.process_group,
            )  # type: ignore[assignment]
    state_dict.pop(f"{prefix}{FLAT_PARAM}")
    return state_dict


@staticmethod
def _post_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> Dict[str, Any]:
    """
    _post_state_dict_hook() is called after the state_dict() of this
    FSDP module is executed. ``self._state_dict_type`` is used to decide
    what postprocessing will be done.
    """
    self = cast(FullyShardedDataParallel, module)
    processed_state_dict = self._post_state_dict_hook_fn[self._state_dict_type](
        state_dict, prefix
    )
    # Restore buffers, which currently are in their full precision type,
    # back to their mixed precision type. This is because buffers are cast
    # during lazy_init() and stay at their mixed precision type before/after
    # forward/backward. As a result state_dict() should maintain this.
    if self._is_root and self._mixed_precision_enabled_for_buffers():
        self._cast_buffers(recurse=True)
    return processed_state_dict
def state_dict(self, *args, **kwargs):
    """
    Entry point for all three FSDP ``state_dict`` flavors (full, local and
    sharded); the flavor is chosen by the current ``state_dict_type``.

    ``StateDictType.FULL_STATE_DICT`` unshards the model on every rank, which
    can run out of GPU memory if the whole model does not fit on one GPU.
    A :class:`FullStateDictConfig` can restrict saving to rank 0 and/or
    offload to CPU memory layer by layer, enabling much larger checkpoints.
    When even CPU memory cannot hold the full model, use
    ``StateDictType.LOCAL_STATE_DICT``, which saves only this rank's local
    shard. ``StateDictType.SHARDED_STATE_DICT`` saves the model parameters
    as ``ShardedTensor`` s.

    Select the flavor with the :meth:`state_dict_type` context manager.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.distributed.fsdp import StateDictType
        >>> torch.cuda.set_device(device_id)
        >>> my_module = nn.Linear(...)
        >>> sharded_module = FSDP(my_module)
        >>> full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT, full_state_dict_config):
        >>>     full_dict = sharded_module.state_dict()
        >>> full_dict.keys()
        >>> odict_keys(['weight', 'bias'])
        >>> # using local state dict
        >>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT):
        >>>     local_dict = sharded_module.state_dict()
        >>> local_dict.keys()
        >>> odict_keys(['flat_param', 'inner.flat_param'])

    Returns:
        The state dict. For a full state dict with ``rank0_only=True``, ranks
        other than 0 get an empty dict.

    Raises:
        RuntimeError: for local/sharded state dicts when a flat parameter
            exists but the handle does not use a sharded strategy.
        ValueError: if the configured ``StateDictType`` is unknown.

    .. warning:: Must be called on all ranks, since synchronization
        primitives may be used.
    """
    # TODO (rohan-varma): separate these out once a state_dict pre-hook
    # is available.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    self._lazy_init()
    requested_type = self._state_dict_type

    if requested_type == StateDictType.FULL_STATE_DICT:
        config = self._state_dict_config
        if config is None:
            config = FullStateDictConfig()
        rank0_only = config.rank0_only
        offload_to_cpu = config.offload_to_cpu

        # Only unshard if we are not already inside a summon context.
        if self.training_state == TrainingState_.SUMMON_FULL_PARAMS:
            summon_ctx = contextlib.nullcontext()
        else:
            summon_ctx = self._summon_full_params(
                recurse=False,
                writeback=False,
                offload_to_cpu=offload_to_cpu,
                rank0_only=rank0_only,
            )

        with summon_ctx:
            # Buffers are not sharded and stay cast, so restore their
            # original user-specified dtypes for the checkpoint; the
            # post-state-dict hook recasts them afterwards. This must happen
            # inside the context since `_summon_full_params` itself calls
            # `_lazy_init()`, which casts buffers.
            restore_buffers = (
                self._is_root and self._mixed_precision_enabled_for_buffers()
            )
            if restore_buffers:
                self._cast_buffers(
                    dtype=self._buffer_name_to_orig_dtype, recurse=False
                )
            full_state_dict = super().state_dict(*args, **kwargs)

        # TODO: support offload to CPU in post state dict hook.
        if rank0_only and self.rank != 0:
            return {}
        return full_state_dict

    if requested_type in (
        StateDictType.LOCAL_STATE_DICT,
        StateDictType.SHARDED_STATE_DICT,
    ):
        wrapped = self._fsdp_wrapped_module
        flattened_but_unsharded = (
            wrapped.flat_param is not None
            and not wrapped.handle.uses_sharded_strategy
        )
        if flattened_but_unsharded:
            raise RuntimeError(
                "sharded_state_dict/local_state_dict can only be called "
                "when parameters are flatten and sharded."
            )
        return super().state_dict(*args, **kwargs)

    raise ValueError(f"Unknown StateDictType {self._state_dict_type}.")
def _local_state_dict(self, *args: Any, **kwargs: Any) -> Any:
    """
    Returns the local state of the module. Parameters are flattened and
    sharded, so the resulting state_dict can only be loaded after the module
    has been wrapped with FSDP.
    """
    with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT):
        return self.state_dict(*args, **kwargs)


def _full_post_load_state_dict_hook(self, *args, **kwargs) -> None:
    """Exits the ``_summon_full_params`` context entered by
    ``_full_pre_load_state_dict_hook``."""
    # We should exit summon_full_params context.
    self._assert_state([TrainingState_.SUMMON_FULL_PARAMS])
    assert getattr(self, '_full_param_ctx', None) is not None
    self._full_param_ctx.__exit__(None, None, None)
    self._full_param_ctx = None


def _sharded_state_dict(self, *args: Any, **kwargs: Any) -> Any:
    """
    Returns the sharded states of the module. Parameters are unflattened and
    sharded, so the resulting state_dict can be used with any parallelism
    (e.g., DDP, model parallelism, and single trainer) after a valid
    resharding.
    """
    # Mirror `_local_state_dict`: the context manager is `state_dict_type`
    # (there is no `set_state_dict_type`), and `self` must not be passed as
    # `state_dict`'s first (destination) positional argument.
    with self.state_dict_type(self, StateDictType.SHARDED_STATE_DICT):
        return self.state_dict(*args, **kwargs)


def _full_pre_load_state_dict_hook(
    self,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    """Enters a writeback ``_summon_full_params`` context (exited by
    ``_full_post_load_state_dict_hook``) and re-adds the wrapped-module
    prefix to the state dict keys."""
    # We do not expect to be calling pre-hooks twice without post-hook
    # call in between.
    assert getattr(self, '_full_param_ctx', None) is None
    # Note that it needs writeback=True to persist.
    self._full_param_ctx = self._summon_full_params(
        recurse=False, writeback=True
    )
    self._full_param_ctx.__enter__()
    _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_WRAPPED_MODULE}.")


def _local_post_load_state_dict_hook(self, *args, **kwargs) -> None:
    """No post-processing is needed after loading a local state dict."""
    pass


def _local_pre_load_state_dict_hook(
    self,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    """
    This hook finds the local flat_param for this FSDP module from the
    state_dict. The flat_param should be a ShardedTensor. This hook converts
    the ShardedTensor to a tensor. No copy happens unless padding is
    required.
    """
    _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_WRAPPED_MODULE}.")
    fqn = f"{prefix}{FSDP_WRAPPED_MODULE}.{FLAT_PARAM}"
    if fqn not in state_dict:
        assert getattr(self._fsdp_wrapped_module, FLAT_PARAM, None) is None, (
            "No flat parameter in state_dict but self._fsdp_wrapped_module.flat_param is not None"
        )
        return
    load_tensor = state_dict[fqn]
    assert isinstance(
        load_tensor, ShardedTensor
    ), "Tensors in local_state_dict should be ShardedTensor."

    # Convert the ShardedTensor to a Tensor.
    shards = load_tensor.local_shards()
    # Exactly one shard is expected; silently taking the first of several
    # would load the wrong data.
    assert len(shards) == 1, "load_local_state_dict assume one shard per ShardedTensor."
    load_tensor = cast(torch.Tensor, shards[0].tensor)

    # Get the metadata of the flat_param to decide whether to pad the loaded
    # tensor.
    flat_param = self._fsdp_wrapped_module.flat_param
    assert flat_param is not None
    if flat_param._shard_numel_padded not in (0, flat_param.numel()):
        assert load_tensor.numel() < flat_param.numel(), (
            f"Local shard size = {flat_param.numel()} and the tensor in "
            f"the state_dict is {load_tensor.numel()}."
        )
        load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded])
    state_dict[fqn] = load_tensor


def _sharded_post_load_state_dict_hook(self, *args, **kwargs) -> None:
    """No post-processing is needed after loading a sharded state dict."""
    pass


def _sharded_pre_load_state_dict_hook(
    self,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    """
    The hook combines the unflattened, sharded parameters (ShardedTensor) to
    a new FlatParameter and shards the new FlatParameter to the local chunk.

    Raises:
        RuntimeError: if the handle does not use a sharded strategy.
    """
    _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_WRAPPED_MODULE}.")
    if not self._fsdp_wrapped_module.has_params:
        return

    if not self._fsdp_wrapped_module.handle.uses_sharded_strategy:
        raise RuntimeError(
            "load_sharded_state_dict can only be called when parameters "
            "are flatten and sharded."
        )

    nonsharded_tensors = []

    # TODO: Reduce the communication by using only one _all_gather_base to
    # gather all the parameters in this layer. This can be achieved by
    # concatenated all the local shards and then append the padding.
    # https://github.com/pytorch/pytorch/issues/77461
    for (param_name, _, module_name) in self._fsdp_wrapped_module.handle.flat_param._param_infos:
        module_name = self._convert_to_wrapped_module_name(module_name)
        fqn = f"{prefix}{FSDP_WRAPPED_MODULE}.{module_name}{param_name}"
        param = state_dict.pop(fqn)

        # All-gather the param (ShardedTensor)
        param, shards = _ext_pre_load_state_dict_transform(param)
        assert len(shards) < 2, (
            f"Expects 0 or 1 shard per rank but got {len(shards)} shards on rank {self.rank}"
        )
        param_numel = param.size().numel()
        dim_0_size = param.size()[0]
        chunk_size = (
            math.ceil(dim_0_size / self.world_size) * param_numel // dim_0_size
        )
        if shards:
            local_tensor = cast(torch.Tensor, shards[0].tensor).flatten()
            if not local_tensor.is_cuda:
                local_tensor = local_tensor.cuda()
            num_padding = chunk_size - local_tensor.numel()
            if num_padding > 0:
                local_tensor = F.pad(local_tensor, [0, num_padding])
        else:
            # Allocate directly on the current CUDA device instead of
            # allocating on CPU and copying over.
            local_tensor = torch.zeros(chunk_size, dtype=param.dtype, device="cuda")
        tensor = torch.empty(
            chunk_size * self.world_size, dtype=local_tensor.dtype, device="cuda"
        )
        dist._all_gather_base(tensor, local_tensor, group=self.process_group)
        tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
        nonsharded_tensors.append(tensor)

    # Create a new flat_param from the loaded, non-sharded tensors.
    flat_param = self._fsdp_wrapped_module.flat_param
    loaded_flat_param = FlatParamHandle.flatten_params(
        nonsharded_tensors, requires_grad=False
    )

    # Get the chunk from the loaded flat_param for the local rank.
    loaded_flat_param, num_to_pad = FlatParamHandle._get_shard(
        loaded_flat_param,
        self.rank,
        self.world_size,
    )
    # `Tensor.to()` is not in-place; keep the returned tensor.
    loaded_flat_param = loaded_flat_param.to(flat_param.device)
    assert flat_param.numel() == loaded_flat_param.numel(), (
        f"The loaded local chunk has different numel({loaded_flat_param.numel()}) "
        f"from the local chunk {flat_param.numel()}."
    )
    assert flat_param._shard_numel_padded == num_to_pad, (
        f"The loaded local chunk has different padding({num_to_pad}) "
        f"from the local chunk {flat_param._shard_numel_padded}."
    )
    state_dict[f"{prefix}{FSDP_WRAPPED_MODULE}.{FLAT_PARAM}"] = loaded_flat_param


@staticmethod
def _pre_load_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> None:
    """
    ``_pre_load_state_dict_hook`` is called before
    ``self._load_from_state_dict()`` is called. ``self._state_dict_type`` is
    used to decide what preprocessing will be done.
    """
    # Code that is common for all state_dict impls
    self = cast(FullyShardedDataParallel, module)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Dispatch into state_dict specific implementation of pre-hook.
    self._pre_load_state_dict_hook_fn[self._state_dict_type](state_dict, prefix)


@staticmethod
def _post_load_state_dict_hook(module: nn.Module, *args: Any) -> None:
    """Dispatches to the post-load hook for ``self._state_dict_type``."""
    # Code that is common for all state_dict impls
    self = cast(FullyShardedDataParallel, module)
    # Dispatch into state_dict type specific implementation of post-hook for
    # loading state_dict.
    self._post_load_state_dict_hook_fn[self._state_dict_type]()
def load_state_dict(
    self,
    state_dict: Mapping[str, Any],
    *args,
    **kwargs,
) -> NamedTuple:
    """
    The entry point of all three FSDP ``load_state_dict`` APIs. By default,
    calling ``load_state_dict`` on an FSDP module will result in FSDP
    attempting to load a "full" state_dict, i.e. a state_dict consisting of
    full, unsharded, unflattened original module parameters. This requires
    FSDP to load the full parameter context on each rank which could result
    in GPU OOM. As a result, :func:`state_dict_type` API is available to
    configure between ``load_state_dict`` implementations. User can thus use
    ``with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT)``
    context manager to load a local state dict checkpoint that will restore
    only local shards of the module. The supported implementations are
    ``StateDictType.LOCAL_STATE_DICT``, ``StateDictType.SHARDED_STATE_DICT``
    and ``StateDictType.FULL_STATE_DICT`` (default). Please see
    :func:`state_dict` for documentation around creating an FSDP checkpoint.

    Positional and keyword arguments (e.g. ``strict``) are forwarded to
    :meth:`torch.nn.Module.load_state_dict`.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.distributed.fsdp import StateDictType
        >>> torch.cuda.set_device(device_id)
        >>> my_module = nn.Linear(...)
        >>> sharded_module = FSDP(my_module)
        >>> checkpoint = torch.load(PATH)
        >>> full_state_dict = checkpoint['full_state_dict']
        >>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT):
        >>>     sharded_module.load_state_dict(full_state_dict)
        >>> # using local state dict
        >>> local_state_dict = checkpoint['local_state_dict']
        >>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT):
        >>>     sharded_module.load_state_dict(local_state_dict)

    .. warning:: This needs to be called on all ranks, since synchronization
        primitives may be used.
    """
    # Forward keyword arguments too: dropping them would make e.g.
    # `strict=False` silently fall back to the default `strict=True`.
    return super().load_state_dict(state_dict, *args, **kwargs)
def _load_local_state_dict(
    self,
    state_dict: Mapping[str, Any],
    *args,
) -> NamedTuple:
    """
    Load states from a flattened, sharded state dictionary.
    """
    with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT):
        return self.load_state_dict(state_dict, *args)


def _load_sharded_state_dict(
    self,
    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
    strict: bool = True,
) -> NamedTuple:
    """
    Load states from a unflattened, sharded state dictionary.
    """
    # Use the `state_dict_type` context manager, consistent with
    # `_load_local_state_dict`; there is no `set_state_dict_type` method.
    with self.state_dict_type(self, StateDictType.SHARDED_STATE_DICT):
        return self.load_state_dict(state_dict, strict)
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """
    Run the wrapped module's forward pass, surrounded by FSDP's pre-forward
    (unshard) and post-forward (reshard) logic.

    The root instance also performs its root-specific pre-forward work
    (synchronization and input casting) through
    ``_fsdp_root_pre_forward``.
    """
    with torch.autograd.profiler.record_function(
        "FullyShardedDataParallel.forward"
    ):
        self._lazy_init()
        args, kwargs = self._fsdp_root_pre_forward(*args, **kwargs)
        handles = self._handles
        is_root = self._is_root
        # The root keeps its `FULL_SHARD` parameters unsharded after forward
        # on the expectation that backward uses them right away (though
        # this may not be true); every other instance frees them.
        free_flags = [
            (not is_root)
            and h._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
            for h in handles
        ]
        pre_unshard = functools.partial(self._pre_forward_unshard, handles=handles)
        post_reshard = functools.partial(self._reshard, handles, free_flags)
        # The trailing `None`s fill the hook-signature `module`/`input` slots.
        self._pre_forward(handles, pre_unshard, None, None)
        for h in handles:
            p_assert(
                h.flat_param.device == self.compute_device,
                "Expected `FlatParameter` to be on the compute device "
                f"{self.compute_device} but got {h.flat_param.device}",
            )
        out = self._fsdp_wrapped_module(*args, **kwargs)
        return self._post_forward(handles, post_reshard, None, None, out)
def _pre_forward(
    self,
    handles: List[FlatParamHandle],
    unshard_fn: Optional[Callable],
    module: nn.Module,
    input: Any,
) -> None:
    """
    Runs the pre-forward logic. This includes an opportunity to unshard
    currently sharded parameters such as those for the current forward and
    registering post-backward hooks for these current parameters.

    Args:
        handles (List[FlatParamHandle]): Handles giving the parameters used
            in the current forward.
        unshard_fn (Optional[Callable]): A callable to unshard any currently
            sharded parameters or ``None`` to not do any unsharding.
        module (nn.Module): Unused; expected by the hook signature.
        input (Any): Unused; expected by the hook signature.
    """
    self.training_state = TrainingState_.FORWARD
    self._exec_order_data.record_pre_forward(handles, self.training)
    for handle in handles:
        handle._training_state = HandleTrainingState.FORWARD
    if unshard_fn is not None:
        unshard_fn()
    # Register post-backward hooks to reshard the parameters and
    # reduce-scatter their gradients. The callee skips any `FlatParameter`
    # that already has a hook, so each parameter gets at most one hook until
    # the hooks are removed again at the end of backward.
    self._register_post_backward_hooks(handles)


def _pre_forward_unshard(
    self,
    handles: List[FlatParamHandle],
) -> None:
    """
    Unshards ``handles`` in the pre-forward, makes the current stream wait
    for the all-gather stream, and triggers prefetching of the next handles.
    Does nothing if ``handles`` is empty.
    """
    if handles:
        self._unshard(handles)
        handles_key = tuple(handles)
        self._needs_pre_forward_unshard[handles_key] = False
        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
        self._prefetch_handles(handles_key)


def _post_forward(
    self,
    handles: List[FlatParamHandle],
    reshard_fn: Optional[Callable],
    module: nn.Module,
    input: Any,
    output: Any,
) -> Any:
    """
    Runs the post-forward logic. This includes an opportunity to reshard
    currently unsharded parameters such as those used in the current forward
    and registering pre-backward hooks on the forward outputs.

    Args:
        handles (List[FlatParamHandle]): Handles giving the parameters used
            in the current forward.
        reshard_fn (Optional[Callable]): A callable to reshard any currently
            unsharded parameters (e.g. from the current forward) or ``None``
            to not do any resharding.
        module (nn.Module): Unused; expected by the hook signature.
        input (Any): Unused; expected by the hook signature.
        output (Any): Forward pass output; pre-backward hooks are registered
            on the tensors that require gradients in this output.

    Returns:
        ``output`` with pre-backward hooks registered on the tensors that
        require gradients.

    Postcondition: Each ``FlatParameter`` 's data points to the sharded
    flattened parameter.
    """
    self._exec_order_data.record_post_forward(handles)
    if reshard_fn is not None:
        reshard_fn()
    # Register pre-backward hooks to unshard the flattened parameters
    # for the gradient computation (if needed)
    output = self._register_pre_backward_hooks(output, handles)
    self.training_state = TrainingState_.IDLE
    for handle in handles:
        handle._training_state = HandleTrainingState.IDLE
    return output


def _cast_forward_inputs(self, *args, **kwargs):
    """
    Moves the forward inputs to the compute device and, if mixed precision
    is enabled for parameters, casts them to ``mixed_precision.param_dtype``.

    Returns:
        The converted ``(args, kwargs)`` pair.
    """
    # TODO: Do not use the side stream for tensor copies for now;
    # investigate the perf with/without it
    # TODO: For mixed precision, move the inputs to the compute device and
    # cast to reduced-precision in a single `to()` call
    args, kwargs = _to_kwargs(args, kwargs, self.compute_device.index, False)
    # `_to_kwargs` returns one (args, kwargs) entry per target device; there
    # is exactly one device here
    args = args[0]
    kwargs = kwargs[0]
    if self._mixed_precision_enabled_for_params():
        input_dtype = self.mixed_precision.param_dtype
        args, kwargs = self._cast_fp_inputs_to_dtype(
            input_dtype,
            *args,
            **kwargs,
        )
    return args, kwargs


def _fsdp_root_pre_forward(self, *args, **kwargs):
    """
    Runs pre-forward logic specific to the root FSDP instance, which should
    run before any individual module's pre-forward. This includes
    synchronizing with the previous iteration and casting the forward inputs
    appropriately. When ``forward_prefetch`` is enabled, it also marks every
    FSDP module's handles as needing a pre-forward unshard. If this is called
    on a non-root FSDP instance, then the forward inputs are returned
    directly.
    """
    p_assert(self._is_root is not None, "Expects a root FSDP to have been set")
    if not self._is_root:
        return args, kwargs
    if self.forward_prefetch:
        for fsdp_module in self.fsdp_modules(self):
            handles_key = tuple(fsdp_module._handles)
            if handles_key:
                self._needs_pre_forward_unshard[handles_key] = True
    self._wait_for_previous_optim_step()
    args, kwargs = self._cast_forward_inputs(*args, **kwargs)
    return args, kwargs
@staticmethod
@contextlib.contextmanager
def summon_full_params(
    module,
    recurse: bool = True,
    writeback: bool = True,
    rank0_only: bool = False,
    offload_to_cpu: bool = False,
) -> Generator:
    r"""
    A context manager that exposes the full (unsharded) parameters of FSDP
    instances. It is useful *after* forward/backward, e.g. to inspect or
    post-process a model's parameters. ``module`` need not be an FSDP
    instance itself. Full parameters are summoned for every FSDP instance it
    contains and, depending on ``recurse``, for their FSDP children.

    .. note:: This may be used on inner FSDP instances.

    .. note:: This may *not* be used within a forward or backward pass, and
        forward/backward may not be started from within this context.

    .. note:: On exit, parameters revert to their local shards; storage
        behavior is the same as in forward.

    .. note:: The full parameters may be modified, but only the portion that
        corresponds to this rank's local shard persists after the context
        exits. Nothing persists if ``writeback=False``. When FSDP does not
        shard the parameters (currently when ``world_size == 1`` or with the
        ``NO_SHARD`` strategy), modifications persist regardless of
        ``writeback``.

    .. note:: For a non-FSDP module containing several independent FSDP
        units, the given arguments apply to every contained unit.

    .. warning:: ``rank0_only=True`` together with ``writeback=True`` is not
        currently supported and raises an error. Parameter shapes differ
        across ranks inside the context, so writing to them could leave the
        ranks inconsistent once the context exits.

    .. warning:: ``offload_to_cpu=True`` with ``rank0_only=False`` copies the
        full parameters to CPU memory redundantly for GPUs on the same
        machine, which risks CPU OOM. Prefer ``offload_to_cpu`` together with
        ``rank0_only=True``.

    Args:
        recurse (bool, Optional): Recursively summon all parameters for
            nested FSDP instances (default: True).
        writeback (bool, Optional): If ``False``, modifications to the
            parameters are discarded when the context exits. Disabling this
            can be slightly more efficient (default: True).
        rank0_only (bool, Optional): If ``True``, full parameters are
            materialized only on global rank 0, and the other ranks keep
            sharded parameters inside the context. Combining this with
            ``writeback=True`` is not supported, because parameter shapes
            differ across ranks and writing to them can leave the ranks
            inconsistent when the context exits.
        offload_to_cpu (bool, Optional): If ``True``, full parameters are
            offloaded to CPU. Offloading currently happens only for sharded
            parameters, i.e. not for ``world_size == 1`` or ``NO_SHARD``.
            Prefer combining this with ``rank0_only=True`` to avoid redundant
            copies of the model parameters in the same CPU memory.
    """
    # Only FSDP roots are entered directly; each root summons its FSDP
    # descendants itself according to ``recurse``.
    fsdp_roots = FullyShardedDataParallel.fsdp_modules(module, root_only=True)
    with contextlib.ExitStack() as exit_stack:
        for fsdp_root in fsdp_roots:
            summon_ctx = fsdp_root._summon_full_params(
                recurse=recurse,
                writeback=writeback,
                rank0_only=rank0_only,
                offload_to_cpu=offload_to_cpu,
            )
            exit_stack.enter_context(summon_ctx)
        # Every FSDP instance now holds full parameters; hand control to the
        # caller. Leaving the ExitStack reshards them again.
        yield
@contextlib.contextmanager
def _summon_full_params(
    self,
    recurse: bool = True,
    writeback: bool = True,
    rank0_only: bool = False,
    offload_to_cpu: bool = False,
) -> Generator:
    """
    Per-instance implementation behind :meth:`summon_full_params`.

    With ``recurse=True``, enters a non-recursive summon for every FSDP
    module returned by ``self.fsdp_modules(self)`` and yields once all of
    them are active.

    With ``recurse=False``, unshards only this instance's handles:

    - On ranks other than 0 when ``rank0_only=True``, the handles are
      resharded again immediately, before yielding.
    - Otherwise, sharded handles are optionally moved to CPU
      (``offload_to_cpu``) and the parameters are exposed unflattened for
      the duration of the context. On exit, this rank's shard is optionally
      written back (``writeback``) and the handles are resharded.

    In the non-recursive case, the training state of this instance and its
    handles is reset to ``IDLE`` when the context exits.

    Raises:
        ValueError: If both ``writeback`` and ``rank0_only`` are ``True``.
    """
    if writeback and rank0_only:
        raise ValueError(
            "writeback=True and rank0_only=True is not supported, as model "
            "parameter shapes will be different across ranks, and writing "
            "to them can lead to inconsistencies across ranks when the "
            "context is exited."
        )
    if offload_to_cpu and not rank0_only:
        warnings.warn(
            "offload_to_cpu and rank0_only=False will result in "
            "full parameters being redundantly copied to CPU memory for "
            "GPUs that reside on the same machine, which may incur the risk of "
            "CPU OOM. It is recommended to use ``offload_to_cpu`` with "
            "rank0_only=True."
        )

    if recurse:
        # Fan out to one non-recursive summon per FSDP module and hold all
        # of them open while the caller runs.
        with contextlib.ExitStack() as stack:
            for module in self.fsdp_modules(self):
                stack.enter_context(
                    module._summon_full_params(
                        recurse=False,
                        writeback=writeback,
                        rank0_only=rank0_only,
                        offload_to_cpu=offload_to_cpu,
                    )
                )
            yield
        return

    torch.cuda.synchronize()
    self._lazy_init()
    self._assert_state([TrainingState_.IDLE])
    for handle in self._handles:
        assert handle._training_state == HandleTrainingState.IDLE
    self.training_state = TrainingState_.SUMMON_FULL_PARAMS
    for handle in self._handles:
        handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS

    # Computed before unsharding and later passed to `_reshard` as the
    # per-handle "free" flags; presumably True for handles that were sharded
    # before this call (TODO confirm `needs_unshard()` semantics).
    free_unsharded_flat_params = [handle.needs_unshard() for handle in self._handles]
    self._unshard(self._handles)
    # Make the compute stream wait for the all-gathers issued by `_unshard`
    torch.cuda.current_stream().wait_stream(self._streams["all_gather"])

    if rank0_only and self.rank != 0:
        # Free the unsharded flattened parameter early
        self._reshard(self._handles, free_unsharded_flat_params)
        try:
            yield
        finally:
            self.training_state = TrainingState_.IDLE
            for handle in self._handles:
                handle._training_state = HandleTrainingState.IDLE
    else:
        # Unflatten the unsharded flattened parameters
        with contextlib.ExitStack() as stack:
            # Invariant: rank == 0 or !rank0_only
            for handle in self._handles:
                if offload_to_cpu and handle.uses_sharded_strategy:
                    stack.enter_context(handle.to_cpu())
            # TODO (awgu): This FPW call assumes 1 `FlatParameter`
            stack.enter_context(self._fsdp_wrapped_module.unflatten_as_params())
            try:
                yield
            finally:
                # Undo the CPU offload / unflattening first so that the
                # write-back and reshard below see the flat parameter again
                stack.close()
                if writeback:
                    self._write_back_to_local_shard(self._handles)
                self._reshard(self._handles, free_unsharded_flat_params)
                self.training_state = TrainingState_.IDLE
                for handle in self._handles:
                    handle._training_state = HandleTrainingState.IDLE


@torch.no_grad()
def _write_back_to_local_shard(self, handles: List[FlatParamHandle]) -> None:
    """
    For each handle, writes this rank's shard of the unsharded flattened
    parameter back to the sharded flattened parameter (``_local_shard``).
    Handles that do not use a sharded strategy are skipped.

    Precondition: Each handle's ``FlatParameter`` 's data points to the
    padded unsharded flattened parameter.
    """
    for handle in handles:
        # For `NO_SHARD`, `_local_shard` is the unsharded flattened
        # parameter as well
        if not handle.uses_sharded_strategy:
            continue
        assert (
            handle.flat_param.ndim == 1
        ), f"Expects `flat_param` to be flattened but got {handle.flat_param.shape}"
        # Get the unpadded shard instead of the padded shard to persist
        # user changes to the padding (though FSDP does not explicitly
        # support this)
        shard, _ = FlatParamHandle._get_unpadded_shard(
            handle.flat_param, handle.rank, handle.world_size
        )
        # Only the first `shard.numel()` elements are overwritten
        handle.flat_param._local_shard[: shard.numel()].copy_(shard)
def named_buffers(
    self,
    *args,
    **kwargs,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """
    Override of :meth:`named_buffers()` that, while inside the
    :meth:`summon_full_params` context manager, strips every occurrence of
    the FSDP-specific prefix (``FSDP_PREFIX``) from the yielded buffer
    names. Outside that context, names are passed through unchanged.
    """
    # Sampled once, when iteration starts, rather than per buffer.
    strip_fsdp_prefix = self.training_state == TrainingState_.SUMMON_FULL_PARAMS
    for name, tensor in super().named_buffers(*args, **kwargs):
        # Nested FSDP wrapping can insert the prefix several times;
        # `str.replace` removes all of them.
        yield (name.replace(FSDP_PREFIX, "") if strip_fsdp_prefix else name, tensor)
def named_parameters(
    self,
    *args,
    **kwargs,
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
    """
    Override of :meth:`named_parameters()` that removes every occurrence of
    the FSDP-specific flattened-parameter prefix (``FSDP_PREFIX``) from the
    yielded names while inside the :meth:`summon_full_params` context
    manager. Outside that context, names are passed through unchanged.
    """
    # The context is checked once, when iteration begins.
    summoning = self.training_state == TrainingState_.SUMMON_FULL_PARAMS
    for fqn, tensor in super().named_parameters(*args, **kwargs):
        if not summoning:
            yield (fqn, tensor)
            continue
        # Each nested FSDP level contributes one prefix; strip them all.
        yield (fqn.replace(FSDP_PREFIX, ""), tensor)
def _register_pre_backward_hooks(
    self,
    outputs: Any,
    handles: List[FlatParamHandle],
) -> Any:
    """
    Registers pre-backward hooks on the tensors that require gradients in the
    forward pass outputs ``outputs``, which were computed using the
    ``FlatParameter`` s of ``handles``.

    Returns:
        Forward pass outputs with pre-backward hooks registered to tensors that
        require gradients.
    """
    # If there is no gradient computation, then there is no need for
    # pre-backward logic
    if not torch.is_grad_enabled():
        return outputs

    if self._is_root:
        self._post_backward_callback_queued = False  # only defined on the root

    handles_key = tuple(handles)
    if handles_key:
        # Reset the per-forward flags for these handles. `_register_hook`
        # below marks the key as needing a pre-backward unshard again as soon
        # as any output tensor requires a gradient.
        self._needs_pre_backward_unshard[handles_key] = False
        self._ran_pre_backward_hook[handles_key] = False

    def _pre_backward_hook(_handles: List[FlatParamHandle], *unused: Any) -> None:
        """Prepares ``_handles`` 's ``FlatParameter`` s for gradient computation."""
        _handles_key = tuple(_handles)  # avoid shadowing `handles_key`
        # Only run the pre-backward hook once per group of handles involved
        # in the same module forward computation
        if _handles_key and self._ran_pre_backward_hook.get(_handles_key, False):
            return

        with torch.autograd.profiler.record_function(
            "FullyShardedDataParallel._pre_backward_hook"
        ):
            # Queue the post-backward callback once for the root FSDP
            # instance to attach it to the outermost backward graph task so
            # that it is called after all backward calls complete
            if self._is_root and not self._post_backward_callback_queued:
                self._queue_wait_for_post_backward()
            elif _handles_key:
                self._assert_state([TrainingState_.IDLE])
            self.training_state = TrainingState_.BACKWARD_PRE
            # Queueing the post-backward callback is the only logic that is
            # not per-handle in the pre-backward hook, so we can return
            # early here if there are no handles.
            if not _handles_key:
                return
            for handle in _handles:
                handle._training_state = HandleTrainingState.BACKWARD_PRE

            # If the handles have been prefetched, this `_unshard()` simply
            # switches to using the unsharded parameter
            self._unshard(_handles)
            torch.cuda.current_stream().wait_stream(self._streams["all_gather"])

            # Set this to `False` to ensure that a mistargeted prefetch
            # does not actually unshard these handles
            self._needs_pre_backward_unshard[_handles_key] = False
            self._prefetch_handles(_handles_key)
            for handle in _handles:
                handle.prepare_gradient()
            self._ran_pre_backward_hook[_handles_key] = True

    def _register_hook(t: torch.Tensor) -> torch.Tensor:
        # Only tensors that require gradients participate in backward
        if t.requires_grad:
            t.register_hook(functools.partial(_pre_backward_hook, handles))
            self._needs_pre_backward_unshard[handles_key] = True
        return t

    return _apply_to_tensors(_register_hook, outputs)


def _register_post_backward_hooks(
    self,
    handles: List[FlatParamHandle],
) -> None:
    """
    Registers post-backward hooks on the ``FlatParameter`` s'
    ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients.

    The ``AccumulateGrad`` object represents the last function that finalizes
    the ``FlatParameter`` 's gradient, so it only runs after its entire
    gradient computation has finished.

    A hook is registered only if the ``FlatParameter`` does not already have
    one (tracked by its ``_post_backward_hook_state`` attribute), i.e. in the
    *first* forward that it participates in. This relies on the
    ``AccumulateGrad`` object being preserved through multiple forwards. The
    hook state is removed again when the parameters are finalized after
    backward.
    """
    # If there is no gradient computation, then there is no need for
    # post-backward logic
    if not torch.is_grad_enabled():
        return
    for handle in handles:
        flat_param = handle.flat_param
        already_registered = hasattr(flat_param, "_post_backward_hook_state")
        if already_registered or not flat_param.requires_grad:
            continue
        # Get the `AccumulateGrad` object: a no-op `expand_as` creates a view
        # whose `grad_fn` points back to the parameter's `AccumulateGrad`
        temp_flat_param = flat_param.expand_as(flat_param)
        p_assert(
            temp_flat_param.grad_fn is not None,
            "The `grad_fn` is needed to access the `AccumulateGrad` and "
            "register the post-backward hook",
        )
        acc_grad = temp_flat_param.grad_fn.next_functions[0][0]
        hook_handle = acc_grad.register_hook(
            functools.partial(self._post_backward_hook, handle)
        )
        flat_param._post_backward_hook_state = (acc_grad, hook_handle)  # type: ignore[attr-defined]


@torch.no_grad()
def _post_backward_hook(
    self,
    handle: FlatParamHandle,
    *unused: Any,
) -> None:
    """
    Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.

    Returns early, without communication, if the ``FlatParameter`` has no
    gradient or if gradients are not being synchronized (``no_sync()``).

    Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
    unsharded gradient for the local batch.

    Postcondition:
        - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
        unsharded gradient.
        - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
        gradient (accumulating with any existing gradient).
    """
    param = handle.flat_param
    param._post_backward_called = True
    with torch.autograd.profiler.record_function(
        "FullyShardedDataParallel._post_backward_hook"
    ):
        # First hook callback will see PRE state. If we have multiple params,
        # then subsequent hook callbacks will see POST state.
        self._assert_state([TrainingState_.BACKWARD_PRE, TrainingState_.BACKWARD_POST])
        self.training_state = TrainingState_.BACKWARD_POST
        handle._training_state = HandleTrainingState.BACKWARD_POST

        if self._use_param_exec_order_policy() and self._param_exec_order_prep_stage:
            # In self._fsdp_params_exec_order, the parameters are ordered based on
            # the execution order in the backward pass in the first iteration.
            self._fsdp_params_exec_order.append(param)

        if param.grad is None:
            return
        if param.grad.requires_grad:
            raise RuntimeError(
                "FSDP only works with gradients that don't require gradients"
            )

        free_unsharded_flat_param = self._should_free_unsharded_flat_param(handle)
        self._reshard([handle], [free_unsharded_flat_param])

        # TODO (awgu): Post-backward prefetching does not support the
        # multiple handles per module case (which was why we keyed by
        # *tuple*). The post-backward hook runs per handle, not per group
        # of handles. To generalize this, we may need a 2-level mapping,
        # where we map each individual handle to its groups of handles and
        # then from the groups of handles to their indices in the order.
        handles_key = (handle,)
        self._prefetch_handles(handles_key)

        if not self._sync_gradients:
            return

        # Wait for all ops in the current stream (e.g. gradient
        # computation) to finish before reduce-scattering the gradient
        self._streams["post_backward"].wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self._streams["post_backward"]):
            orig_grad_data = param.grad.data
            if (
                self._mixed_precision_enabled_for_reduce()
                and not self._low_precision_hook_enabled()
            ):
                # Cast gradient to precision in which it should be communicated.
                # If a low precision hook is registered and reduce_dtype is specified
                # in `MixedPrecision`, communication hook will take care of
                # casting to lower precision and back.
                # TODO: Make this a communication hook when communication hooks
                # are implemented for FSDP. Note that this is a noop if the
                # reduce_dtype matches the param dtype.
                param.grad.data = param.grad.data.to(self.mixed_precision.reduce_dtype)

            if self._exec_order_data.is_first_iter:
                # For all sharding strategies, communication is performed
                # through `_communication_hook`. The default comm hooks are
                # `reduce_scatter` for sharded strategies and `all_reduce`
                # for non-sharded strategies. This check asserts that
                # `_communication_hook` and `_communication_hook_state`,
                # which are required for communication, are not `None`.
                p_assert(
                    self._communication_hook is not None,
                    "Communication hook should not be None",
                )
                p_assert(
                    self._communication_hook_state is not None,
                    "Communication hook state should not be None",
                )
            grad = param.grad.data
            if handle.uses_sharded_strategy:
                # We clear `param.grad` to permit repeated gradient
                # computations when this FSDP module is called multiple times.
                # This is to avoid a race among multiple re-entrant backward
                # passes. For example, the second backward pass computation
                # precedes ahead of the first backward pass reduction, which is
                # possible since the reduction is in a different stream and is
                # async. Then, the first backward pass may be incorrectly
                # reducing the second backward pass's `param.grad`.
                # The reduced gradients are accumulated in
                # `param._saved_grad_shard`, and the gradient reductions can
                # happen in arbitrary order, though we tolerate this due to the
                # (approximate) commutativity of floating-point addition.
                param.grad = None
                grad_flatten = torch.flatten(grad)
                chunks = list(grad_flatten.chunk(self.world_size))
                # Pad so that the input splits into `world_size` equal chunks
                num_pad = self.world_size * chunks[0].numel() - grad.numel()
                input_flattened = F.pad(grad_flatten, [0, num_pad])
                output = torch.zeros_like(chunks[0])
                self._communication_hook(
                    self._communication_hook_state, input_flattened, output
                )

                self._cast_grad_to_param_dtype(output, param)

                # To support gradient accumulation outside `no_sync()`, we save
                # the gradient data to `param._saved_grad_shard` before the
                # backward pass, accumulate gradients into it here, and set
                # `param.grad` with the accumulated value at the end of the
                # backward pass in preparation for the optimizer step.
                accumulate_grad = hasattr(param, "_saved_grad_shard")
                if accumulate_grad:
                    p_assert(
                        param._saved_grad_shard.shape == output.shape,  # type: ignore[attr-defined]
                        "Shape mismatch when accumulating gradients: "  # type: ignore[attr-defined]
                        f"existing grad shape={param._saved_grad_shard.shape} "
                        f"new grad shape={output.shape}",  # type: ignore[attr-defined]
                    )
                    p_assert(
                        param._saved_grad_shard.device == output.device,  # type: ignore[attr-defined]
                        "Device mismatch when accumulating gradients: "  # type: ignore[attr-defined]
                        f"existing grad device={param._saved_grad_shard.device} "
                        f"new grad device={output.device}",  # type: ignore[attr-defined]
                    )
                    param._saved_grad_shard += output  # type: ignore[attr-defined]
                else:
                    param._saved_grad_shard = output  # type: ignore[attr-defined]
                grad = param._saved_grad_shard  # type: ignore[attr-defined]
            else:
                if self.sharding_strategy == ShardingStrategy.NO_SHARD:
                    self._communication_hook(self._communication_hook_state, param.grad)

                # For NO_SHARD keeping grads in the reduced precision, we
                # can simply omit the cast as needed, we can't do this for
                # other sharding strategies because grad field is assigned
                # in _finalize_params. TODO (rvarm1) this divergence in
                # logic is not ideal.
                if not self._mixed_precision_keep_low_precision_grads():
                    self._cast_grad_to_param_dtype(param.grad, param)

            # Regardless of sharding or not, offload the grad to CPU if we are
            # offloading params. This is so param and grad reside on same device
            # which is needed for the optimizer step.
            if handle._config.offload_params:
                # We specify non_blocking=True
                # and ensure the appropriate synchronization is done by waiting
                # streams in _wait_for_post_backward.
                param._cpu_grad.copy_(  # type: ignore[attr-defined]
                    grad.detach(), non_blocking=True
                )
                # Don't let this memory get reused until after the transfer.
                grad.data.record_stream(torch.cuda.current_stream())

            # After _post_backward_hook returns, orig_grad_data will eventually
            # go out of scope, at which point it could otherwise be freed for
            # further reuse by the main stream while the div/reduce_scatter/copy
            # are underway in the post_backward stream. See:
            # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
            orig_grad_data.record_stream(self._streams["post_backward"])


def _cast_grad_to_param_dtype(
    self,
    grad: torch.Tensor,
    param: FlatParameter,
) -> None:
    """
    Casts gradient ``grad`` back to the full parameter dtype so that the
    optimizer step runs with that dtype. This performs an actual cast if
    1. parameters were in reduced precision during the forward since then
    gradients would be in that reduced precision, or
    2. parameters were not in reduced precision but gradients were in
    reduced precision for communication.
    However, if a low precision communication hook is registered, then this
    dtype cast happens in the hook instead.
    """
    self._assert_state(TrainingState_.BACKWARD_POST)
    if (
        not self._low_precision_hook_enabled()
        and (
            self._mixed_precision_enabled_for_params()
            or self._mixed_precision_enabled_for_reduce()
        )
    ):
        low_prec_grad_data = grad.data
        grad.data = grad.data.to(dtype=param.dtype)
        # Do not let the low precision gradient memory get reused until
        # the cast to full parameter precision completes
        low_prec_grad_data.record_stream(torch.cuda.current_stream())


def _should_free_unsharded_flat_param(self, handle: FlatParamHandle):
    """
    Returns whether resharding ``handle`` should free its unsharded flattened
    parameter. This is the case when gradients are being synchronized and
    ``handle`` uses a sharded strategy, or whenever ``handle`` uses
    ``FULL_SHARD``.
    """
    return (
        (self._sync_gradients and handle.uses_sharded_strategy)
        or handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
    )


def _queue_wait_for_post_backward(self) -> None:
    """
    Queues a post-backward callback from the root FSDP instance, which should
    happen at the beginning of its pre-backward. Does nothing if the callback
    is already queued.
    """
    p_assert(
        self._is_root,
        "`_queue_wait_for_post_backward()` should be called on the root FSDP instance",
    )
    if self._post_backward_callback_queued:
        return
    self._assert_state([TrainingState_.IDLE])
    self._post_backward_callback_queued = True
    Variable._execution_engine.queue_callback(self._wait_for_post_backward)


@torch.no_grad()
def _wait_for_post_backward(self) -> None:
    """
    Waits for post-backward work to finish and then finalizes every FSDP
    module for the next iteration. Only called on the root instance; this is
    the callback queued by ``_queue_wait_for_post_backward``.
    """
    assert self._is_root, "_wait_for_post_backward can only be called on root."
    # Root's training state might be backward_pre or backward_post depending on
    # if root parameter's post backward hook was called. The post-backward hook
    # may not have been called if gradient was not computed for this param/FSDP
    # module.

    if self._sync_gradients:
        torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
        if self.cpu_offload.offload_params:
            # We need to wait for the non-blocking GPU ->
            # CPU grad transfers to finish. We need to do this for GPU -> CPU
            # copies because when grad is on CPU, it won't wait for any CUDA
            # stream to finish GPU -> CPU copies unless we explicitly block the
            # host-side with synchronize().
            torch.cuda.current_stream().synchronize()
    self._exec_order_data.next_iter()

    # A backward pass is done, clean up below.
    def _catch_all_reshard(fsdp_module: FullyShardedDataParallel) -> None:
        """
        Reshards full parameters that may have not been resharded in
        post_backward_hook. This can happen when an FSDP module's output
        is used in forward so its pre-backward fires unsharding the param,
        but post-backward does not fire since the output was not ultimately
        used in loss computation so FSDP parameter did not get a gradient.
        """
        # Note that we wrap resharding logic in a try-catch as a defensive
        # approach, as if an error is thrown, we are in the backwards pass,
        # and autograd would not print out much useful info about the actual
        # error hit.
        try:
            free_unsharded_flat_params: List[bool] = []
            handles_to_reshard: List[FlatParamHandle] = []
            for handle in fsdp_module._handles:
                # TODO: This already-resharded check is brittle:
                # https://github.com/pytorch/pytorch/issues/83956
                already_resharded = (
                    handle.flat_param.data_ptr() == handle.flat_param._local_shard.data_ptr()
                )
                if already_resharded:
                    continue
                # Decide and reshard through the module that owns the
                # handles, as `_post_backward_hook` does for its own
                # handles, rather than through the root.
                free_unsharded_flat_params.append(
                    fsdp_module._should_free_unsharded_flat_param(handle)
                )
                handles_to_reshard.append(handle)
            fsdp_module._reshard(handles_to_reshard, free_unsharded_flat_params)
        except Exception as e:
            p_assert(
                False,
                f"Got exception while resharding module {fsdp_module}: {str(e)}",
                raise_assertion_error=False,
            )
            raise

    def _finalize_params(fsdp_module: FullyShardedDataParallel) -> None:
        """
        Removes ``fsdp_module`` 's post-backward hooks and, when
        synchronizing gradients, sets each ``FlatParameter`` 's ``.grad``
        for the optimizer step.
        """
        for handle in fsdp_module._handles:
            p = handle.flat_param
            if p.requires_grad:
                if hasattr(p, "_post_backward_hook_state"):
                    p_assert(
                        len(p._post_backward_hook_state) == 2,  # type: ignore[attr-defined]
                        "p._post_backward_hook_state fields are not valid.",
                    )
                    p._post_backward_hook_state[1].remove()  # type: ignore[attr-defined]
                    delattr(p, "_post_backward_hook_state")
                # Preserve the gradient accumulation state if not
                # synchronizing: `p.grad` remains the unsharded gradient
                # accumulated from prior `no_sync()` iterations, and
                # `p._saved_grad_shard` remains the sharded gradient from
                # the last synchronized iteration
                if not fsdp_module._sync_gradients:
                    continue
                # Set `p.grad` as needed to ensure optimizer correctness
                # since optimizers operate on the `grad` attribute
                if hasattr(p, "_cpu_grad"):
                    p_assert(
                        p.device == torch.device("cpu"),
                        f"Device mismatch: p={p.device} "  # type: ignore[attr-defined]
                        f"p._cpu_grad={p._cpu_grad}",
                    )
                    p.grad = p._cpu_grad  # type: ignore[attr-defined]
                elif hasattr(p, "_saved_grad_shard"):
                    p_assert(
                        p.device == p._saved_grad_shard.device,  # type: ignore[attr-defined]
                        f"Device mismatch: p={p.device} "  # type: ignore[attr-defined]
                        f"p._saved_grad_shard={p._saved_grad_shard.device}",
                    )
                    # Check if post-backward was called for this param (FSDP unit).
                    # TODO: This logic will have to be revisited when non-recursive wrapping
                    # lands. If it was not called, there is no new gradient to accumulate
                    if p._post_backward_called:
                        p.grad = p._saved_grad_shard
                        if fsdp_module._mixed_precision_keep_low_precision_grads():
                            p.grad.data = p.grad.to(fsdp_module.mixed_precision.param_dtype)
                else:
                    p_assert(
                        not handle.uses_sharded_strategy or not p._post_backward_called,
                        "All sharded parameters that received a gradient "
                        "should use `_saved_grad_shard`",
                    )
                if hasattr(p, "_saved_grad_shard"):
                    delattr(p, "_saved_grad_shard")

                p_assert(
                    hasattr(p, '_post_backward_called'),
                    "Expected flag _post_backward_called to be set on param.",
                )
                # Reset _post_backward_called in preparation for the next iteration.
                p._post_backward_called = False

    # Update root and nested FSDP's hooks and flags.
    for m in self.fsdp_modules(self):  # includes self
        _finalize_params(m)
        _catch_all_reshard(m)
        m._ran_pre_backward_hook.clear()
        m.training_state = TrainingState_.IDLE
        for handle in m._handles:
            handle._training_state = HandleTrainingState.IDLE
        m._handles_prefetched.clear()
        if m._is_root:
            # reset this flag for cases like "one forward pass + multiple backward passes"
            self._post_backward_callback_queued = False

    if self._use_param_exec_order_policy() and self._param_exec_order_prep_stage:
        self._param_exec_order_policy_second_iter_init()


def _param_exec_order_policy_second_iter_init(self) -> None:
    """
    Ends the parameter-execution-order preparation stage after the first
    iteration. It reverses the order recorded during backward so that it
    follows forward execution order, and it clears the prep-stage flag on
    this instance and on every nested FSDP module.
    """
    self._param_exec_order_prep_stage = False
    # Let the parameters in self._fsdp_params_exec_order ordered based on
    # the execution order in the forward pass.
    self._fsdp_params_exec_order.reverse()
    for m in self.modules():
        if m is not self and isinstance(m, FullyShardedDataParallel):
            assert hasattr(
                m, "_param_exec_order_policy"
            ), "Non-root FSDP modules should also have _param_exec_order_policy attribute"
            assert hasattr(
                m, "_param_exec_order_prep_stage"
            ), "Non-root FSDP modules should also have _param_exec_order_prep_stage attribute"
            m._param_exec_order_prep_stage = False
    # TODO (linjianma): Construct a fsdp_wrap_map whose keys are all children modules with a FSDP wrap,
    # and values are its FSDP wraps. These children FSDP wraps will be detached from the root FSDP module
    # and will be used to schedule the parameters (rebuild_full_params and reshard).
    # TODO (linjianma): Remove all internal FSDP wraps from the root FSDP module.
    # TODO (linjianma): Based on self._fsdp_params_exec_order, get the information
    # needed to patch the forward() function of each key in the fsdp_wrap_map. The rules are as follows:
    # 1: Before each forward(), rebuild_full_params of all parameters that are currently sharded and
    # will be used in the forward, and reshard all parameters that are currently full and will not be
    # used in the next forward()
    # 2: After each forward(), reshard all parameters just used in the forward, and rebuild_full_params of
    # all parameters that will be used next.
    # TODO (linjianma): Patch the forward of each model in the keys
    # of fsdp_wrap_map based on the information above.


def _assert_state(self, state: Union[TrainingState_, List[TrainingState_]]) -> None:
    """
    Assert we are in the given state.

    Raises:
        ValueError: If ``self.training_state`` is not ``state`` (or not one
            of the states in ``state``).
    """
    # Since assert can be turned off and this error checking
    # is really important, we use explicit error checking
    # and raise a ValueError if needed.
    if isinstance(state, TrainingState_):
        state = [state]
    if self.training_state not in state:
        msg = (
            f"expected to be in states {state} but current state "
            f"is {self.training_state}"
        )
        # In case we are failing in the context of autograd hook, asserting
        # may not generate useful msg. So, let's print it to be sure.
        if self.rank == 0:
            print(f"Asserting FSDP instance is: {self}")
            print(f"ERROR: {msg}")
            traceback.print_stack()
        raise ValueError(msg)
@contextmanager
def no_sync(self) -> Generator:
    """
    A context manager that turns off gradient synchronization for every FSDP
    instance in this module tree. Inside the context, gradients accumulate
    in module variables. They are synchronized during the first
    forward-backward pass run after the context exits.

    This should only be used on the root FSDP instance, and it applies
    recursively to all FSDP children.

    .. note:: Expect higher memory usage, since FSDP accumulates the full
        model gradients (rather than gradient shards) until the eventual
        sync.

    .. note:: With CPU offloading, gradients are not offloaded to CPU while
        inside the context manager. They are offloaded right after the
        eventual sync.
    """
    self._lazy_init()
    assert self._is_root, "`no_sync()` on inner FSDP instances is not supported"
    self._assert_state(TrainingState_.IDLE)
    # Remember each instance's flag so it can be restored on exit.
    saved_flags = []
    for submodule in self.modules():
        if not isinstance(submodule, FullyShardedDataParallel):
            continue
        saved_flags.append((submodule, submodule._sync_gradients))
        submodule._sync_gradients = False
    try:
        yield
    finally:
        for submodule, previous in saved_flags:
            # Nothing inside the context should have re-enabled syncing.
            assert not submodule._sync_gradients, (
                "`_sync_gradients` was incorrectly set to "
                "`True` while in the `no_sync()` context manager"
            )
            submodule._sync_gradients = previous
@property
def params_with_grad(self) -> List[Parameter]:
    """
    The parameters yielded by ``self.parameters()`` (recursively) whose
    ``.grad`` is not ``None``, in ``parameters()`` order, as a new list.
    """
    with_grad: List[Parameter] = []
    for param in self.parameters():
        if param.grad is not None:
            with_grad.append(param)
    return with_grad
@torch.no_grad()
def clip_grad_norm_(
    self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
) -> torch.Tensor:
    """
    Clip all gradients at this point in time. The norm is computed over all
    gradients together, as if they were concatenated into a single vector.
    Gradients are modified in-place.

    Args:
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'``
            for infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector). It is on
        CPU when parameters are offloaded to CPU and on the CUDA device
        otherwise.

    .. note:: This is analogous to ``torch.nn.utils.clip_grad_norm_`` but
        handles the partitioning and multiple devices per rank under the
        hood. The default torch util is not applicable here, because each
        rank only has a partial view of all the grads in the model, so
        calling it for FSDP models would lead to different scaling being
        applied per subset of model parameters.

    .. warning:: This needs to be called on all ranks, since synchronization
        primitives will be used.
    """
    self._lazy_init()
    self._wait_for_previous_optim_step()
    assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
    self._assert_state(TrainingState_.IDLE)

    max_norm = float(max_norm)
    norm_type = float(norm_type)
    # Collect the parameters with gradients once; the property walks all
    # parameters on every access.
    params_with_grad = self.params_with_grad
    # Computes the norm of this rank's gradient shards, then combines it
    # across workers
    local_norm = _calc_grad_norm(params_with_grad, norm_type).cuda()  # type: ignore[arg-type]
    if norm_type == math.inf:
        total_norm = local_norm
        dist.all_reduce(
            total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group
        )
    else:
        # Sum the per-rank ||g||_p^p values, then take the p-th root
        total_norm = local_norm**norm_type
        dist.all_reduce(total_norm, group=self.process_group)
        total_norm = total_norm ** (1.0 / norm_type)

    # Check the offload flag itself (as `_wait_for_post_backward` does)
    # rather than the truthiness of the `cpu_offload` config object.
    if self.cpu_offload.offload_params:
        total_norm = total_norm.cpu()

    clip_coef = torch.tensor(
        max_norm, dtype=total_norm.dtype, device=total_norm.device
    ) / (total_norm + 1e-6)
    if clip_coef < 1:
        # multiply by clip_coef, aka, (max_norm/total_norm).
        for p in params_with_grad:
            assert p.grad is not None
            p.grad.detach().mul_(clip_coef.to(p.grad.device))
    return total_norm
@staticmethoddef_warn_optim_input(optim_input):ifoptim_inputisnotNone:warnings.warn("The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. You may remove it ""from your code without changing its functionality.")@staticmethoddef_is_using_optim_input(optim_input,optim)->bool:ifoptim_inputisNoneandoptimisNone:# Use the default behavior of `optim_input``returnTrueifoptim_inputisnotNone:# Use the `optim_input` code pathreturnTrue# Use the `optim` code pathreturnFalse
@staticmethod
def full_optim_state_dict(
    model: torch.nn.Module,
    optim: torch.optim.Optimizer,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    rank0_only: bool = True,
    group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    Consolidates the full optimizer state on rank 0 and returns it
    as a :class:`dict` following the convention of
    :meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"``
    and ``"param_groups"``. The flattened parameters in ``FSDP`` modules
    contained in ``model`` are mapped back to their unflattened parameters.

    .. warning:: This needs to be called on all ranks since synchronization
        primitives are used. However, if ``rank0_only=True``, then the
        state dict is only populated on rank 0, and all other ranks return
        an empty :class:`dict`.

    .. warning:: Unlike ``torch.optim.Optimizer.state_dict()``, this method
        uses full parameter names as keys instead of parameter IDs.

    .. note:: Like in :meth:`torch.optim.Optimizer.state_dict`, the tensors
        contained in the optimizer state dict are not cloned, so there may
        be aliasing surprises. For best practices, consider saving the
        returned optimizer state dict immediately, e.g. using
        ``torch.save()``.

    Args:
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance) whose parameters
            were passed into the optimizer ``optim``.
        optim (torch.optim.Optimizer): Optimizer for ``model`` 's
            parameters.
        optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
            Input passed into the optimizer ``optim`` representing either a
            :class:`list` of parameter groups or an iterable of parameters;
            if ``None``, then this method assumes the input was
            ``model.parameters()``. This argument is deprecated, and there
            is no need to pass it in anymore. (Default: ``None``)
        rank0_only (bool): If ``True``, saves the populated :class:`dict`
            only on rank 0; if ``False``, saves it on all ranks. (Default:
            ``True``)
        group (dist.ProcessGroup): Model's process group or ``None`` if
            using the default process group. (Default: ``None``)

    Returns:
        Dict[str, Any]: A :class:`dict` containing the optimizer state for
        ``model`` 's original unflattened parameters and including keys
        "state" and "param_groups" following the convention of
        :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``,
        then nonzero ranks return an empty :class:`dict`.
    """
    FullyShardedDataParallel._warn_optim_input(optim_input)
    using_optim_input = FullyShardedDataParallel._is_using_optim_input(
        optim_input,
        optim,
    )
    return _optim_state_dict(
        model=model,
        optim=optim,
        optim_input=optim_input,
        rank0_only=rank0_only,
        shard_state=False,
        group=group,
        using_optim_input=using_optim_input,
    )
@staticmethod
def sharded_optim_state_dict(
    model: torch.nn.Module,
    optim: torch.optim.Optimizer,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    The API is similar to :meth:`full_optim_state_dict` but this API chunks
    all non-zero-dimension states to :class:`ShardedTensor` to save memory.
    This API should only be used when the model ``state_dict`` is derived
    with the context manager ``with state_dict_type(SHARDED_STATE_DICT):``.

    For the detailed usage, refer to :meth:`full_optim_state_dict`.

    .. warning:: The returned state dict contains ``ShardedTensor`` and
        cannot be directly used by the regular ``optim.load_state_dict``.
    """
    FullyShardedDataParallel._warn_optim_input(optim_input)
    using_optim_input = FullyShardedDataParallel._is_using_optim_input(
        optim_input,
        optim,
    )
    # TODO: The ultimate goal of the optimizer state APIs should be the same
    # as state_dict/load_state_dict -- using one API to get optimizer states
    # and one API to load optimizer states. ``state_dict_type`` will be used
    # to decide which optimizer states should be returned.
    # There are currently two APIs to load a full optimizer state. So the
    # first step of the unification is to merge the two full optimizer state
    # loading APIs.
    # Task: https://github.com/pytorch/pytorch/issues/82232
    return _optim_state_dict(
        model=model,
        optim=optim,
        optim_input=optim_input,
        rank0_only=False,
        shard_state=True,
        group=group,
        using_optim_input=using_optim_input,
    )
@staticmethod
def shard_full_optim_state_dict(
    full_optim_state_dict: Dict[str, Any],
    model: torch.nn.Module,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
    """
    Shards the full optimizer state dict ``full_optim_state_dict`` by
    remapping the state to flattened parameters instead of unflattened
    parameters and restricting to only this rank's part of the optimizer
    state. The first argument should be the return value of
    :meth:`full_optim_state_dict`.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> model, optim = ...
        >>> full_osd = FSDP.full_optim_state_dict(model, optim)
        >>> torch.save(full_osd, PATH)
        >>> # Define new model with possibly different world size
        >>> new_model, new_optim = ...
        >>> full_osd = torch.load(PATH)
        >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
        >>> new_optim.load_state_dict(sharded_osd)

    .. note:: Both :meth:`shard_full_optim_state_dict` and
        :meth:`scatter_full_optim_state_dict` may be used to get the
        sharded optimizer state dict to load. Assuming that the full
        optimizer state dict resides in CPU memory, the former requires
        each rank to have the full dict in CPU memory, where each rank
        individually shards the dict without any communication, while the
        latter requires only rank 0 to have the full dict in CPU memory,
        where rank 0 moves each shard to GPU memory (for NCCL) and
        communicates it to ranks appropriately. Hence, the former has
        higher aggregate CPU memory cost, while the latter has higher
        communication cost.

    Args:
        full_optim_state_dict (Dict[str, Any]): Optimizer state dict
            corresponding to the unflattened parameters and holding the
            full non-sharded optimizer state.
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance) whose parameters
            correspond to the optimizer state in ``full_optim_state_dict``.
        optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
            Input passed into the optimizer representing either a
            :class:`list` of parameter groups or an iterable of parameters;
            if ``None``, then this method assumes the input was
            ``model.parameters()``. This argument is deprecated, and there
            is no need to pass it in anymore. (Default: ``None``)
        optim (Optional[torch.optim.Optimizer]): Optimizer that will load
            the state dict returned by this method. This is the preferred
            argument to use over ``optim_input``. (Default: ``None``)

    Returns:
        Dict[str, Any]: The full optimizer state dict now remapped to
        flattened parameters instead of unflattened parameters and
        restricted to only include this rank's part of the optimizer state.
    """
    FullyShardedDataParallel._warn_optim_input(optim_input)
    using_optim_input = FullyShardedDataParallel._is_using_optim_input(
        optim_input,
        optim,
    )
    # Each rank shards the full dict locally; no communication happens here.
    sharded_osd = _flatten_optim_state_dict(
        full_optim_state_dict,
        model=model,
        shard_state=True,
    )
    return _rekey_sharded_optim_state_dict(
        sharded_osd,
        model,
        optim,
        optim_input,
        using_optim_input,
    )
@staticmethod
def flatten_sharded_optim_state_dict(
    sharded_optim_state_dict: Dict[str, Any],
    model: torch.nn.Module,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
    """
    The API is similar to :meth:`shard_full_optim_state_dict`. The only
    difference is that the input ``sharded_optim_state_dict`` should be
    returned from :meth:`sharded_optim_state_dict`. Therefore, there will
    be all-gather calls on each rank to gather ``ShardedTensor`` s.

    Args:
        sharded_optim_state_dict (Dict[str, Any]): Optimizer state dict
            corresponding to the unflattened parameters and holding the
            sharded optimizer state.
        model (torch.nn.Module):
            Refer to :meth:`shard_full_optim_state_dict`.
        optim_input: Deprecated; refer to
            :meth:`shard_full_optim_state_dict`. (Default: ``None``)
        optim (Optional[torch.optim.Optimizer]): Refer to
            :meth:`shard_full_optim_state_dict`. (Default: ``None``)

    Returns:
        Refer to :meth:`shard_full_optim_state_dict`.
    """
    FullyShardedDataParallel._warn_optim_input(optim_input)
    using_optim_input = FullyShardedDataParallel._is_using_optim_input(
        optim_input,
        optim,
    )
    # TODO: The implementation is the same as ``shard_full_optim_state_dict``.
    # See the TODO in ``sharded_optim_state_dict`` for the future
    # unification plan.
    flattened_osd = _flatten_optim_state_dict(
        sharded_optim_state_dict,
        model=model,
        shard_state=True,
    )
    return _rekey_sharded_optim_state_dict(
        flattened_osd,
        model,
        optim,
        optim_input,
        using_optim_input,
    )
@staticmethod
def scatter_full_optim_state_dict(
    full_optim_state_dict: Optional[Dict[str, Any]],
    model: torch.nn.Module,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    group: Optional[Any] = None,
) -> Dict[str, Any]:
    """
    Scatters the full optimizer state dict from rank 0 to all other ranks,
    returning the sharded optimizer state dict on each rank. The return
    value is the same as :meth:`shard_full_optim_state_dict`, and on rank
    0, the first argument should be the return value of
    :meth:`full_optim_state_dict`.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> model, optim = ...
        >>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
        >>> # Define new model with possibly different world size
        >>> new_model, new_optim, new_group = ...
        >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
        >>> new_optim.load_state_dict(sharded_osd)

    .. note:: Both :meth:`shard_full_optim_state_dict` and
        :meth:`scatter_full_optim_state_dict` may be used to get the
        sharded optimizer state dict to load. Assuming that the full
        optimizer state dict resides in CPU memory, the former requires
        each rank to have the full dict in CPU memory, where each rank
        individually shards the dict without any communication, while the
        latter requires only rank 0 to have the full dict in CPU memory,
        where rank 0 moves each shard to GPU memory (for NCCL) and
        communicates it to ranks appropriately. Hence, the former has
        higher aggregate CPU memory cost, while the latter has higher
        communication cost.

    Args:
        full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state
            dict corresponding to the unflattened parameters and holding
            the full non-sharded optimizer state if on rank 0; the argument
            is ignored on nonzero ranks.
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance) whose parameters
            correspond to the optimizer state in ``full_optim_state_dict``.
        optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
            Input passed into the optimizer representing either a
            :class:`list` of parameter groups or an iterable of parameters;
            if ``None``, then this method assumes the input was
            ``model.parameters()``. This argument is deprecated, and there
            is no need to pass it in anymore. (Default: ``None``)
        optim (Optional[torch.optim.Optimizer]): Optimizer that will load
            the state dict returned by this method. This is the preferred
            argument to use over ``optim_input``. (Default: ``None``)
        group (dist.ProcessGroup): Model's process group or ``None`` if
            using the default process group. When ``None`` and ``model``
            has a ``process_group`` attribute, that group is used.
            (Default: ``None``)

    Returns:
        Dict[str, Any]: The full optimizer state dict now remapped to
        flattened parameters instead of unflattened parameters and
        restricted to only include this rank's part of the optimizer state.

    Raises:
        RuntimeError: If the process group uses NCCL but no GPU is
            available.
        ValueError: On rank 0, if ``full_optim_state_dict`` is ``None``.
    """
    FullyShardedDataParallel._warn_optim_input(optim_input)
    using_optim_input = FullyShardedDataParallel._is_using_optim_input(
        optim_input,
        optim,
    )
    # Try to use the passed-in process group, the model's process group,
    # or the default process group (i.e. `None`) in that priority order
    if group is None and hasattr(model, "process_group"):
        group = model.process_group
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    # Check for a valid broadcast device, preferring GPU when available
    cuda_available = torch.cuda.is_available()
    using_nccl = dist.distributed_c10d._check_for_nccl_backend(group)
    if using_nccl and not cuda_available:
        raise RuntimeError("NCCL requires a GPU for collectives")
    broadcast_device = torch.device("cuda") if cuda_available else torch.device("cpu")
    # Only rank 0 holds the full state; the other ranks keep these as `None`
    # and receive their part through the broadcasts below.
    flat_osd: Optional[Dict[str, Any]] = None
    processed_osd: Optional[Dict[str, Any]] = None
    # Flatten the optimizer state dict and construct a copy with the
    # positive-dimension tensors' shapes in place of the tensors themselves
    # since those tensors will be broadcast separately to avoid copying
    if rank == 0:
        if full_optim_state_dict is None:
            raise ValueError("Rank 0 must pass in the full optimizer state dict")
        flat_osd = _flatten_optim_state_dict(
            full_optim_state_dict,
            model=model,
            shard_state=False,
        )
        processed_osd = _process_pos_dim_tensor_state(flat_osd, world_size)
    # Broadcast the optim state dict without positive-dimension tensor
    # state and the FSDP parameter IDs from rank 0 to all ranks
    processed_osd = _broadcast_processed_optim_state_dict(
        processed_osd,
        rank,
        group,
    )
    # Broadcast positive-dimension tensor state (both sharded tensors for
    # FSDP parameters and unsharded tensors for non-FSDP parameters)
    sharded_osd = _broadcast_pos_dim_tensor_states(
        processed_osd,
        flat_osd,
        rank,
        world_size,
        group,
        broadcast_device,
    )
    # Rekey the optimizer state dict to use parameter IDs according to this
    # rank's `optim`
    sharded_osd = _rekey_sharded_optim_state_dict(
        sharded_osd,
        model,
        optim,
        optim_input,
        using_optim_input,
    )
    return sharded_osd
@staticmethod
def rekey_optim_state_dict(
    optim_state_dict: Dict[str, Any],
    optim_state_key_type: OptimStateKeyType,
    model: torch.nn.Module,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
    """
    Re-keys the optimizer state dict ``optim_state_dict`` to use the key
    type ``optim_state_key_type``. This can be used to achieve
    compatibility between optimizer state dicts from models with FSDP
    instances and ones without.

    To re-key an FSDP full optimizer state dict (i.e. from
    :meth:`full_optim_state_dict`) to use parameter IDs and be loadable to
    a non-wrapped model::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> wrapped_model, wrapped_optim = ...
        >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
        >>> nonwrapped_model, nonwrapped_optim = ...
        >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
        >>> nonwrapped_optim.load_state_dict(rekeyed_osd)

    To re-key a normal optimizer state dict from a non-wrapped model to be
    loadable to a wrapped model::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> nonwrapped_model, nonwrapped_optim = ...
        >>> osd = nonwrapped_optim.state_dict()
        >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
        >>> wrapped_model, wrapped_optim = ...
        >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
        >>> wrapped_optim.load_state_dict(sharded_osd)

    Args:
        optim_state_dict (Dict[str, Any]): Optimizer state dict whose
            ``"state"`` keys are all parameter names (:class:`str`) or all
            parameter IDs (:class:`int`).
        optim_state_key_type (OptimStateKeyType): Target key type, either
            ``PARAM_NAME`` or ``PARAM_ID``.
        model (torch.nn.Module): Module whose parameter names the state
            dict refers to.
        optim_input: Deprecated; see :meth:`shard_full_optim_state_dict`.
            (Default: ``None``)
        optim (Optional[torch.optim.Optimizer]): Optimizer used to map
            between parameters and parameter IDs; preferred over
            ``optim_input``. (Default: ``None``)

    Returns:
        Dict[str, Any]: The optimizer state dict re-keyed using the
        parameter keys specified by ``optim_state_key_type``. If the keys
        already have the target type, ``optim_state_dict`` itself is
        returned (not a copy).

    Raises:
        ValueError: If the existing ``"state"`` keys are not uniformly
            typed.
    """
    FullyShardedDataParallel._warn_optim_input(optim_input)
    using_optim_input = FullyShardedDataParallel._is_using_optim_input(
        optim_input,
        optim,
    )
    assert optim_state_key_type in (
        OptimStateKeyType.PARAM_NAME,
        OptimStateKeyType.PARAM_ID,
    )
    osd = optim_state_dict  # alias
    # Validate that the existing parameter keys are uniformly typed
    uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]]
    uses_param_id_mask = [type(param_key) is int for param_key in osd["state"]]
    if (any(uses_param_name_mask) and not all(uses_param_name_mask)) or (
        any(uses_param_id_mask) and not all(uses_param_id_mask)
    ):
        error_msg = f"Invalid parameter keys: {osd['state'].keys()}"
        raise ValueError(error_msg)
    # Return directly if the existing key type matches the target key type
    if (
        optim_state_key_type == OptimStateKeyType.PARAM_NAME
        and all(uses_param_name_mask)
    ) or (
        optim_state_key_type == OptimStateKeyType.PARAM_ID
        and all(uses_param_id_mask)
    ):
        return osd
    # Otherwise, actually perform the re-keying
    new_osd = {}
    if optim_state_key_type == OptimStateKeyType.PARAM_NAME:  # ID -> name
        param_id_to_param = (
            _get_param_id_to_param_from_optim_input(model, optim_input)
            if using_optim_input
            else _get_param_id_to_param(optim)
        )
        param_to_param_name = _get_param_to_param_name(model)
        param_id_to_param_name: List[str] = [
            param_to_param_name[param] for param in param_id_to_param
        ]
        new_osd["state"] = {
            param_id_to_param_name[param_id]: param_state
            for param_id, param_state in osd["state"].items()
        }
        new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
        for param_group in new_osd["param_groups"]:
            param_group["params"] = sorted(
                [param_id_to_param_name[param_id] for param_id in param_group["params"]]
            )
        return new_osd
    elif optim_state_key_type == OptimStateKeyType.PARAM_ID:  # name -> ID
        param_name_to_param = _get_param_name_to_param(model)
        param_to_param_id = (
            _get_param_to_param_id_from_optim_input(model, optim_input)
            if using_optim_input
            else _get_param_to_param_id(optim)
        )
        # Because not all model parameters may be passed as the optimizer
        # input, we may need to drop some parameters from this mapping
        param_name_to_param_id = {
            param_name: param_to_param_id[param]
            for param_name, param in param_name_to_param.items()
            if param in param_to_param_id
        }
        new_osd["state"] = {
            param_name_to_param_id[param_name]: param_state
            for param_name, param_state in osd["state"].items()
        }
        new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
        for param_group in new_osd["param_groups"]:
            param_group["params"] = sorted(
                [param_name_to_param_id[param_name] for param_name in param_group["params"]]
            )
        return new_osd
    return new_osd  # should never reach here
def _get_default_comm_hook(self) -> Any:
    r"""
    Returns a default communication hook based on a sharding strategy.

    ``ShardingStrategy.NO_SHARD`` maps to ``default_hooks.allreduce_hook``;
    every other strategy maps to ``default_hooks.reduce_scatter_hook``.
    """
    if self.sharding_strategy == ShardingStrategy.NO_SHARD:
        return default_hooks.allreduce_hook
    return default_hooks.reduce_scatter_hook

def _get_default_comm_hook_state(self) -> Any:
    r"""
    Returns a default communication hook state based on a sharding strategy.

    The state is a ``default_hooks.DefaultState`` built from this
    instance's ``process_group``.
    """
    return default_hooks.DefaultState(process_group=self.process_group)
def register_comm_hook(self, state: object, hook: Callable):
    """
    Registers a communication hook which is an enhancement that provides a
    flexible hook to users where they can specify how FSDP aggregates
    gradients across multiple workers.
    This hook can be used to implement several algorithms like
    `GossipGrad <https://arxiv.org/abs/1803.05880>`_ and gradient
    compression which involve different communication strategies for
    parameter syncs while training with :class:`FullyShardedDataParallel`.

    .. warning ::
        FSDP communication hook should be registered before running an
        initial forward pass and only once.

    Args:
        state (object): Passed to the hook to maintain any state
            information during the training process. Examples include error
            feedback in gradient compression, peers to communicate with next
            in `GossipGrad <https://arxiv.org/abs/1803.05880>`_, etc.
            It is locally stored by each worker and shared by all the
            gradient tensors on the worker.
        hook (Callable): Callable, which has one of the following
            signatures:

            1) ``hook: Callable[torch.Tensor] -> None``:
               This function takes in a Python tensor, which represents the
               full, flattened, unsharded gradient with respect to all
               variables corresponding to the model this FSDP unit is
               wrapping (that are not wrapped by other FSDP sub-units).
               It then performs all necessary processing and returns
               ``None``;
            2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
               This function takes in two Python tensors, the first one
               represents the full, flattened, unsharded gradient with
               respect to all variables corresponding to the model this
               FSDP unit is wrapping (that are not wrapped by other FSDP
               sub-units). The latter represents a pre-sized tensor to
               store a chunk of a sharded gradient after reduction.

            In both cases, callable performs all necessary processing and
            returns ``None``. Callables with signature 1 are expected to
            handle gradient communication for a `NO_SHARD` case. Callables
            with signature 2 are expected to handle gradient communication
            for sharded cases.

    Raises:
        AssertionError: If called on a non-root instance, if a hook was
            already registered on any FSDP instance in the tree, or if any
            instance's current hook is not the default one. No instance is
            modified when this is raised.
    """
    if not self.check_is_root():
        raise AssertionError("register_comm_hook can only be called on a root instance.")
    submodules = list(self.fsdp_modules(self))
    # Validate every instance before mutating any of them, so that a failed
    # check cannot leave the tree with hooks only partially registered.
    # Explicit raises (rather than `assert`) keep the checks under `python -O`.
    for submodule in submodules:
        if submodule._hook_registered:
            raise AssertionError("communication hook can be only registered once")
        if submodule._communication_hook != self._get_default_comm_hook():
            raise AssertionError(
                f"communication hook should be default, but it is {submodule._communication_hook.__name__} instead"
            )
    for submodule in submodules:
        submodule._hook_registered = True
        submodule._communication_hook_state = state
        submodule._communication_hook = hook
def_init_param_exec_order_wrap_policy(self,*args,**kwargs)->None:auto_wrap_policy=kwargs["auto_wrap_policy"]module=kwargs["module"]asserthasattr(auto_wrap_policy,"tracing_config")ifnot_TORCH_FX_AVAIL:assert(auto_wrap_policy.tracing_configisNone),"tracing_config should be None when torch.fx is not enabled"elifisinstance(auto_wrap_policy.tracing_config,TracingConfig):tracer=auto_wrap_policy.tracing_config.tracerexecution_info=_init_execution_info(module)forminmodule.modules():assertnotisinstance(m,FullyShardedDataParallel),"The input module of _patch_tracer should not contain FSDP modules"with_patch_tracer(tracer=tracer,root_module=module,execution_info=execution_info,):try:tracer.trace(module,auto_wrap_policy.tracing_config.concrete_args)exceptBaseExceptionase:raiseRuntimeError("tracer.trace failed inside _init_param_exec_order_wrap_policy"f" with the error: {e}.")else:assert(auto_wrap_policy.tracing_configisNone),"tracing_config should either be an instance of TracingConfig or be None"# The initial FSDP wrapping is done with auto_wrap_policy.init_policykwargs["auto_wrap_policy"]=auto_wrap_policy.init_policyself.__init__(*args,**kwargs)self._param_exec_order_policy:bool=True# self._param_exec_order_prep_stage is set to True before we get the execution orderself._param_exec_order_prep_stage:bool=True# A list that stores the flatten parameters and its name based on the parameter execution orderself._fsdp_params_exec_order:List[FlatParameter]=[]if_TORCH_FX_AVAILandisinstance(auto_wrap_policy.tracing_config,TracingConfig):# Initialize a dict that maps each module to its parent FSDP wrapmodule_to_fsdp:Dict[nn.Module,FullyShardedDataParallel]=dict()forwrapinself.fsdp_modules(self):module_to_fsdp[wrap.module]=wrap# Set self._fsdp_params_exec_order based on execution_info.module_forward_order.# TODO (linjianma): self._fsdp_params_exec_order will be set based on# the parameter execution order rather than module_forward_order,# once the non-recursive wrapping policy is fully 
implemented.forminexecution_info.module_forward_order:ifminmodule_to_fsdp:forflat_paraminmodule_to_fsdp[m].params:self._fsdp_params_exec_order.append(flat_param)self._param_exec_order_prep_stage=Falseforminself.modules():ifmisnotselfandisinstance(m,FullyShardedDataParallel):# Assignment by reference, so each children FSDP wrap has access to# the _fsdp_params_exec_order of the root modulem._fsdp_params_exec_order=self._fsdp_params_exec_orderm._param_exec_order_policy=self._param_exec_order_policym._param_exec_order_prep_stage=self._param_exec_order_prep_stagedef_use_param_exec_order_policy(self)->bool:return(hasattr(self,"_param_exec_order_policy")andself._param_exec_order_policy)def_is_param_exec_order_prep_stage(self)->bool:is_prep_stage=(hasattr(self,"_param_exec_order_prep_stage")andself._param_exec_order_prep_stage)ifnotis_prep_stage:forpinself.parameters():assert(nothasattr(p,"_params_exec_order_hook_handle")),"When not in execution order prep stage, all _params_exec_order_hook_handle should be removed."returnis_prep_stage
def_calc_grad_norm(parameters:List[torch.nn.Parameter],p:float)->torch.Tensor:r"""Calculate gradient norm of an iterable of parameters. Returns: Total norm of the parameters (viewed as a single vector). """parameters=[pforpinparametersifp.gradisnotNone]iflen(parameters)==0:returntorch.tensor(0.0)ifp==math.inf:local_norm=torch.tensor(max(par.grad.detach().abs().max()forparinparameters))else:# Compute the norm in full precision no matter whatlocal_norm=torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(par.grad.detach(),p,dtype=torch.float32)forparinparameters]),p,)local_norm.to(dtype=parameters[0].dtype)returnlocal_normdef_get_param_to_unflat_param_names(model:torch.nn.Module,dedup_shared_params:bool=True,)->Dict[torch.nn.Parameter,List[str]]:""" Constructs a mapping from flattened parameter (including non-FSDP-module parameters) to its unflattened parameter names. For non-FSDP-module parameters, these mapped-to lists always contain a single element. The unflattened parameter names should match the keys of the model state dict. For shared parameters, only the first parameter name is included (following the ``torch.nn.Module.parameters()`` order). Args: model (torch.nn.Module): Root module (which may or may not be a :class:`FullyShardedDataParallel` instance). dedup_shared_params (bool): If ``True``, only includes the first list of unflattened parameter names corresponding to a parameter in the module walk order; if ``False``, then includes all of the unflattened parameter names. 
"""defmodule_fn(module,prefix,param_to_unflat_param_names):# For FSDP modules, only add the entry when considering the contained# `FlattenParamsWrapper` to avoid duplicationifnotisinstance(module,FullyShardedDataParallel):forparam_name,paraminmodule.named_parameters(recurse=False):module_prefixed_param_names=(param._prefixed_param_namesiftype(param)isFlatParameterelse[param_name])# prefixed from `module`fully_prefixed_param_names=[clean_tensor_name(prefix+name)fornameinmodule_prefixed_param_names]# fully prefixed from the top level including `prefix`# If this parameter has already been visited, then it is a# shared parameter; then, only take the first parameter nameis_shared_param=paraminparam_to_unflat_param_namesifnotis_shared_param:param_to_unflat_param_names[param]=fully_prefixed_param_nameselifnotdedup_shared_params:param_to_unflat_param_names[param].extend(fully_prefixed_param_names)defreturn_fn(param_to_unflat_param_names):returnparam_to_unflat_param_namesparam_to_unflat_param_names:Dict[torch.nn.Parameter,List[str]]={}return_apply_to_modules(model,module_fn,return_fn,param_to_unflat_param_names,)def_get_param_to_param_name(model:torch.nn.Module,)->Dict[torch.nn.Parameter,str]:""" Constructs a mapping from parameters to their parameter names. ``model`` should not contain any :class:`FullyShardedDataParallel` instances, which means that none of the parameters should be ``FlatParameter`` s. As a result, compared to :meth:`_get_param_to_unflat_param_names`, the mapped values may be flattened from singleton :class:`list` s to the contained names themselves. Args: model (torch.nn.Module): Root module, which should not contain any :class:`FullyShardedDataParallel` instances. """param_to_param_names=_get_param_to_unflat_param_names(model)forparam_namesinparam_to_param_names.values():assertlen(param_names)>0,"`_get_param_to_unflat_param_names()` " \
"should not construct empty lists"iflen(param_names)>1:raiseRuntimeError("Each parameter should only map to one parameter name but got "f"{len(param_names)}: {param_names}")param_to_param_name={param:param_names[0]forparam,param_namesinparam_to_param_names.items()}returnparam_to_param_namedef_get_param_name_to_param(model:torch.nn.Module,)->Dict[str,torch.nn.Parameter]:"""Constructs the inverse mapping of :meth:`_get_param_to_param_name`."""param_to_param_name=_get_param_to_param_name(model)returndict(zip(param_to_param_name.values(),param_to_param_name.keys()))defclean_tensor_name(tensor_name:str)->str:"""Cleans the parameter or buffer name by removing any module wrapper prefixes."""# Call `replace()` twice separately since the name may not have bothtensor_name=tensor_name.replace(FSDP_WRAPPED_MODULE+".","")tensor_name=tensor_name.replace(FPW_MODULE+".","")# TODO: Explicitly replacing checkpoint_wrapper prefix is not ideal,# as it increases coupling between CheckpointWrapper and FSDP. This is also not# scalable for additional wrapped modules, we should come up with a general solution# for this issue.tensor_name=tensor_name.replace(_CHECKPOINT_PREFIX+".","")returntensor_name
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.