
Distributed Optimizers

class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)

This class wraps an arbitrary optim.Optimizer and shards its states across ranks in the group, as described by ZeRO. The local optimizer instance on each rank is only responsible for updating approximately 1 / world_size of the parameters and hence only needs to keep 1 / world_size of the optimizer states. After parameters are updated locally, each rank broadcasts its parameters to all other peers to keep all model replicas in the same state. ZeroRedundancyOptimizer can be used in conjunction with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption.
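To put the 1 / world_size saving in numbers, the back-of-the-envelope calculation below uses the 20-layer nn.Linear(2000, 2000) model from the example on this page with torch.optim.Adam. It is a sketch under simplifying assumptions: only Adam's two float32 state tensors per parameter (exp_avg and exp_avg_sq) are counted, and the split is treated as perfectly even, whereas the real partition assigns whole tensors to ranks.

```python
# Back-of-the-envelope estimate of Adam optimizer-state memory with and without ZeRO sharding.
# Assumption: only exp_avg and exp_avg_sq are counted, both float32 (4 bytes per element).

LAYERS = 20
IN_FEATURES = OUT_FEATURES = 2000
BYTES_PER_ELEMENT = 4        # float32
ADAM_STATES_PER_PARAM = 2    # exp_avg, exp_avg_sq

params_per_layer = IN_FEATURES * OUT_FEATURES + OUT_FEATURES  # weight + bias
total_params = LAYERS * params_per_layer

full_state_bytes = total_params * ADAM_STATES_PER_PARAM * BYTES_PER_ELEMENT


def per_rank_state_bytes(world_size: int) -> float:
    # Idealized even split; the real partition assigns whole tensors, so it is only approximately even.
    return full_state_bytes / world_size


print(total_params)                   # 80040000
print(full_state_bytes / 1e6)         # 640.32 (MB held by a single, unsharded optimizer)
print(per_rank_state_bytes(4) / 1e6)  # 160.08 (MB per rank with world_size=4)
```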

ZeroRedundancyOptimizer uses a sorted-greedy algorithm to assign parameters to ranks. Each parameter belongs to a single rank and is not divided among ranks. The partition is arbitrary and might not match the parameter registration or usage order.
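A minimal sketch of a sorted-greedy partition, assuming each parameter is represented only by its element count: visit parameters from largest to smallest and give each one to the rank with the smallest running total. It illustrates the idea only; it is not ZeroRedundancyOptimizer's internal code, and the function name sorted_greedy_partition is hypothetical.

```python
from typing import List, Sequence


def sorted_greedy_partition(sizes: Sequence[int], world_size: int) -> List[List[int]]:
    """Return, for each rank, the indices of the parameters it owns (hypothetical helper)."""
    assignment: List[List[int]] = [[] for _ in range(world_size)]
    load = [0] * world_size
    # Largest parameters first, so the small ones can even out the totals at the end.
    order = sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True)
    for i in order:
        rank = load.index(min(load))  # least-loaded rank; ties go to the lowest rank
        assignment[rank].append(i)    # the whole parameter goes to one rank; it is never split
        load[rank] += sizes[i]
    return assignment


# Five parameters with 10, 3, 7, 2 and 8 elements, split across two ranks.
result = sorted_greedy_partition([10, 3, 7, 2, 8], world_size=2)
print(result)  # [[0, 1, 3], [4, 2]]
```

Both ranks end up owning 15 elements, and the order within each rank follows parameter size rather than registration order, which is why the partition need not match how the parameters were registered or used.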

Parameters

params (Iterable) – an Iterable of torch.Tensor objects giving all parameters, which will be sharded across ranks.

Keyword Arguments
  • optimizer_class (torch.optim.Optimizer) – the class of the local optimizer.

  • process_group (ProcessGroup, optional) – torch.distributed ProcessGroup (default: dist.group.WORLD initialized by torch.distributed.init_process_group()).

  • parameters_as_bucket_view (bool, optional) – if True, parameters are packed into buckets to speed up communication, and param.data fields point to bucket views at different offsets; if False, each individual parameter is communicated separately, and each param.data stays intact (default: False).

  • overlap_with_ddp (bool, optional) – if True, step() is overlapped with DistributedDataParallel's gradient synchronization; this requires (1) either a functional optimizer for the optimizer_class argument or one with a functional equivalent and (2) registering a DDP communication hook constructed from one of the functions in ddp_zero_hook.py; parameters are packed into buckets matching those in DistributedDataParallel, meaning that the parameters_as_bucket_view argument is ignored. If False, step() runs separately after the backward pass, as usual (default: False).

  • **defaults – any trailing arguments, which are forwarded to the local optimizer.
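The parameters_as_bucket_view=True layout can be pictured with plain tensors: one flat buffer per bucket, with each parameter's .data replaced by a view into that buffer at its own offset. The sketch below only illustrates that aliasing; the helper name pack_into_bucket is hypothetical, and this is not how ZeroRedundancyOptimizer builds its buckets internally.

```python
import torch


def pack_into_bucket(params):
    """Copy params into one flat buffer and make each param.data a view into it (hypothetical helper)."""
    total = sum(p.numel() for p in params)
    flat = torch.empty(total, dtype=params[0].dtype, device=params[0].device)
    offset = 0
    for p in params:
        n = p.numel()
        flat[offset:offset + n].copy_(p.data.flatten())
        p.data = flat[offset:offset + n].view_as(p)  # the parameter now aliases a slice of the bucket
        offset += n
    return flat


weight = torch.nn.Parameter(torch.ones(2, 3))
bias = torch.nn.Parameter(torch.zeros(3))
bucket = pack_into_bucket([weight, bias])

# One collective on `bucket` (for example a broadcast) would now update both parameters at once;
# here an in-place fill stands in for that communication.
bucket.fill_(5.0)
print(weight.data.sum().item(), bias.data.sum().item())  # 30.0 15.0
```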

Example:

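>>> import torch
>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP

>>> # Assumes torch.distributed.init_process_group() has already run on every rank
>>> # and that `rank` holds this process's rank (one GPU per rank).
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
...     ddp.parameters(),
...     optimizer_class=torch.optim.Adam,
...     lr=0.01
... )
>>> inputs = torch.randn(32, 2000, device=rank)  # a dummy input batch
>>> ddp(inputs).sum().backward()
>>> opt.step()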
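To try the example above without GPUs or a multi-process launcher, a minimal sketch can run it as a one-process group on CPU. It assumes the gloo backend with a file-based rendezvous (file://), shrinks the model, and puts a tanh between the layers; with world_size=1 there is nothing to shard, so it exercises the API but not the memory savings.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

# One-process group on CPU: gloo backend plus a file-based rendezvous (no env vars needed).
init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.Tanh(), nn.Linear(8, 1))
ddp = DDP(model)  # CPU module, so no device_ids
opt = ZeroRedundancyOptimizer(ddp.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)

before = [p.detach().clone() for p in ddp.parameters()]
ddp(torch.randn(4, 8)).sum().backward()
opt.step()

# After one Adam step every parameter tensor should differ from its initial value.
changed = all(not torch.equal(b, p) for b, p in zip(before, ddp.parameters()))
print(changed)  # True

dist.destroy_process_group()
```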

Warning

Currently, ZeroRedundancyOptimizer requires that all of the passed-in parameters are the same dense type.
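A quick pre-flight check along these lines can catch mixed parameter types before the optimizer is constructed. It is a sketch that assumes "same dense type" means the same Tensor.type() string (which encodes dtype and device, e.g. torch.FloatTensor for float32 on CPU) and no sparse tensors; the helper name have_same_dense_type is hypothetical.

```python
import torch


def have_same_dense_type(tensors) -> bool:
    """Hypothetical pre-flight check: all tensors dense and of the first tensor's type()."""
    tensors = list(tensors)
    if not tensors:
        return True
    expected = tensors[0].type()  # e.g. 'torch.FloatTensor' for float32 on CPU
    return all(not t.is_sparse and t.type() == expected for t in tensors)


same = [torch.zeros(2, 3), torch.zeros(3)]
mixed_dtype = [torch.zeros(2, 3), torch.zeros(3, dtype=torch.float64)]
with_sparse = [torch.zeros(2, 3), torch.zeros(2, 3).to_sparse()]

print(have_same_dense_type(same))         # True
print(have_same_dense_type(mixed_dtype))  # False
print(have_same_dense_type(with_sparse))  # False
```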

Warning

If you pass overlap_with_ddp=True, be aware of the following: given the way that overlapping DistributedDataParallel with ZeroRedundancyOptimizer is currently implemented, the first two or three training iterations do not perform parameter updates in the optimizer step, depending on whether static_graph=False or static_graph=True, respectively. This is because the optimizer needs information about the gradient bucketing strategy used by DistributedDataParallel, which is not finalized until the second forward pass if static_graph=False or until the third forward pass if static_graph=True. To adjust for this, one option is to prepend dummy inputs to the training data.
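One way to prepend dummy inputs is to place a fixed number of dummy batches in front of the real ones: two when static_graph=False and three when static_graph=True, per the warning above. The helper names below are hypothetical, and the sketch is framework-agnostic: it only builds the sequence of batches to feed to the training loop.

```python
from itertools import chain, repeat
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def num_skipped_iterations(static_graph: bool) -> int:
    # With overlap_with_ddp=True, no parameter updates happen in the first 2 (or 3) iterations.
    return 3 if static_graph else 2


def with_dummy_warmup(batches: Iterable[T], dummy_batch: T, static_graph: bool) -> Iterator[T]:
    """Yield `dummy_batch` once per skipped iteration, then the real batches (hypothetical helper)."""
    return chain(repeat(dummy_batch, num_skipped_iterations(static_graph)), batches)


print(list(with_dummy_warmup(["b0", "b1"], "dummy", static_graph=False)))
# ['dummy', 'dummy', 'b0', 'b1']
```

The dummy batches should still go through the usual forward and backward passes, since those passes are what let DistributedDataParallel finalize its bucketing.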

Warning

ZeroRedundancyOptimizer is experimental and subject to change.

add_param_group(param_group)

Add a parameter group to the Optimizer's param_groups.

This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters

param_group (dict) – specifies the parameters to be optimized and group-specific optimization options.

Warning

This method handles updating the shards on all partitions but must be called on all ranks. Calling it on only a subset of the ranks will cause training to hang, because communication primitives are invoked based on the managed parameters and expect all ranks to participate with the same set of parameters.
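A minimal single-process sketch of the fine-tuning pattern: the optimizer starts with only the head's parameters, and the previously frozen backbone is later unfrozen and added with its own learning rate. It assumes a one-process gloo group on CPU with a file-based rendezvous; the names backbone and head are illustrative. In a real job the add_param_group() call must run on every rank.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer

init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)

backbone = nn.Linear(8, 8)  # illustrative "pre-trained" part
head = nn.Linear(8, 2)      # illustrative new head
for p in backbone.parameters():
    p.requires_grad_(False)  # frozen at first

# Only the head is optimized initially.
opt = ZeroRedundancyOptimizer(head.parameters(), optimizer_class=torch.optim.SGD, lr=0.1)

# Later in training: unfreeze the backbone and hand it to the optimizer with a smaller learning rate.
for p in backbone.parameters():
    p.requires_grad_(True)
opt.add_param_group({"params": list(backbone.parameters()), "lr": 0.01})  # call on all ranks

print(len(opt.param_groups), [g["lr"] for g in opt.param_groups])  # 2 [0.1, 0.01]

dist.destroy_process_group()
```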

consolidate_state_dict(to=0)

Consolidate a list of state_dict objects (one per rank) on the target rank.

Parameters

to (int) – the rank that receives the optimizer states (default: 0).

Raises

RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt.

Warning

This needs to be called on all ranks.
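Conceptually, consolidation gathers each rank's shard-local optimizer state on the target rank and re-keys it by global parameter index. The pure-Python sketch below models only that merge step, using plain dicts and a hypothetical partition list; it is not the library's implementation, which also performs the collective communication between ranks.

```python
from typing import Dict, List


def merge_rank_states(
    local_states: List[Dict[int, dict]], partition: List[List[int]]
) -> Dict[int, dict]:
    """Merge per-rank optimizer states into one dict keyed by global parameter index.

    local_states[r] maps rank r's local parameter index -> that parameter's state;
    partition[r][i] is the global index of rank r's i-th local parameter.
    (Hypothetical helper; it models the bookkeeping only, not the communication.)
    """
    merged: Dict[int, dict] = {}
    for rank, state in enumerate(local_states):
        for local_idx, param_state in state.items():
            merged[partition[rank][local_idx]] = param_state
    return merged


partition = [[0, 2], [1]]  # rank 0 owns parameters 0 and 2; rank 1 owns parameter 1
local_states = [
    {0: {"exp_avg": 0.1}, 1: {"exp_avg": 0.3}},  # held by rank 0
    {0: {"exp_avg": 0.2}},                        # held by rank 1
]
merged = merge_rank_states(local_states, partition)
print(merged)  # {0: {'exp_avg': 0.1}, 2: {'exp_avg': 0.3}, 1: {'exp_avg': 0.2}}
```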

join_hook(**kwargs)

Returns the ZeRO join hook, which enables training on uneven inputs by shadowing the collective communications in the optimizer step.

Gradients must be properly set before this hook is called.

Parameters

kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs.

This hook does not support any keyword arguments; i.e. kwargs is unused.
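A minimal sketch of using this hook through the torch.distributed.algorithms.join.Join context manager, passing both the DistributedDataParallel model and the ZeroRedundancyOptimizer as Joinable objects. It assumes a one-process gloo group on CPU, so no rank actually runs out of inputs early; in a real job each rank would iterate over its own, possibly shorter, list of batches.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.join import Join
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)

model = DDP(nn.Linear(4, 1))
opt = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.SGD, lr=0.01)
batches = [torch.randn(2, 4) for _ in range(3)]  # with more ranks, counts could differ per rank

steps = 0
# Both objects are Joinable; the optimizer contributes the ZeRO join hook.
with Join([model, opt]):
    for x in batches:
        opt.zero_grad()
        model(x).sum().backward()
        opt.step()
        steps += 1

print(steps)  # 3

dist.destroy_process_group()
```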

load_state_dict(state_dict)

Load the state pertaining to the given rank from the input state_dict, updating the local optimizer as needed.

Parameters

state_dict (dict) – optimizer state; should be an object returned from a call to state_dict().

Raises

RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt.

state_dict()

Returns the last global optimizer state known to this rank.

Raises

RuntimeError – if overlap_with_ddp=True and this method is called before this ZeroRedundancyOptimizer instance has been fully initialized, which happens once DistributedDataParallel gradient buckets have been rebuilt; or if this method is called without a preceding call to consolidate_state_dict().
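Putting consolidate_state_dict(), state_dict() and load_state_dict() together, a checkpoint round trip might look like the single-process sketch below. It assumes a one-process gloo group on CPU and keeps the checkpoint in memory; in a multi-rank job every rank calls consolidate_state_dict(), only the target rank calls state_dict() and saves the result, and every rank later passes that same full state dict to load_state_dict().

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer

init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)

model = nn.Linear(4, 2)
opt = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
model(torch.randn(3, 4)).sum().backward()
opt.step()  # populate Adam's per-parameter state

opt.consolidate_state_dict(to=0)  # collective: every rank must call this
# Only the target rank (0) may call state_dict(); with more ranks it would save this to disk
# and every rank would read the same file back before calling load_state_dict().
checkpoint = opt.state_dict()

# Resume into a fresh optimizer created with a different learning rate; loading restores lr=0.01.
restored = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.5)
restored.load_state_dict(checkpoint)

print(restored.param_groups[0]["lr"], len(checkpoint["state"]))  # 0.01 2

dist.destroy_process_group()
```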

step(closure=None, **kwargs)

Performs a single optimizer step and syncs parameters across all ranks.

Parameters

closure (callable) – a closure that re-evaluates the model and returns the loss; optional for most optimizers.

Returns

Optional loss depending on the underlying local optimizer.
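A single-process sketch of passing a closure: the closure clears gradients, recomputes the loss, runs backward, and returns the loss, and step() returns whatever the local optimizer's step() returns. It assumes a one-process gloo group on CPU and torch.optim.SGD, whose step() evaluates the closure and returns its loss.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer

init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
opt = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.SGD, lr=0.01)


def closure():
    # Re-evaluate the model and return the loss, as step() expects.
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss


first = opt.step(closure).item()   # loss evaluated before the first update
second = opt.step(closure).item()  # loss evaluated after one SGD update
print(first > second)  # True: a small gradient step lowers this quadratic loss

dist.destroy_process_group()
```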
