ignite.distributed
Helper module to use distributed settings for multiple backends:
backends from native torch distributed configuration: “nccl”, “gloo”, “mpi”
XLA on TPUs via pytorch/xla
using Horovod framework as a backend
Distributed launcher and auto helpers
We provide a context manager to simplify the code of distributed configuration setup for all the backends listed above.
In addition, methods like auto_model(), auto_optim() and auto_dataloader() help to transparently adapt the provided model, optimizer and data loaders to the existing configuration:
# main.py
import torch.optim as optim
from torchvision.models import resnet50

import ignite.distributed as idist


def training(local_rank, config, **kwargs):
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # `dataset` is assumed to be defined by the user
    train_loader = idist.auto_dataloader(dataset, batch_size=32, num_workers=12, shuffle=True)
    # batch size, num_workers and sampler are automatically adapted to existing configuration
    # ...
    model = resnet50()
    model = idist.auto_model(model)
    # model is DDP or DP or just itself according to existing configuration
    # ...
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    optimizer = idist.auto_optim(optimizer)
    # optimizer is itself, except for XLA (where `step()` is overridden) and Horovod configurations.
    # User can safely call `optimizer.step()` (`xm.optimizer_step(optimizer)` is performed behind the scenes)


if __name__ == "__main__":
    backend = "nccl"  # torch native distributed configuration on multiple GPUs
    # backend = "xla-tpu"  # XLA TPUs distributed configuration
    # backend = None  # no distributed configuration
    config = {}  # user-defined configuration passed to `training`
    dist_configs = {}  # e.g. {"nproc_per_node": 4} to spawn processes when launched with `python main.py`
    with idist.Parallel(backend=backend, **dist_configs) as parallel:
        parallel.run(training, config, a=1, b=2)
The above code may be executed with the torch.distributed.launch tool, or with python by specifying the distributed configuration in the code. For more details, please see Parallel, auto_model(), auto_optim() and auto_dataloader().
A complete example of CIFAR10 training can be found here.
ignite.distributed.auto
- class ignite.distributed.auto.DistributedProxySampler(sampler, num_replicas=None, rank=None)
Distributed sampler proxy to adapt user’s sampler for distributed data parallelism configuration.
Code is based on https://github.com/pytorch/pytorch/issues/23430#issuecomment-562350407
Note
Input sampler is assumed to have a constant size.
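The rank-splitting idea behind such a proxy can be sketched in plain Python. This is a minimal sketch and not ignite's implementation. `ToyDistributedProxySampler` is a hypothetical name. The sketch assumes the wrapped sampler is a sized, re-iterable sequence of indices, and its strategy of padding by repeating the sampler is also an assumption.

```python
import math
from typing import Iterator, List, Sequence


class ToyDistributedProxySampler:
    """Hypothetical sketch: split a sampler's indices into equal, per-rank slices."""

    def __init__(self, sampler: Sequence[int], num_replicas: int, rank: int) -> None:
        if not 0 <= rank < num_replicas:
            raise ValueError(f"rank should be in [0, {num_replicas - 1}], got {rank}")
        self.sampler = sampler
        self.num_replicas = num_replicas
        self.rank = rank
        # Every rank yields the same number of indices (rounded up); this relies on
        # the wrapped sampler having a constant size, as the note above requires
        self.num_samples = math.ceil(len(sampler) / num_replicas)
        self.total_size = self.num_samples * num_replicas

    def __iter__(self) -> Iterator[int]:
        indices: List[int] = []
        # Repeat the wrapped sampler until the indices can be split evenly
        while len(indices) < self.total_size:
            indices += list(self.sampler)
        indices = indices[: self.total_size]
        # Rank r takes every num_replicas-th index, starting at position r
        return iter(indices[self.rank : self.total_size : self.num_replicas])

    def __len__(self) -> int:
        return self.num_samples


# Usage: 10 indices shared by 4 ranks, so each rank gets 3 (two indices are repeated)
parts = [list(ToyDistributedProxySampler(list(range(10)), num_replicas=4, rank=r)) for r in range(4)]
print(parts)
```

The slicing `indices[rank::num_replicas]` keeps the per-rank subsets disjoint before padding, and every index of the wrapped sampler is seen by some rank.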
- ignite.distributed.auto.auto_dataloader(dataset, **kwargs)
Helper method to create a dataloader adapted for non-distributed and distributed configurations (supporting all available backends from available_backends()).
Internally, we create a dataloader with the provided kwargs while applying the following updates:
batch size is scaled by world size: batch_size / world_size if it is larger than or equal to the world size.
number of workers is scaled by number of local processes: num_workers / nprocs if it is larger than or equal to the world size.
if no sampler is provided by the user, a torch DistributedSampler is set up.
if a sampler is provided by the user, it is wrapped by DistributedProxySampler.
if the default device is ‘cuda’, pin_memory is automatically set to True.
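The two scaling rules above amount to simple integer arithmetic. This is a minimal sketch, and `scale_loader_kwargs` is a hypothetical helper that is not part of ignite. It assumes floor division, although the library's exact rounding may differ. It also assumes that both conditions compare against the world size, as written above.

```python
from typing import Any, Dict


def scale_loader_kwargs(kwargs: Dict[str, Any], world_size: int, nproc_per_node: int) -> Dict[str, Any]:
    """Hypothetical sketch of the batch size / num_workers rules listed above."""
    scaled = dict(kwargs)  # leave the caller's dict untouched
    # Each process loads batch_size / world_size samples per step (floor division assumed)
    if scaled.get("batch_size", 0) >= world_size:
        scaled["batch_size"] //= world_size
    # The node's workers are shared between its local processes (floor division assumed)
    if scaled.get("num_workers", 0) >= world_size:
        scaled["num_workers"] //= nproc_per_node
    return scaled


# Usage: single node with 4 processes
print(scale_loader_kwargs({"batch_size": 32, "num_workers": 12, "shuffle": True}, world_size=4, nproc_per_node=4))
```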
Warning
Custom batch sampler is not adapted for distributed configuration. Please make sure that the provided batch sampler is compatible with the distributed configuration.
Examples:
import ignite.distributed as idist

# `train_dataset` is assumed to be defined by the user
train_loader = idist.auto_dataloader(
    train_dataset,
    batch_size=32,
    num_workers=4,
    shuffle=True,
    pin_memory="cuda" in idist.device().type,
    drop_last=True,
)
- Parameters
dataset (Dataset) – input torch dataset
**kwargs – keyword arguments for torch DataLoader.
- Returns
torch DataLoader or XLA MpDeviceLoader for XLA devices
- Return type
Union[DataLoader, _MpDeviceLoader]
- ignite.distributed.auto.auto_model(model, sync_bn=False, **kwargs)
Helper method to adapt the provided model for non-distributed and distributed configurations (supporting all available backends from available_backends()).
Internally, we perform the following:
send the model to the current device() if the model's parameters are not on the device.
wrap the model in torch DistributedDataParallel for native torch distributed if world size is larger than 1.
wrap the model in torch DataParallel if no distributed context is found and more than one CUDA device is available.
broadcast the initial variable states from rank 0 to all other processes if the Horovod distributed framework is used.
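The rules above form a small decision table. The sketch below restates it as a function. `model_adaptation_steps` is hypothetical and does not call ignite. The sketch assumes "no distributed context" corresponds to `backend is None`, and it treats "nccl", "gloo" and "mpi" as the native torch backends.

```python
from typing import List, Optional

NATIVE_BACKENDS = ("nccl", "gloo", "mpi")


def model_adaptation_steps(backend: Optional[str], world_size: int, num_cuda_devices: int) -> List[str]:
    """Hypothetical sketch: which of the documented auto_model steps would apply."""
    # The model is always moved to idist.device() first, if needed
    steps = ["move parameters to idist.device() if needed"]
    if backend in NATIVE_BACKENDS and world_size > 1:
        steps.append("wrap in DistributedDataParallel")
    elif backend == "horovod":
        steps.append("broadcast initial state from rank 0")
    elif backend is None and num_cuda_devices > 1:
        steps.append("wrap in DataParallel")
    return steps


print(model_adaptation_steps("nccl", world_size=4, num_cuda_devices=4))
print(model_adaptation_steps(None, world_size=1, num_cuda_devices=2))
```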
Examples:
import ignite.distributed as idist

model = idist.auto_model(model)

In addition, with NVidia/Apex, it can be used in the following way:
import ignite.distributed as idist
from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model = idist.auto_model(model)
- Parameters
model (torch.nn.Module) – model to adapt.
sync_bn (bool) – if True, applies torch convert_sync_batchnorm to the model for native torch distributed only. Default, False. Note, if using Nvidia/Apex, batchnorm conversion should be applied before calling amp.initialize.
**kwargs – kwargs to the model's wrapping class: torch DistributedDataParallel or torch DataParallel if applicable. Please make sure to use acceptable kwargs for the given backend.
- Returns
torch.nn.Module
- Return type
torch.nn.Module
Changed in version 0.4.2: Added Horovod distributed framework. Added sync_bn argument.
Changed in version 0.4.3: Added kwargs to idist.auto_model.
- ignite.distributed.auto.auto_optim(optimizer)
Helper method to adapt optimizer for non-distributed and distributed configurations (supporting all available backends from available_backends()).
Internally, this method is a no-op for non-distributed and torch native distributed configurations.
For XLA distributed configuration, we create a new class that inherits from the provided optimizer. The goal is to override the step() method with the specific xm.optimizer_step implementation.
For Horovod distributed configuration, the optimizer is wrapped with Horovod Distributed Optimizer and its state is broadcast from rank 0 to all other processes.
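The XLA case relies on creating a subclass at runtime, so that `step()` can be replaced while everything else is inherited. The sketch below shows that idea with plain Python objects. `override_step` and `ToyOptimizer` are hypothetical, and the lambda stands in for a backend call such as `xm.optimizer_step`. Reassigning `__class__` is a choice specific to this sketch. Ignite may instead construct a new optimizer instance.

```python
from typing import Any, Callable, List, Optional


def override_step(optimizer: Any, step_impl: Callable[[Any], None]) -> Any:
    """Hypothetical sketch: give `optimizer` a subclass of its own class whose step() calls step_impl."""
    base_cls = type(optimizer)

    def step(self: Any, closure: Optional[Callable] = None) -> None:
        # A backend-specific call replaces base_cls.step
        step_impl(self)

    # Same class name; every other attribute and method is inherited from the original class
    new_cls = type(base_cls.__name__, (base_cls,), {"step": step})
    optimizer.__class__ = new_cls
    return optimizer


class ToyOptimizer:
    def __init__(self) -> None:
        self.calls: List[str] = []

    def step(self, closure: Optional[Callable] = None) -> None:
        self.calls.append("base step")


# Usage: the user keeps calling optimizer.step(); the overridden implementation runs
opt = override_step(ToyOptimizer(), lambda o: o.calls.append("backend step"))
opt.step()
print(opt.calls)
```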
Examples:
import ignite.distributed as idist

optimizer = idist.auto_optim(optimizer)
- Parameters
optimizer (Optimizer) – input torch optimizer
- Returns
Optimizer
- Return type
Optimizer
Changed in version 0.4.2: Added Horovod distributed optimizer.
ignite.distributed.launcher
- class ignite.distributed.launcher.Parallel(backend=None, nproc_per_node=None, nnodes=None, node_rank=None, master_addr=None, master_port=None, **spawn_kwargs)
Distributed launcher context manager to simplify distributed configuration setup for multiple backends:
backends from native torch distributed configuration: “nccl”, “gloo”, “mpi” (if available)
XLA on TPUs via pytorch/xla (if installed)
using Horovod distributed framework (if installed)
Namely, it can 1) spawn nproc_per_node child processes and initialize a processing group according to the provided backend (useful for standalone scripts) or 2) only initialize a processing group given the backend (useful with tools like torch.distributed.launch, horovodrun, etc).
Examples
1) Single node or Multi-node, Multi-GPU training launched with torch.distributed.launch or horovodrun tools
Single node option with 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py
# or if horovod is installed
horovodrun -np 4 python main.py
Multi-node option : 2 nodes with 8 GPUs each
## node 0
python -m torch.distributed.launch --nnodes=2 --node_rank=0 --master_addr=master --master_port=3344 --nproc_per_node=8 --use_env main.py
# or if horovod is installed
horovodrun -np 16 -H hostname1:8,hostname2:8 python main.py

## node 1
python -m torch.distributed.launch --nnodes=2 --node_rank=1 --master_addr=master --master_port=3344 --nproc_per_node=8 --use_env main.py
User code is the same for both options:
# main.py
import ignite.distributed as idist


def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...


backend = "nccl"  # or "horovod" if package is installed
config = {}  # user-defined configuration

with idist.Parallel(backend=backend) as parallel:
    parallel.run(training, config, a=1, b=2)
Single node, Multi-GPU training launched with python
python main.py
# main.py
import ignite.distributed as idist


def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...


if __name__ == "__main__":
    backend = "nccl"  # or "horovod" if package is installed
    config = {}  # user-defined configuration

    with idist.Parallel(backend=backend, nproc_per_node=4) as parallel:
        parallel.run(training, config, a=1, b=2)
Single node, Multi-TPU training launched with python
python main.py
# main.py
import ignite.distributed as idist


def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...


if __name__ == "__main__":
    config = {}  # user-defined configuration

    with idist.Parallel(backend="xla-tpu", nproc_per_node=8) as parallel:
        parallel.run(training, config, a=1, b=2)
Multi-node, Multi-GPU training launched with python. For example, 2 nodes with 8 GPUs:
Using torch native distributed framework:
# node 0
python main.py --node_rank=0
# node 1
python main.py --node_rank=1
# main.py
import argparse

import ignite.distributed as idist


def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--node_rank", type=int, default=0)
    args = parser.parse_args()

    config = {}  # user-defined configuration
    dist_config = {
        "nproc_per_node": 8,
        "nnodes": 2,
        "node_rank": args.node_rank,
        "master_addr": "master",
        "master_port": 15000,
    }
    with idist.Parallel(backend="nccl", **dist_config) as parallel:
        parallel.run(training, config, a=1, b=2)
- Parameters
backend (str, optional) – backend to use: nccl, gloo, xla-tpu, horovod. If None, no distributed configuration.
nproc_per_node (int, optional) – optional argument, number of processes per node to specify. If not None, run() will spawn nproc_per_node processes that run the input function with its arguments.
nnodes (int, optional) – optional argument, number of nodes participating in distributed configuration. If not None, run() will spawn nproc_per_node processes that run the input function with its arguments. Total world size is nproc_per_node * nnodes. This option is only supported by the native torch distributed module. For other modules, please set up spawn_kwargs with backend-specific arguments.
node_rank (int, optional) – optional argument, current machine index. Mandatory argument if nnodes is specified and larger than one. This option is only supported by the native torch distributed module. For other modules, please set up spawn_kwargs with backend-specific arguments.
master_addr (str, optional) – optional argument, master node TCP/IP address for torch native backends (nccl, gloo). Mandatory argument if nnodes is specified and larger than one.
master_port (int, optional) – optional argument, master node port for torch native backends (nccl, gloo). Mandatory argument if master_addr is specified.
**spawn_kwargs – kwargs to the idist.spawn function.
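The mandatory-argument rules and the world-size formula above can be restated as a small validation function. This is a sketch only. `check_dist_config` is hypothetical and does not mirror ignite's actual checks or error messages. The sketch assumes that `nproc_per_node=None` means an external launcher starts the processes, so the world size is unknown at this point.

```python
from typing import Optional


def check_dist_config(
    nproc_per_node: Optional[int] = None,
    nnodes: Optional[int] = None,
    node_rank: Optional[int] = None,
    master_addr: Optional[str] = None,
    master_port: Optional[int] = None,
) -> Optional[int]:
    """Hypothetical sketch of the argument rules above; returns the total world size if known."""
    if nnodes is not None and nnodes > 1:
        # Multi-node runs need this machine's index and the master node address
        if node_rank is None:
            raise ValueError("node_rank is mandatory if nnodes > 1")
        if master_addr is None:
            raise ValueError("master_addr is mandatory if nnodes > 1")
    if master_addr is not None and master_port is None:
        raise ValueError("master_port is mandatory if master_addr is specified")
    if nproc_per_node is None:
        return None  # processes are assumed to be started by an external launcher
    # Total world size is nproc_per_node * nnodes
    return nproc_per_node * (nnodes if nnodes is not None else 1)


print(check_dist_config(nproc_per_node=8, nnodes=2, node_rank=0, master_addr="master", master_port=15000))
```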
Changed in version 0.4.2: backend now accepts horovod distributed framework.
- run(func, *args, **kwargs)
Execute func with provided arguments in distributed context.
Example
def training(local_rank, config, **kwargs):
    # ...
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())
    # ...


with idist.Parallel(backend=backend) as parallel:
    parallel.run(training, config, a=1, b=2)
- Parameters
func (Callable) – function to execute. First argument of the function should be local_rank - local process index.
*args – positional arguments of func (without local_rank).
**kwargs – keyword arguments of func.
- Return type
None
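The calling convention of run() is `func(local_rank, *args, **kwargs)`. The single-process stand-in below makes it concrete. `ToyParallel` is hypothetical. It starts no processes and creates no process group. It only records setup and teardown and calls `func` once per simulated local process, in order.

```python
from typing import Any, Callable, List


class ToyParallel:
    """Hypothetical, single-process stand-in for the Parallel calling convention."""

    def __init__(self, nproc_per_node: int = 1) -> None:
        self.nproc_per_node = nproc_per_node
        self.events: List[str] = []

    def __enter__(self) -> "ToyParallel":
        self.events.append("setup")  # a real launcher would initialize the distributed context here
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.events.append("finalize")  # a real launcher would finalize the context here

    def run(self, func: Callable, *args: Any, **kwargs: Any) -> None:
        # local_rank is always passed first, followed by the user's arguments
        for local_rank in range(self.nproc_per_node):
            func(local_rank, *args, **kwargs)


seen = []


def training(local_rank, config, **kwargs):
    seen.append((local_rank, config["lr"], kwargs["a"], kwargs["b"]))


with ToyParallel(nproc_per_node=2) as parallel:
    parallel.run(training, {"lr": 0.01}, a=1, b=2)
print(seen, parallel.events)
```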
ignite.distributed.utils
This module wraps common methods to fetch information about distributed configuration, initialize/finalize process group or spawn multiple processes.
- ignite.distributed.utils.has_native_dist_support
True if torch.distributed is available
- ignite.distributed.utils.has_xla_support
True if torch_xla package is found
- ignite.distributed.utils.all_gather(tensor)
Helper method to perform all gather operation.
- Parameters
tensor (torch.Tensor or number or str) – tensor or number or str to collect across participating processes.
- Returns
torch.Tensor of shape (world_size * tensor.shape[0], tensor.shape[1], ...) if input is a tensor, torch.Tensor of shape (world_size, ) if input is a number, or List of strings if input is a string.
- Return type
torch.Tensor or List[str]
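The shape rule above can be checked with a tiny simulation that uses nested lists in place of tensors. Both helpers are hypothetical and do not communicate between processes. The sketch assumes that gathered values are ordered by rank, and it represents a number input as an empty shape `()`.

```python
from typing import Any, List, Sequence, Tuple


def all_gather_output_shape(world_size: int, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical helper restating the documented output shape of all_gather."""
    if len(input_shape) == 0:
        # A number becomes one entry per process: shape (world_size,)
        return (world_size,)
    # Tensors are concatenated along the first dimension
    return (world_size * input_shape[0],) + tuple(input_shape[1:])


def simulate_all_gather(per_rank_rows: Sequence[List[Any]]) -> List[Any]:
    """Concatenate each rank's rows in rank order, as every rank would receive them."""
    gathered: List[Any] = []
    for rows in per_rank_rows:  # list index plays the role of the rank
        gathered.extend(rows)
    return gathered


# Usage: 3 ranks, each holding a (2, 2) "tensor" written as nested lists
result = simulate_all_gather([[[0, 0], [0, 1]], [[1, 0], [1, 1]], [[2, 0], [2, 1]]])
print(len(result), all_gather_output_shape(3, (2, 2)))
```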
- ignite.distributed.utils.all_reduce(tensor, op='SUM')
Helper method to perform all reduce operation.
- Parameters
tensor (torch.Tensor or number) – tensor or number to collect across participating processes.
op (str) – reduction operation, “SUM” by default. Possible values: “SUM”, “PRODUCT”, “MIN”, “MAX”, “AND”, “OR”. Please note that some of these values are not supported by backends like “horovod”.
- Returns
torch.Tensor or number
- Return type
torch.Tensor or number
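After an all_reduce, every process holds the same reduced value. The single-process simulation below maps each documented op name to a pairwise Python reduction. The `REDUCE_OPS` mapping and `simulate_all_reduce` are hypothetical. Treating "AND" and "OR" as bitwise operations is an assumption of this sketch.

```python
import operator
from functools import reduce
from typing import Any, Callable, Dict, Sequence

# Hypothetical mapping of the documented op names to pairwise Python reductions
REDUCE_OPS: Dict[str, Callable[[Any, Any], Any]] = {
    "SUM": operator.add,
    "PRODUCT": operator.mul,
    "MIN": min,
    "MAX": max,
    "AND": operator.and_,  # bitwise for ints, logical for bools
    "OR": operator.or_,
}


def simulate_all_reduce(per_rank: Sequence[Any], op: str = "SUM") -> Any:
    """Return the value every rank would hold after all_reduce (list index = rank)."""
    if op not in REDUCE_OPS:
        raise ValueError(f"unsupported reduction op: {op!r}")
    return reduce(REDUCE_OPS[op], per_rank)


print(simulate_all_reduce([1, 2, 3, 4]), simulate_all_reduce([1, 2, 3, 4], op="MAX"))
```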
- ignite.distributed.utils.backend()
Returns computation model’s backend.
None for no distributed configuration
“nccl” or “gloo” or “mpi” for native torch distributed configuration
“xla-tpu” for XLA distributed configuration
“horovod” for Horovod distributed framework
Changed in version 0.4.2: Added Horovod distributed framework.
- ignite.distributed.utils.barrier()
Helper method to synchronize all processes.
- Return type
None
- ignite.distributed.utils.broadcast(tensor, src=0)
Helper method to perform broadcast operation.
- Parameters
tensor (torch.Tensor or number or str) – tensor or number or str to broadcast to participating processes. Make sure that a torch tensor input has the same dtype on all processes, otherwise execution will crash.
src (int) – source rank. Default, 0.
- Returns
torch.Tensor or string or number
- Return type
torch.Tensor or str or number
Examples
import torch

import ignite.distributed as idist

if idist.get_rank() == 0:
    t1 = torch.rand(4, 5, 6, device=idist.device())
    s1 = "abc"
    x = 12.3456
else:
    t1 = torch.empty(4, 5, 6, device=idist.device())
    s1 = ""
    x = 0.0

# Broadcast tensor t1 from rank 0 to all processes
t1 = idist.broadcast(t1, src=0)
assert isinstance(t1, torch.Tensor)

# Broadcast string s1 from rank 0 to all processes
s1 = idist.broadcast(s1, src=0)
# >>> s1 = "abc"

# Broadcast float number x from rank 0 to all processes
x = idist.broadcast(x, src=0)
# >>> x = 12.3456
New in version 0.4.2.
- ignite.distributed.utils.device()
Returns current device according to current distributed configuration.
torch.device(“cpu”) if no distributed configuration or torch native gloo distributed configuration
torch.device(“cuda:local_rank”) if torch native nccl or horovod distributed configuration
torch.device(“xla:index”) if XLA distributed configuration
- Returns
torch.device
- Return type
torch.device
Changed in version 0.4.2: Added Horovod distributed framework.
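The three documented cases can be restated as a lookup from backend to device string. `expected_device_str` is a hypothetical helper that covers only the listed cases. The real device() may also depend on details of the runtime environment that this list does not mention.

```python
from typing import Optional


def expected_device_str(backend: Optional[str], local_rank: int = 0, xla_index: int = 0) -> str:
    """Hypothetical helper restating the documented device() cases as device strings."""
    if backend is None or backend == "gloo":
        return "cpu"
    if backend in ("nccl", "horovod"):
        return f"cuda:{local_rank}"
    if backend == "xla-tpu":
        return f"xla:{xla_index}"
    # Other backends (e.g. "mpi") are not covered by the list above
    raise ValueError(f"no documented device rule for backend {backend!r}")


print(expected_device_str("nccl", local_rank=3))
```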
- ignite.distributed.utils.finalize()
Finalizes distributed configuration. For example, in case of native pytorch distributed configuration, it calls dist.destroy_process_group().
- Return type
None
- ignite.distributed.utils.get_local_rank()
Returns local process rank within current distributed configuration. Returns 0 if no distributed configuration.
- Return type
int
- ignite.distributed.utils.get_nnodes()
Returns number of nodes within current distributed configuration. Returns 1 if no distributed configuration.
- Return type
int
- ignite.distributed.utils.get_node_rank()
Returns node rank within current distributed configuration. Returns 0 if no distributed configuration.
- Return type
int
- ignite.distributed.utils.get_nproc_per_node()
Returns number of processes (or tasks) per node within current distributed configuration. Returns 1 if no distributed configuration.
- Return type
int
- ignite.distributed.utils.get_rank()
Returns process rank within current distributed configuration. Returns 0 if no distributed configuration.
- Return type
int
- ignite.distributed.utils.get_world_size()
Returns world size of current distributed configuration. Returns 1 if no distributed configuration.
- Return type
int
- ignite.distributed.utils.hostname()
Returns host name for current process within current distributed configuration.
- Return type
str
- ignite.distributed.utils.initialize(backend, **kwargs)
Initializes distributed configuration according to provided backend.
Examples
Launch single node multi-GPU training with torch.distributed.launch utility.
# >>> python -m torch.distributed.launch --nproc_per_node=4 main.py
# main.py
import torch

import ignite.distributed as idist


def train_fn(local_rank, a, b, c):
    import torch.distributed as dist

    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == 4

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")


idist.initialize("nccl")
local_rank = idist.get_local_rank()
a, b, c = 1, 2, 3  # placeholder arguments
train_fn(local_rank, a, b, c)
idist.finalize()
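- Parameters
backend (str) – backend to use: nccl, gloo, xla-tpu, horovod.
**kwargs – acceptable kwargs according to provided backend.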
- Return type
None
Changed in version 0.4.2: backend now accepts horovod distributed framework.
- ignite.distributed.utils.model_name()
Returns distributed configuration name (given by ignite)
serial for no distributed configuration
native-dist for native torch distributed configuration
xla-dist for XLA distributed configuration
horovod-dist for Horovod distributed framework
Changed in version 0.4.2: horovod-dist will be returned for Horovod distributed framework.
- Return type
str
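- ignite.distributed.utils.one_rank_only(rank=0, with_barrier=False)
Decorator to filter handlers with respect to a rank number.
- Parameters
rank (int) – rank number of the handler. Default, 0.
with_barrier (bool) – if True, all processes are synchronized with a barrier. Default, False.
- Return type
Callable
Examples
from ignite.distributed.utils import one_rank_only

engine = ...

@engine.on(...)
@one_rank_only()  # means @one_rank_only(rank=0)
def some_handler(_):
    ...

@engine.on(...)
@one_rank_only(rank=1)
def some_other_handler(_):
    ...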
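The filtering logic can be sketched without ignite by passing in the rank and barrier functions. `make_one_rank_only` is a hypothetical factory. In this sketch the barrier runs after the handler on every rank. That placement is an assumption of the sketch and is not documented above.

```python
from functools import wraps
from typing import Any, Callable, Optional


def make_one_rank_only(get_rank: Callable[[], int], barrier: Callable[[], None]) -> Callable:
    """Hypothetical factory: builds a one_rank_only-style decorator from injected rank/barrier functions."""

    def one_rank_only(rank: int = 0, with_barrier: bool = False) -> Callable:
        def decorator(func: Callable) -> Callable:
            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
                result = None
                # Only the selected rank executes the handler body
                if get_rank() == rank:
                    result = func(*args, **kwargs)
                # Every rank reaches the barrier, so no process runs ahead
                if with_barrier:
                    barrier()
                return result

            return wrapper

        return decorator

    return one_rank_only


# Usage: pretend the current process is rank 1
barrier_calls = []
one_rank_only = make_one_rank_only(get_rank=lambda: 1, barrier=lambda: barrier_calls.append("sync"))


@one_rank_only()  # rank 0 only
def on_rank0() -> str:
    return "handled by rank 0"


@one_rank_only(rank=1, with_barrier=True)
def on_rank1() -> str:
    return "handled by rank 1"
```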
- ignite.distributed.utils.set_local_rank(index)
Method to hint the local rank in case a torch native distributed context is created by the user without using initialize() or spawn().
Usage:
User sets up a torch native distributed process group
import torch.distributed as dist

import ignite.distributed as idist


def run(local_rank, *args, **kwargs):
    idist.set_local_rank(local_rank)
    # ...
    dist.init_process_group(**dist_info)  # `dist_info` is defined by the user
    # ...
- Parameters
index (int) – local rank or current process index
- Return type
None
- ignite.distributed.utils.show_config()
Helper method to display distributed configuration via logging.
- Return type
None
- ignite.distributed.utils.spawn(backend, fn, args, kwargs_dict=None, nproc_per_node=1, **kwargs)
Spawns nproc_per_node processes that run fn with args/kwargs_dict and initialize the distributed configuration defined by backend.
Examples
Launch single node multi-GPU training using torch native distributed framework
# >>> python main.py
# main.py
import torch

import ignite.distributed as idist


def train_fn(local_rank, a, b, c, d=12):
    import torch.distributed as dist

    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == 4

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")


if __name__ == "__main__":
    a, b, c = 1, 2, 3  # placeholder arguments
    idist.spawn("nccl", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=4)
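Launch multi-node multi-GPU training using torch native distributed framework
# >>> (node 0): python main.py --node_rank=0 --nnodes=8 --master_addr=master --master_port=2222
# >>> (node 1): python main.py --node_rank=1 --nnodes=8 --master_addr=master --master_port=2222
# >>> ...
# >>> (node 7): python main.py --node_rank=7 --nnodes=8 --master_addr=master --master_port=2222
# main.py
import argparse

import torch

import ignite.distributed as idist


def train_fn(local_rank, nnodes, nproc_per_node):
    import torch.distributed as dist

    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == nnodes * nproc_per_node

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--node_rank", type=int, required=True)
    parser.add_argument("--nnodes", type=int, required=True)
    parser.add_argument("--master_addr", type=str, required=True)
    parser.add_argument("--master_port", type=int, required=True)
    args = parser.parse_args()

    nproc_per_node = torch.cuda.device_count()  # assumption: one process per local GPU

    idist.spawn(
        "nccl",
        train_fn,
        args=(args.nnodes, nproc_per_node),
        nproc_per_node=nproc_per_node,
        nnodes=args.nnodes,
        node_rank=args.node_rank,
        master_addr=args.master_addr,
        master_port=args.master_port,
    )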
Launch single node multi-TPU training (for example on Google Colab) using PyTorch/XLA
# >>> python main.py
# main.py
import ignite.distributed as idist


def train_fn(local_rank, a, b, c, d=12):
    import torch_xla.core.xla_model as xm

    assert xm.get_world_size() == 8

    device = idist.device()
    assert "xla" in device.type


if __name__ == "__main__":
    a, b, c = 1, 2, 3  # placeholder arguments
    idist.spawn("xla-tpu", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=8)
- Parameters
backend (str) – backend to use: nccl, gloo, xla-tpu, horovod
fn (function) – function to be called as the entrypoint of the spawned process. This function must be defined at the top level of a module so it can be pickled and spawned. This is a requirement imposed by multiprocessing. The function is called as fn(i, *args, **kwargs_dict), where i is the process index and args is the passed through tuple of arguments.
args (tuple) – arguments passed to fn.
kwargs_dict (Mapping) – kwargs passed to fn.
nproc_per_node (int) – number of processes to spawn on a single node. Default, 1.
**kwargs –
acceptable kwargs according to provided backend:
- “nccl” or “gloo” : nnodes (default, 1), node_rank (default, 0), master_addr (default, “127.0.0.1”), master_port (default, 2222), timeout to dist.init_process_group function and kwargs for mp.start_processes function.
- “xla-tpu” : nnodes (default, 1), node_rank (default, 0) and kwargs to xmp.spawn function.
- “horovod”: hosts (default, None) and other kwargs to hvd_run function. Arguments nnodes=1 and node_rank=0 are tolerated and ignored, otherwise an exception is raised.
- Return type
None
Changed in version 0.4.2: backend now accepts horovod distributed framework.
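In the multi-node example, each spawned process has a node rank, a local rank and a global rank, and the world size is nnodes * nproc_per_node. The sketch below enumerates that layout. It assumes the common node-major convention `rank = node_rank * nproc_per_node + local_rank`. That convention is an assumption here and is not stated above. `global_ranks` is a hypothetical helper.

```python
from typing import List, Tuple


def global_ranks(nnodes: int, nproc_per_node: int) -> List[Tuple[int, int, int]]:
    """Hypothetical sketch: (node_rank, local_rank, global rank) for every spawned process.

    Assumes the node-major convention rank = node_rank * nproc_per_node + local_rank.
    """
    return [
        (node_rank, local_rank, node_rank * nproc_per_node + local_rank)
        for node_rank in range(nnodes)
        for local_rank in range(nproc_per_node)
    ]


# Usage: 8 nodes with 4 processes each give a world size of 32
layout = global_ranks(nnodes=8, nproc_per_node=4)
print(len(layout), layout[-1])
```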
- ignite.distributed.utils.sync(temporary=False)
Helper method to force this module to synchronize with current distributed context. This method should be used when distributed context is manually created or destroyed.
- Parameters
temporary (bool) – If True, distributed model synchronization is done on every call of idist.get_* methods. This may have a negative performance impact.
- Return type
None