ignite.distributed
Helper module to use distributed settings for multiple backends:
backends from native torch distributed configuration: “nccl”, “gloo”, “mpi”
XLA on TPUs via pytorch/xla
using Horovod framework as a backend
Distributed launcher and auto helpers
We provide a context manager that simplifies distributed configuration setup for all of the backends listed above.
In addition, methods like auto_model(), auto_optim() and auto_dataloader() transparently adapt the provided model, optimizer and data loaders to the existing configuration:
# main.py
import torch.optim as optim
from torchvision.models import resnet50

import ignite.distributed as idist


def training(local_rank, config, **kwargs):
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())

    # `dataset` is assumed to be a torch Dataset defined by the user
    train_loader = idist.auto_dataloader(dataset, batch_size=32, num_workers=12, shuffle=True)
    # batch size, num_workers and sampler are automatically adapted to the existing configuration
    # ...
    model = resnet50()
    model = idist.auto_model(model)
    # model is wrapped in DDP or DP, or returned as is, according to the existing configuration
    # ...
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    optimizer = idist.auto_optim(optimizer)
    # optimizer is returned as is, except for XLA configurations, where its `step()` method is overridden.
    # User can safely call `optimizer.step()` (`xm.optimizer_step(optimizer)` is performed behind the scenes)


if __name__ == "__main__":
    config = {}  # placeholder for the user's training configuration

    backend = "nccl"  # torch native distributed configuration on multiple GPUs
    # backend = "xla-tpu"  # XLA TPUs distributed configuration
    # backend = None  # no distributed configuration

    dist_configs = {}
    # dist_configs = {"nproc_per_node": 4}  # Use specified distributed configuration if launched as python main.py
    # dist_configs["start_method"] = "fork"  # Add start_method as "fork" if using Jupyter Notebook

    with idist.Parallel(backend=backend, **dist_configs) as parallel:
        parallel.run(training, config, a=1, b=2)
The code above can be executed with the torch.distributed.launch tool, or with python by specifying the distributed configuration in the code. For more details, please see Parallel, auto_model(), auto_optim() and auto_dataloader().
A complete example of CIFAR10 training can be found here.
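A comment in the first example says that the batch size and num_workers are adapted to the existing configuration. The sketch below shows one plausible form of that adaptation. It assumes that the batch_size passed to auto_dataloader() is the global batch size, split evenly across all processes, and that num_workers is shared by the processes on one node. The helper per_process_loader_kwargs is hypothetical and is not ignite's implementation.

```python
import math


def per_process_loader_kwargs(batch_size: int, num_workers: int,
                              world_size: int, nproc_per_node: int) -> dict:
    """Hypothetical helper: split global DataLoader settings across processes.

    Assumption: `batch_size` is the global batch size and `num_workers`
    is the total number of loader workers available on one node.
    """
    if world_size > 1:
        # Each process loads an equal share of the global batch (at least 1 sample).
        batch_size = max(batch_size // world_size, 1)
        # Workers on a node are shared by the processes running there (rounded up).
        num_workers = math.ceil(num_workers / nproc_per_node)
    return {"batch_size": batch_size, "num_workers": num_workers}


# Single node with 4 processes, as in the commented "nproc_per_node": 4 configuration.
print(per_process_loader_kwargs(32, 12, world_size=4, nproc_per_node=4))
# {'batch_size': 8, 'num_workers': 3}
```

Without a distributed configuration (world_size of 1), the settings pass through unchanged.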
ignite.distributed.auto
- Distributed sampler proxy to adapt the user's sampler for a distributed data parallelism configuration.
- auto_dataloader(): Helper method to create a dataloader adapted for non-distributed and distributed configurations (supporting all available backends).
- auto_model(): Helper method to adapt the provided model for non-distributed and distributed configurations (supporting all available backends).
- auto_optim(): Helper method to adapt the optimizer for non-distributed and distributed configurations (supporting all available backends).
Note
In a distributed configuration, the methods auto_model(), auto_optim() and auto_dataloader() take effect only when the distributed process group is initialized.
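The distributed sampler proxy listed above gives each process its own share of the sampler's indices. Below is a minimal stdlib sketch of the partitioning strategy used by PyTorch's torch.utils.data.DistributedSampler: pad the index list so its length divides evenly by the world size, then let each rank take every world_size-th index starting at its own rank. Shuffling is omitted. The function shard_indices is hypothetical, and ignite's proxy may differ in details.

```python
import math


def shard_indices(indices: list, rank: int, world_size: int) -> list:
    """Hypothetical helper: indices handled by process `rank` (assumes non-empty `indices`)."""
    num_samples = math.ceil(len(indices) / world_size)
    total_size = num_samples * world_size
    # Pad by cycling through the indices so the total divides evenly across ranks.
    repeats = math.ceil(total_size / len(indices))
    padded = (indices * repeats)[:total_size]
    # Rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return padded[rank:total_size:world_size]


indices = list(range(10))
for rank in range(4):
    print(rank, shard_indices(indices, rank, world_size=4))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Every rank receives the same number of samples, so a few indices (here 0 and 1) are seen twice per epoch.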
ignite.distributed.launcher
Parallel: Distributed launcher context manager to simplify distributed configuration setup for multiple backends.
ignite.distributed.utils
This module wraps common methods that fetch information about the distributed configuration, initialize or finalize the process group, and spawn multiple processes.
- backend(): Returns the computation model's backend.
- broadcast(): Helper method to perform a broadcast operation.
- device(): Returns the current device according to the current distributed configuration.
- Returns the available backends.
- model_name(): Returns the distributed configuration name (given by ignite).
- get_world_size(): Returns the world size of the current distributed configuration.
- get_rank(): Returns the process rank within the current distributed configuration.
- get_local_rank(): Returns the local process rank within the current distributed configuration.
- get_nproc_per_node(): Returns the number of processes (or tasks) per node within the current distributed configuration.
- get_node_rank(): Returns the node rank within the current distributed configuration.
- get_nnodes(): Returns the number of nodes within the current distributed configuration.
- spawn(): Spawns nproc_per_node processes that run fn with args/kwargs_dict and initializes the distributed configuration defined by backend.
- initialize(): Initializes the distributed configuration according to the provided backend.
- finalize(): Finalizes the distributed configuration.
- show_config(): Helper method to display the distributed configuration via logging.
- set_local_rank(): Method to hint the local rank when a torch native distributed context is created by the user without using initialize() or spawn().
- all_reduce(): Helper method to perform an all reduce operation.
- all_gather(): Helper method to perform an all gather operation.
- barrier(): Helper method to synchronize all processes.
- hostname(): Returns the host name of the current process within the current distributed configuration.
- sync(): Helper method to force this module to synchronize with the current distributed context.
- one_rank_only(): Decorator to filter handlers with respect to a rank number.
- ignite.distributed.utils.has_native_dist_support
True if torch.distributed is available
- ignite.distributed.utils.has_xla_support
True if torch_xla package is found
- ignite.distributed.utils.all_gather(tensor)
Helper method to perform all gather operation.
- Parameters
tensor (Union[Tensor, float, str]) – tensor or number or str to collect across participating processes.
- Returns
torch.Tensor of shape (world_size * tensor.shape[0], tensor.shape[1], ...) if the input is a tensor, torch.Tensor of shape (world_size, ) if the input is a number, or List of strings if the input is a string.
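The shape rule in the Returns section can be checked with plain Python. Tensors from all processes are concatenated along the first dimension, and numbers become a 1-D tensor with one entry per process. The helper gathered_shape below is hypothetical. It only computes the documented output shape and does not communicate between processes.

```python
from typing import Optional, Tuple


def gathered_shape(input_shape: Optional[Tuple[int, ...]], world_size: int) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of all_gather as documented above.

    `input_shape` is the shape of the input tensor (at least 1-D),
    or None when the input is a Python number.
    """
    if input_shape is None:
        # One number per process -> shape (world_size, )
        return (world_size,)
    if len(input_shape) == 0:
        raise ValueError("this sketch only covers tensors with at least one dimension")
    # Tensors from all processes are concatenated along the first dimension.
    return (world_size * input_shape[0],) + tuple(input_shape[1:])


print(gathered_shape((4, 5, 6), world_size=8))  # (32, 5, 6)
print(gathered_shape(None, world_size=8))       # (8,)
```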
- ignite.distributed.utils.all_reduce(tensor, op='SUM')
Helper method to perform all reduce operation.
- Parameters
tensor – tensor or number to reduce across participating processes.
op (str) – reduction operation. Default, "SUM".
- Returns
torch.Tensor or number
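After an all reduce, every process holds the same reduced value. The pure-Python simulation below uses one list entry per rank to illustrate that semantics. It is a hypothetical helper (simulate_all_reduce) with a few reduction names chosen for illustration. It performs no inter-process communication.

```python
import math

# Reduction names chosen for illustration; each maps a list of values to one value.
_REDUCERS = {
    "SUM": sum,
    "PRODUCT": math.prod,
    "MIN": min,
    "MAX": max,
}


def simulate_all_reduce(values_per_rank: list, op: str = "SUM") -> list:
    """Hypothetical simulation: what each rank holds after an all reduce."""
    reduced = _REDUCERS[op](values_per_rank)
    # Every rank receives the same reduced result.
    return [reduced] * len(values_per_rank)


# Four processes, each contributing its own local value.
print(simulate_all_reduce([1.0, 2.0, 3.0, 4.0]))            # [10.0, 10.0, 10.0, 10.0]
print(simulate_all_reduce([1.0, 2.0, 3.0, 4.0], op="MAX"))  # [4.0, 4.0, 4.0, 4.0]
```

A common pattern is averaging a metric across processes: reduce with "SUM", then divide by get_world_size().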
- ignite.distributed.utils.backend()
Returns computation model’s backend.
None for no distributed configuration
“nccl” or “gloo” or “mpi” for native torch distributed configuration
“xla-tpu” for XLA distributed configuration
“horovod” for Horovod distributed framework
Changed in version 0.4.2: Added Horovod distributed framework.
- ignite.distributed.utils.barrier()
Helper method to synchronize all processes.
- Return type
None
- ignite.distributed.utils.broadcast(tensor, src=0, safe_mode=False)
Helper method to perform broadcast operation.
- Parameters
tensor (Optional[Union[Tensor, float, str]]) – tensor or number or str to broadcast to participating processes. Make sure to respect the data type of the torch tensor input for all processes, otherwise execution will crash. None can be used for non-source data with safe_mode=True.
src (int) – source rank. Default, 0.
safe_mode (bool) – if True, non-source input data can be None or anything (it will be discarded); otherwise, the data type of the input tensor must be respected for all processes. Please keep in mind that this mode works only for dense tensors as source input if a tensor is provided. Additional collective ops are performed before the broadcast, so this mode can be slower than the default. Default, False.
- Returns
torch.Tensor or string or number
Examples
import torch
import ignite.distributed as idist

y = None
if idist.get_rank() == 0:
    t1 = torch.rand(4, 5, 6, device=idist.device())
    s1 = "abc"
    x = 12.3456
    y = torch.rand(1, 2, 3, device=idist.device())
else:
    t1 = torch.empty(4, 5, 6, device=idist.device())
    s1 = ""
    x = 0.0

# Broadcast tensor t1 from rank 0 to all processes
t1 = idist.broadcast(t1, src=0)
assert isinstance(t1, torch.Tensor)

# Broadcast string s1 from rank 0 to all processes
s1 = idist.broadcast(s1, src=0)
# >>> s1 = "abc"

# Broadcast float number x from rank 0 to all processes
x = idist.broadcast(x, src=0)
# >>> x = 12.3456

# Broadcast any of those types from rank 0,
# but other ranks do not define the placeholder
y = idist.broadcast(y, src=0, safe_mode=True)
assert isinstance(y, torch.Tensor)

New in version 0.4.2.
Changed in version 0.4.5: added safe_mode.
- ignite.distributed.utils.device()
Returns current device according to current distributed configuration.
torch.device("cpu") if no distributed configuration or torch native gloo distributed configuration
torch.device("cuda:local_rank") if torch native nccl or horovod distributed configuration
torch.device("xla:index") if XLA distributed configuration
- Returns
torch.device
Changed in version 0.4.2: Added Horovod distributed framework.
- ignite.distributed.utils.finalize()
Finalizes distributed configuration. For example, in case of native pytorch distributed configuration, it calls dist.destroy_process_group().
- Return type
None
- ignite.distributed.utils.get_local_rank()
Returns local process rank within current distributed configuration. Returns 0 if no distributed configuration.
- Return type
int
- ignite.distributed.utils.get_nnodes()
Returns number of nodes within current distributed configuration. Returns 1 if no distributed configuration.
- Return type
int
- ignite.distributed.utils.get_node_rank()
Returns node rank within current distributed configuration. Returns 0 if no distributed configuration.
- Return type
int
- ignite.distributed.utils.get_nproc_per_node()
Returns number of processes (or tasks) per node within current distributed configuration. Returns 1 if no distributed configuration.
- Return type
int
- ignite.distributed.utils.get_rank()
Returns process rank within current distributed configuration. Returns 0 if no distributed configuration.
- Return type
int
- ignite.distributed.utils.get_world_size()
Returns world size of current distributed configuration. Returns 1 if no distributed configuration.
- Return type
int
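The values returned by get_rank(), get_local_rank(), get_node_rank(), get_nproc_per_node(), get_nnodes() and get_world_size() are related. The sketch below assumes the usual convention of torch native launchers: every node runs the same number of processes, and global ranks are assigned node by node. Under that assumption, a global rank decomposes into a node rank and a local rank with divmod. RankInfo and rank_layout are hypothetical names.

```python
from typing import List, NamedTuple


class RankInfo(NamedTuple):
    rank: int
    local_rank: int
    node_rank: int
    world_size: int


def rank_layout(nnodes: int, nproc_per_node: int) -> List[RankInfo]:
    """Hypothetical helper: rank information for every process in the job.

    Assumes the same number of processes on each node, with global ranks
    assigned node by node.
    """
    world_size = nnodes * nproc_per_node
    layout = []
    for rank in range(world_size):
        # rank == node_rank * nproc_per_node + local_rank
        node_rank, local_rank = divmod(rank, nproc_per_node)
        layout.append(RankInfo(rank, local_rank, node_rank, world_size))
    return layout


for info in rank_layout(nnodes=2, nproc_per_node=2):
    print(info)
```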
- ignite.distributed.utils.hostname()
Returns host name for current process within current distributed configuration.
- Return type
str
- ignite.distributed.utils.initialize(backend, **kwargs)
Initializes distributed configuration according to the provided backend.
- Parameters
backend (str) – backend: nccl, gloo, xla-tpu, horovod.
kwargs (Any) – acceptable kwargs according to the provided backend:
- "nccl" or "gloo": timeout(=timedelta(minutes=30)), init_method(=None), rank(=None), world_size(=None). By default, init_method will be "env://". See more info about parameters: torch_init.
- "horovod": comm(=None), more info: hvd_init.
- Return type
None
Examples
Launch single node multi-GPU training with the torchrun utility.
# >>> torchrun --nproc_per_node=4 main.py

# main.py
import torch
import ignite.distributed as idist


def train_fn(local_rank, a, b, c):
    import torch.distributed as dist

    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == 4

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")


a, b, c = 1, 2, 3  # placeholder arguments for train_fn

backend = "nccl"  # or "gloo" or "horovod" or "xla-tpu"
idist.initialize(backend)
# or for torch native distributed on Windows (NCCL is not available there):
# idist.initialize("gloo", init_method="file://tmp/shared")

local_rank = idist.get_local_rank()
train_fn(local_rank, a, b, c)

idist.finalize()
Changed in version 0.4.2: backend now accepts horovod distributed framework.
Changed in version 0.4.5: kwargs now accepts init_method, rank, world_size for PyTorch native distributed backend.
- ignite.distributed.utils.model_name()
Returns distributed configuration name (given by ignite)
serial for no distributed configuration
native-dist for native torch distributed configuration
xla-dist for XLA distributed configuration
horovod-dist for Horovod distributed framework
Changed in version 0.4.2: horovod-dist will be returned for Horovod distributed framework.
- Return type
str
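The lists for backend() and model_name() pair up one to one. The dictionary below is a hypothetical lookup built only from those two lists. It is not part of ignite's API, but it can be handy for checking or logging which configuration a given backend produces.

```python
from typing import Optional

# Assembled from the backend() and model_name() descriptions in this module.
BACKEND_TO_MODEL_NAME = {
    None: "serial",
    "nccl": "native-dist",
    "gloo": "native-dist",
    "mpi": "native-dist",
    "xla-tpu": "xla-dist",
    "horovod": "horovod-dist",
}


def expected_model_name(backend: Optional[str]) -> str:
    """Hypothetical helper: configuration name expected for a given backend."""
    try:
        return BACKEND_TO_MODEL_NAME[backend]
    except KeyError:
        raise ValueError(f"Unknown backend: {backend!r}") from None


print(expected_model_name("gloo"))  # native-dist
```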
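- ignite.distributed.utils.one_rank_only(rank=0, with_barrier=False)
Decorator to filter handlers with respect to a rank number.
- Parameters
rank (int) – rank number of the handler. Default, 0.
with_barrier (bool) – synchronization with a barrier. Default, False.
- Return type
Callable
Examples
from ignite.distributed.utils import one_rank_only
from ignite.engine import Engine, Events

engine = Engine(lambda engine, batch: None)

@engine.on(Events.EPOCH_COMPLETED)  # any event can be used here
@one_rank_only()  # means @one_rank_only(rank=0)
def some_handler(_):
    ...

@engine.on(Events.EPOCH_COMPLETED)
@one_rank_only(rank=1)
def some_other_handler(_):
    ...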
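For illustration only, the filtering idea behind one_rank_only can be sketched with the standard library. The decorator below (rank_only) is hypothetical and is not ignite's implementation. The current rank comes from an injected get_rank callable, a role that idist.get_rank() plays in real code. The with_barrier option is omitted.

```python
import functools
from typing import Any, Callable, Optional


def rank_only(get_rank: Callable[[], int], rank: int = 0) -> Callable:
    """Hypothetical decorator: run the wrapped function only on process `rank`."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
            if get_rank() == rank:
                return func(*args, **kwargs)
            # Other ranks skip the call entirely.
            return None

        return wrapper

    return decorator


calls = []
current_rank = 1  # pretend this process is rank 1


@rank_only(lambda: current_rank, rank=1)
def log_metrics(engine):
    calls.append(engine)


log_metrics("engine")  # runs, since current_rank == 1
current_rank = 0
log_metrics("engine")  # skipped: this process is now rank 0
print(calls)  # ['engine']
```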
- ignite.distributed.utils.set_local_rank(index)
Method to hint the local rank when a torch native distributed context is created by the user without using initialize() or spawn().
- Parameters
index (int) – local rank or current process index
- Return type
None
Examples
The user sets up the torch native distributed process group:
import torch.distributed as dist
import ignite.distributed as idist


def run(local_rank, *args, **kwargs):
    idist.set_local_rank(local_rank)
    # ...
    # dist_info: user-defined keyword arguments for init_process_group
    dist.init_process_group(**dist_info)
    # ...
- ignite.distributed.utils.show_config()
Helper method to display distributed configuration via logging.
- Return type
None
- ignite.distributed.utils.spawn(backend, fn, args, kwargs_dict=None, nproc_per_node=1, **kwargs)
Spawns nproc_per_node processes that run fn with args/kwargs_dict and initializes the distributed configuration defined by backend.
- Parameters
backend (str) – backend to use: nccl, gloo, xla-tpu, horovod
fn (Callable) – function to be called as the entrypoint of the spawned process. This function must be defined at the top level of a module so it can be pickled and spawned. This is a requirement imposed by multiprocessing. The function is called as fn(i, *args, **kwargs_dict), where i is the process index and args is the passed-through tuple of arguments.
args (Tuple) – arguments passed to fn.
kwargs_dict – keyword arguments passed to fn. Default, None.
nproc_per_node (int) – number of processes to spawn on a single node. Default, 1.
kwargs (Any) – acceptable kwargs according to the provided backend:
- "nccl" or "gloo": nnodes (default, 1), node_rank (default, 0), master_addr (default, "127.0.0.1"), master_port (default, 2222), init_method (default, "env://"), timeout to dist.init_process_group function and kwargs for mp.start_processes function.
- and node_rank=0 are tolerated and ignored, otherwise an exception is raised.
- Return type
None
Examples
Launch single node multi-GPU training using torch native distributed framework
# >>> python main.py

# main.py
import torch
import ignite.distributed as idist


def train_fn(local_rank, a, b, c, d=12):
    import torch.distributed as dist

    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == 4

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")


if __name__ == "__main__":
    a, b, c = 1, 2, 3  # placeholder arguments for train_fn
    idist.spawn("nccl", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=4)
Launch multi-node multi-GPU training using torch native distributed framework
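# >>> (node 0): python main.py --node_rank=0 --nnodes=8 --master_addr=master --master_port=2222
# >>> (node 1): python main.py --node_rank=1 --nnodes=8 --master_addr=master --master_port=2222
# >>> ...
# >>> (node 7): python main.py --node_rank=7 --nnodes=8 --master_addr=master --master_port=2222

# main.py
import argparse

import torch
import ignite.distributed as idist


def train_fn(local_rank, nnodes, nproc_per_node):
    import torch.distributed as dist

    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == nnodes * nproc_per_node

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")


if __name__ == "__main__":
    # Command-line options matching the launch commands above
    parser = argparse.ArgumentParser()
    parser.add_argument("--node_rank", type=int, default=0)
    parser.add_argument("--nnodes", type=int, default=1)
    parser.add_argument("--master_addr", type=str, default="127.0.0.1")
    parser.add_argument("--master_port", type=int, default=2222)
    opts = parser.parse_args()

    nproc_per_node = torch.cuda.device_count()  # assumption: one process per GPU

    idist.spawn(
        "nccl",
        train_fn,
        args=(opts.nnodes, nproc_per_node),
        nproc_per_node=nproc_per_node,
        nnodes=opts.nnodes,
        node_rank=opts.node_rank,
        master_addr=opts.master_addr,
        master_port=opts.master_port,
    )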
Launch single node multi-TPU training (for example on Google Colab) using PyTorch/XLA
# >>> python main.py

# main.py
import ignite.distributed as idist


def train_fn(local_rank, a, b, c, d=12):
    import torch_xla.core.xla_model as xm

    assert xm.get_world_size() == 8

    device = idist.device()
    assert "xla" in device.type


if __name__ == "__main__":
    a, b, c = 1, 2, 3  # placeholder arguments for train_fn
    idist.spawn("xla-tpu", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=8)
Changed in version 0.4.2: backend now accepts horovod distributed framework.
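With the default init_method "env://", PyTorch's process group initialization reads the rendezvous information from the environment variables MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE. The sketch below computes these values for one spawned process from the spawn keyword arguments and the defaults listed above (master_addr "127.0.0.1", master_port 2222). It also adds LOCAL_RANK, which launchers such as torchrun export. The helper process_env is hypothetical and does not describe how ignite sets the variables internally.

```python
from typing import Dict


def process_env(local_rank: int, nproc_per_node: int = 1, nnodes: int = 1,
                node_rank: int = 0, master_addr: str = "127.0.0.1",
                master_port: int = 2222) -> Dict[str, str]:
    """Hypothetical helper: env:// variables for one spawned process."""
    world_size = nnodes * nproc_per_node
    # Global rank: all processes of earlier nodes come first.
    rank = node_rank * nproc_per_node + local_rank
    return {
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),
        "WORLD_SIZE": str(world_size),
        "RANK": str(rank),
        "LOCAL_RANK": str(local_rank),
    }


# Node 1 of the 8-node example, assuming 4 processes per node.
print(process_env(local_rank=2, nproc_per_node=4, nnodes=8, node_rank=1, master_addr="master"))
# {'MASTER_ADDR': 'master', 'MASTER_PORT': '2222', 'WORLD_SIZE': '32', 'RANK': '6', 'LOCAL_RANK': '2'}
```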
- ignite.distributed.utils.sync(temporary=False)
Helper method to force this module to synchronize with the current distributed context. This method should be used when a distributed context is manually created or destroyed.
- Parameters
temporary (bool) – If True, distributed model synchronization is done on every call of the idist.get_* methods. This may have a negative performance impact.
- Return type
None