
Distributed Optimizers

class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, **defaults)[source]

This class wraps an arbitrary optim.Optimizer and shards its states across ranks in the group as described by ZeRO. The local optimizer instance in each rank is only responsible for updating approximately 1 / world_size parameters and hence only needs to keep 1 / world_size optimizer states. After parameters are updated locally, each rank will broadcast its parameters to all other peers to keep all model replicas in the same state. ZeroRedundancyOptimizer can be used in conjunction with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption.
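To make the memory claim concrete, here is a back-of-the-envelope sketch (plain Python, not part of the API; the numbers are illustrative assumptions) of per-rank optimizer-state memory for Adam, which keeps two fp32 states (`exp_avg` and `exp_avg_sq`) per parameter:

```python
def adam_state_bytes(num_params, world_size=1, bytes_per_elem=4, states_per_param=2):
    """Approximate per-rank optimizer-state memory for Adam.

    With ZeRO-style sharding each rank keeps only ~1/world_size of the
    optimizer states, so the per-rank footprint shrinks accordingly.
    """
    total = num_params * states_per_param * bytes_per_elem
    return total // world_size

# A 1-billion-parameter model: ~8 GB of Adam state unsharded,
# ~1 GB per rank when sharded across 8 ranks.
unsharded = adam_state_bytes(1_000_000_000)              # 8_000_000_000 bytes
sharded = adam_state_bytes(1_000_000_000, world_size=8)  # 1_000_000_000 bytes
```

The model parameters and gradients themselves are not sharded by this class; only the optimizer states shrink per rank.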

ZeroRedundancyOptimizer uses a sorted-greedy algorithm to pack a number of parameters at each rank. Each parameter belongs to a single rank and is not divided among ranks. The partition is arbitrary and might not match the parameter registration or usage order.
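A minimal sketch of the sorted-greedy packing idea (standalone illustrative code, not the internal implementation): visit parameters from largest to smallest and assign each to the rank with the smallest total so far.

```python
def greedy_partition(param_sizes, world_size):
    """Assign each parameter (by index) to a rank via sorted-greedy packing.

    Parameters are visited largest-first; each goes to the rank whose
    assigned total is currently smallest. Every parameter lands on exactly
    one rank, and the resulting order is unrelated to registration order,
    mirroring the behavior described above.
    """
    loads = [0] * world_size
    assignment = {}
    # Largest first keeps the per-rank totals balanced.
    for idx in sorted(range(len(param_sizes)),
                      key=lambda i: param_sizes[i], reverse=True):
        rank = loads.index(min(loads))
        assignment[idx] = rank
        loads[rank] += param_sizes[idx]
    return assignment, loads

assignment, loads = greedy_partition([100, 60, 40, 30, 10], world_size=2)
# assignment == {0: 0, 1: 1, 2: 1, 3: 0, 4: 1}; loads == [130, 110]
```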


Parameters
  • params (Iterable) – an Iterable of torch.Tensor s giving all parameters, which will be sharded across ranks.

Keyword Arguments
  • optimizer_class (torch.optim.Optimizer) – the class of the local optimizer.

  • process_group (ProcessGroup, optional) – torch.distributed ProcessGroup (default: dist.group.WORLD initialized by torch.distributed.init_process_group()).

  • parameters_as_bucket_view (bool) – when enabled, parameters are packed into larger buckets to speed up communication, and param.data fields point to bucket views at different offsets; when disabled, each individual parameter is communicated separately, and each param.data stays intact.

  • **defaults – any trailing arguments, which are forwarded to the local optimizer.


>>> import torch
>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP

>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(),
>>>     optimizer_class=torch.optim.Adam,
>>>     lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()

add_param_group(param_group)[source]

Add a parameter group to the Optimizer's param_groups.

This can be useful when fine tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.


Parameters
  • param_group (dict) – specifies the parameters to be optimized and group-specific optimization options.


consolidate_state_dict(to=0)[source]

Consolidate a list of state_dict s (one per rank) on the target rank.


Parameters
  • to (int) – the rank that receives the optimizer states (default: 0).
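As an illustration of the consolidation semantics only (a pure-Python sketch, not the torch.distributed implementation): each rank holds optimizer state for its own shard of parameters, and consolidation gathers those disjoint partial dicts onto one target rank.

```python
def consolidate(per_rank_states, to=0):
    """Merge per-rank partial optimizer states onto the target rank.

    `per_rank_states` stands in for the state_dict each rank would
    contribute; in the real API the transfer happens over the process
    group rather than in local memory.
    """
    merged = {}
    for shard in per_rank_states:
        merged.update(shard)  # parameter ids are disjoint across ranks
    # Only the target rank ends up holding the full state.
    return {to: merged}

full = consolidate([{0: "state_p0"}, {1: "state_p1"}], to=0)
# full == {0: {0: "state_p0", 1: "state_p1"}}
```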


load_state_dict(state_dict)[source]

Load the state pertaining to the given rank from the input state_dict, updating the local optimizer as needed.


Parameters
  • state_dict (dict) – optimizer state; should be an object returned from a call to state_dict().


state_dict()[source]

Returns the last global optimizer state known to this rank.

step(closure=None, **kwargs)[source]

Performs a single optimization step (parameter update).


Parameters
  • closure (callable, optional) – a closure that re-evaluates the model and returns the loss; optional for most optimizers.


Returns
  Optional loss, depending on the underlying local optimizer.
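The closure contract can be illustrated without torch (a hedged sketch using a toy quadratic loss and a hypothetical `step_with_closure` helper): step() calls the closure to recompute the loss and gradients, applies the update, and returns whatever the closure returned.

```python
def step_with_closure(state, lr, closure=None):
    """Mimic Optimizer.step(closure): re-evaluate the loss, update, return it.

    `state` is a dict with keys 'x' (the parameter) and 'grad'; this is a
    toy stand-in for an optimizer's internal bookkeeping, not torch code.
    """
    loss = None
    if closure is not None:
        loss = closure()  # re-evaluates the model and refreshes state['grad']
    state['x'] -= lr * state['grad']  # plain gradient-descent update
    return loss

state = {'x': 4.0, 'grad': 0.0}

def closure():
    # Loss f(x) = x**2, so grad = 2x.
    state['grad'] = 2 * state['x']
    return state['x'] ** 2

loss = step_with_closure(state, lr=0.25, closure=closure)
# loss == 16.0 (computed at x=4); x then moves to 4 - 0.25*8 = 2.0
```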

