Advanced Model Training with Fully Sharded Data Parallel (FSDP)

Author: Hamid Shojanazeri, Less Wright, Rohan Varma, Yanli Zhao

This tutorial introduces more advanced features of Fully Sharded Data Parallel (FSDP) as part of the PyTorch 1.12 release. To get familiar with FSDP, please refer to the FSDP getting started tutorial.

In this tutorial, we fine-tune a HuggingFace (HF) T5 model with FSDP for text summarization as a working example.

The example uses the WikiHow dataset. For simplicity, we showcase training on a single node, a P4dn instance with 8 A100 GPUs. We will soon have a blog post on large-scale FSDP training on a multi-node cluster; please stay tuned for that on the PyTorch Medium channel.

FSDP is a production-ready package with a focus on ease of use, performance, and long-term support. One of its main benefits is a reduced memory footprint on each GPU compared with DDP, while the overlap of computation and communication keeps training efficient. The reduced memory pressure can be used either to train larger models or to increase the batch size, potentially helping overall training throughput. You can read more about PyTorch FSDP here.

FSDP Features in This Tutorial

  • Transformer Auto Wrap Policy

  • Mixed Precision

  • Initializing FSDP Model on Device

  • Sharding Strategy

  • Backward Prefetch

  • Model Checkpoint Saving via Streaming to CPU

Recap on How FSDP Works

At a high level, FSDP works as follows:

In constructor

  • Shard model parameters and each rank only keeps its own shard

In forward pass

  • Run all_gather to collect all shards from all ranks to recover the full parameter for this FSDP unit

  • Run forward computation

  • Discard non-owned parameter shards it has just collected to free memory

In backward pass

  • Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit

  • Run backward computation

  • Discard non-owned parameters to free memory

  • Run reduce_scatter to sync gradients
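The steps above can be sketched as a toy, single-process simulation that needs no GPU. This is only an illustration under simplifying assumptions: plain Python lists stand in for tensors, there is a single FSDP unit, and the helper names shard, all_gather and reduce_scatter are hypothetical stand-ins rather than the torch.distributed collectives.

```python
from typing import List


def shard(flat_params: List[float], world_size: int) -> List[List[float]]:
    """Split a flat parameter list into equal, contiguous per-rank shards."""
    assert len(flat_params) % world_size == 0, "toy example: no padding"
    size = len(flat_params) // world_size
    return [flat_params[r * size:(r + 1) * size] for r in range(world_size)]


def all_gather(shards: List[List[float]]) -> List[float]:
    """Every rank ends up with the concatenation of all shards."""
    return [x for s in shards for x in s]


def reduce_scatter(per_rank_grads: List[List[float]], world_size: int) -> List[List[float]]:
    """Sum full-size gradients across ranks, then give each rank only its slice."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return shard(summed, world_size)


world_size = 2
params = [1.0, 2.0, 3.0, 4.0]

# Constructor: each rank keeps only its own shard.
shards = shard(params, world_size)

# Forward / backward: rebuild the full parameters, compute, then drop them again.
full = all_gather(shards)

# Pretend each rank computed a full-size gradient on its own mini-batch.
grads_rank0 = [1.0, 1.0, 1.0, 1.0]
grads_rank1 = [2.0, 2.0, 2.0, 2.0]
grad_shards = reduce_scatter([grads_rank0, grads_rank1], world_size)

print(shards)
print(full)
print(grad_shards)
```

After reduce_scatter, each rank holds the summed gradient only for the parameters it owns, which is what its optimizer shard needs for the update.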

Fine-tuning HF T5

HF T5 pre-trained models are available in five different sizes, ranging from small with 60 million parameters to XXL with 11 billion parameters. In this tutorial, we demonstrate fine-tuning T5 3B with FSDP for text summarization using the WikiHow dataset. The main focus of this tutorial is to highlight the FSDP features that help when training large-scale models with more than 3B parameters. We also cover features specific to Transformer-based models. The code for this tutorial is available in PyTorch examples.


1.1 Install PyTorch Nightlies

We will install PyTorch nightlies, as some features, such as activation checkpointing, are available only in nightlies and will be added in the next PyTorch release after 1.12.

pip3 install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html

1.2 Dataset Setup

Please create a data folder, download the WikiHow dataset files wikihowAll.csv and wikihowSep.csv, and place them in the data folder. We will use the wikihow dataset class from summarization_dataset.

Next, we add the following code snippets to a Python script “T5_training.py”.


The full source code for this tutorial is available in PyTorch examples.

1.3 Import necessary packages:

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.cuda.nccl as nccl  # used by the bf16 check in 2.4
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from datasets import load_dataset  # HuggingFace datasets, used in 2.4
from summarization_dataset import *
# distutils was removed in Python 3.12; LooseVersion is used by the bf16 check in 2.4
from distutils.version import LooseVersion
from typing import Type
import time
import tqdm
from datetime import datetime

1.4 Distributed training setup:

Here we use two helper functions, one to initialize the processes for distributed training and one to clean up after training completes. In this tutorial, we use torch elastic via torchrun, which sets the worker RANK and WORLD_SIZE automatically.

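def setup():
    # initialize the process group
    dist.init_process_group("nccl")

def cleanup():
    # tear down the process group once training is done
    dist.destroy_process_group()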

2.1 Set up the HuggingFace T5 model:

def setup_model(model_name):
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer =  T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer

We also add a couple of helper functions here for the date of the run and for formatting memory metrics.

def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run

g_gigabyte = 1024**3  # number of bytes in one gigabyte

def format_metrics_to_gb(item):
    """quick function to format numbers to gigabyte and round to 4 digit precision"""
    metric_num = item / g_gigabyte
    metric_num = round(metric_num, ndigits=4)
    return metric_num

2.2 Define a train function:

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    # fsdp_loss[0] accumulates the loss, fsdp_loss[1] the normalizing count
    fsdp_loss = torch.zeros(2).to(local_rank)

    if sampler:
        # reshuffle the per-rank data split every epoch
        sampler.set_epoch(epoch)
    if rank==0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    for batch in train_loader:
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)
        optimizer.zero_grad()
        output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += len(batch)
        if rank==0:
            inner_pbar.update(1)

    # sum the statistics over all ranks
    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_accuracy = fsdp_loss[0] / fsdp_loss[1]

    if rank == 0:
        inner_pbar.close()
        print(
                f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
            )
    return train_accuracy

2.3 Define a validation function:

def validation(model, rank, world_size, val_loader):
    model.eval()
    correct = 0
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(3).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss
            fsdp_loss[1] += len(batch)

            if rank==0:
                inner_pbar.update(1)

    # sum the statistics over all ranks
    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss

2.4 Define a distributed train function that wraps the model in FSDP:

def fsdp_main(args):

    model, tokenizer = setup_model("t5-base")

    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of Validation dataset: ", dataset['validation'].shape)

    #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    # initialize the process group before any collective call
    setup()


    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3

    #init_start_event = torch.cuda.Event(enable_timing=True)
    #init_end_event = torch.cuda.Event(enable_timing=True)


    bf16_ready = (
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and LooseVersion(torch.version.cuda) >= "11.0"
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )

    if bf16_ready:
        mp_policy = bfSixteen
    else:
        mp_policy = None # defaults to fp32

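    # model is on CPU before input to FSDP
    torch.cuda.set_device(local_rank)
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device())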

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

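            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")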

        if args.save_model and curr_val_loss < best_val_loss:

            # save
            if rank == 0:
                print(f"--> entering save model state")

            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")

            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)

        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank==0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    # wait for all ranks, then tear down the process group
    dist.barrier()
    cleanup()


2.5 Parse the arguments and set the main function:

if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 4)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 4)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--lr', type=float, default=.002, metavar='LR',
                        help='learning rate (default: .002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    # the three flags below are enabled by default; passing a flag turns it off
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='pass to disable GPU memory tracking (default: enabled)')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='pass to skip validation (default: enabled)')
    parser.add_argument('--save-model', action='store_false', default=True,
                        help='pass to skip saving the model (default: enabled)')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    fsdp_main(args)



To run the training using torchrun:

torchrun --nnodes 1 --nproc_per_node 4  T5_training.py
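torchrun passes each worker its identity through the RANK, LOCAL_RANK and WORLD_SIZE environment variables, which is what fsdp_main reads in 2.4. A minimal sketch of that lookup follows; the DistEnv class, the read_dist_env helper, and the single-process fallback values are hypothetical additions for illustration and are not part of the tutorial script.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class DistEnv:
    rank: int        # global rank across all nodes
    local_rank: int  # rank within this node, used to pick the GPU
    world_size: int  # total number of workers


def read_dist_env(environ=os.environ) -> DistEnv:
    """Read the worker identity, falling back to single-process values."""
    return DistEnv(
        rank=int(environ.get("RANK", 0)),
        local_rank=int(environ.get("LOCAL_RANK", 0)),
        world_size=int(environ.get("WORLD_SIZE", 1)),
    )


# Values the last worker would see with `--nnodes 1 --nproc_per_node 4`.
example = read_dist_env({"RANK": "3", "LOCAL_RANK": "3", "WORLD_SIZE": "4"})
print(example)
```

On a single node RANK and LOCAL_RANK coincide; they differ only in multi-node runs, where LOCAL_RANK still selects the GPU on each host.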

Transformer Wrapping Policy

As discussed in the previous tutorial, auto_wrap_policy is one of the FSDP features that make it easy to automatically shard a given model and put the model, optimizer and gradient shards into distinct FSDP units.

For some architectures, such as Transformer encoder-decoders, some parts of the model, such as the embedding table, are shared between the encoder and the decoder. In this case, we need to place the embedding table in the outer FSDP unit so that it can be accessed from both the encoder and the decoder. In addition, registering the layer class for a transformer makes the sharding plan much more communication efficient. PyTorch 1.12 added this support to FSDP, so there is now a wrapping policy for transformers.

It can be created as follows, where the T5Block represents the T5 transformer layer class (holding MHSA and FFN).

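t5_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        T5Block,
    },
)

model = FSDP(model,
    auto_wrap_policy=t5_auto_wrap_policy)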

To see the wrapped model, you can easily print the model and visually inspect the sharding and FSDP units as well.
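To make the effect of a layer-class-based policy concrete, here is a toy sketch in plain Python. It is not how FSDP implements wrapping: Module, Embedding, Block, Stack and assign_units are hypothetical stand-ins, and the traversal is simplified. It only illustrates the idea that every instance of the registered layer class becomes its own FSDP unit, while a shared embedding stays in the outermost unit.

```python
from typing import Dict, List, Optional, Set, Type


class Module:
    """Tiny stand-in for torch.nn.Module: a name plus child modules."""

    def __init__(self, name: str, children: Optional[List["Module"]] = None):
        self.name = name
        self.children = children or []


class Embedding(Module):
    pass


class Block(Module):  # plays the role of T5Block
    pass


class Stack(Module):  # plays the role of the encoder / decoder stack
    pass


def assign_units(root: Module, layer_cls: Set[Type[Module]]) -> Dict[str, str]:
    """Map each module name to the name of the unit that owns its parameters."""
    owner: Dict[str, str] = {}
    seen = set()

    def visit(module: Module, unit: str) -> None:
        if id(module) in seen:  # shared module: keep its first (outermost) owner
            return
        seen.add(id(module))
        if isinstance(module, tuple(layer_cls)):
            unit = module.name  # this module starts its own unit
        owner[module.name] = unit
        for child in module.children:
            visit(child, unit)

    visit(root, root.name)
    return owner


shared = Embedding("shared_embedding")
encoder = Stack("encoder", [shared, Block("enc_block_0"), Block("enc_block_1")])
decoder = Stack("decoder", [shared, Block("dec_block_0")])
model = Module("t5", [shared, encoder, decoder])

units = assign_units(model, {Block})
for name, unit in units.items():
    print(f"{name:18s} -> unit {unit}")
```

Because only Block is registered, the encoder and decoder stacks are not wrapped themselves, so the embedding they share is owned by the root unit and is visible to both.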

Mixed Precision

FSDP supports flexible mixed precision training, allowing for arbitrary reduced precision types (such as fp16 or bfloat16). Currently BFloat16 is only available on Ampere GPUs, so you need to confirm native support before you use it. On V100s, for example, BFloat16 can still be run, but because it runs non-natively, it can result in significant slowdowns.

To check if BFloat16 is natively supported, you can use the following:

bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
)

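One of the advantages of mixed precision in FSDP is that it provides granular control over the precision of parameters, gradients, and buffers, as follows:

fpSixteen = MixedPrecision(
    # Parameter precision.
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    # Parameter precision.
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    # Parameter precision.
    param_dtype=torch.float32,
    # Gradient communication precision.
    reduce_dtype=torch.float32,
    # Buffer precision.
    buffer_dtype=torch.float32,
)

Note that if a certain type (parameter, reduce, buffer) is not specified, it will not be cast at all.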

This flexibility gives users fine-grained control, such as setting only gradient communication to happen in reduced precision while all parameter / buffer computation is done in full precision. This is potentially useful in cases where intra-node communication is the main bottleneck and parameters / buffers must be in full precision to avoid accuracy issues. This can be done with the following policy:

grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)
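To give a feel for the accuracy trade-off behind these choices, the sketch below emulates bfloat16 rounding with the standard library only. The to_bfloat16 helper is a hypothetical illustration, not a PyTorch API; it assumes round-to-nearest-even and does not handle NaN or infinity.

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float.

    bfloat16 keeps the sign bit, the 8 exponent bits and the top 7 mantissa
    bits of a float32, so it has float32's range but far fewer digits.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    round_bias = 0x7FFF + ((bits >> 16) & 1)             # ties go to the even result
    bits = (bits + round_bias) & 0xFFFF0000              # drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for value in (1.0, 1.001, 257.0, 1e38):
    rounded = to_bfloat16(value)
    print(f"{value!r} -> {rounded!r} (finite: {math.isfinite(rounded)})")
```

Small increments can be lost entirely (257.0 rounds back to 256.0), which is one reason to keep parameters or buffers in full precision when accuracy matters, while very large magnitudes such as 1e38 still fit.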

In 2.4, we just add the relevant mixed precision policy to the FSDP wrapper:

model = FSDP(model,
    auto_wrap_policy=t5_auto_wrap_policy,
    mixed_precision=bfSixteen)

In our experiments, we have observed up to 4x speed up by using BFloat16 for training and memory reduction of approximately 30% in some experiments that can be used for batch size increases.

Initializing FSDP Model on Device

In 1.12, FSDP supports a device_id argument meant to initialize input CPU module on the device given by device_id. This is useful when the entire model does not fit on a single GPU, but fits in a host’s CPU memory. When device_id is specified, FSDP will move the model to the specified device on a per-FSDP unit basis, avoiding GPU OOM issues while initializing several times faster than CPU-based initialization:


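torch.cuda.set_device(local_rank)

model = FSDP(model,
    auto_wrap_policy=t5_auto_wrap_policy,
    mixed_precision=bfSixteen,
    device_id=torch.cuda.current_device())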

Sharding Strategy

By default, the FSDP sharding strategy is full sharding: model parameters, gradients, and optimizer states are all sharded across ranks (also termed Zero3 sharding). If you prefer the Zero2 sharding strategy, where only optimizer states and gradients are sharded, pass “ShardingStrategy.SHARD_GRAD_OP” instead of “ShardingStrategy.FULL_SHARD” to the FSDP initialization as follows:

model = FSDP(model,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)  # ZERO2

This reduces the communication overhead in FSDP: in this case, each rank holds the full parameters after the forward pass and through the backward pass.

This saves an all_gather during the backward pass, so there is less communication at the cost of a higher memory footprint. Note that the full model parameters are freed at the end of the backward pass, and all_gather happens again on the next forward pass.
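The memory side of this trade-off can be estimated with back-of-the-envelope arithmetic. The sketch below is a rough model, not a measurement: per_rank_state_bytes is a hypothetical helper, it assumes fp32 parameters and gradients (4 bytes each) plus two fp32 AdamW states (8 bytes per parameter), and it ignores activations and the parameters that are temporarily gathered for the unit being computed.

```python
def per_rank_state_bytes(
    num_params: int,
    world_size: int,
    strategy: str,
    param_bytes: int = 4,
    grad_bytes: int = 4,
    optim_bytes: int = 8,
) -> float:
    """Approximate bytes of model state held by one rank during forward/backward."""
    params = num_params * param_bytes
    grads = num_params * grad_bytes
    optim = num_params * optim_bytes
    if strategy == "NO_SHARD":       # DDP-like: everything replicated
        return params + grads + optim
    if strategy == "SHARD_GRAD_OP":  # Zero2-like: full params kept through fwd/bwd
        return params + (grads + optim) / world_size
    if strategy == "FULL_SHARD":     # Zero3-like: everything sharded
        return (params + grads + optim) / world_size
    raise ValueError(f"unknown strategy: {strategy}")


gib = 1024 ** 3
for name in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD"):
    estimate = per_rank_state_bytes(3_000_000_000, world_size=8, strategy=name)
    print(f"{name:14s} ~ {estimate / gib:.1f} GiB per rank")
```

Under these assumptions, SHARD_GRAD_OP sits between replicating everything and FULL_SHARD, matching the description above: less communication, more resident parameter memory.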

Backward Prefetch

The backward prefetch setting controls when the next FSDP unit’s parameters are requested. With BACKWARD_PRE, the next FSDP unit’s parameters can be requested, and can arrive, before the computation of the current unit starts. This overlaps the all_gather communication with gradient computation, which can increase training speed in exchange for slightly higher memory consumption. It can be used in the FSDP wrapper in 2.4 as follows:


 model = FSDP(model,
        backward_prefetch = BackwardPrefetch.BACKWARD_PRE)

backward_prefetch has two modes, BACKWARD_PRE and BACKWARD_POST. BACKWARD_POST means that the next FSDP unit’s params will not be requested until the current FSDP unit processing is complete, thus minimizing memory overhead. In some cases, using BACKWARD_PRE can increase model training speed by 2-10%, with even higher speed improvements noted for larger models.
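The benefit of requesting parameters early can be illustrated with an idealised timeline. This is a toy model under strong assumptions, not a description of the FSDP scheduler: backward_time is a hypothetical helper, each unit needs one all_gather followed by one gradient computation, communication and computation can overlap perfectly, and reduce_scatter is ignored.

```python
from typing import Sequence


def backward_time(comm: Sequence[float], compute: Sequence[float], prefetch: bool) -> float:
    """Idealised duration of a backward pass over a sequence of FSDP units.

    comm[i] is the all_gather time of unit i, compute[i] its gradient
    computation time. With prefetch, the all_gather of unit i+1 is issued
    as soon as the computation of unit i starts.
    """
    if not prefetch:
        # Gather, compute, then gather the next unit: nothing overlaps.
        return sum(c + t for c, t in zip(comm, compute))

    gather_done = comm[0]  # the first all_gather cannot be hidden
    compute_done = 0.0
    for i, t in enumerate(compute):
        start = max(gather_done, compute_done)  # need the params and a free GPU
        if i + 1 < len(comm):
            gather_done = start + comm[i + 1]   # request next unit while computing
        compute_done = start + t
    return compute_done


comm = [2, 2, 2]     # time units per all_gather
compute = [3, 3, 3]  # time units per gradient computation
print("no overlap:", backward_time(comm, compute, prefetch=False))
print("prefetch:  ", backward_time(comm, compute, prefetch=True))
```

The price of the overlap is that two units' full parameters can be resident at the same time, which is the extra memory mentioned above.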

Model Checkpoint Saving, by streaming to the Rank0 CPU

To save model checkpoints with FULL_STATE_DICT, which saves the model in the same fashion as a local (non-sharded) model, PyTorch 1.12 offers a few utilities that support saving larger models.

First, a FullStateDictConfig can be specified, allowing the state_dict to be populated on rank 0 only and offloaded to the CPU.

When using this configuration, FSDP will all_gather model parameters and offload them to the CPU one by one, only on rank 0. When the state_dict is finally saved, it is populated only on rank 0 and contains CPU tensors. This avoids potential OOM for models that are larger than a single GPU’s memory and allows users to checkpoint models whose size is roughly up to the available CPU RAM on the user’s machine.

This feature can be run as follows:

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
    model, StateDictType.FULL_STATE_DICT, save_policy
):
    cpu_state = model.state_dict()
if rank == 0:
    save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
    torch.save(cpu_state, save_name)


In this tutorial, we have introduced many new features for FSDP available in PyTorch 1.12 and used HF T5 as the running example. Using the proper wrapping policy, especially for transformer models, along with mixed precision and backward prefetch should speed up your training runs. Also, features such as initializing the model on device and checkpoint saving via streaming to CPU should help avoid OOM errors when dealing with large models.

We are actively working to add new features to FSDP for the next release. If you have feedback, feature requests, or questions, or are encountering issues using FSDP, please feel free to contact us by opening an issue in the PyTorch GitHub repository.

