
Distributed Optimizers

class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, group=None, parameters_as_bucket_view=False, **default)

This class wraps an arbitrary optim.Optimizer and shards its states across ranks in the group as described by ZeRO. The local optimizer instance on each rank is responsible for updating only about 1 / world_size of the parameters, and hence only needs to keep about 1 / world_size of the optimizer states. After the parameters are updated locally, each rank broadcasts its parameters to all other peers so that all model replicas stay in the same state. ZeroRedundancyOptimizer can be used in conjunction with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption.
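To put rough numbers on the savings, the arithmetic below is a back-of-the-envelope sketch. It assumes fp32 optimizer state, an Adam-style local optimizer that keeps two state tensors per parameter, and a perfectly even split across ranks (real partitions are only approximately even, because whole parameters are assigned to ranks). The model size corresponds to 20 nn.Linear(2000, 2000) layers; the helper names are hypothetical.

```python
# Back-of-the-envelope estimate of optimizer-state memory per rank.
# Assumptions (hypothetical, for illustration): fp32 state (4 bytes/element),
# an Adam-style optimizer with two state tensors per parameter (exp_avg and
# exp_avg_sq), and an even split across ranks. Adam's small per-parameter
# "step" counter is ignored.

BYTES_PER_ELEMENT = 4   # fp32
STATES_PER_PARAM = 2    # first and second moment estimates


def linear_numel(in_features: int, out_features: int) -> int:
    """Number of elements in an nn.Linear layer: weight plus bias."""
    return in_features * out_features + out_features


def optimizer_state_bytes(num_params: int, world_size: int = 1) -> int:
    """Optimizer-state bytes held by one rank when states are split evenly."""
    total = num_params * STATES_PER_PARAM * BYTES_PER_ELEMENT
    return total // world_size


num_params = 20 * linear_numel(2000, 2000)                 # 20 x nn.Linear(2000, 2000)
replicated = optimizer_state_bytes(num_params)             # every rank keeps all state
sharded = optimizer_state_bytes(num_params, world_size=4)  # ZeRO-style sharding

print(f"{num_params:,} parameters")
print(f"replicated: {replicated / 1e6:.1f} MB of optimizer state per rank")
print(f"sharded over 4 ranks: {sharded / 1e6:.1f} MB per rank")
```

The parameters and gradients themselves are still fully replicated in this setting; only the optimizer state is divided.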

ZeroRedundancyOptimizer uses a greedy algorithm to pack a number of parameters onto each rank. Each parameter belongs to a single rank and is never split across ranks. The resulting partition is arbitrary and might not match the parameter registration or usage order.
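A minimal sketch of such a greedy partition in plain Python, assuming each parameter is represented only by its element count and always goes, whole, to the currently least-loaded rank. The helper greedy_partition is hypothetical, and PyTorch's actual ordering and tie-breaking rules may differ; the point is that a layer's weight and bias can land on different ranks, so the partition need not follow registration order.

```python
# Greedy partitioning sketch: each parameter (represented only by its element
# count) is assigned whole to the rank that currently holds the fewest elements.
# Hypothetical helper for illustration; not PyTorch's internal code.
from typing import List


def greedy_partition(sizes: List[int], world_size: int) -> List[List[int]]:
    """Return, for each rank, the indices of the parameters it owns."""
    load = [0] * world_size                          # elements assigned to each rank so far
    owned: List[List[int]] = [[] for _ in range(world_size)]
    for idx, numel in enumerate(sizes):
        rank = load.index(min(load))                 # least-loaded rank (lowest index on ties)
        owned[rank].append(idx)
        load[rank] += numel
    return owned


# Weight and bias sizes of three small linear layers, in registration order.
sizes = [400, 20, 400, 20, 100, 5]
partition = greedy_partition(sizes, world_size=2)
print(partition)
```

Here the first layer's weight (index 0) and bias (index 1) end up on different ranks.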


Parameters
  • params (Iterable) – an Iterable of torch.Tensor objects to optimize

Keyword Arguments
  • optimizer_class (torch.optim.Optimizer) – the class of the local optimizer.

  • group (ProcessGroup, optional) – torch.distributed ProcessGroup (default: group.WORLD initialized by torch.distributed.init_process_group()).

  • parameters_as_bucket_view (bool) – when enabled, parameters are packed into larger buckets to speed up communication, and each param.data field points to a view into a bucket at a different offset. When disabled, each parameter is communicated separately, but param.data stays intact.

  • **default – all trailing keyword arguments are forwarded to the constructor of optimizer_class.
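The bucket-view idea can be illustrated with the standard library alone, as an analogy rather than PyTorch code: array.array plays the role of the flat bucket, and memoryview slices play the role of per-parameter views at different offsets. The parameter sizes are made up. Because the views share memory with the bucket, sending the bucket once carries every parameter, which is why packing can cut the number of communication calls.

```python
# Analogy for parameters_as_bucket_view using only the standard library.
# Not PyTorch code: array/memoryview stand in for a flat tensor and its views.
from array import array

param_sizes = [3, 2, 4]                          # hypothetical parameter lengths
bucket = array("d", [0.0] * sum(param_sizes))    # one contiguous buffer
bucket_mem = memoryview(bucket)

# Each "parameter" is a view into the bucket at its own offset.
views = []
offset = 0
for n in param_sizes:
    views.append(bucket_mem[offset:offset + n])
    offset += n

# Writing through a parameter's view updates the shared bucket in place,
# so a single transfer of `bucket` would carry all three parameters.
views[1][0] = 7.0
views[2][3] = 1.5

print(bucket.tolist())
```

Without bucketing, each of the three parameters would be sent in its own message, but no parameter storage would be redirected.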


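Example:

>>> import torch
>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP

>>> # Assumes the default process group is already initialized, and that
>>> # `rank` (this process's rank, also used as its CUDA device index)
>>> # and `inputs` (a batch on that device) are defined.
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
...     ddp.parameters(),
...     optimizer_class=torch.optim.Adam,
...     lr=0.01
... )
>>> ddp(inputs).sum().backward()
>>> opt.step()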
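What a step like opt.step() does across ranks can be modeled in a single process to show why the replicas stay consistent. This is a simulation under stated assumptions, not the library's implementation: parameters are scalars, the owner list stands in for the partition, the local optimizer is plain SGD, and the broadcast is a copy from each parameter's owner to every rank.

```python
# Single-process simulation of one sharded optimizer step across ranks.
# Assumptions: scalar parameters, a fixed (hypothetical) ownership partition,
# plain SGD as the local optimizer, and "broadcast" modeled as a copy.
from typing import List

world_size = 2
lr = 0.1
owner = [0, 1, 1, 0]           # assumed partition: rank that owns each parameter
# Every rank starts from the same replica of the parameters.
replicas: List[List[float]] = [[1.0, 2.0, 3.0, 4.0] for _ in range(world_size)]
grads = [0.5, -1.0, 2.0, 0.0]  # after DDP's all-reduce, identical on every rank

# 1) Local step: each rank updates only the parameters it owns.
for rank in range(world_size):
    for i, g in enumerate(grads):
        if owner[i] == rank:
            replicas[rank][i] -= lr * g

# 2) Broadcast: each parameter's owner sends its updated value to all ranks.
for i in range(len(grads)):
    value = replicas[owner[i]][i]
    for rank in range(world_size):
        replicas[rank][i] = value

print(replicas)
```

Between the two phases the replicas briefly disagree (each rank has stale values for parameters it does not own); the broadcast restores identical copies.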

add_param_group(param_group)

Add a param group to the Optimizer's param_groups.

This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters
  • param_group (dict) – specifies which Tensors should be optimized, along with group-specific optimization options.
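A plain-Python sketch of what adding a group involves, assuming the torch.optim convention that hyperparameters missing from the new group fall back to the optimizer-wide defaults. The helper add_param_group below, the (name, numel) parameter stand-ins, and the rule of placing new parameters on the least-loaded rank are all hypothetical; the real ZeroRedundancyOptimizer may recompute the whole partition instead.

```python
# Sketch of adding a param group, with plain Python stand-ins.
# Hypothetical: parameters are (name, numel) pairs and there are two ranks.
from typing import Dict, List

defaults = {"lr": 0.01, "weight_decay": 0.0}    # options given via **default
param_groups: List[dict] = [{"params": [("backbone", 600)], **defaults}]
rank_load = [600, 0]                             # elements owned by rank 0 and rank 1
owner: Dict[str, int] = {"backbone": 0}


def add_param_group(param_group: dict) -> None:
    """Hypothetical helper: fill in defaults, then place the new parameters."""
    group = dict(param_group)
    group["params"] = list(group["params"])
    for key, value in defaults.items():
        group.setdefault(key, value)             # group-specific values win
    for name, numel in group["params"]:
        rank = rank_load.index(min(rank_load))   # assumption: least-loaded rank
        owner[name] = rank
        rank_load[rank] += numel
    param_groups.append(group)


# Unfreeze a head layer during fine-tuning, with its own smaller learning rate.
add_param_group({"params": [("head", 200)], "lr": 0.001})
print(param_groups[1])
print(owner, rank_load)
```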


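consolidate_state_dict(to=0)

Update the consolidated state_dict list, one entry per rank, on the rank given by to. This needs to be called on all ranks.

Parameters
  • to (int) – the rank that receives the global states (default: 0).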
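A single-process model of consolidation with made-up local states: afterwards, only rank `to` holds the full list (one entry per rank) and the other ranks hold nothing. The consolidate helper is hypothetical and replaces the real communication with a list copy; in the real API every rank takes part in the call.

```python
# Single-process model of consolidating per-rank optimizer state onto one rank.
# Hypothetical data and helper; plain floats stand in for tensors.
from typing import List, Optional

world_size = 3
# Each rank's local state_dict (invented values).
local_states = [
    {"state": {0: {"momentum_buffer": float(rank)}},
     "param_groups": [{"params": [0], "lr": 0.01}]}
    for rank in range(world_size)
]


def consolidate(local_states: List[dict], to: int = 0) -> List[Optional[List[dict]]]:
    """Return what each rank holds afterwards: the full list on `to`, None elsewhere."""
    # Stand-in for the collective: every rank contributes its shard, in rank order.
    gathered = [dict(state) for state in local_states]
    return [gathered if rank == to else None for rank in range(len(local_states))]


held = consolidate(local_states, to=0)
print(held[0][2]["state"])  # rank 2's shard, now available on rank 0
```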


load_state_dict(state_dict)

Restore the global parameter groups as well as this rank's shard.

Parameters
  • state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().


local_state_dict()

Gets this rank's state_dict.

Returns
The state of the optimizer as a dict. It contains two entries:

  • state - a dict holding the current optimization state. Its content differs between optimizer classes.

  • param_groups - a list containing all parameter groups, where each parameter group is a dict.
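For concreteness, here is an illustrative shape of such a dict, assuming the local optimizer is SGD with momentum (which stores a momentum_buffer per parameter). Plain lists stand in for tensors, the values are invented, and a real param group carries more hyperparameters than shown.

```python
# Illustrative shape of one rank's local_state_dict() for SGD with momentum.
# Plain lists stand in for tensors; values are made up; groups are trimmed.
local_state = {
    # Per-parameter state, keyed by the parameter's index in this rank's shard.
    "state": {
        0: {"momentum_buffer": [0.1, -0.2]},
        1: {"momentum_buffer": [0.05]},
    },
    # A list of parameter groups; "params" holds indices, not tensors.
    "param_groups": [
        {"params": [0, 1], "lr": 0.01, "momentum": 0.9},
    ],
}

# Walk every parameter of every group together with its optimizer state.
for group in local_state["param_groups"]:
    for idx in group["params"]:
        print(idx, group["lr"], local_state["state"][idx])
```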


partition_parameters()

Partitions parameters across distributed data parallel ranks.

Returns
A list of param_groups (each of which is a list of dicts), where each element of the outer list contains the param_groups for one rank: element 0 corresponds to rank 0, and so on. The partitions for all ranks are kept because step() needs them to broadcast each rank's updated parameters.
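The shape of that return value can be sketched in plain Python: each rank gets a copy of every param group's hyperparameters together with only the parameters it owns, possibly none. The function partition_param_groups, the (name, numel) stand-ins, and the least-loaded assignment rule are assumptions for illustration.

```python
# Sketch of the per-rank structure returned by a partitioning step:
# element r is the list of param groups rank r is responsible for.
# Hypothetical helper; parameters are (name, numel) pairs.
from typing import List


def partition_param_groups(param_groups: List[dict], world_size: int) -> List[List[dict]]:
    load = [0] * world_size                       # elements assigned to each rank so far
    per_rank: List[List[dict]] = [[] for _ in range(world_size)]
    for group in param_groups:
        buckets: List[list] = [[] for _ in range(world_size)]
        for name, numel in group["params"]:
            rank = load.index(min(load))          # assumption: least-loaded rank
            buckets[rank].append((name, numel))
            load[rank] += numel
        for rank in range(world_size):
            # Copy the group's hyperparameters, keep only this rank's parameters.
            shard = {k: v for k, v in group.items() if k != "params"}
            shard["params"] = buckets[rank]
            per_rank[rank].append(shard)
    return per_rank


groups = [
    {"params": [("w1", 100), ("b1", 10), ("w2", 100)], "lr": 0.01},
    {"params": [("head", 50)], "lr": 0.001},
]
partitioned = partition_param_groups(groups, world_size=2)
for rank, rank_groups in enumerate(partitioned):
    print(rank, rank_groups)
```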

static rank_local_state_dict(rank, state_dict)

Returns the local_state_dict for a given rank.

Parameters
  • rank (int) – the rank to get the local_state_dict for

  • state_dict (dict) – the global state_dict
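A minimal sketch, assuming a hypothetical global layout in which "state" and "param_groups" are lists indexed by rank. The actual layout produced by ZeroRedundancyOptimizer.state_dict() may differ between releases, so treat this only as an illustration of selecting one rank's shard from the global state.

```python
# Sketch of extracting one rank's shard from a consolidated (global) state.
# Hypothetical layout: "state" and "param_groups" are lists indexed by rank.
def rank_local_state_dict(rank: int, state_dict: dict) -> dict:
    return {
        "state": state_dict["state"][rank],
        "param_groups": state_dict["param_groups"][rank],
    }


global_state = {
    "state": [{0: {"step": 3}}, {0: {"step": 3}, 1: {"step": 3}}],
    "param_groups": [
        [{"params": [0], "lr": 0.01}],
        [{"params": [0, 1], "lr": 0.01}],
    ],
}
shard = rank_local_state_dict(1, global_state)
print(shard)
```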


state_dict()

Returns
The last known global optimizer state, which consists of a list of the shards. It reflects the most recent call to consolidate_state_dict().

step(closure=None, **kwargs)

Performs a single optimization step (parameter update).

Parameters
  • closure (callable) – a closure that reevaluates the model and returns the loss. Optional for most optimizers.

Returns
Optional loss, depending on the underlying optimizer.
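The closure protocol can be modeled without PyTorch: the optimizer calls the closure, which recomputes the loss and gradients, then applies the update and returns the loss it received. ToySGD and the quadratic loss are hypothetical stand-ins; per the class description, ZeroRedundancyOptimizer additionally broadcasts the updated parameters after the local step.

```python
# Plain-Python model of the step(closure) protocol.
# ToySGD and the loss x**2 are hypothetical stand-ins, not PyTorch classes.
from typing import Callable, List, Optional


class ToySGD:
    def __init__(self, params: List[float], lr: float) -> None:
        self.params = params
        self.grads = [0.0] * len(params)
        self.lr = lr

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        # Re-evaluate the model (loss and gradients) if a closure is given.
        loss = closure() if closure is not None else None
        for i, g in enumerate(self.grads):
            self.params[i] -= self.lr * g
        return loss


opt = ToySGD([3.0], lr=0.25)


def closure() -> float:
    x = opt.params[0]
    opt.grads[0] = 2 * x   # gradient of x**2
    return x * x           # loss before the update


losses = [opt.step(closure) for _ in range(3)]
print(losses, opt.params)
```

Without a closure, step() simply applies the gradients already stored and returns None, which matches "optional loss" above.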

