# Fully Sharded Data Parallel in PyTorch XLA

Fully Sharded Data Parallel (FSDP) in PyTorch XLA is a utility for
sharding Module parameters across data-parallel workers.

Example usage:

``` python3
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# `my_module`, `x`, and `y` are placeholders for your nn.Module and its inputs.
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
# Call `optim.step()` directly; do not use `xm.optimizer_step()` with FSDP.
optim.step()
```

It is also possible to shard individual layers separately and have an
outer wrapper handle any leftover parameters.
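
For instance (a minimal sketch; the `nn.Linear` layers and sizes are
illustrative), the large inner layers can be wrapped individually, and the
outer wrapper then shards whatever parameters remain:

``` python3
import torch.nn as nn
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# The two large inner layers are wrapped with nested FSDP; the outer wrapper
# shards the remaining, un-wrapped parameters (here, the final head layer).
model = FSDP(
    nn.Sequential(
        FSDP(nn.Linear(1024, 1024)),
        nn.ReLU(),
        FSDP(nn.Linear(1024, 1024)),
        nn.Linear(1024, 10),
    )
)
```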

Notes:

-   The `XlaFullyShardedDataParallel` class supports both the ZeRO-2
    optimizer (sharding gradients and optimizer states) and the ZeRO-3
    optimizer (sharding parameters, gradients, and optimizer states)
    described in <https://arxiv.org/abs/1910.02054>.
    -   The ZeRO-3 optimizer should be implemented via nested FSDP with
        `reshard_after_forward=True`. See
        `test/test_train_mp_mnist_fsdp_with_ckpt.py` and
        `test/test_train_mp_imagenet_fsdp.py` for examples.
    -   For large models that cannot fit into a single TPU memory or the
        host CPU memory, one should interleave submodule construction
        with inner FSDP wrapping. See
        [FSDPViTModel](https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py)
        for an example.
-   A simple wrapper `checkpoint_module` is provided (based on
    `torch_xla.utils.checkpoint.checkpoint` from
    <https://github.com/pytorch/xla/pull/3524>) to perform [gradient
    checkpointing](https://spell.ml/blog/gradient-checkpointing-pytorch-YGypLBAAACEAefHs)
    over a given `nn.Module` instance. See
    `test/test_train_mp_mnist_fsdp_with_ckpt.py` and
    `test/test_train_mp_imagenet_fsdp.py` for examples.
-   Auto-wrapping submodules: instead of manually nesting FSDP wrappers,
    one can also specify an `auto_wrap_policy` argument to automatically
    wrap submodules with inner FSDP. `size_based_auto_wrap_policy` in
    `torch_xla.distributed.fsdp.wrap` is an example of an
    `auto_wrap_policy` callable; this policy wraps layers whose parameter
    count exceeds 100M. `transformer_auto_wrap_policy` in
    `torch_xla.distributed.fsdp.wrap` is an example of an
    `auto_wrap_policy` callable for transformer-like model architectures.

For example, to automatically wrap all `torch.nn.Conv2d` submodules with
inner FSDP, one can use:

``` python3
import torch
from functools import partial
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
```
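
The resulting policy is then passed to the outer wrapper when constructing
the model. Below is a sketch, where `my_module` is a placeholder for your
model and the second variant relies on `size_based_auto_wrap_policy` with
its default 100M-parameter threshold:

``` python3
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap every matching submodule (here, Conv2d layers) with inner FSDP.
model = FSDP(my_module, auto_wrap_policy=auto_wrap_policy)

# Alternatively, wrap submodules by parameter count with the size-based policy.
model = FSDP(my_module, auto_wrap_policy=size_based_auto_wrap_policy)
```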

Additionally, one can also specify an `auto_wrapper_callable` argument
to use a custom callable wrapper for the submodules (the default wrapper
is just the `XlaFullyShardedDataParallel` class itself). For example,
one can use the following to apply gradient checkpointing (i.e.
activation checkpointing/rematerialization) to each auto-wrapped
submodule.

``` python3
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel, checkpoint_module

# Apply gradient checkpointing to each auto-wrapped submodule before sharding it.
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
```
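
Both arguments are then passed together when wrapping the model (a sketch;
`my_module` is again a placeholder):

``` python3
model = FSDP(my_module,
             auto_wrap_policy=auto_wrap_policy,
             auto_wrapper_callable=auto_wrapper_callable)
```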

-   When stepping the optimizer, directly call `optimizer.step` and do
    not call `xm.optimizer_step`. The latter reduces the gradient across
    ranks, which is not needed for FSDP (where the parameters are
    already sharded).
-   When saving model and optimizer checkpoints during training, each
    training process needs to save its own checkpoint of the (sharded)
    model and optimizer state dicts (use `master_only=False` and set
    different paths for each rank in `xm.save`). When resuming, it needs
    to load the checkpoint for the corresponding rank.
-   Please also save `model.get_shard_metadata()` along with
    `model.state_dict()` as follows and use
    `consolidate_sharded_model_checkpoints` to stitch the sharded model
    checkpoints together into a full model state dict. See
    `test/test_train_mp_mnist_fsdp_with_ckpt.py` for an example.

``` python3
ckpt = {
    'model': model.state_dict(),
    'shard_metadata': model.get_shard_metadata(),
    'optimizer': optimizer.state_dict(),
}
ckpt_path = f'/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth'
# Every rank saves its own shard, hence `master_only=False` and per-rank paths.
xm.save(ckpt, ckpt_path, master_only=False)
```
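
The sharded checkpoints can then be stitched together in Python with
`consolidate_sharded_model_checkpoints`. Below is a sketch; the
prefix/suffix split is an assumption chosen to match the example paths
above:

``` python3
from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

# Combine the per-rank shards saved above into a full model state dict
# (run once after training, e.g. on the master host).
consolidate_sharded_model_checkpoints(
    ckpt_prefix='/tmp/rank-',
    ckpt_suffix='*-of-*.pth')
```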

-   The checkpoint consolidation script can also be launched from the
    command line as follows.

``` bash
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /path/to/your_sharded_checkpoint_files \
  --ckpt_suffix "_rank-*-of-*.pth"
```

The implementation of this class is largely inspired by and mostly
follows the structure of `fairscale.nn.FullyShardedDataParallel` in
<https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html>. One of
the biggest differences from `fairscale.nn.FullyShardedDataParallel` is
that in XLA we don't have explicit parameter storage, so here we resort
to a different approach to free full parameters for ZeRO-3.

## Example training scripts on MNIST and ImageNet

-   Minimal example:
    [examples/fsdp/train_resnet_fsdp_auto_wrap.py](https://github.com/pytorch/xla/blob/master/examples/fsdp/train_resnet_fsdp_auto_wrap.py)
-   MNIST:
    [test/test_train_mp_mnist_fsdp_with_ckpt.py](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_fsdp_with_ckpt.py)
    (it also tests checkpoint consolidation)
-   ImageNet:
    [test/test_train_mp_imagenet_fsdp.py](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_fsdp.py)

### Installation

FSDP is available in the PyTorch/XLA 1.12 release and in newer nightly
builds. Please refer to
<https://github.com/pytorch/xla#-available-images-and-wheels> for the
installation guide.

### Clone PyTorch/XLA repo

``` bash
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
```

### Train MNIST on v3-8 TPU

It reaches around 98.9% accuracy after 2 epochs:

``` bash
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing
```

This script automatically tests checkpoint consolidation at the end. You
can also manually consolidate the sharded checkpoints via:

``` bash
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"
```

### Train ImageNet with ResNet-50 on v3-8 TPU

It reaches around 75.9% accuracy after 100 epochs. Download
[ImageNet-1k](https://github.com/pytorch/examples/tree/master/imagenet#requirements)
to `/datasets/imagenet-1k`, then run:

``` bash
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp
```

You can also add `--use_gradient_checkpointing` (which needs to be used
along with `--use_nested_fsdp` or `--auto_wrap_policy`) to apply
gradient checkpointing on the residual blocks.

## Example training scripts on TPU pod (with 10 billion parameters)

To train large models that cannot fit into a single TPU, one should
apply auto-wrap or manually wrap the submodules with inner FSDP when
building the entire model to implement the ZeRO-3 algorithm.
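
A minimal sketch of this pattern is shown below; the `nn.Linear` blocks and
sizes are illustrative stand-ins for real transformer layers:

``` python3
import torch.nn as nn
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

def build_sharded_model(num_layers=48, hidden=4096):
    # Wrap each block with inner FSDP as soon as it is constructed, so the
    # full unsharded model never has to fit in host or device memory at once.
    blocks = [
        FSDP(nn.Linear(hidden, hidden), reshard_after_forward=True)
        for _ in range(num_layers)
    ]
    # The outer wrapper shards any remaining (not yet wrapped) parameters.
    return FSDP(nn.Sequential(*blocks), reshard_after_forward=True)
```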

Please see <https://github.com/ronghanghu/vit_10b_fsdp_example> for an
example of sharded training of a Vision Transformer (ViT) model using
this XLA FSDP implementation.