Source code for torch.distributed.optim.post_localSGD_optimizer

import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers

class PostLocalSGDOptimizer(torch.optim.Optimizer):
    r"""
    Wraps an arbitrary :class:`torch.optim.Optimizer` and runs
    `post-local SGD <https://arxiv.org/abs/1808.07217>`_,
    This optimizer runs local optimizer at every step.
    After the warm-up stage, it averages parameters periodically after the local optimizer is applied.

    Args:
        optim: The local optimizer.
        averager: A model averager instance to run post-localSGD algorithm.

    Example::

        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>> from torch.distributed.optim import PostLocalSGDOptimizer
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>     PostLocalSGDState,
        >>>     post_localSGD_hook,
        >>> )
        >>>
        >>> model = nn.parallel.DistributedDataParallel(
        >>>    module, device_ids=[rank], output_device=rank
        >>> )
        >>>
        >>> # Register a post-localSGD communication hook.
        >>> subgroup, subgroups = dist.new_subgroups()
        >>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Create a post-localSGD optimizer that wraps a local optimizer.
        >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
        >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
        >>> opt = PostLocalSGDOptimizer(
        >>>     optim=local_optim,
        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> )
        >>>
        >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
        >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
        >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
        >>> for step in range(0, 200):
        >>>    opt.zero_grad()
        >>>    loss = loss_fn(output, labels)
        >>>    loss.backward()
        >>>    opt.step()
    """

    def __init__(
        self,
        optim: torch.optim.Optimizer,
        averager: averagers.ModelAverager
    ):
        # The base ``Optimizer.__init__`` is intentionally not invoked: this
        # class is a thin wrapper, and parameter groups / state are shared
        # with (or delegated to) the wrapped local optimizer.
        self.optim = optim
        # Alias (not copy) the local optimizer's groups so in-place edits made
        # through either object are visible to both.
        self.param_groups = self.optim.param_groups
        self.averager = averager

    @property
    def state(self):
        """Per-parameter state of the wrapped local optimizer."""
        return self.optim.state

    def __repr__(self):
        """Return the representation of the wrapped local optimizer."""
        return self.optim.__repr__()

    def state_dict(self):
        """Return the state dict of the wrapped local optimizer."""
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        """
        Load ``state_dict`` into the wrapped local optimizer.

        Args:
            state_dict: A state dict as returned by :meth:`state_dict`.
        """
        self.optim.load_state_dict(state_dict)
        # Loading may rebind ``self.optim.param_groups`` to a new list;
        # re-alias it so this wrapper never exposes stale groups.
        self.param_groups = self.optim.param_groups

    def step(self):
        r"""
        Performs a single optimization step (parameter update).

        The local optimizer is applied first. The averager is then invoked
        exactly once with every parameter that has a gradient, so any per-call
        bookkeeping in the averager (such as counting steps toward
        ``warmup_steps`` / ``period``) advances once per optimizer step rather
        than once per parameter.
        """
        self.optim.step()
        params_with_grad = [
            param
            for param_group in self.param_groups
            for param in param_group["params"]
            if param.grad is not None
        ]
        # Nothing received a gradient: there is nothing to average, so the
        # averager is not called at all.
        if params_with_grad:
            # Pass the parameters themselves. Iterating over a single tensor
            # would yield its rows and fails outright for 0-dim parameters.
            self.averager.average_parameters(iter(params_with_grad))

    def zero_grad(self):
        """Reset the gradients through the wrapped local optimizer."""
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        """
        Add a parameter group to the wrapped local optimizer.

        Args:
            param_group: The group specification forwarded unchanged to the
                local optimizer's ``add_param_group``.
        """
        self.optim.add_param_group(param_group)


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources