Shortcuts

Source code for torch.distributed.optim.post_localSGD_optimizer

# mypy: allow-untyped-defs
import warnings

import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers


class PostLocalSGDOptimizer(torch.optim.Optimizer):
    r"""
    Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD <https://arxiv.org/abs/1808.07217>`_.

    This optimizer runs local optimizer at every step.
    After the warm-up stage, it averages parameters periodically after the local optimizer is applied.

    Args:
        optim: The local optimizer.
        averager: A model averager instance to run post-localSGD algorithm.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>> from torch.distributed.optim import PostLocalSGDOptimizer
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>   PostLocalSGDState,
        >>>   post_localSGD_hook,
        >>> )
        >>>
        >>> model = nn.parallel.DistributedDataParallel(
        >>>    module, device_ids=[rank], output_device=rank
        >>> )
        >>>
        >>> # Register a post-localSGD communication hook.
        >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Create a post-localSGD optimizer that wraps a local optimizer.
        >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
        >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
        >>> opt = PostLocalSGDOptimizer(
        >>>     optim=local_optim,
        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> )
        >>>
        >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
        >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
        >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
        >>> for step in range(0, 200):
        >>>    opt.zero_grad()
        >>>    loss = loss_fn(output, labels)
        >>>    loss.backward()
        >>>    opt.step()
    """

    def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager):
        # The base-class initializer is not invoked here: parameter groups and
        # optimizer state live in the wrapped optimizer and are exposed by
        # delegation (see ``param_groups`` below and the ``state`` property).
        self.optim = optim
        # Alias (not a copy) of the local optimizer's group list, so edits made
        # through the wrapper are seen by the local optimizer.
        self.param_groups = self.optim.param_groups
        self.averager = averager

    @property
    def state(self):
        """Per-parameter state of the wrapped local optimizer."""
        return self.optim.state

    def __repr__(self):
        """Return the representation of the wrapped local optimizer."""
        return self.optim.__repr__()

    def state_dict(self):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`,
        but adds an extra entry to record model averager's step to the checkpoint
        to ensure reload does not cause unnecessary warm up again.

        Returns:
            The local optimizer's state dict with an additional ``"step"`` entry
            holding the averager's current step.
        """
        optim_state_dict = self.optim.state_dict()
        optim_state_dict["step"] = self.averager.step
        return optim_state_dict

    def load_state_dict(self, state_dict):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`,
        but also restores model averager's step value to the one
        saved in the provided ``state_dict``.

        If there is no ``"step"`` entry in ``state_dict``,
        it will raise a warning and initialize the model averager's step to 0.

        Args:
            state_dict: A state dict as produced by :meth:`state_dict` (or by the
                local optimizer directly, in which case ``"step"`` is absent).
        """
        self.optim.load_state_dict(state_dict)
        # Loading may make the local optimizer rebind ``param_groups`` to a new
        # list (``torch.optim.Optimizer.load_state_dict`` does); refresh the alias
        # so the wrapper does not keep pointing at the stale groups.
        self.param_groups = self.optim.param_groups
        if "step" in state_dict:
            self.averager.step = state_dict["step"]
        else:
            warnings.warn(
                "Loaded state dict does not contain a step counter for an averager. "
                "Setting step counter to 0.",
                stacklevel=2,  # attribute the warning to the caller
            )
            self.averager.step = 0

    def step(self, closure=None):
        r"""
        Performs a single optimization step (parameter update).

        The local optimizer is stepped first, then the averager is asked to
        average the parameters in ``param_groups``.

        Args:
            closure: Optional callable forwarded to the local optimizer's ``step``.
                When ``None`` (the default) the local optimizer is stepped
                without arguments.

        Returns:
            Whatever the local optimizer's ``step`` returns.
        """
        if closure is None:
            loss = self.optim.step()
        else:
            loss = self.optim.step(closure)
        self.averager.average_parameters(params=self.param_groups)
        return loss

    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
        """Reset gradients by delegating to the local optimizer's ``zero_grad``."""
        self.optim.zero_grad(set_to_none=set_to_none)

    def add_param_group(self, param_group):
        """Add a parameter group by delegating to the local optimizer."""
        self.optim.add_param_group(param_group)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources