
ignite.distributed

Helper module to use distributed settings for multiple backends:

  • backends from native torch distributed configuration: “nccl”, “gloo”, “mpi”

  • XLA on TPUs via pytorch/xla

Distributed launcher and auto helpers

We provide a context manager to simplify the code of distributed configuration setup for all of the above supported backends. In addition, methods like auto_model(), auto_optim() and auto_dataloader() help to transparently adapt the provided model, optimizer and data loaders to the existing configuration:

# main.py

import torch.optim as optim
from torchvision.models import resnet50

import ignite.distributed as idist

# `dataset`, `config` and `dist_configs` are assumed to be defined by the user.

def training(local_rank, config, **kwargs):

    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())

    train_loader = idist.auto_dataloader(dataset, batch_size=32, num_workers=12, shuffle=True)
    # batch size, num_workers and sampler are automatically adapted to the existing configuration
    # ...
    model = resnet50()
    model = idist.auto_model(model)
    # model is wrapped in DDP or DP, or returned unchanged, according to the existing configuration
    # ...
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    optimizer = idist.auto_optim(optimizer)
    # optimizer is returned unchanged, except for the XLA configuration, where its `step()` method is overridden.
    # The user can safely call `optimizer.step()`: `xm.optimizer_step(optimizer)` is performed behind the scenes.


if __name__ == "__main__":
    backend = "nccl"  # torch native distributed configuration on multiple GPUs
    # backend = "xla-tpu"  # XLA TPUs distributed configuration
    # backend = None  # no distributed configuration
    with idist.Parallel(backend=backend, **dist_configs) as parallel:
        parallel.run(training, config, a=1, b=2)

The code above can be executed with the torch.distributed.launch tool, or with plain python by specifying the distributed configuration in the code. For more details, please see Parallel, auto_model(), auto_optim() and auto_dataloader().

A complete example of CIFAR10 training can be found here.

ignite.distributed.auto

class ignite.distributed.auto.DistributedProxySampler(sampler, num_replicas=None, rank=None)

Distributed sampler proxy to adapt user’s sampler for distributed data parallelism configuration.

Code is based on https://github.com/pytorch/pytorch/issues/23430#issuecomment-562350407

Note

Input sampler is assumed to have a constant size.

Parameters
  • sampler (Sampler) – Input torch data sampler.

  • num_replicas (int, optional) – Number of processes participating in distributed training.

  • rank (int, optional) – Rank of the current process within num_replicas.
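To make the proxy idea concrete, here is a minimal pure-Python sketch of how the indices drawn from a wrapped sampler could be split between processes. The function name proxy_shard is hypothetical and not part of ignite. The sketch assumes the wrapped sampler yields the same order on every process (for example, because all processes seed it identically) and, as in the note above, that its size is constant. It pads the index list to a multiple of num_replicas and lets each rank take every num_replicas-th index starting at its own rank, which is the interleaving used by torch's DistributedSampler.

```python
import math
from typing import Iterable, List


def proxy_shard(sampler: Iterable[int], num_replicas: int, rank: int) -> List[int]:
    """Indices that process `rank` would draw from a wrapped sampler (illustrative sketch)."""
    base = list(sampler)  # one pass over the user's sampler
    if not base:
        return []
    # Every replica gets the same number of samples, so the total is padded up.
    num_samples = math.ceil(len(base) / num_replicas)
    total_size = num_samples * num_replicas
    indices: List[int] = []
    while len(indices) < total_size:  # repeat the sampler output until long enough
        indices += base
    indices = indices[:total_size]
    # Interleaved split: rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]
```

For 10 indices and 4 replicas, each rank receives 3 indices, and two indices are repeated to pad the total to 12.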

ignite.distributed.auto.auto_dataloader(dataset, **kwargs)

Helper method to create a dataloader adapted for non-distributed and distributed configurations (supporting all available backends from available_backends()).

Internally, we create a dataloader with provided kwargs while applying the following updates:

  • batch size is divided by world size: batch_size / world_size, if batch_size is larger than or equal to world size.

  • number of workers is divided by the number of local processes: num_workers / nprocs, if num_workers is larger than or equal to nprocs.

  • if no sampler is provided by the user, a torch DistributedSampler is set up.

  • if a sampler is provided by the user, it is wrapped by DistributedProxySampler.

  • if the default device is ‘cuda’, pin_memory is automatically set to True.
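The batch size and worker rescaling rules in the list above can be sketched as plain arithmetic on the DataLoader kwargs. adapt_loader_kwargs is a hypothetical helper, not ignite's code; it compares num_workers against the number of local processes and rounds the per-process worker count up, and both choices are assumptions of this sketch. Sampler and pin_memory handling are left out.

```python
from typing import Any, Dict


def adapt_loader_kwargs(kwargs: Dict[str, Any], world_size: int, nproc_per_node: int) -> Dict[str, Any]:
    """Return a copy of DataLoader kwargs rescaled per process (illustrative sketch)."""
    out = dict(kwargs)  # never mutate the caller's dict
    if world_size <= 1:
        return out  # non-distributed: nothing to rescale
    batch_size = out.get("batch_size")
    if batch_size is not None and batch_size >= world_size:
        out["batch_size"] = batch_size // world_size  # per-process batch size
    num_workers = out.get("num_workers")
    if num_workers is not None and num_workers >= nproc_per_node:
        # Split the node's workers between its local processes (rounding up is an assumption).
        out["num_workers"] = (num_workers + nproc_per_node - 1) // nproc_per_node
    return out
```

With world size 8 and 4 processes per node, batch_size=32 and num_workers=12 become 4 and 3 per process.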

Warning

A custom batch sampler is not adapted for distributed configuration. Please make sure that the provided batch sampler is compatible with the distributed configuration.

Examples:

import ignite.distributed as idist

train_loader = idist.auto_dataloader(
    train_dataset,
    batch_size=32,
    num_workers=4,
    shuffle=True,
    pin_memory="cuda" in idist.device().type,
    drop_last=True,
)
Parameters
  • dataset (Dataset) – input torch dataset

  • **kwargs – keyword arguments for torch DataLoader.

Returns

torch DataLoader or XLA MpDeviceLoader for XLA devices

ignite.distributed.auto.auto_model(model)

Helper method to adapt provided model for non-distributed and distributed configurations (supporting all available backends from available_backends()).

Internally, we perform the following:

  • send the model to the current device() if the model’s parameters are not on the device.

  • wrap the model in torch DistributedDataParallel for native torch distributed configuration if world size is larger than 1.

  • wrap the model in torch DataParallel if no distributed context is found and more than one CUDA device is available.
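The wrapping rules above amount to a small decision table. The sketch below returns the name of the wrapper that would apply instead of wrapping a real model, so it runs without torch. choose_wrapper and its arguments are hypothetical; the XLA case simply falls through to "unwrapped" here, since the list above does not mention a wrapper for it.

```python
from typing import Optional


def choose_wrapper(backend: Optional[str], world_size: int, n_cuda_devices: int) -> str:
    """Name of the wrapper the auto_model rules would select (illustrative sketch)."""
    native_backends = ("nccl", "gloo", "mpi")
    if backend in native_backends and world_size > 1:
        return "DistributedDataParallel"
    if backend is None and n_cuda_devices > 1:
        return "DataParallel"  # no distributed context, several local GPUs
    return "unwrapped"  # single device, or XLA, where this sketch applies no wrapper
```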

Examples:

import ignite.distributed as idist

model = idist.auto_model(model)

In addition, with NVIDIA/Apex, it can be used in the following way:

import ignite.distributed as idist
from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model = idist.auto_model(model)

Parameters

model (torch.nn.Module) – model to adapt.

Returns

torch.nn.Module

Return type

Module

ignite.distributed.auto.auto_optim(optimizer)

Helper method to adapt optimizer for non-distributed and distributed configurations (supporting all available backends from available_backends()).

Internally, this method is a no-op for non-distributed and torch native distributed configurations. For XLA distributed configuration, we create a new class that inherits from the provided optimizer's class in order to override the step() method with the specific xm.optimizer_step implementation.
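The subclass-and-override approach described above can be illustrated in plain Python. In this sketch, SGDLike is a hypothetical stand-in for a torch optimizer and before_step is a hypothetical callback standing in for the gradient reduction that xm.optimizer_step performs before stepping. The code builds a new class that inherits from the optimizer's class, overrides step(), and rebinds the instance to it. This is one way to express the idea, not ignite's actual construction.

```python
from typing import Any, Callable, List, Optional


class SGDLike:
    """Hypothetical stand-in for a torch optimizer."""

    def __init__(self, lr: float) -> None:
        self.lr = lr
        self.calls: List[str] = []

    def step(self, closure: Optional[Callable[[], Any]] = None) -> None:
        self.calls.append("plain_step")


def wrap_step(optimizer: Any, before_step: Callable[[Any], None]) -> Any:
    """Rebind `optimizer` to a subclass of its own class whose step() runs `before_step` first."""
    base_cls = type(optimizer)

    def step(self: Any, closure: Optional[Callable[[], Any]] = None) -> Any:
        before_step(self)  # stands in for the cross-device gradient reduction
        return base_cls.step(self, closure)  # then the original optimizer update

    # New class with the same name, inheriting from the original optimizer class.
    optimizer.__class__ = type(base_cls.__name__, (base_cls,), {"step": step})
    return optimizer
```

Because the new class inherits from the original one, isinstance checks and all other optimizer attributes keep working; only step() changes.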

Examples:

import ignite.distributed as idist

optimizer = idist.auto_optim(optimizer)
Parameters

optimizer (Optimizer) – input torch optimizer

Returns

Optimizer

Return type

Optimizer

ignite.distributed.launcher

class ignite.distributed.launcher.Parallel(backend=None, nproc_per_node=None, nnodes=None, node_rank=None, master_addr=None, master_port=None, **spawn_kwargs)

Distributed launcher context manager to simplify distributed configuration setup for multiple backends:

  • backends from native torch distributed configuration: “nccl”, “gloo”, “mpi”

  • XLA on TPUs via pytorch/xla

Namely, it can 1) spawn nproc_per_node child processes and initialize a process group according to the provided backend (useful for standalone scripts) or 2) only initialize a process group given the backend (useful with tools like torch.distributed.launch).

Examples

  1. Single node or Multi-node, Multi-GPU training launched with torch.distributed.launch tool

Single node option:

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py

Multi-node option:

# node 0
python -m torch.distributed.launch --nnodes=2 --node_rank=0 --master_addr=master --master_port=3344 --nproc_per_node=8 --use_env main.py

# node 1
python -m torch.distributed.launch --nnodes=2 --node_rank=1 --master_addr=master --master_port=3344 --nproc_per_node=8 --use_env main.py

User code is the same for both options:

# main.py

import ignite.distributed as idist

def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...

with idist.Parallel(backend="nccl") as parallel:
    parallel.run(training, config, a=1, b=2)

  2. Single node, Multi-GPU training launched with python

python main.py

# main.py

import ignite.distributed as idist

def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...

if __name__ == "__main__":
    with idist.Parallel(backend="nccl", nproc_per_node=4) as parallel:
        parallel.run(training, config, a=1, b=2)

  3. Single node, Multi-TPU training launched with python

python main.py

# main.py

import ignite.distributed as idist

def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...

if __name__ == "__main__":
    with idist.Parallel(backend="xla-tpu", nproc_per_node=8) as parallel:
        parallel.run(training, config, a=1, b=2)

  4. Multi-node, Multi-GPU training launched with python. For example, 2 nodes with 8 GPUs:

# node 0
python main.py --node_rank=0

# node 1
python main.py --node_rank=1

# main.py

import argparse

import ignite.distributed as idist

def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--node_rank", type=int, required=True)
    args = parser.parse_args()

    dist_config = {
        "nproc_per_node": 8,
        "nnodes": 2,
        "node_rank": args.node_rank,
        "master_addr": "master",
        "master_port": 15000,
    }

    with idist.Parallel(backend="nccl", **dist_config) as parallel:
        parallel.run(training, config, a=1, b=2)

Parameters
  • backend (str, optional) – backend to use: nccl, gloo, xla-tpu. If None, no distributed configuration.

  • nproc_per_node (int, optional) – optional argument, number of processes per node to specify. If not None, run() will spawn nproc_per_node processes that run input function with its arguments.

  • nnodes (int, optional) – optional argument, number of nodes participating in distributed configuration. If not None, run() will spawn nproc_per_node processes that run input function with its arguments. Total world size is nproc_per_node * nnodes.

  • node_rank (int, optional) – optional argument, current machine index. Mandatory argument if nnodes is specified and larger than one.

  • master_addr (str, optional) – optional argument, master node TCP/IP address for torch native backends (nccl, gloo). Mandatory argument if nnodes is specified and larger than one.

  • master_port (int, optional) – optional argument, master node port for torch native backends (nccl, gloo). Mandatory argument if master_addr is specified.

  • **spawn_kwargs – kwargs to idist.spawn function.
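The argument rules listed above (which arguments become mandatory, and how the total world size is computed) can be checked with a small validation function. check_parallel_args is hypothetical and only mirrors the documented rules; Parallel performs its own validation, which may differ in detail.

```python
from typing import Optional


def check_parallel_args(
    nproc_per_node: Optional[int] = None,
    nnodes: Optional[int] = None,
    node_rank: Optional[int] = None,
    master_addr: Optional[str] = None,
    master_port: Optional[int] = None,
) -> Optional[int]:
    """Validate Parallel-style arguments; return the total world size, or None if nothing is spawned."""
    if nnodes is not None and nnodes > 1:
        if node_rank is None:
            raise ValueError("node_rank is mandatory when nnodes > 1")
        if master_addr is None:
            raise ValueError("master_addr is mandatory when nnodes > 1")
    if master_addr is not None and master_port is None:
        raise ValueError("master_port is mandatory when master_addr is given")
    if nproc_per_node is None:
        return None  # only initialize a process group, e.g. under torch.distributed.launch
    return nproc_per_node * (nnodes or 1)  # total world size = nproc_per_node * nnodes
```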

run(func, *args, **kwargs)

Execute func with provided arguments in distributed context.

Example

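import ignite.distributed as idist

def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...

if __name__ == "__main__":
    with idist.Parallel(backend="nccl", nproc_per_node=4) as parallel:
        parallel.run(training, config, a=1, b=2)
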
Parameters
  • func (Callable) – function to execute. The first argument of the function should be local_rank, the local process index.

  • *args – positional arguments of func (without local_rank).

  • **kwargs – keyword arguments of func.

ignite.distributed.utils

This module wraps common methods to fetch information about distributed configuration, initialize/finalize process group or spawn multiple processes.

ignite.distributed.utils.has_native_dist_support

True if torch.distributed is available

ignite.distributed.utils.has_xla_support

True if torch_xla package is found

ignite.distributed.utils.all_gather(tensor)

Helper method to perform an all-gather operation.

Parameters

tensor (torch.Tensor or number or str) – tensor or number or str to collect across participating processes.

Returns

torch.Tensor of shape (world_size * tensor.shape[0], tensor.shape[1], …) or List of strings

Return type

Union[Tensor, Number, List[str]]
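The shape rule for tensors, and the list result for strings, can be modelled without any processes. In this hypothetical sketch, a 2-D tensor is represented as a list of rows and per_rank holds each rank's input in rank order; the return value is what every process would receive.

```python
from typing import List, Sequence, Union

Rows = List[List[float]]  # a 2-D "tensor" represented as a list of rows


def simulate_all_gather(per_rank: Sequence[Union[Rows, str]]) -> Union[Rows, List[str]]:
    """Result every process would receive from all_gather (illustrative model, no communication)."""
    if all(isinstance(x, str) for x in per_rank):
        return list(per_rank)  # strings come back as a list, one entry per rank
    gathered: Rows = []
    for rows in per_rank:  # rank order
        gathered.extend(rows)  # concatenate along the first dimension
    return gathered
```

With 2 ranks each holding a tensor of shape (2, 2), the gathered result has shape (2 * 2, 2), matching (world_size * tensor.shape[0], tensor.shape[1]).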

ignite.distributed.utils.all_reduce(tensor, op='SUM')

Helper method to perform an all-reduce operation.

Parameters
  • tensor (torch.Tensor or number) – tensor or number to reduce across participating processes.

  • op (str) – reduction operation, “SUM” by default. Possible values: “SUM”, “PRODUCT”, “MIN”, “MAX”, “AND”, “OR”.

Returns

torch.Tensor or number

Return type

Union[Tensor, Number]
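The op names above correspond to ordinary element-wise reductions. The sketch below folds one value per rank with the matching Python operator to show what every process holds after the call; the function and the op table are hypothetical illustrations, not ignite's internals.

```python
import functools
import operator
from typing import Any, Callable, Dict, Sequence

# Hypothetical mapping from the documented op names to Python reductions.
_OPS: Dict[str, Callable[[Any, Any], Any]] = {
    "SUM": operator.add,
    "PRODUCT": operator.mul,
    "MIN": min,
    "MAX": max,
    "AND": operator.and_,
    "OR": operator.or_,
}


def simulate_all_reduce(values: Sequence[Any], op: str = "SUM") -> Any:
    """Value every process would hold after all_reduce, given one value per rank."""
    return functools.reduce(_OPS[op], values)
```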

ignite.distributed.utils.available_backends()

Returns available backends.

Return type

Tuple[str]

ignite.distributed.utils.backend()

Returns computation model’s backend.

  • None for no distributed configuration

  • “nccl” or “gloo” or “mpi” for native torch distributed configuration

  • “xla-tpu” for XLA distributed configuration

Returns

str or None

Return type

Optional[str]

ignite.distributed.utils.barrier()

Helper method to synchronize all processes.

ignite.distributed.utils.device()

Returns current device according to current distributed configuration.

  • torch.device(“cpu”) if no distributed configuration or native gloo distributed configuration

  • torch.device(“cuda:local_rank”) if native nccl distributed configuration

  • torch.device(“xla:index”) if XLA distributed configuration

Returns

torch.device

Return type

device

ignite.distributed.utils.finalize()

Finalizes distributed configuration. For example, in case of native pytorch distributed configuration, it calls dist.destroy_process_group().

ignite.distributed.utils.get_local_rank()

Returns local process rank within current distributed configuration. Returns 0 if no distributed configuration.

Return type

int

ignite.distributed.utils.get_nnodes()

Returns number of nodes within current distributed configuration. Returns 1 if no distributed configuration.

Return type

int

ignite.distributed.utils.get_node_rank()

Returns node rank within current distributed configuration. Returns 0 if no distributed configuration.

Return type

int

ignite.distributed.utils.get_nproc_per_node()

Returns number of processes (or tasks) per node within current distributed configuration. Returns 1 if no distributed configuration.

Return type

int

ignite.distributed.utils.get_rank()

Returns process rank within current distributed configuration. Returns 0 if no distributed configuration.

Return type

int

ignite.distributed.utils.get_world_size()

Returns world size of current distributed configuration. Returns 1 if no distributed configuration.

Return type

int
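For torch native configurations where every node runs the same number of processes, the values returned by get_rank(), get_local_rank(), get_node_rank(), get_nproc_per_node() and get_world_size() are usually related by simple arithmetic; this is, for example, how torch.distributed.launch assigns ranks. The helpers below are hypothetical and only illustrate that relation under the assumption of a uniform number of processes per node.

```python
from typing import Tuple


def global_rank(node_rank: int, nproc_per_node: int, local_rank: int) -> int:
    """Global rank of a process, assuming the same number of processes on every node."""
    return node_rank * nproc_per_node + local_rank


def split_rank(rank: int, nproc_per_node: int) -> Tuple[int, int]:
    """Inverse mapping: (node_rank, local_rank) for a given global rank."""
    return divmod(rank, nproc_per_node)


def total_world_size(nnodes: int, nproc_per_node: int) -> int:
    """Total number of processes across all nodes."""
    return nnodes * nproc_per_node
```

On 2 nodes with 8 processes each, local process 3 on node 1 has global rank 1 * 8 + 3 = 11, and the world size is 16.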

ignite.distributed.utils.hostname()

Returns host name for current process within current distributed configuration.

Return type

str

ignite.distributed.utils.initialize(backend, **kwargs)

Initializes distributed configuration according to the provided backend.

Examples

Launch single node multi-GPU training with the torch.distributed.launch utility.

# >>> python -m torch.distributed.launch --nproc_per_node=4 main.py

# main.py

import torch

import ignite.distributed as idist

def train_fn(local_rank, a, b, c):
    import torch.distributed as dist
    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == 4

    device = idist.device()
    assert device == torch.device("cuda:{}".format(local_rank))


a, b, c = 1, 2, 3  # placeholder arguments for train_fn
idist.initialize("nccl")
local_rank = idist.get_local_rank()
train_fn(local_rank, a, b, c)
idist.finalize()

Parameters
  • backend (str, optional) – backend: nccl, gloo, xla-tpu.

  • **kwargs

    acceptable kwargs according to provided backend:

    • “nccl” or “gloo”: timeout (default: timedelta(minutes=30))

ignite.distributed.utils.model_name()

Returns distributed configuration name (given by ignite)

  • serial for no distributed configuration

  • native-dist for native torch distributed configuration

  • xla-dist for XLA distributed configuration

Return type

str

ignite.distributed.utils.one_rank_only(rank=0, with_barrier=False)

Decorator to filter handlers with respect to a rank number: the decorated handler is only executed by the process with that rank.

Parameters
  • rank (int) – rank number of the handler (default: 0).

  • with_barrier (bool) – synchronisation with a barrier (default: False).

Examples:

from ignite.distributed.utils import one_rank_only

engine = ...

@engine.on(...)
@one_rank_only()  # means @one_rank_only(rank=0)
def some_handler(_):
    ...

@engine.on(...)
@one_rank_only(rank=1)
def some_other_handler(_):
    ...

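The behaviour described by the rank and with_barrier parameters can be sketched as an ordinary decorator factory. To keep the example self-contained, the current rank and the barrier are injected as plain functions instead of calling ignite; make_one_rank_only is hypothetical. When with_barrier=True, every process reaches the barrier, including those that skip the handler, which is what keeps the processes in step.

```python
import functools
from typing import Any, Callable, Optional


def make_one_rank_only(get_rank: Callable[[], int], barrier: Callable[[], None]):
    """Build a one_rank_only-style decorator from injected rank/barrier functions (sketch)."""

    def one_rank_only(rank: int = 0, with_barrier: bool = False):
        def decorator(func: Callable[..., Any]) -> Callable[..., Optional[Any]]:
            @functools.wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
                result = None
                if get_rank() == rank:  # only the selected process runs the handler
                    result = func(*args, **kwargs)
                if with_barrier:  # all processes synchronize here
                    barrier()
                return result

            return wrapper

        return decorator

    return one_rank_only
```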
ignite.distributed.utils.set_local_rank(index)

Method to hint the local rank in case the torch native distributed context is created by the user without using initialize() or spawn().

Usage:

The user sets up the torch native distributed process group:

import torch.distributed as dist

import ignite.distributed as idist

def run(local_rank, *args, **kwargs):

    idist.set_local_rank(local_rank)
    # ...
    dist.init_process_group(**dist_info)
    # ...

Parameters

index (int) – local rank or current process index

ignite.distributed.utils.show_config()

Helper method to display distributed configuration via logging.

ignite.distributed.utils.spawn(backend, fn, args, kwargs_dict=None, nproc_per_node=1, **kwargs)

Spawns nproc_per_node processes that run fn with args/kwargs_dict and initializes the distributed configuration defined by backend.

Examples

  1. Launch single node multi-GPU training

# >>> python main.py

# main.py

import torch

import ignite.distributed as idist

def train_fn(local_rank, a, b, c, d=12):
    import torch.distributed as dist
    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == 4

    device = idist.device()
    assert device == torch.device("cuda:{}".format(local_rank))


if __name__ == "__main__":
    a, b, c = 1, 2, 3  # placeholder positional arguments for train_fn
    idist.spawn("nccl", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=4)

  2. Launch multi-node multi-GPU training

# >>> (node 0): python main.py --node_rank=0 --nnodes=8 --master_addr=master --master_port=2222
# >>> (node 1): python main.py --node_rank=1 --nnodes=8 --master_addr=master --master_port=2222
# >>> ...
# >>> (node 7): python main.py --node_rank=7 --nnodes=8 --master_addr=master --master_port=2222

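# main.py

import argparse

import torch
import ignite.distributed as idist

def train_fn(local_rank, nnodes, nproc_per_node):
    import torch.distributed as dist
    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == nnodes * nproc_per_node

    device = idist.device()
    assert device == torch.device("cuda:{}".format(local_rank))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--node_rank", type=int, required=True)
    parser.add_argument("--nnodes", type=int, required=True)
    parser.add_argument("--master_addr", type=str, required=True)
    parser.add_argument("--master_port", type=int, required=True)
    args = parser.parse_args()

    nproc_per_node = torch.cuda.device_count()  # one process per visible GPU

    idist.spawn(
        "nccl",
        train_fn,
        args=(args.nnodes, nproc_per_node),
        nproc_per_node=nproc_per_node,
        nnodes=args.nnodes,
        node_rank=args.node_rank,
        master_addr=args.master_addr,
        master_port=args.master_port,
    )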

  3. Launch single node multi-TPU training (for example on Google Colab)

# >>> python main.py

# main.py

import ignite.distributed as idist

def train_fn(local_rank, a, b, c, d=12):
    import torch_xla.core.xla_model as xm
    assert xm.get_world_size() == 8

    device = idist.device()
    assert "xla" in device.type


if __name__ == "__main__":
    a, b, c = 1, 2, 3  # placeholder positional arguments for train_fn
    idist.spawn("xla-tpu", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=8)

Parameters
  • backend (str) – backend to use: nccl, gloo, xla-tpu

  • fn (function) – function to be called as the entrypoint of the spawned process. This function must be defined at the top level of a module so it can be pickled and spawned; this is a requirement imposed by multiprocessing. The function is called as fn(i, *args, **kwargs_dict), where i is the process index and args is the passed-through tuple of arguments.

  • args (tuple) – arguments passed to fn.

  • kwargs_dict (Mapping) – kwargs passed to fn.

  • nproc_per_node (int) – number of processes to spawn on a single node. Default: 1.

  • **kwargs

    acceptable kwargs according to provided backend:

    • “nccl” or “gloo”: nnodes (default: 1), node_rank (default: 0), master_addr (default: “127.0.0.1”), master_port (default: 2222), timeout for the dist.init_process_group function, and kwargs for the mp.spawn function.

    • “xla-tpu”: nnodes (default: 1), node_rank (default: 0), and kwargs for the xmp.spawn function.
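The calling convention described for fn, where the process index comes first, followed by args and then kwargs_dict, can be illustrated without spawning anything. run_sequentially below is hypothetical: it calls the function once per index in the current process and performs no backend initialization, so it only demonstrates how the arguments reach fn.

```python
from typing import Any, Callable, List, Mapping, Optional, Tuple


def run_sequentially(
    fn: Callable[..., Any],
    args: Tuple[Any, ...],
    kwargs_dict: Optional[Mapping[str, Any]] = None,
    nproc_per_node: int = 1,
) -> List[Any]:
    """Call fn(i, *args, **kwargs_dict) once per process index, in the current process."""
    kwargs = dict(kwargs_dict or {})
    return [fn(i, *args, **kwargs) for i in range(nproc_per_node)]


def train_fn(local_rank: int, a: int, b: int, c: int, d: int = 12) -> Tuple[int, int]:
    """Toy entrypoint: returns its index and the sum of its arguments."""
    return local_rank, a + b + c + d
```

With args=(1, 2, 3), kwargs_dict={"d": 23} and nproc_per_node=4, train_fn is called with indices 0 to 3, and each call sees d=23.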

ignite.distributed.utils.sync(temporary=False)

Helper method to force this module to synchronize with the current distributed context. This method should be used when the distributed context is manually created or destroyed.

Parameters

temporary (bool) – If True, distributed model synchronization is done on every call of the idist.get_* methods. This may have a negative impact on performance.