torch.distributed.fsdp.fully_shard
==================================

PyTorch FSDP2 (``fully_shard``)
-------------------------------

PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation
targeting performant eager-mode execution while using per-parameter sharding
for improved usability.

- If you are new to FSDP, we recommend that you start with FSDP2 due to improved
  usability.
- If you are currently using FSDP1, consider evaluating the following
  differences to see if you should switch to FSDP2:

Compared to PyTorch FSDP1 (``FullyShardedDataParallel``):

- FSDP2 uses ``DTensor``-based dim-0 per-parameter sharding for a simpler
  sharding representation compared to FSDP1's flat-parameter sharding, while
  preserving similar throughput performance. More specifically, FSDP2 chunks
  each parameter on dim-0 across the data parallel workers (using
  ``torch.chunk(dim=0)``), whereas FSDP1 flattens, concatenates, and chunks a
  group of tensors together, which makes it complex to reason about what data
  is present on each worker and to reshard to different parallelisms. Per-parameter
  sharding provides a more intuitive user experience, relaxes constraints
  around frozen parameters, and allows for communication-free (sharded) state
  dicts, which otherwise require all-gathers in FSDP1.
- FSDP2 implements a different memory management approach for multi-stream
  usage that avoids ``torch.Tensor.record_stream``. This ensures deterministic
  and expected memory usage and does not require blocking the CPU, as FSDP1's
  ``limit_all_gathers=True`` does.
- FSDP2 exposes APIs for manual control over prefetching and collective
  scheduling, allowing power users more customization. See the methods on
  ``FSDPModule`` below for details.
- FSDP2 simplifies some of the API surface: for example, FSDP2 does not
  directly support full state dicts. Instead, users can reshard the sharded
  state dicts containing ``DTensor``\ s to full state dicts themselves using
  ``DTensor`` APIs like ``DTensor.full_tensor()`` or by using higher-level APIs
  like `PyTorch Distributed Checkpoint <https://pytorch.org/docs/stable/distributed.checkpoint.html>`_'s
  distributed state dict APIs. Some other arguments have also been removed; see
  `here <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md>`_ for
  details.
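
The difference between dim-0 per-parameter sharding and flat-parameter sharding
can be illustrated on a single process with plain tensors. The following is a
minimal sketch, not FSDP's implementation: ``shard_dim0`` is a hypothetical
helper, the tensor shapes and the world size of 4 are made up, and the
flat-parameter half is a simplified picture of the flatten-concatenate-chunk
scheme described above.

```python
import torch


def shard_dim0(param: torch.Tensor, world_size: int) -> list:
    """Return one dim-0 shard per rank, following torch.chunk semantics."""
    chunks = list(torch.chunk(param, world_size, dim=0))
    # torch.chunk can return fewer than world_size chunks when dim 0 is small
    # or uneven; in this sketch the trailing ranks get an empty shard.
    empty = param.new_empty((0, *param.shape[1:]))
    chunks += [empty] * (world_size - len(chunks))
    return chunks


world_size = 4  # hypothetical number of data parallel workers
weight = torch.arange(10, dtype=torch.float32).reshape(5, 2)
bias = torch.full((5,), -1.0)

# Per-parameter sharding: every rank holds whole rows of `weight`.
shards = shard_dim0(weight, world_size)
shard_rows = [s.shape[0] for s in shards]

# Simplified flat-parameter sharding: flatten, concatenate, then chunk.
flat = torch.cat([weight.flatten(), bias.flatten()])
flat_shards = torch.chunk(flat, world_size)
# The third flat shard mixes the tail of `weight` with the head of `bias`.
mixed_shard = flat_shards[2].tolist()
```

In this sketch each per-parameter shard is a contiguous block of rows from one
parameter, while a flat shard can straddle two parameters, which is the
reasoning burden the comparison above refers to.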

If you are adopting FSDP for the first time or if any of the above appeals to
your use case, we recommend that you consider using FSDP2.

See `this RFC <https://github.com/pytorch/pytorch/issues/114299>`_ for details
on system design and implementation.

.. note::
  ``torch.distributed.fsdp.fully_shard`` is currently in prototype state and
  under development. The core API will likely not change, but we may make some
  API changes if necessary.

.. currentmodule:: torch.distributed.fsdp

The frontend API is ``fully_shard``, which can be called on a ``module``:

.. autofunction:: fully_shard
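
As a rough usage sketch, ``fully_shard`` can be applied to submodules first and
then to the root module. ``TinyModel``, its ``layers`` attribute, and the
``apply_fsdp2`` helper are hypothetical names used only for illustration. The
sketch assumes the default process group has already been initialized (for
example, by launching with ``torchrun`` and calling
``torch.distributed.init_process_group``) before ``apply_fsdp2`` is called.

```python
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical model: a stack of linear layers."""

    def __init__(self, dim: int = 8, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def apply_fsdp2(model: TinyModel) -> TinyModel:
    """Shard each layer, then the root module (assumed usage pattern)."""
    # Imported inside the helper so that defining it does not require a
    # PyTorch build that ships FSDP2.
    from torch.distributed.fsdp import fully_shard

    for layer in model.layers:
        fully_shard(layer)  # each layer's parameters form their own group
    fully_shard(model)  # the root picks up any remaining parameters
    return model


model = TinyModel()
# In a distributed job, after process group initialization:
# model = apply_fsdp2(model)
```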

Calling ``fully_shard(module)`` dynamically constructs a new class that
subclasses ``type(module)`` and an FSDP class ``FSDPModule``. For example, if
we call ``fully_shard(linear)`` on a module ``linear: nn.Linear``, then FSDP
constructs a new class ``FSDPLinear`` and changes ``linear``'s type to this
class. Otherwise, ``fully_shard`` does not change the module structure or the
parameter fully-qualified names. The class ``FSDPModule`` provides some
FSDP-specific methods on the module.
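
The class-swapping pattern described above can be sketched without any
distributed setup. ``ShardingMixin`` and ``swap_class`` are hypothetical
stand-ins for ``FSDPModule`` and for the class construction inside
``fully_shard``. The sketch only shows how a dynamically built subclass leaves
the module structure and parameter names untouched, and does not perform any
sharding.

```python
import torch
import torch.nn as nn


class ShardingMixin:
    """Hypothetical stand-in for FSDPModule: adds extra methods to a module."""

    def is_swapped(self) -> bool:
        return True


def swap_class(module: nn.Module) -> nn.Module:
    """Replace the module's class with a dynamically built subclass, in place."""
    orig_cls = type(module)
    # New class inherits from both the mixin and the original module class.
    new_cls = type(f"Swapped{orig_cls.__name__}", (ShardingMixin, orig_cls), {})
    module.__class__ = new_cls
    return module


linear = nn.Linear(4, 2)
names_before = [name for name, _ in linear.named_parameters()]

swap_class(linear)
names_after = [name for name, _ in linear.named_parameters()]

# The module still behaves like nn.Linear in forward.
out_shape = tuple(linear(torch.zeros(1, 4)).shape)
```

Because the new class subclasses the original one, ``isinstance`` checks
against ``nn.Linear`` continue to pass, and the mixin's methods become
available on the same object.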

.. autoclass:: FSDPModule
    :members:
    :member-order: bysource

.. autoclass:: UnshardHandle
    :members:

.. autofunction:: register_fsdp_forward_method

.. autoclass:: MixedPrecisionPolicy
    :members:

.. autoclass:: OffloadPolicy
    :members:

.. autoclass:: CPUOffloadPolicy
    :members: