Multi GPU training with DDP

Created On: Sep 27, 2022 | Last Updated: Nov 03, 2024 | Last Verified: Not Verified

Authors: Suraj Subramanian

What you will learn
  • How to migrate a single-GPU training script to multi-GPU via DDP

  • Setting up the distributed process group

  • Saving and loading models in a distributed setup

View the code used in this tutorial on GitHub

Prerequisites
  • High-level overview of how DDP works

  • A machine with multiple GPUs (this tutorial uses an AWS p3.8xlarge instance)

  • PyTorch installed with CUDA

Follow along with the video below or on YouTube.

In the previous tutorial, we got a high-level overview of how DDP works; now we see how to use DDP in code. In this tutorial, we start with a single-GPU training script and migrate it to run on 4 GPUs on a single node. Along the way, we will talk through important concepts in distributed training while implementing them in our code.

Note

If your model contains any BatchNorm layers, it needs to be converted to SyncBatchNorm to sync the running stats of BatchNorm layers across replicas.

Use the helper function torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) to convert all BatchNorm layers in the model to SyncBatchNorm.
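In a DDP script, apply the conversion to the model before wrapping it with DDP(...). The sketch below shows the helper on a small model. build_model is a hypothetical toy model, not part of the tutorial. The conversion only swaps module types, so it does not need an initialized process group or a GPU.

```python
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical toy model containing one BatchNorm layer, for illustration only.
    return nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
    )


model = build_model()
# Recursively replaces every BatchNorm*D module with nn.SyncBatchNorm,
# keeping num_features, eps, momentum, affine parameters, and running stats.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(model[1]).__name__)  # SyncBatchNorm
```

The synchronization itself only happens in the forward pass during training, once the process group exists and the world size is greater than 1.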

Diff for single_gpu.py vs. multigpu.py

These are the changes you typically make to a single-GPU training script to enable DDP.

Imports

  • torch.multiprocessing is a PyTorch wrapper around Python’s native multiprocessing.

  • The distributed process group contains all the processes that can communicate and synchronize with each other.

import torch
import torch.nn.functional as F
from utils import MyTrainDataset

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

Constructing the process group

  • First, before initializing the process group, call set_device, which sets the default GPU for each process. This is important to prevent hangs or excessive memory utilization on GPU:0.

  • The process group can be initialized over TCP (the default, which reads the rendezvous address from the MASTER_ADDR and MASTER_PORT environment variables) or from a shared file system. Read more on process group initialization.

  • init_process_group initializes the distributed process group.

  • Read more about choosing a DDP backend

def ddp_setup(rank: int, world_size: int):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"  # address of the machine running rank 0
    os.environ["MASTER_PORT"] = "12355"      # any free port on that machine
    torch.cuda.set_device(rank)              # bind this process to GPU `rank`
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
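You can check the same setup logic on a machine without GPUs. The sketch below makes several assumptions:

- The CPU-capable gloo backend replaces nccl, and the set_device call is dropped.
- The job has a single process (world_size=1).
- Port 29501 is free.

The name ddp_setup_cpu is hypothetical. The sketch performs one all_reduce to show that the group is usable.

```python
import os

import torch
import torch.distributed as dist


def ddp_setup_cpu(rank: int, world_size: int) -> None:
    # Same rendezvous settings as ddp_setup, but with the CPU-only "gloo" backend.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"  # assumed to be a free port
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)


ddp_setup_cpu(rank=0, world_size=1)
t = torch.tensor([1.0, 2.0])
dist.all_reduce(t)  # default op is SUM over all ranks; a single rank leaves t unchanged
summary = (dist.get_rank(), dist.get_world_size(), t.tolist())
dist.destroy_process_group()
print(summary)  # (0, 1, [1.0, 2.0])
```

With more processes, each one would call ddp_setup_cpu with its own rank and the same world_size. all_reduce would then leave the element-wise sum of every rank's tensor on every rank.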

Constructing the DDP model

self.model = DDP(model, device_ids=[gpu_id])
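DDP keeps the original model as its .module attribute and all-reduces gradients during backward(). The CPU sketch below illustrates this under several assumptions:

- The gloo backend is used.
- The job has one process.
- Port 29502 is free.
- A toy nn.Linear model stands in for the tutorial's model.

For CPU modules, device_ids must be None. On a GPU you would pass device_ids=[gpu_id], as above.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29502"  # assumed free port
dist.init_process_group(backend="gloo", rank=0, world_size=1)

torch.manual_seed(0)
model = nn.Linear(4, 2)  # toy stand-in for the real model
ddp_model = DDP(model, device_ids=None)  # CPU module, so no device_ids

out = ddp_model(torch.randn(3, 4))
out.sum().backward()  # gradients are all-reduced across ranks during backward
wrapped_is_original = ddp_model.module is model
grad_shape = tuple(model.weight.grad.shape)
dist.destroy_process_group()
print(wrapped_is_original, grad_shape)  # True (2, 4)
```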

Distributing input data

  • DistributedSampler chunks the input data across all distributed processes.

  • The DataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.

  • Each process will receive an input batch of 32 samples; the effective batch size is 32 * nprocs, or 128 when using 4 GPUs.

train_data = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=False,  # must be False when a sampler is passed; DistributedSampler does the shuffling
    sampler=DistributedSampler(train_dataset),  # use the DistributedSampler here
)
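To see how DistributedSampler chunks the data, you can build one sampler per rank without a process group. Pass num_replicas and rank explicitly; inside a real DDP job both are read from the process group. The 10-item dataset and 4 ranks are arbitrary values chosen for illustration, and shuffling is turned off so the output is deterministic.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))  # toy dataset of 10 samples
world_size = 4

shards = {}
for rank in range(world_size):
    # Explicit num_replicas/rank, so no initialized process group is required.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    shards[rank] = list(sampler)

print(shards)  # {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6, 0], 3: [3, 7, 1]}
```

Each rank takes every world_size-th index. The list is padded with repeated samples (0 and 1 here) so that all ranks get the same number of samples. With batch_size=32 per process and 4 processes, each optimizer step therefore covers 4 * 32 = 128 samples.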
  • Calling the set_epoch() method on the DistributedSampler at the beginning of each epoch is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be used in each epoch.

def _run_epoch(self, epoch):
    b_sz = len(next(iter(self.train_data))[0])
    self.train_data.sampler.set_epoch(epoch)   # call this additional line at every epoch
    for source, targets in self.train_data:
        ...
        self._run_batch(source, targets)
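The effect of set_epoch can be checked directly. With shuffle=True, DistributedSampler seeds its permutation with seed + epoch. Iterating twice within one epoch therefore repeats the order, and a new epoch number produces a new order. The sketch below assumes a toy 100-item dataset and 2 ranks, and passes them explicitly so no process group is needed.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))  # toy dataset
# num_replicas/rank passed explicitly so no process group is needed.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)

sampler.set_epoch(0)
epoch0_first = list(sampler)
epoch0_again = list(sampler)  # same epoch, same seed + epoch, same order

sampler.set_epoch(1)
epoch1 = list(sampler)  # new epoch number, new permutation

print(epoch0_first == epoch0_again, epoch0_first == epoch1)
```

Without the set_epoch(1) call, epoch1 would equal epoch0_first. This is the repeated ordering the bullet above warns about.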

Saving model checkpoints

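  • We only need to save model checkpoints from one process. Without this condition, each process would save its own copy of the identical model. Because DDP wraps the original model, its weights are saved through self.model.module; the wrapper's own state_dict() prefixes every key with "module.". Read more on saving and loading models with DDP.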

- ckp = self.model.state_dict()
+ ckp = self.model.module.state_dict()
...
...
- if epoch % self.save_every == 0:
+ if self.gpu_id == 0 and epoch % self.save_every == 0:
  self._save_checkpoint(epoch)
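The sketch below shows why the diff switches to self.model.module.state_dict(). It makes these assumptions:

- The CPU gloo backend is used, with a single process.
- Port 29503 is free.
- A toy nn.Linear model stands in for the tutorial's model.
- An in-memory io.BytesIO buffer stands in for the checkpoint file.

```python
import io
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29503"  # assumed free port
dist.init_process_group(backend="gloo", rank=0, world_size=1)

ddp_model = DDP(nn.Linear(4, 2))  # CPU module, device_ids left as None
wrapped_keys = list(ddp_model.state_dict().keys())       # ['module.weight', 'module.bias']
inner_keys = list(ddp_model.module.state_dict().keys())  # ['weight', 'bias']

buffer = io.BytesIO()  # stands in for a checkpoint file path
if dist.get_rank() == 0:  # only rank 0 writes the checkpoint
    torch.save(ddp_model.module.state_dict(), buffer)
dist.destroy_process_group()

# The unprefixed state dict loads directly into a plain, non-DDP model.
buffer.seek(0)
plain_model = nn.Linear(4, 2)
plain_model.load_state_dict(torch.load(buffer))
```

Saving the unwrapped state dict keeps the checkpoint loadable without DDP, for example for single-GPU inference.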

Warning

Collective calls (for example all_reduce, broadcast, or gather) are functions that run on all the distributed processes, typically to combine values across processes or gather them on a specific process. Every rank must execute the same collective call; if one rank skips it, the other ranks block waiting for it. In this example, _save_checkpoint must not contain any collective calls because it runs only on the rank 0 process. If you need to make collective calls, place them before the if self.gpu_id == 0 check.
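A minimal sketch of the safe ordering: every rank runs the collective first, then only rank 0 does its exclusive work. The sketch makes these assumptions:

- The CPU gloo backend is used, with a single process.
- Port 29504 is free.
- local_loss is a hypothetical per-rank metric.

With one rank nothing can hang. The ordering is what matters once world_size > 1.

```python
import os

import torch
import torch.distributed as dist

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29504"  # assumed free port
dist.init_process_group(backend="gloo", rank=0, world_size=1)

rank = dist.get_rank()
local_loss = torch.tensor([0.5])  # hypothetical per-rank metric

# Collective first: every rank must reach this call, or the others wait forever.
dist.all_reduce(local_loss, op=dist.ReduceOp.SUM)
mean_loss = local_loss.item() / dist.get_world_size()

if rank == 0:
    # Rank-0-only work (saving, logging) contains no collective calls.
    print(f"mean loss across ranks: {mean_loss}")

dist.destroy_process_group()
```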

Running the distributed training job

  • Include new arguments rank (replacing device) and world_size.

  • rank is assigned by mp.spawn, which calls main(i, *args) with the process index i as the first argument.

  • world_size is the number of processes across the training job. For GPU training, this corresponds to the number of GPUs in use, and each process works on a dedicated GPU.

- def main(device, total_epochs, save_every):
+ def main(rank, world_size, total_epochs, save_every):
+  ddp_setup(rank, world_size)
   dataset, model, optimizer = load_train_objs()
   train_data = prepare_dataloader(dataset, batch_size=32)
-  trainer = Trainer(model, train_data, optimizer, device, save_every)
+  trainer = Trainer(model, train_data, optimizer, rank, save_every)
   trainer.train(total_epochs)
+  destroy_process_group()

if __name__ == "__main__":
   import sys
   total_epochs = int(sys.argv[1])
   save_every = int(sys.argv[2])
-  device = 0      # shorthand for cuda:0
-  main(device, total_epochs, save_every)
+  world_size = torch.cuda.device_count()
+  mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
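The pieces above can be combined into one spawn-compatible main. This sketch replaces MyTrainDataset, load_train_objs, and Trainer with hypothetical toy regression data and a plain loop. It also makes these assumptions:

- It falls back to the gloo backend on CPU when CUDA is unavailable.
- Port 29505 is free.

Like the diff, it takes rank first so that mp.spawn can call it.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def main(rank: int, world_size: int, total_epochs: int) -> float:
    # mp.spawn calls main(rank, *args), so rank comes first.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29505"  # assumed free port
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(rank)
    dist.init_process_group("nccl" if use_cuda else "gloo", rank=rank, world_size=world_size)
    device = torch.device(f"cuda:{rank}" if use_cuda else "cpu")

    # Hypothetical toy data standing in for MyTrainDataset (same on every rank).
    torch.manual_seed(0)
    dataset = TensorDataset(torch.randn(64, 20), torch.randn(64, 1))
    loader = DataLoader(dataset, batch_size=8, shuffle=False, sampler=DistributedSampler(dataset))

    model = DDP(nn.Linear(20, 1).to(device), device_ids=[rank] if use_cuda else None)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    last_loss = float("nan")
    for epoch in range(total_epochs):
        loader.sampler.set_epoch(epoch)  # reshuffle differently each epoch
        for source, targets in loader:
            source, targets = source.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = F.mse_loss(model(source), targets)
            loss.backward()  # DDP averages gradients across ranks here
            optimizer.step()
            last_loss = loss.item()

    dist.destroy_process_group()
    return last_loss
```

To launch it, follow the diff: under an if __name__ == "__main__": guard, set world_size = torch.cuda.device_count() and call mp.spawn(main, args=(world_size, total_epochs), nprocs=world_size). Calling main(0, 1, total_epochs) directly runs a one-process job, which is a quick smoke test.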
