ignite.distributed

Helper module to use distributed settings for multiple backends:

  • backends from native torch distributed configuration: “nccl”, “gloo”, “mpi”

  • XLA on TPUs via pytorch/xla

  • using Horovod framework as a backend

Distributed launcher and auto helpers

We provide a context manager to simplify the code of distributed configuration setup for all of the backends supported above. In addition, methods like auto_model(), auto_optim() and auto_dataloader() help to transparently adapt the provided model, optimizer and data loaders to the existing configuration:

# main.py

import torch.optim as optim
from torchvision.models import resnet50

import ignite.distributed as idist


def training(local_rank, config, **kwargs):

    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())

    # `dataset` is a placeholder for your torch Dataset
    train_loader = idist.auto_dataloader(dataset, batch_size=32, num_workers=12, shuffle=True)
    # batch size, num_workers and sampler are automatically adapted to the existing configuration
    # ...
    model = resnet50()
    model = idist.auto_model(model)
    # model is DDP, DP or the model itself, according to the existing configuration
    # ...
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    optimizer = idist.auto_optim(optimizer)
    # optimizer is returned as-is, except for XLA (its `step()` is overridden to call
    # `xm.optimizer_step(optimizer)`) and Horovod (it is wrapped in a distributed optimizer).
    # Either way, user code can safely call `optimizer.step()`.


config = {}  # placeholder: your training configuration
dist_configs = {}  # placeholder: e.g. {"nproc_per_node": 4} to spawn processes from this script

backend = "nccl"  # torch native distributed configuration on multiple GPUs
# backend = "xla-tpu"  # XLA TPUs distributed configuration
# backend = None  # no distributed configuration

if __name__ == "__main__":
    with idist.Parallel(backend=backend, **dist_configs) as parallel:
        parallel.run(training, config, a=1, b=2)

The code above may be executed with the torch.distributed.launch tool, or with plain python by specifying the distributed configuration in the code. For more details, please see Parallel, auto_model(), auto_optim() and auto_dataloader().

A complete example of CIFAR10 training can be found here.

ignite.distributed.auto

auto_dataloader

Helper method to create a dataloader adapted for non-distributed and distributed configurations (supporting all available backends from available_backends()).

auto_model

Helper method to adapt provided model for non-distributed and distributed configurations (supporting all available backends from available_backends()).

auto_optim

Helper method to adapt optimizer for non-distributed and distributed configurations (supporting all available backends from available_backends()).

DistributedProxySampler

Distributed sampler proxy to adapt user's sampler for distributed data parallelism configuration.

class ignite.distributed.auto.DistributedProxySampler(sampler, num_replicas=None, rank=None)

Distributed sampler proxy to adapt user’s sampler for distributed data parallelism configuration.

Code is based on https://github.com/pytorch/pytorch/issues/23430#issuecomment-562350407

Note

Input sampler is assumed to have a constant size.

Parameters
  • sampler (Sampler) – Input torch data sampler.

  • num_replicas (Optional[int]) – Number of processes participating in distributed training.

  • rank (Optional[int]) – Rank of the current process within num_replicas.

ignite.distributed.auto.auto_dataloader(dataset, **kwargs)

Helper method to create a dataloader adapted for non-distributed and distributed configurations (supporting all available backends from available_backends()).

Internally, we create a dataloader with provided kwargs while applying the following updates:

  • batch size is scaled by world size: batch_size / world_size, if batch_size is larger than or equal to world size.

  • number of workers is scaled by the number of local processes: num_workers / nprocs, if num_workers is larger than or equal to nprocs.

  • if no sampler is provided by the user, a torch DistributedSampler is set up.

  • if a sampler is provided by the user, it is wrapped by DistributedProxySampler.

  • if the default device is ‘cuda’, pin_memory is automatically set to True.
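The batch size and worker rules above can be sketched as a pure function, which makes the arithmetic easy to check without launching any processes. `adapt_loader_kwargs` is a hypothetical helper, not part of ignite; flooring the batch size, rounding the worker count up, and comparing num_workers against the number of local processes are assumptions about the implementation.

```python
import math


def adapt_loader_kwargs(kwargs, world_size, nproc_per_node):
    """Hypothetical helper (not part of ignite) mirroring the batch_size and
    num_workers rules described for auto_dataloader."""
    out = dict(kwargs)  # leave the caller's dict untouched
    if world_size > 1:
        # per-process batch size, so the global batch stays close to the requested one
        if "batch_size" in out and out["batch_size"] >= world_size:
            out["batch_size"] = out["batch_size"] // world_size
        # data-loading workers are shared by the processes running on the same node
        if "num_workers" in out and out["num_workers"] >= nproc_per_node:
            out["num_workers"] = math.ceil(out["num_workers"] / nproc_per_node)
    return out


# 2 nodes x 4 processes: world_size = 8, nproc_per_node = 4
print(adapt_loader_kwargs({"batch_size": 32, "num_workers": 12}, world_size=8, nproc_per_node=4))
# {'batch_size': 4, 'num_workers': 3}
```

Values smaller than the divisor are left untouched, so a tiny batch size is never scaled down to zero.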

Warning

A custom batch sampler is not adapted for distributed configuration. Please make sure that the provided batch sampler is compatible with distributed configuration.

Examples:

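import torch
from torch.utils.data import TensorDataset

import ignite.distributed as idist

train_dataset = TensorDataset(torch.randn(1000, 10))  # toy dataset for illustration

train_loader = idist.auto_dataloader(
    train_dataset,
    batch_size=32,
    num_workers=4,
    shuffle=True,
    pin_memory="cuda" in idist.device().type,
    drop_last=True,
)

Parameters
  • dataset (Dataset) – input torch dataset.

  • kwargs (Any) – keyword arguments for torch DataLoader.

Returns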

torch DataLoader or XLA MpDeviceLoader for XLA devices

Return type

Union[DataLoader, _MpDeviceLoader]

ignite.distributed.auto.auto_model(model, sync_bn=False, **kwargs)

Helper method to adapt provided model for non-distributed and distributed configurations (supporting all available backends from available_backends()).

Internally, we perform the following:

  • send the model to the current device() if the model’s parameters are not on that device.

  • wrap the model in torch DistributedDataParallel for native torch distributed if world size is larger than 1.

  • wrap the model in torch DataParallel if no distributed context is found and more than one CUDA device is available.

  • broadcast the initial variable states from rank 0 to all other processes if the Horovod distributed framework is used.

Examples:

import ignite.distributed as idist

model = idist.auto_model(model)

In addition, with NVIDIA/Apex, it can be used as follows:

from apex import amp

import ignite.distributed as idist

# model, optimizer and opt_level are defined beforehand
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model = idist.auto_model(model)

Parameters
  • model (Module) – model to adapt.

  • sync_bn (bool) – if True, applies torch convert_sync_batchnorm to the model for native torch distributed only. Default, False. Note, if using Nvidia/Apex, batchnorm conversion should be applied before calling amp.initialize.

  • kwargs (Any) – kwargs to model’s wrapping class: torch DistributedDataParallel or torch DataParallel if applicable. Please, make sure to use acceptable kwargs for given backend.

Returns

torch.nn.Module

Return type

Module

Changed in version 0.4.2:

  • Added Horovod distributed framework.

  • Added sync_bn argument.

Changed in version 0.4.3: Added kwargs to idist.auto_model.

ignite.distributed.auto.auto_optim(optimizer)

Helper method to adapt optimizer for non-distributed and distributed configurations (supporting all available backends from available_backends()).

Internally, this method is a no-op for non-distributed and torch native distributed configurations.

For XLA distributed configuration, we create a new class that inherits from the provided optimizer's class. The goal is to override the step() method with the specific xm.optimizer_step implementation.

For Horovod distributed configuration, the optimizer is wrapped with the Horovod Distributed Optimizer and its state is broadcast from rank 0 to all other processes.

Examples:

import ignite.distributed as idist

optimizer = idist.auto_optim(optimizer)
Parameters

optimizer (Optimizer) – input torch optimizer

Returns

Optimizer

Return type

Optimizer

Changed in version 0.4.2: Added Horovod distributed optimizer.

ignite.distributed.launcher

Parallel

Distributed launcher context manager to simplify distributed configuration setup for multiple backends:

class ignite.distributed.launcher.Parallel(backend=None, nproc_per_node=None, nnodes=None, node_rank=None, master_addr=None, master_port=None, **spawn_kwargs)

Distributed launcher context manager to simplify distributed configuration setup for multiple backends:

Namely, it can 1) spawn nproc_per_node child processes and initialize a processing group according to the provided backend (useful for standalone scripts) or 2) only initialize a processing group given the backend (useful with tools like torch.distributed.launch, horovodrun, etc.).

Examples

1) Single node or Multi-node, Multi-GPU training launched with torch.distributed.launch or horovodrun tools

Single node option with 4 GPUs

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py
# or, if horovod is installed
horovodrun -np 4 python main.py

Multi-node option: 2 nodes with 8 GPUs each

## node 0
python -m torch.distributed.launch --nnodes=2 --node_rank=0 --master_addr=master \
    --master_port=3344 --nproc_per_node=8 --use_env main.py

# or, if horovod is installed
horovodrun -np 16 -H hostname1:8,hostname2:8 python main.py

## node 1
python -m torch.distributed.launch --nnodes=2 --node_rank=1 --master_addr=master \
    --master_port=3344 --nproc_per_node=8 --use_env main.py

User code is the same for both options:

# main.py

import ignite.distributed as idist

def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...

config = {}  # placeholder: your training configuration
backend = "nccl"  # or "horovod" if package is installed

with idist.Parallel(backend=backend) as parallel:
    parallel.run(training, config, a=1, b=2)
2) Single node, Multi-GPU training launched with python

python main.py

# main.py

import ignite.distributed as idist

def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...

config = {}  # placeholder: your training configuration
backend = "nccl"  # or "horovod" if package is installed

if __name__ == "__main__":
    # nproc_per_node=4 makes Parallel spawn 4 processes from this script
    with idist.Parallel(backend=backend, nproc_per_node=4) as parallel:
        parallel.run(training, config, a=1, b=2)
3) Single node, Multi-TPU training launched with python

python main.py

# main.py

import ignite.distributed as idist

def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...

config = {}  # placeholder: your training configuration

if __name__ == "__main__":
    with idist.Parallel(backend="xla-tpu", nproc_per_node=8) as parallel:
        parallel.run(training, config, a=1, b=2)
4) Multi-node, Multi-GPU training launched with python. For example, 2 nodes with 8 GPUs:

Using torch native distributed framework:

# node 0
python main.py --node_rank=0

# node 1
python main.py --node_rank=1

# main.py

import argparse

import ignite.distributed as idist

def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--node_rank", type=int, default=0)
    args = parser.parse_args()

    config = {}  # placeholder: your training configuration
    dist_config = {
        "nproc_per_node": 8,
        "nnodes": 2,
        "node_rank": args.node_rank,
        "master_addr": "master",
        "master_port": 15000
    }

    with idist.Parallel(backend="nccl", **dist_config) as parallel:
        parallel.run(training, config, a=1, b=2)
Parameters
  • backend (Optional[str]) – backend to use: nccl, gloo, xla-tpu, horovod. If None, no distributed configuration.

  • nproc_per_node (Optional[int]) – optional argument, number of processes per node to specify. If not None, run() will spawn nproc_per_node processes that run input function with its arguments.

  • nnodes (Optional[int]) – optional argument, number of nodes participating in distributed configuration. If not None, run() will spawn nproc_per_node processes that run input function with its arguments. Total world size is nproc_per_node * nnodes. This option is only supported by native torch distributed module. For other modules, please set up spawn_kwargs with backend specific arguments.

  • node_rank (Optional[int]) – optional argument, current machine index. Mandatory argument if nnodes is specified and larger than one. This option is only supported by native torch distributed module. For other modules, please set up spawn_kwargs with backend specific arguments.

  • master_addr (Optional[str]) – optional argument, master node TCP/IP address for torch native backends (nccl, gloo). Mandatory argument if nnodes is specified and larger than one.

  • master_port (Optional[int]) – optional argument, master node port for torch native backends (nccl, gloo). Mandatory argument if master_addr is specified.

  • spawn_kwargs (Any) – kwargs to idist.spawn function.

Changed in version 0.4.2: backend now accepts horovod distributed framework.

run(func, *args, **kwargs)

Execute func with provided arguments in distributed context.

Example

def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...

with idist.Parallel(backend=backend) as parallel:
    parallel.run(training, config, a=1, b=2)
Parameters
  • func (Callable) – function to execute. First argument of the function should be local_rank - local process index.

  • args (Any) – positional arguments of func (without local_rank).

  • kwargs (Any) – keyword arguments of func.

Return type

None

ignite.distributed.utils

This module wraps common methods to fetch information about distributed configuration, initialize/finalize process group or spawn multiple processes.

backend

Returns computation model's backend.

broadcast

Helper method to perform broadcast operation.

device

Returns current device according to current distributed configuration.

available_backends

Returns available backends.

model_name

Returns distributed configuration name (given by ignite)

get_world_size

Returns world size of current distributed configuration.

get_rank

Returns process rank within current distributed configuration.

get_local_rank

Returns local process rank within current distributed configuration.

get_nproc_per_node

Returns number of processes (or tasks) per node within current distributed configuration.

get_node_rank

Returns node rank within current distributed configuration.

get_nnodes

Returns number of nodes within current distributed configuration.

spawn

Spawns nproc_per_node processes that run fn with args/kwargs_dict and initialize distributed configuration defined by backend.

initialize

Initializes distributed configuration according to provided backend

finalize

Finalizes distributed configuration.

show_config

Helper method to display distributed configuration via logging.

set_local_rank

Method to hint the local rank in case the torch native distributed context is created by the user without using initialize() or spawn().

all_reduce

Helper method to perform all reduce operation.

all_gather

Helper method to perform all gather operation.

barrier

Helper method to synchronize all processes.

hostname

Returns host name for current process within current distributed configuration.

sync

Helper method to force this module to synchronize with current distributed context.

one_rank_only

Decorator to filter handlers with respect to a rank number

ignite.distributed.utils.has_native_dist_support

True if torch.distributed is available

ignite.distributed.utils.has_xla_support

True if torch_xla package is found

ignite.distributed.utils.all_gather(tensor)

Helper method to perform all gather operation.

Parameters

tensor (Union[Tensor, float, str]) – tensor or number or str to collect across participating processes.

Returns

torch.Tensor of shape (world_size * tensor.shape[0], tensor.shape[1], ...) if the input is a tensor, torch.Tensor of shape (world_size, ) if the input is a number, or a list of strings if the input is a string

Return type

Union[Tensor, float, List[float], List[str]]
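The returned shape can be made concrete without launching processes. In this single-process illustration, per_rank is a stand-in for the tensor each process would pass to all_gather; concatenating along dim 0 reproduces the documented shape (world_size * tensor.shape[0], tensor.shape[1], ...). No communication happens here, and the rank-ordered layout is an assumption.

```python
import torch

world_size = 3

# Stand-ins for the (2, 5) tensor each rank would pass to idist.all_gather,
# filled with the rank index so the origin of each row stays visible
per_rank = [torch.full((2, 5), float(rank)) for rank in range(world_size)]

# Concatenation along dim 0, rank by rank, gives the documented result shape
gathered = torch.cat(per_rank, dim=0)
print(gathered.shape)  # torch.Size([6, 5])
print(gathered[:, 0])  # tensor([0., 0., 1., 1., 2., 2.])
```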

ignite.distributed.utils.all_reduce(tensor, op='SUM')

Helper method to perform all reduce operation.

Parameters
  • tensor (Union[Tensor, float]) – tensor or number to collect across participating processes.

  • op (str) – reduction operation, “SUM” by default. Possible values: “SUM”, “PRODUCT”, “MIN”, “MAX”, “AND”, “OR”. Horovod backend supports only “SUM”, “AVERAGE”, “ADASUM”, “MIN”, “MAX”, “PRODUCT”.

Returns

torch.Tensor or number

Return type

Union[Tensor, float]
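To illustrate what each reduction op yields, the sketch below computes, in a single process, the value every rank would receive from an all reduce. `simulate_all_reduce` is a hypothetical helper, not part of ignite, and only covers the arithmetic ops; no communication is performed.

```python
import torch


def simulate_all_reduce(per_rank, op="SUM"):
    """Hypothetical single-process stand-in: returns the value every rank would
    receive from an all reduce over `per_rank` (one tensor per process)."""
    stacked = torch.stack(per_rank)  # shape: (world_size, *tensor.shape)
    if op == "SUM":
        return stacked.sum(dim=0)
    if op == "PRODUCT":
        return stacked.prod(dim=0)
    if op == "MIN":
        return stacked.min(dim=0).values
    if op == "MAX":
        return stacked.max(dim=0).values
    raise ValueError(f"unsupported op in this sketch: {op}")


# three "ranks", each holding a 2-element tensor
per_rank = [torch.tensor([1.0, 5.0]), torch.tensor([2.0, 3.0]), torch.tensor([4.0, 1.0])]
print(simulate_all_reduce(per_rank))         # tensor([7., 9.])
print(simulate_all_reduce(per_rank, "MAX"))  # tensor([4., 5.])
```

Unlike all_gather, the result keeps the shape of each input tensor: the world_size dimension is reduced away.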

ignite.distributed.utils.available_backends()

Returns available backends.

Return type

Tuple[str, …]

ignite.distributed.utils.backend()

Returns computation model’s backend.

  • None for no distributed configuration

  • “nccl” or “gloo” or “mpi” for native torch distributed configuration

  • “xla-tpu” for XLA distributed configuration

  • “horovod” for Horovod distributed framework

Returns

str or None

Return type

Optional[str]

Changed in version 0.4.2: Added Horovod distributed framework.

ignite.distributed.utils.barrier()

Helper method to synchronize all processes.

Return type

None

ignite.distributed.utils.broadcast(tensor, src=0)

Helper method to perform broadcast operation.

Parameters
  • tensor (Union[Tensor, float, str]) – tensor or number or str to broadcast to participating processes. Make sure to respect dtype of torch tensor input for all processes, otherwise execution will crash.

  • src (int) – source rank. Default, 0.

Returns

torch.Tensor or string or number

Return type

Union[Tensor, float, str]

Examples

import torch

import ignite.distributed as idist

if idist.get_rank() == 0:
    t1 = torch.rand(4, 5, 6, device=idist.device())
    s1 = "abc"
    x = 12.3456
else:
    t1 = torch.empty(4, 5, 6, device=idist.device())
    s1 = ""
    x = 0.0

# Broadcast tensor t1 from rank 0 to all processes
t1 = idist.broadcast(t1, src=0)
assert isinstance(t1, torch.Tensor)

# Broadcast string s1 from rank 0 to all processes
s1 = idist.broadcast(s1, src=0)
# >>> s1 = "abc"

# Broadcast float number x from rank 0 to all processes
x = idist.broadcast(x, src=0)
# >>> x = 12.3456

New in version 0.4.2.

ignite.distributed.utils.device()

Returns current device according to current distributed configuration.

  • torch.device(“cpu”) if no distributed configuration or torch native gloo distributed configuration

  • torch.device(“cuda:local_rank”) if torch native nccl or horovod distributed configuration

  • torch.device(“xla:index”) if XLA distributed configuration

Returns

torch.device

Return type

device

Changed in version 0.4.2: Added Horovod distributed framework.

ignite.distributed.utils.finalize()

Finalizes distributed configuration. For example, in case of native pytorch distributed configuration, it calls dist.destroy_process_group().

Return type

None

ignite.distributed.utils.get_local_rank()

Returns local process rank within current distributed configuration. Returns 0 if no distributed configuration.

Return type

int

ignite.distributed.utils.get_nnodes()

Returns number of nodes within current distributed configuration. Returns 1 if no distributed configuration.

Return type

int

ignite.distributed.utils.get_node_rank()

Returns node rank within current distributed configuration. Returns 0 if no distributed configuration.

Return type

int

ignite.distributed.utils.get_nproc_per_node()

Returns number of processes (or tasks) per node within current distributed configuration. Returns 1 if no distributed configuration.

Return type

int

ignite.distributed.utils.get_rank()

Returns process rank within current distributed configuration. Returns 0 if no distributed configuration.

Return type

int

ignite.distributed.utils.get_world_size()

Returns world size of current distributed configuration. Returns 1 if no distributed configuration.

Return type

int

ignite.distributed.utils.hostname()

Returns host name for current process within current distributed configuration.

Return type

str

ignite.distributed.utils.initialize(backend, **kwargs)

Initializes distributed configuration according to provided backend

Examples

Launch single node multi-GPU training with torch.distributed.launch utility.

# >>> python -m torch.distributed.launch --nproc_per_node=4 main.py

# main.py

import torch

import ignite.distributed as idist

def train_fn(local_rank, a, b, c):
    import torch.distributed as dist
    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == 4

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")


a, b, c = 1, 2, 3  # placeholder arguments for train_fn
idist.initialize("nccl")
local_rank = idist.get_local_rank()
train_fn(local_rank, a, b, c)
idist.finalize()

Parameters
  • backend (str) – backend: nccl, gloo, xla-tpu, horovod.

  • kwargs (Any) –

    acceptable kwargs according to provided backend:

    • “nccl” or “gloo” : timeout(=timedelta(minutes=30)).

    • “horovod” : comm(=None), more info: hvd_init.

Return type

None

Changed in version 0.4.2: backend now accepts horovod distributed framework.

ignite.distributed.utils.model_name()

Returns distributed configuration name (given by ignite)

  • serial for no distributed configuration

  • native-dist for native torch distributed configuration

  • xla-dist for XLA distributed configuration

  • horovod-dist for Horovod distributed framework

Changed in version 0.4.2: horovod-dist will be returned for Horovod distributed framework.

Return type

str

ignite.distributed.utils.one_rank_only(rank=0, with_barrier=False)

Decorator to filter handlers with respect to a rank number

Parameters
  • rank (int) – rank number of the handler (default: 0).

  • with_barrier (bool) – synchronisation with a barrier (default: False).

Return type

Callable

from ignite.distributed.utils import one_rank_only

engine = ...

@engine.on(...)
@one_rank_only()  # means @one_rank_only(rank=0)
def some_handler(_):
    ...

@engine.on(...)
@one_rank_only(rank=1)
def another_handler(_):
    ...

ignite.distributed.utils.set_local_rank(index)

Method to hint the local rank in case the torch native distributed context is created by the user without using initialize() or spawn().

Usage:

The user sets up the torch native distributed process group:

import torch.distributed as dist

import ignite.distributed as idist

def run(local_rank, *args, **kwargs):

    idist.set_local_rank(local_rank)
    # ...
    dist.init_process_group(**dist_info)  # dist_info: your process group arguments
    # ...

Parameters

index (int) – local rank or current process index

Return type

None

ignite.distributed.utils.show_config()

Helper method to display distributed configuration via logging.

Return type

None

ignite.distributed.utils.spawn(backend, fn, args, kwargs_dict=None, nproc_per_node=1, **kwargs)

Spawns nproc_per_node processes that run fn with args/kwargs_dict and initialize distributed configuration defined by backend.

Examples

  1. Launch single node multi-GPU training using torch native distributed framework

# >>> python main.py

# main.py

import torch

import ignite.distributed as idist

def train_fn(local_rank, a, b, c, d=12):
    import torch.distributed as dist
    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == 4

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")


if __name__ == "__main__":
    a, b, c = 1, 2, 3  # placeholder arguments for train_fn
    idist.spawn("nccl", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=4)

  2. Launch multi-node multi-GPU training using torch native distributed framework

# >>> (node 0): python main.py --node_rank=0 --nnodes=8 --master_addr=master --master_port=2222
# >>> (node 1): python main.py --node_rank=1 --nnodes=8 --master_addr=master --master_port=2222
# >>> ...
# >>> (node 7): python main.py --node_rank=7 --nnodes=8 --master_addr=master --master_port=2222

# main.py

import argparse

import torch
import ignite.distributed as idist

def train_fn(local_rank, nnodes, nproc_per_node):
    import torch.distributed as dist
    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == nnodes * nproc_per_node

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--node_rank", type=int, default=0)
    parser.add_argument("--nnodes", type=int, default=1)
    parser.add_argument("--master_addr", type=str, default="127.0.0.1")
    parser.add_argument("--master_port", type=int, default=2222)
    args = parser.parse_args()
    nproc_per_node = 8  # assumption: number of GPUs on each node

    idist.spawn(
        "nccl",
        train_fn,
        args=(args.nnodes, nproc_per_node),
        nproc_per_node=nproc_per_node,
        nnodes=args.nnodes,
        node_rank=args.node_rank,
        master_addr=args.master_addr,
        master_port=args.master_port,
    )

  3. Launch single node multi-TPU training (for example on Google Colab) using PyTorch/XLA

# >>> python main.py

# main.py

import ignite.distributed as idist

def train_fn(local_rank, a, b, c, d=12):
    import torch_xla.core.xla_model as xm
    assert xm.get_world_size() == 8

    device = idist.device()
    assert "xla" in device.type


if __name__ == "__main__":
    a, b, c = 1, 2, 3  # placeholder arguments for train_fn
    idist.spawn("xla-tpu", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=8)

Parameters
  • backend (str) – backend to use: nccl, gloo, xla-tpu, horovod

  • fn (Callable) – function to be called as the entrypoint of the spawned process. This function must be defined at the top level of a module so it can be pickled and spawned. This is a requirement imposed by multiprocessing. The function is called as fn(i, *args, **kwargs_dict), where i is the process index and args is the passed-through tuple of arguments.

  • args (Tuple) – arguments passed to fn.

  • kwargs_dict (Optional[Mapping]) – kwargs passed to fn.

  • nproc_per_node (int) – number of processes to spawn on a single node. Default, 1.

  • kwargs (Any) –

    acceptable kwargs according to provided backend:

    • “nccl” or “gloo” : nnodes (default, 1), node_rank (default, 0), master_addr
      (default, “127.0.0.1”), master_port (default, 2222), timeout to dist.init_process_group function
      and kwargs for mp.start_processes function.
    • “xla-tpu” : nnodes (default, 1), node_rank (default, 0) and kwargs to xmp.spawn function.
    • “horovod”: hosts (default, None) and other kwargs to hvd_run function. Arguments nnodes=1
      and node_rank=0 are tolerated and ignored, otherwise an exception is raised.

Return type

None

Changed in version 0.4.2: backend now accepts horovod distributed framework.

ignite.distributed.utils.sync(temporary=False)

Helper method to force this module to synchronize with current distributed context. This method should be used when distributed context is manually created or destroyed.

Parameters

temporary (bool) – If True, distributed model synchronization is done on every call of idist.get_* methods. This may have a negative performance impact.

Return type

None