FullyShardedDataParallel

class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=None, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False)

A wrapper for sharding Module parameters across data parallel workers. This is inspired by Xu et al. as well as the ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.

Example:

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()
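
To illustrate what sharding means here, the following pure-Python sketch mimics the idea behind full sharding at toy scale: the parameters of one FSDP unit are flattened into a single list, padded so the length divides evenly by the world size, and each rank keeps one equal-sized chunk. An all-gather (emulated here by concatenation) restores the full values when they are needed. Plain lists stand in for tensors, and the helpers flatten, shard and unshard are hypothetical names; this is a conceptual sketch, not FSDP's implementation.

```python
import math
from typing import List, Sequence


def flatten(params: Sequence[Sequence[float]]) -> List[float]:
    """Concatenate several parameter 'tensors' (plain lists) into one flat list."""
    return [x for p in params for x in p]


def shard(flat: List[float], world_size: int, rank: int) -> List[float]:
    """Pad the flat list to a multiple of world_size and return rank's chunk."""
    chunk = math.ceil(len(flat) / world_size)
    padded = flat + [0.0] * (chunk * world_size - len(flat))
    return padded[rank * chunk:(rank + 1) * chunk]


def unshard(shards: Sequence[List[float]], numel: int) -> List[float]:
    """Emulate an all-gather: concatenate every rank's shard, then drop padding."""
    return [x for s in shards for x in s][:numel]


# A toy "module" with a 3-element weight and a 2-element bias.
params = [[1.0, 2.0, 3.0], [4.0, 5.0]]
flat = flatten(params)
world_size = 2
shards = [shard(flat, world_size, r) for r in range(world_size)]
print(shards)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
assert unshard(shards, len(flat)) == flat
```

Between uses, each rank stores only about 1/world_size of the parameters; the price is the communication needed to reassemble them.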

Warning

The optimizer must be initialized after the module has been wrapped, since FSDP will shard parameters in-place and this will break any previously initialized optimizers.

Warning

If the destination CUDA device has ID dev_id, either (1) module should already be placed on that device, (2) the device should be set using torch.cuda.set_device(dev_id), or (3) dev_id should be passed into the device_id constructor argument. This FSDP instance's compute device will be that destination device. For (1) and (3), FSDP initialization always occurs on the GPU. For (2), FSDP initialization happens on module's current device, which may be the CPU.

Warning

FSDP currently does not support gradient accumulation outside no_sync() when using CPU offloading. Trying to do so yields incorrect results since FSDP will use the newly-reduced gradient instead of accumulating with any existing gradient.

Warning

Changing the original parameter variable names after construction will lead to undefined behavior.

Warning

Passing sync_module_states=True requires either that module is already on a GPU or that the device_id argument specifies a CUDA device to which FSDP will move module. This is because sync_module_states=True requires GPU communication.

Warning

As of PyTorch 1.12, FSDP only offers limited support for shared parameters (for example, setting one Linear layer’s weight to another’s). In particular, modules that share parameters must be wrapped as part of the same FSDP unit. If enhanced shared parameter support is needed for your use case, please ping https://github.com/pytorch/pytorch/issues/77724

Note

Inputs to the FSDP forward function are moved to the compute device (the same device the FSDP module is on) before the forward pass runs, so users do not have to manually move inputs from CPU to GPU.

Parameters
  • module (nn.Module) – module to be wrapped with FSDP.

  • process_group (Optional[ProcessGroup]) – process group for sharding

  • sharding_strategy (Optional[ShardingStrategy]) – Configures the sharding algorithm. Different sharding algorithms trade off memory savings against communication overhead. FULL_SHARD is used if sharding_strategy is not specified.

  • cpu_offload (Optional[CPUOffload]) – CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It can be enabled by passing in cpu_offload=CPUOffload(offload_params=True). Note that this currently implicitly enables gradient offloading to CPU so that parameters and gradients are on the same device, as the optimizer requires. This API is subject to change. Default is None, in which case there is no offloading.

  • auto_wrap_policy (Optional[Callable]) –

    A callable specifying a policy to recursively wrap layers with FSDP. Note that this policy currently only applies to child modules of the passed-in module; the remaining modules are always wrapped in the returned FSDP root instance. size_based_auto_wrap_policy in torch.distributed.fsdp.wrap is an example of an auto_wrap_policy callable; it wraps layers whose number of parameters is at least 100M. transformer_auto_wrap_policy in torch.distributed.fsdp.wrap is an example of an auto_wrap_policy callable for transformer-like model architectures. Users can supply a custom auto_wrap_policy callable that accepts the following arguments: module: nn.Module, recurse: bool, unwrapped_params: int. Extra custom arguments can be added to the custom auto_wrap_policy callable as well. It is good practice to print out the sharded model, check whether it is what the application wants, and adjust accordingly.

    Example:

    >>> def custom_auto_wrap_policy(
    >>>     module: nn.Module,
    >>>     recurse: bool,
    >>>     unwrapped_params: int,
    >>>     # These are customizable for this policy function.
    >>>     min_num_params: int = int(1e8),
    >>> ) -> bool:
    >>>     return unwrapped_params >= min_num_params
    

  • backward_prefetch (Optional[BackwardPrefetch]) – This is an experimental feature that is subject to change in the near future. It allows users to enable one of two backward_prefetch algorithms to help overlap backward communication and computation. The pros and cons of each algorithm are explained in the class BackwardPrefetch.

  • mixed_precision (Optional[MixedPrecision]) – A MixedPrecision instance describing the mixed precision training config to be used. MixedPrecision supports configuring the parameter, buffer, and gradient communication dtype. Note that only floating point data is cast to the reduced precision. This lets users potentially save memory and speed up training at the cost of some accuracy. If None, no mixed precision is applied. Note that if mixed_precision is enabled for an FSDP model that contains BatchNorm and an auto_wrap_policy is used, FSDP disables mixed precision for BatchNorm units by wrapping them separately in their own FSDP unit with mixed_precision=None. This is done because several BatchNorm kernels do not currently support reduced precision. If wrapping the model manually, users must take care to set mixed_precision=None for BatchNorm units. (Default: None)

  • ignored_modules (Optional[Iterable[torch.nn.Module]]) – Modules whose own parameters and child modules’ parameters and buffers are ignored by this instance. None of the modules directly in ignored_modules should be FullyShardedDataParallel instances, and any child modules that are already-constructed FullyShardedDataParallel instances will not be ignored if they are nested under this instance. This argument may be used to avoid sharding specific parameters when using an auto_wrap_policy or if parameters’ sharding is not managed by FSDP. (Default: None)

  • param_init_fn (Optional[Callable[[nn.Module], None]]) –

    A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. Note that as of v1.12, we detect modules on the meta device via an is_meta check. If param_init_fn is not specified, we apply a default initialization that calls the reset_parameters method of the passed-in nn.Module; otherwise, we run param_init_fn to initialize the passed-in nn.Module. In particular, this means that if is_meta=True for any parameters of modules that will be wrapped with FSDP and param_init_fn is not specified, we assume your module properly implements reset_parameters() and will throw errors if not. Note that we additionally support modules initialized with torchdistX's (https://github.com/pytorch/torchdistX) deferred_init API. In this case, deferred modules are initialized by a default initialization function that calls torchdistX's materialize_module, or by the passed-in param_init_fn if it is not None. The same Callable is applied to initialize all meta modules. Note that this initialization function is applied before any FSDP sharding logic.

    Example:

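    >>> import torch
    >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    >>> from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
    >>> module = MyModule(device="meta")
    >>> def my_init_fn(module):
    >>>     # responsible for initializing a module, e.g. by materializing it on
    >>>     # the current CUDA device and then calling reset_parameters()
    >>>     module.to_empty(device=torch.device("cuda", torch.cuda.current_device()))
    >>>     module.reset_parameters()
    >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
    >>> print(next(fsdp_model.parameters()).device) # current CUDA device
    >>> # With torchdistX
    >>> from torchdistx import deferred_init
    >>> module = deferred_init.deferred_init(MyModule, device="cuda")
    >>> # Will initialize via deferred_init.materialize_module().
    >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)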
    

  • device_id (Optional[Union[int, torch.device]]) – An int or torch.device describing the CUDA device that the FSDP module should be moved to, which determines where initialization such as sharding takes place. If this argument is not specified and module is on CPU, we move module to the current CUDA device for faster initialization and move it back to CPU before returning. If specified, the resulting FSDP instances will reside on this device. Note that if device_id is specified but module is already on a different CUDA device, an error will be thrown. (Default: None)

  • sync_module_states (bool) – If True, each individually wrapped FSDP unit broadcasts module parameters from rank 0 to ensure they are the same across all ranks after initialization. This helps ensure model parameters are the same across ranks before training starts, but adds communication overhead to __init__, since at least one broadcast is triggered per individually wrapped FSDP unit. It can also help load checkpoints saved with state_dict via load_state_dict in a memory-efficient way. See the documentation for FullStateDictConfig for an example. (Default: False)
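
The recursive wrapping driven by auto_wrap_policy can be pictured with a small pure-Python simulation. ToyModule, total_params, size_policy and recursive_wrap are hypothetical names, not part of torch.distributed.fsdp.wrap. size_policy uses the same (module, recurse, unwrapped_params) signature as the custom policy example in the auto_wrap_policy entry, with a threshold of 100 instead of 1e8 so the numbers stay readable. The traversal order assumed here (visit children first, then decide on the parent using only the parameters its wrapped children did not claim) is an illustration, not a copy of FSDP's source.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyModule:
    """Hypothetical stand-in for nn.Module: own parameter count plus children."""
    name: str
    own_params: int
    children: List["ToyModule"] = field(default_factory=list)


def total_params(module: ToyModule) -> int:
    return module.own_params + sum(total_params(c) for c in module.children)


def size_policy(module: ToyModule, recurse: bool, unwrapped_params: int,
                min_num_params: int = 100) -> bool:
    # Always descend into children; wrap a module once the parameters not
    # already claimed by wrapped descendants reach the threshold.
    if recurse:
        return True
    return unwrapped_params >= min_num_params


Policy = Callable[..., bool]


def recursive_wrap(module: ToyModule, policy: Policy, wrapped: List[str],
                   only_wrap_children: bool = False) -> int:
    """Post-order traversal; returns how many of module's params ended up in a unit."""
    num_params = total_params(module)
    if not policy(module=module, recurse=True, unwrapped_params=num_params):
        return 0
    wrapped_params = sum(recursive_wrap(c, policy, wrapped) for c in module.children)
    remainder = num_params - wrapped_params
    if not only_wrap_children and policy(module=module, recurse=False,
                                         unwrapped_params=remainder):
        wrapped.append(module.name)
        return num_params
    return wrapped_params


model = ToyModule("root", 10, [
    ToyModule("block0", 150),
    ToyModule("block1", 80, [ToyModule("block1.sub", 60)]),
])
wrapped: List[str] = []
# Skip the root itself: its leftover parameters belong to the root FSDP instance.
recursive_wrap(model, size_policy, wrapped, only_wrap_children=True)
print(wrapped)  # ['block0', 'block1']
```

block1.sub alone is too small to wrap, so its 60 parameters are folded into block1's unit, and the root's own 10 parameters stay with the root instance.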

apply(fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).

Compared to torch.nn.Module.apply, this version additionally gathers the full parameters before applying fn. It should not be called from within another summon_full_params context.

Parameters

fn (Module -> None) – function to be applied to each submodule

Returns

self

Return type

Module

clip_grad_norm_(max_norm, norm_type=2.0)

Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.

Parameters
  • max_norm (float or int) – max norm of the gradients

  • norm_type (float or int) – type of the p-norm to use. Can be 'inf' for the infinity norm.

Returns

Total norm of the parameters' gradients (viewed as a single vector).

Note

This is analogous to torch.nn.utils.clip_grad_norm_ but handles the partitioning and multiple devices per rank under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads in the model, so calling it for FSDP models would lead to different scaling being applied per subset of model parameters.

Warning

This needs to be called on all ranks, since synchronization primitives will be used.
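
Why a rank-local norm is insufficient can be seen in a pure-Python sketch: each rank reduces its own gradient shard to a partial sum of |g|^p (or a max for the infinity norm), the partial results are combined across ranks, and every shard is then scaled by one shared coefficient. Plain lists stand in for gradient tensors and for the collective operations, and local_term, global_norm and clip_sharded_grads_ are hypothetical names. The 1e-6 epsilon and the clip_coef < 1 check follow the arithmetic of torch.nn.utils.clip_grad_norm_; the rest is an illustration, not FSDP's code.

```python
import math
from typing import List, Sequence


def local_term(grad_shard: Sequence[float], norm_type: float) -> float:
    """One rank's contribution: max |g| for the inf-norm, else the sum of |g|**p."""
    if math.isinf(norm_type):
        return max((abs(g) for g in grad_shard), default=0.0)
    return sum(abs(g) ** norm_type for g in grad_shard)


def global_norm(shards: Sequence[Sequence[float]], norm_type: float = 2.0) -> float:
    terms = [local_term(s, norm_type) for s in shards]
    if math.isinf(norm_type):
        return max(terms)                    # emulates an all-reduce with MAX
    return sum(terms) ** (1.0 / norm_type)   # emulates an all-reduce with SUM


def clip_sharded_grads_(shards: List[List[float]], max_norm: float,
                        norm_type: float = 2.0) -> float:
    """Scale every rank's shard by the same coefficient; returns the pre-clip norm."""
    total = global_norm(shards, norm_type)
    clip_coef = max_norm / (total + 1e-6)
    if clip_coef < 1.0:
        for shard in shards:
            for i in range(len(shard)):
                shard[i] *= clip_coef
    return total


grads = [[3.0, 0.0], [4.0]]  # rank 0 and rank 1 hold different gradient shards
print(clip_sharded_grads_(grads, max_norm=1.0))  # 5.0
```

Had each rank clipped using only its local norm (3.0 and 4.0 here), the two shards would be scaled by different factors, which is the problem the Note describes.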

static fsdp_modules(module, root_only=False)

Returns all nested FSDP instances, possibly including module itself and only including FSDP root modules if root_only=True.

Parameters
  • module (torch.nn.Module) – Root module, which may or may not be an FSDP module.

  • root_only (bool) – Whether to return only FSDP root modules. (Default: False)

Returns

FSDP modules that are nested in the input module.

Return type

List[FullyShardedDataParallel]

static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True)

Consolidates the full optimizer state on rank 0 and returns it as a dict following the convention of torch.optim.Optimizer.state_dict(), i.e. with keys "state" and "param_groups". The flattened parameters in FSDP modules contained in model are mapped back to their unflattened parameters.

Warning

This needs to be called on all ranks since synchronization primitives are used. However, if rank0_only=True, then the state dict is only populated on rank 0, and all other ranks return an empty dict.

Warning

Unlike torch.optim.Optimizer.state_dict(), this method uses full parameter names as keys instead of parameter IDs.

Warning

If you do not pass model.parameters() as the first argument to the optimizer, then you should pass that same value to this method as optim_input.

Note

Like in torch.optim.Optimizer.state_dict(), the tensors contained in the optimizer state dict are not cloned, so there may be aliasing surprises. For best practices, consider saving the returned optimizer state dict immediately, e.g. using torch.save().

Parameters
  • model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.

  • optim (torch.optim.Optimizer) – Optimizer for model's parameters.

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer optim representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). (Default: None)

  • rank0_only (bool) – If True, saves the populated dict only on rank 0; if False, saves it on all ranks. (Default: True)

Returns

A dict containing the optimizer state for model's original unflattened parameters and including keys "state" and "param_groups" following the convention of torch.optim.Optimizer.state_dict(). If rank0_only=True, then nonzero ranks return an empty dict.

Return type

Dict[str, Any]

load_state_dict(state_dict, *args)

The entry point of all three FSDP load_state_dict APIs. By default, calling load_state_dict on an FSDP module will result in FSDP attempting to load a “full” state_dict, i.e. a state_dict consisting of full, unsharded, unflattened original module parameters. This requires FSDP to load the full parameter context on each rank, which could result in GPU OOM. As a result, the state_dict_type() API is available to choose between load_state_dict implementations. Users can thus use the FSDP.state_dict_type(module, StateDictType.LOCAL_STATE_DICT) context manager to load a local state dict checkpoint that restores only the local shards of the module. Currently, the only supported implementations are StateDictType.LOCAL_STATE_DICT and StateDictType.FULL_STATE_DICT (default). Please see state_dict() for documentation on creating an FSDP checkpoint.

Example:

>>> import torch
>>> from torch import nn
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> torch.cuda.set_device(device_id)
>>> my_module = nn.Linear(...)
>>> sharded_module = FSDP(my_module)
>>> checkpoint = torch.load(PATH)
>>> full_state_dict = checkpoint['full_state_dict']
>>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT):
>>>     sharded_module.load_state_dict(full_state_dict)
>>> full_state_dict.keys()
odict_keys(['weight', 'bias'])
>>> # using local state dict
>>> local_state_dict = checkpoint['local_state_dict']
>>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT):
>>>     sharded_module.load_state_dict(local_state_dict)
>>> local_state_dict.keys()
odict_keys(['flat_param', 'inner.flat_param'])

Warning

This needs to be called on all ranks, since synchronization primitives may be used.

property module

Makes model.module accessible, just like DDP.

named_buffers(*args, **kwargs)

Overrides named_buffers() to intercept buffer names and remove all occurrences of the FSDP-specific flattened buffer prefix when inside the summon_full_params() context manager.

named_parameters(*args, **kwargs)

Overrides named_parameters() to intercept parameter names and remove all occurrences of the FSDP-specific flattened parameter prefix when inside the summon_full_params() context manager.

no_sync()

A context manager to disable gradient synchronization across FSDP instances. Within this context, gradients will be accumulated in module variables, which will later be synchronized in the first forward-backward pass after exiting the context. This should only be used on the root FSDP instance and will recursively apply to all child FSDP instances.

Note

This likely results in higher memory usage because FSDP will accumulate the full model gradients (instead of gradient shards) until the eventual sync.

Note

When used with CPU offloading, the gradients will not be offloaded to CPU when inside the context manager. Instead, they will only be offloaded right after the eventual sync.
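
Deferring synchronization inside no_sync() does not change the accumulated result, because the cross-rank gradient reduction is an average, which is linear: reducing every micro-batch and summing gives the same gradient as summing locally and reducing once. The pure-Python sketch below checks this with toy numbers. mean_across_ranks and add are hypothetical helpers standing in for the collective and for in-place accumulation; real FSDP also reduce-scatters so that each rank ends up with only its gradient shard.

```python
from typing import List

Grad = List[float]


def mean_across_ranks(per_rank: List[Grad]) -> Grad:
    """Emulate the gradient reduction: elementwise average over ranks."""
    return [sum(vals) / len(per_rank) for vals in zip(*per_rank)]


def add(a: Grad, b: Grad) -> Grad:
    return [x + y for x, y in zip(a, b)]


# micro_grads[step][rank]: gradient of one micro-batch on one rank (toy values)
micro_grads = [
    [[1.0, 2.0], [3.0, 4.0]],  # micro-batch 0 on ranks 0 and 1
    [[0.5, 0.5], [1.5, 2.5]],  # micro-batch 1 on ranks 0 and 1
]
world_size = 2

# Synchronizing on every backward pass: reduce, then accumulate.
synced = [0.0, 0.0]
for step in micro_grads:
    synced = add(synced, mean_across_ranks(step))

# no_sync() style: each rank accumulates its full local gradient, reduces once.
local = [[0.0, 0.0] for _ in range(world_size)]
for step in micro_grads:
    for rank in range(world_size):
        local[rank] = add(local[rank], step[rank])
deferred = mean_across_ranks(local)

print(synced, deferred)  # [3.0, 4.5] [3.0, 4.5]
```

The saving is one reduction per micro-batch; the cost, as the first Note says, is that each rank keeps the full unsharded accumulated gradient (local above) until the final sync.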

property params_with_grad

Recursively returns a list of all module parameters that have a gradient.

static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None)

Re-keys the optimizer state dict optim_state_dict to use the key type optim_state_key_type. This can be used to achieve compatibility between optimizer state dicts from models with FSDP instances and ones without.

To re-key an FSDP full optimizer state dict (i.e. from full_optim_state_dict()) to use parameter IDs and be loadable to a non-wrapped model:

>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)

To re-key a normal optimizer state dict from a non-wrapped model to be loadable to a wrapped model:

>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)

Returns

The optimizer state dict re-keyed using the parameter keys specified by optim_state_key_type.

Return type

Dict[str, Any]
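
At the level of plain dictionaries, re-keying only changes which keys index "state" and the "params" list of each param group. The sketch below is a hypothetical illustration: ids_to_names and names_to_ids are not FSDP functions, and they assume you already have the list of parameter names in the order the parameters were given to the optimizer (torch.optim numbers parameters sequentially across param_groups). rekey_optim_state_dict() itself is given model and optim_input for this purpose, so you do not build that list by hand.

```python
from typing import Any, Dict, List


def ids_to_names(osd: Dict[str, Any], names: List[str]) -> Dict[str, Any]:
    """Rekey an ID-keyed optimizer state dict so keys are parameter names.

    Assumes parameter ID i is the i-th parameter passed to the optimizer and
    names[i] is its name.
    """
    return {
        "state": {names[pid]: s for pid, s in osd["state"].items()},
        "param_groups": [
            dict(g, params=[names[pid] for pid in g["params"]])
            for g in osd["param_groups"]
        ],
    }


def names_to_ids(osd: Dict[str, Any], names: List[str]) -> Dict[str, Any]:
    """Inverse mapping: parameter names back to integer IDs."""
    index = {name: i for i, name in enumerate(names)}
    return {
        "state": {index[n]: s for n, s in osd["state"].items()},
        "param_groups": [
            dict(g, params=[index[n] for n in g["params"]])
            for g in osd["param_groups"]
        ],
    }


id_keyed = {
    "state": {0: {"step": 1, "exp_avg": [0.1, 0.2]}, 1: {"step": 1, "exp_avg": [0.3]}},
    "param_groups": [{"lr": 1e-3, "params": [0, 1]}],
}
names = ["weight", "bias"]
name_keyed = ids_to_names(id_keyed, names)
print(sorted(name_keyed["state"]))  # ['bias', 'weight']
assert names_to_ids(name_keyed, names) == id_keyed
```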

static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, group=None)

Scatters the full optimizer state dict from rank 0 to all other ranks, returning the sharded optimizer state dict on each rank. The return value is the same as shard_full_optim_state_dict(), and on rank 0, the first argument should be the return value of full_optim_state_dict().

Example:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)

Note

Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.

Parameters
  • full_optim_state_dict (Optional[Dict[str, Any]]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state if on rank 0; the argument is ignored on nonzero ranks.

  • model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(); the argument is ignored on nonzero ranks. (Default: None)

  • group (Optional[Any]) – Model’s process group or None if using the default process group. (Default: None)

Returns

The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.

Return type

Dict[str, Any]

static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None)

Shards the full optimizer state dict full_optim_state_dict by remapping the state to flattened parameters instead of unflattened parameters and restricting to only this rank’s part of the optimizer state. The first argument should be the return value of full_optim_state_dict().

Example:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)

Warning

If you do not pass model.parameters() as the first argument to the optimizer, then you should pass that same value to this method as optim_input.

Note

Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.

Parameters
  • full_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state.

  • model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). (Default: None)

Returns

The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.

Return type

Dict[str, Any]
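
Conceptually, sharding a full optimizer state dict applies the same flatten, pad and chunk idea to each tensor-valued state entry (such as Adam's exp_avg) that FSDP applies to the parameters, and the two methods compared in the Note differ only in who does the slicing. The pure-Python sketch below is a simplified, hypothetical model of that idea: lists replace tensors, shard_state and scatter_state are illustrative names, zero padding is an assumption, and scalar entries such as step are not handled.

```python
import math
from typing import Dict, List


def shard_state(full_state: Dict[str, List[float]], param_order: List[str],
                world_size: int, rank: int) -> List[float]:
    """Flatten one per-parameter state field in param_order, pad it to a
    multiple of world_size, and return this rank's chunk. Every rank does
    this itself from its own full copy (the shard_full_optim_state_dict() style)."""
    flat = [x for name in param_order for x in full_state[name]]
    chunk = math.ceil(len(flat) / world_size)
    flat += [0.0] * (chunk * world_size - len(flat))
    return flat[rank * chunk:(rank + 1) * chunk]


def scatter_state(full_state: Dict[str, List[float]], param_order: List[str],
                  world_size: int) -> List[List[float]]:
    """Only 'rank 0' holds full_state and prepares every rank's chunk to send
    (the scatter_full_optim_state_dict() style)."""
    return [shard_state(full_state, param_order, world_size, r) for r in range(world_size)]


exp_avg = {"weight": [0.1, 0.2, 0.3, 0.4], "bias": [0.5]}
order = ["weight", "bias"]
local_shards = [shard_state(exp_avg, order, 2, r) for r in range(2)]
assert local_shards == scatter_state(exp_avg, order, 2)
print(local_shards)  # [[0.1, 0.2, 0.3], [0.4, 0.5, 0.0]]
```

Both paths produce identical per-rank chunks. If the full dict takes S bytes, the first needs roughly world_size × S of aggregate CPU memory and no communication, while the second keeps S on rank 0 only but must send each chunk.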

state_dict(*args, **kwargs)

This is the entry point of all three FSDP state_dict APIs: full, local, and sharded. For the full state dict (StateDictType.FULL_STATE_DICT), FSDP attempts to unshard the model on all ranks, which may result in an OOM error if the full model cannot fit on a single GPU. In that case, users may pass in a FullStateDictConfig to only save the checkpoint on rank 0 and/or to offload it to CPU memory layer by layer, enabling much larger checkpoints. If the full model cannot fit in CPU memory, then users may instead take a local state dict (StateDictType.LOCAL_STATE_DICT) that only saves the local shard of the model. The sharded state dict (StateDictType.SHARDED_STATE_DICT) saves the model parameters as ShardedTensors. The state_dict type can be configured using the state_dict_type() context manager.

Example:

>>> import torch
>>> from torch import nn
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import FullStateDictConfig, StateDictType
>>> torch.cuda.set_device(device_id)
>>> my_module = nn.Linear(...)
>>> sharded_module = FSDP(my_module)
>>> full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT, full_state_dict_config):
>>>     full_dict = sharded_module.state_dict()
>>> full_dict.keys()
odict_keys(['weight', 'bias'])
>>> # using local state dict
>>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT):
>>>     local_dict = sharded_module.state_dict()
>>> local_dict.keys()
odict_keys(['flat_param', 'inner.flat_param'])

Warning

This needs to be called on all ranks, since synchronization primitives may be used.

static state_dict_type(module, state_dict_type, state_dict_config=None)

A context manager to set the state_dict_type of all the descendant FSDP modules of the target module. The target module does not have to be an FSDP module. If the target module is an FSDP module, its state_dict_type will also be changed.

Note

This API should be called for only the top-level (root) module.

Note

This API enables users to transparently use the conventional state_dict API to take model checkpoints in cases where the root FSDP module is wrapped by another nn.Module. For example, the following will ensure state_dict is called on all non-FSDP instances, while dispatching into local_state_dict implementation for FSDP:

Example:

>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
>>>     checkpoint = model.state_dict()

Parameters
  • module (torch.nn.Module) – Root module.

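  • state_dict_type (StateDictType) – the desired state_dict_type to set.

  • state_dict_config (Optional[StateDictConfig]) – the configuration for the target state_dict_type, e.g. a FullStateDictConfig as in the state_dict() example. (Default: None)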

static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False)

A context manager to expose full params for FSDP instances. Can be useful after forward/backward for a model to get the params for additional processing or checking. It can take a non-FSDP module and will summon full params for all contained FSDP modules as well as their children, depending on the recurse argument.

Note

This can be used on inner FSDPs.

Note

This cannot be used within a forward or backward pass, nor can forward and backward be started from within this context.

Note

Parameters will revert to their local shards after the context manager exits; storage behavior is the same as in forward.

Note

The full parameters can be modified, but only the portion corresponding to the local param shard will persist after the context manager exits (unless writeback=False, in which case changes will be discarded). When FSDP does not shard the parameters, which currently happens only when world_size == 1 or with the NO_SHARD config, the modification is persisted regardless of writeback.

Note

This method works on modules which are not FSDP themselves but may contain multiple independent FSDP units. In that case, the given arguments will apply to all contained FSDP units.

Warning

Note that rank0_only=True in conjunction with writeback=True is not currently supported and will raise an error. This is because model parameter shapes would be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.

Warning

Note that using offload_to_cpu with rank0_only=False will result in full parameters being redundantly copied to CPU memory for GPUs that reside on the same machine, which risks CPU OOM. It is recommended to use offload_to_cpu with rank0_only=True.

Parameters
  • recurse (bool, Optional) – recursively summon all params for nested FSDP instances (default: True).

  • writeback (bool, Optional) – if False, modifications to params are discarded after the context manager exits; disabling this can be slightly more efficient (default: True)

  • rank0_only (bool, Optional) – if True, full parameters are materialized on only global rank 0. This means that within the context, only rank 0 will have full parameters and the other ranks will have sharded parameters. Note that setting rank0_only=True with writeback=True is not supported, as model parameter shapes will be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.

  • offload_to_cpu (bool, Optional) – If True, full parameters are offloaded to CPU. Note that this offloading currently only occurs if the parameter is sharded, which is always the case except for world_size = 1 or the NO_SHARD config. It is recommended to use offload_to_cpu with rank0_only=True to avoid redundant copies of model parameters being offloaded to the same CPU memory.
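
The writeback rule for summon_full_params() (only the portion corresponding to the local param shard persists) is easiest to see when ranks edit the full parameter differently. The pure-Python simulation below is hypothetical: summon_full_params_sim and fill_with_rank are not PyTorch functions, and plain lists stand in for one sharded flat parameter with padding on the last shard. Each emulated rank receives its own full copy, edits it, and on exit copies back only the slice it owns.

```python
from typing import Callable, List

Shard = List[float]


def summon_full_params_sim(shards: List[Shard], numel: int,
                           modify: Callable[[int, List[float]], None],
                           writeback: bool = True) -> None:
    """Emulate summon_full_params on every rank of a toy flat parameter.

    Each rank gets its own full copy (an emulated all-gather), user code edits
    it in place, and on exit only the slice owned by that rank is copied back
    into its local shard (padding is skipped).
    """
    chunk = len(shards[0])
    gathered = [x for s in shards for x in s][:numel]  # all-gather, drop padding
    for rank, shard in enumerate(shards):
        full = list(gathered)   # this rank's full copy inside the context
        modify(rank, full)      # the user's edits to the full parameter
        if writeback:
            start = rank * chunk
            n = max(0, min(chunk, numel - start))
            shard[:n] = full[start:start + n]


def fill_with_rank(rank: int, full: List[float]) -> None:
    full[:] = [float(rank)] * len(full)  # each rank writes different values


shards = [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]  # 5 real elements + 1 padding
summon_full_params_sim(shards, numel=5, modify=fill_with_rank)
print(shards)  # [[0.0, 0.0, 0.0], [1.0, 1.0, 0.0]]
```

Although every rank overwrote the whole parameter, the result mixes rank 0's edits (first three elements) with rank 1's (last two), and with writeback=False the shards would be left unchanged.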
