
How to do DistributedDataParallel (DDP)

This document shows how to use torch.nn.parallel.DistributedDataParallel in xla, and further describes how it differs from the native xla data parallel approach. You can find a minimal runnable example here.

Background / Motivation

Customers have long requested the ability to use PyTorch’s DistributedDataParallel API with xla, and here we enable it as an experimental feature.

How to use DistributedDataParallel

For those who switched from PyTorch eager mode to XLA, here are all the changes you need to make to convert your eager DDP model into an XLA model. We assume that you already know how to use XLA on a single device.
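
If you need a refresher on the single-device workflow, the sketch below shows the basic pattern. It is illustrative only and not part of the DDP-specific changes listed in steps 1-5:

import torch
import torch_xla.core.xla_model as xm

# acquire the XLA device (e.g. a TPU core) and move the model and inputs to it
device = xm.xla_device()
model = torch.nn.Linear(10, 5).to(device)

out = model(torch.randn(20, 10).to(device))
out.sum().backward()
# xla specific API to compile and execute the recorded graph
xm.mark_step()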

  1. Import xla specific distributed packages:

    import torch_xla
    import torch_xla.runtime as xr
    import torch_xla.distributed.xla_backend
    
  2. Initialize the xla process group, similar to other process groups such as nccl and gloo.

    dist.init_process_group("xla", rank=rank, world_size=world_size)
    
  3. Use xla specific APIs to get rank and world_size if you need to.

    new_rank = xr.global_ordinal()
    world_size = xr.world_size()
    
  4. Pass gradient_as_bucket_view=True to the DDP wrapper.

    ddp_model = DDP(model, gradient_as_bucket_view=True)
    
  5. Finally, launch your model with the xla specific launcher.

    torch_xla.launch(demo_fn)
    

Here we have put everything together (the example is actually taken from the DDP tutorial). The way you code it is pretty similar to the eager experience, just with the xla specific touches on a single device plus the above five changes to your script.

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

# additional imports for xla
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend

def setup(rank, world_size):
    os.environ['PJRT_DEVICE'] = 'TPU'

    # initialize the xla process group
    dist.init_process_group("xla", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 1000000)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(1000000, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank):
    # xla specific APIs to get rank, world_size.
    new_rank = xr.global_ordinal()
    assert new_rank == rank
    world_size = xr.world_size()

    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to XLA device
    device = xm.xla_device()
    model = ToyModel().to(device)
    # currently, gradient_as_bucket_view is needed to make DDP work for xla
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    # xla specific API to execute the graph
    xm.mark_step()

    cleanup()


def run_demo(demo_fn):
    # xla specific launcher
    torch_xla.launch(demo_fn)

if __name__ == "__main__":
    run_demo(demo_basic)

Benchmarking

Resnet50 with fake data

The following results are collected by running the command below on a TPU VM V3-8 environment with ToT PyTorch and PyTorch/XLA:

python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

The statistical metrics are produced by the script in this pull request. The unit for the rate is images per second, and CV is the coefficient of variation (standard deviation divided by the mean).

Type               Mean     Median   90th %   Std Dev   CV
xm.optimizer_step  418.54   419.22   430.40   9.76      0.02
DDP                395.97   395.54   407.13   7.60      0.02
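
For illustration, summary statistics like these can be derived from a list of per-step throughput samples as in the sketch below. This is only an assumed equivalent using numpy, not the actual script from the pull request:

import numpy as np

def summarize(rates):
    # rates: per-step throughput samples in images per second
    rates = np.asarray(rates, dtype=np.float64)
    mean = rates.mean()
    std = rates.std()
    return {
        "Mean": mean,
        "Median": np.median(rates),
        "90th %": np.percentile(rates, 90),
        "Std Dev": std,
        "CV": std / mean,  # coefficient of variation
    }

# example call with made-up numbers, for illustration only
print(summarize([410.2, 415.8, 419.3, 421.0, 430.4]))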

The performance gap between our native approach to distributed data parallel and the DistributedDataParallel wrapper is 1 - 395.97 / 418.54 = 5.39%. This result seems reasonable given that the DDP wrapper introduces extra overhead for tracing the DDP runtime.
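
For reference, the native approach benchmarked in the xm.optimizer_step row uses a training step like the minimal sketch below. It reuses ToyModel and the imports from the DDP example above; the key difference is that there is no process group and no DDP wrapper, and gradients are instead all-reduced by xm.optimizer_step:

def demo_native(rank):
    # create model and move it to the XLA device; no DDP wrapper here
    device = xm.xla_device()
    model = ToyModel().to(device)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    # all-reduce the gradients across devices, then apply the optimizer update
    xm.optimizer_step(optimizer)
    # xla specific API to execute the graph
    xm.mark_step()

if __name__ == "__main__":
    torch_xla.launch(demo_native)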

MNIST with fake data

The following results are collected by running python test/test_train_mp_mnist.py --fake_data on a TPU VM V3-8 environment with ToT PyTorch and PyTorch/XLA. The statistical metrics are again produced by the script in this pull request, and the unit for the rate is images per second.

Type               Mean       Median     90th %     Std Dev   CV
xm.optimizer_step  17864.19   20108.96   24351.74   5866.83   0.33
DDP                10701.39   11770.00   14313.78   3102.92   0.29

The performance gap between our native approach to distributed data parallel and the DistributedDataParallel wrapper is 1 - 14313.78 / 24351.74 = 41.22%. Here we compare the 90th percentile instead of the mean, since the dataset is small and the first few rounds are heavily impacted by data loading. This slowdown is large but makes sense given how small the model is; the additional DDP runtime tracing overhead is hard to amortize.

MNIST with real data

The following results are collected by running the command below on a TPU VM V3-8 environment with ToT PyTorch and PyTorch/XLA:

python test/test_train_mp_mnist.py --logdir mnist/

We can observe that the DDP wrapper converges more slowly than the native XLA approach, even though it still achieves a high final accuracy of 97.48%. (The native approach reaches 99%.)

Disclaimer

This feature is still experimental and under active development. Use it with caution and feel free to file any bugs in the xla GitHub repo. For those who are interested in the native xla data parallel approach, here is the tutorial.

Here are some of the known issues that are under investigation:

  * gradient_as_bucket_view=True needs to be enforced.
  * There are some issues when it is used together with torch.utils.data.DataLoader; test_train_mp_mnist.py with real data crashes before exiting.
