

class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, cpu_offload=None, fsdp_auto_wrap_policy=None, backward_prefetch=None)[source]

A wrapper for sharding module parameters across data-parallel workers. This is inspired by Xu et al. as well as ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.


>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()


The optimizer must be initialized after the module has been wrapped, since FSDP shards parameters in place, which breaks any previously initialized optimizer.

Parameters:

  • module (nn.Module) – module to be wrapped with FSDP.

  • process_group (Optional[ProcessGroup]) – process group used for sharding.

  • cpu_offload (Optional[CPUOffload]) – CPU offloading config. Currently, only parameter and gradient CPU offloading is supported. It can be enabled by passing cpu_offload=CPUOffload(offload_params=True). Note that this currently implicitly enables gradient offloading to CPU as well, so that parameters and gradients are on the same device for the optimizer to work with. This API is subject to change. Default is None, in which case there is no offloading.
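    As a sketch, the offload config is built and passed as follows (the FSDP construction is commented out because it assumes a module named my_module and an initialized process group):

    ```python
    from torch.distributed.fsdp import CPUOffload
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Build the offload config; offload_params=True offloads parameters
    # (and, implicitly, gradients) to CPU.
    offload_config = CPUOffload(offload_params=True)

    # sharded_module = FSDP(my_module, cpu_offload=offload_config)
    ```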

  • fsdp_auto_wrap_policy (Optional[Callable]) – a callable specifying a policy to recursively wrap layers with FSDP. Note that this policy currently applies only to child modules of the passed-in module; the remaining modules are always wrapped in the returned FSDP root instance. default_auto_wrap_policy, defined in torch.distributed.fsdp.wrap, is an example of an fsdp_auto_wrap_policy callable; this policy wraps layers whose parameter size is larger than 100M. A user-supplied fsdp_auto_wrap_policy callable should accept the following arguments: module: nn.Module, recurse: bool, unwrapped_params: int. Extra customized arguments can be added to the callable as well.


    >>> def custom_auto_wrap_policy(
    >>>     module: nn.Module,
    >>>     recurse: bool,
    >>>     unwrapped_params: int,
    >>>     # These are customizable for this policy function.
    >>>     min_num_params: int = int(1e8),
    >>> ) -> bool:
    >>>     return unwrapped_params >= min_num_params
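    Since FSDP invokes the policy with only the first three arguments, an extra customized argument such as min_num_params is typically pre-bound before handing the policy to FSDP. A sketch using functools.partial (the FSDP call is commented out because it assumes my_module and an initialized process group):

    ```python
    import functools

    # Same decision rule as the policy above, shown standalone: wrap a
    # module once the count of not-yet-wrapped parameters reaches the
    # threshold.
    def custom_auto_wrap_policy(module, recurse, unwrapped_params,
                                min_num_params=int(1e8)):
        return unwrapped_params >= min_num_params

    # Pre-bind a smaller threshold (1e5 parameters) for the extra argument.
    policy = functools.partial(custom_auto_wrap_policy,
                               min_num_params=int(1e5))

    # sharded_module = FSDP(my_module, fsdp_auto_wrap_policy=policy)
    ```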

  • backward_prefetch (Optional[BackwardPrefetch]) – This is an experimental feature that is subject to change in the near future. It allows users to enable two different backward prefetching algorithms to help overlap backward-pass communication and computation. The pros and cons of each algorithm are explained in the BackwardPrefetch class.
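A sketch of selecting one of the two algorithms via the BackwardPrefetch enum (the FSDP construction is commented out because it assumes my_module and an initialized process group):

```python
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Choose one of the two prefetching strategies described above.
prefetch = BackwardPrefetch.BACKWARD_PRE  # or BackwardPrefetch.BACKWARD_POST

# sharded_module = FSDP(my_module, backward_prefetch=prefetch)
```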

property module

Makes model.module accessible, just like DDP.

