DDP Communication Hooks
=======================

A DDP communication hook is a generic interface for controlling how gradients
are communicated across workers, by overriding the vanilla allreduce in
:class:`torch.nn.parallel.DistributedDataParallel`.
A few built-in communication hooks are provided,
and users can apply any of these hooks to optimize communication.
The hook interface also supports user-defined communication
strategies for more advanced use cases.

How to Use a Communication Hook?
--------------------------------

To use a communication hook, the user only needs to register the hook on the
DDP model before the training loop, by calling the method below.

:func:`torch.nn.parallel.DistributedDataParallel.register_comm_hook`

What Does a Communication Hook Operate On?
------------------------------------------

A communication hook provides a flexible way to allreduce gradients.
It therefore operates mainly on the gradients of each replica before allreduce,
which are bucketized to increase the overlap between communication and computation.
In particular, :class:`torch.distributed.GradBucket` represents a bucket of gradient tensors to be allreduced.

.. autoclass:: torch.distributed.GradBucket

.. autofunction:: torch.distributed.GradBucket.index
.. autofunction:: torch.distributed.GradBucket.buffer
.. autofunction:: torch.distributed.GradBucket.gradients
.. autofunction:: torch.distributed.GradBucket.is_last
.. autofunction:: torch.distributed.GradBucket.set_buffer
.. autofunction:: torch.distributed.GradBucket.parameters

Default Communication Hooks
---------------------------

Default communication hooks are simple **stateless** hooks, so the input state
in ``register_comm_hook`` is either a process group or ``None``.
The input ``bucket`` is a :class:`torch.distributed.GradBucket` object.

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks
.. autofunction:: allreduce_hook
.. autofunction:: fp16_compress_hook
.. autofunction:: bf16_compress_hook

Additionally, communication hook wrappers are provided that apply
:meth:`~fp16_compress_hook` or :meth:`~bf16_compress_hook` compression around
another communication hook, so that compression can be combined with other hooks.

.. autofunction:: fp16_compress_wrapper
.. autofunction:: bf16_compress_wrapper

PowerSGD Communication Hook
---------------------------

PowerSGD (Vogels et al., NeurIPS 2019) is a gradient compression algorithm
that can provide very high compression rates and accelerate
bandwidth-bound distributed training.
The algorithm needs to maintain both hyperparameters and internal state.
The PowerSGD communication hook is therefore a **stateful** hook,
and the user needs to provide a state object defined as below.

PowerSGD State
^^^^^^^^^^^^^^^^

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook
.. autoclass:: PowerSGDState

PowerSGD Hooks
^^^^^^^^^^^^^^^^

.. warning ::
    PowerSGD typically requires extra memory of the same size as the model's
    gradients to enable error feedback, which can compensate for biased
    compressed communication and improve accuracy.

.. warning ::
    PowerSGD hooks may conflict with the Apex automatic mixed precision package.
    Please use the PyTorch native automatic mixed precision package instead.

.. autofunction:: powerSGD_hook
.. autofunction:: batched_powerSGD_hook

Debugging Communication Hooks
-----------------------------

As the name implies, debugging communication hooks are used **only** for debugging and performance optimization purposes.

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks

.. warning ::
    Debugging communication hooks do not necessarily output the correct results.

.. autofunction:: noop_hook

Checkpointing of Communication Hooks
------------------------------------

.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook

A stateful communication hook can be saved as part of a model checkpoint to enable trainer restarts.
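
What saving a hook's state involves can be sketched without any distributed setup.
The example below is only an illustration under stated assumptions: ``ToyHookState`` is a
hypothetical class, not part of PyTorch, and a ``threading.Lock`` stands in for a
runtime-only handle (such as a process group) that cannot be pickled. It shows the state
being pickled with the non-serializable attributes left out, and those attributes being
re-created on load.

```python
import pickle
import threading


class ToyHookState:
    """Hypothetical hook state: counters plus handles that cannot be pickled."""

    def __init__(self, process_group=None, start_iter=10):
        # Stand-in for a process group or another runtime-only handle.
        self.process_group = process_group
        self.start_iter = start_iter
        self.iter = 0
        # Locks cannot be pickled, which mimics a non-serializable attribute.
        self._lock = threading.Lock()

    def __getstate__(self):
        # Copy the attribute dict and leave out what cannot be pickled.
        state = self.__dict__.copy()
        del state["process_group"]
        del state["_lock"]
        return state

    def __setstate__(self, state):
        # Restore the saved attributes, then re-create the excluded ones.
        self.__dict__.update(state)
        self.process_group = None  # assumption: fall back to a default group
        self._lock = threading.Lock()


original = ToyHookState(process_group=object(), start_iter=5)
original.iter = 42

# Round-trip through pickle, the same mechanism a checkpoint file relies on.
restored = pickle.loads(pickle.dumps(original))
print(restored.iter, restored.start_iter, restored.process_group)
```

The serializable fields survive the round trip unchanged, while the excluded handles are
rebuilt from scratch on the loading side.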
To make a hook serializable, ``__setstate__`` and ``__getstate__`` should be defined.

.. warning ::
    ``__getstate__`` should exclude non-serializable attributes from the returned dictionary.

.. warning ::
    ``__setstate__`` should properly initialize the non-serializable attributes that were excluded from the provided ``state``.

:class:`PowerSGDState` has ``__setstate__`` and ``__getstate__`` implemented and can be used as a reference.

.. class:: PowerSGDState
    :noindex:

    .. automethod:: PowerSGDState.__getstate__
    .. automethod:: PowerSGDState.__setstate__

Here is a simple, end-to-end example of saving and reloading PowerSGD state and hook.

::

    import os
    import tempfile

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim

    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
    from torch.nn.parallel import DistributedDataParallel

    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(24, 24)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(24, 12)

        def forward(self, x):
            return self.fc2(self.relu(self.fc1(x)))

    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def cleanup():
        dist.destroy_process_group()

    def run_demo(demo_fn, world_size):
        # spawn one process per GPU; each receives its rank as the first argument
        mp.spawn(
            demo_fn,
            args=(world_size,),
            nprocs=world_size,
            join=True)

    def demo_serialization(rank, world_size):
        setup(rank, world_size)

        CHECKPOINT = os.path.join(tempfile.gettempdir(), "checkpoint.pt")

        model = SimpleModel().to(rank)
        ddp_model = DistributedDataParallel(model, device_ids=[rank])

        powersgd_hook = powerSGD.powerSGD_hook
        powersgd_state = powerSGD.PowerSGDState(process_group=None)

        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
        ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

        # save the model weights together with the hook and its state
        state = {
            'state_dict': ddp_model.state_dict(),
            'comm_hook': powersgd_hook,
            'comm_hook_state': powersgd_state}

        if rank == 0:
            torch.save(state, CHECKPOINT)

        # make sure the checkpoint is written before any rank reads it
        dist.barrier()
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        # The checkpoint holds a pickled function and state object, not only
        # tensors, so PyTorch versions that default to ``weights_only=True``
        # need it disabled explicitly. Only load checkpoints you trust.
        checkpoint = torch.load(
            CHECKPOINT, map_location=map_location, weights_only=False)

        # A DDP model accepts only one communication hook registration,
        # so restore into a fresh DDP instance.
        new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
        new_ddp_model.load_state_dict(checkpoint['state_dict'])
        powersgd_hook = checkpoint['comm_hook']
        powersgd_state = checkpoint['comm_hook_state']

        new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

        # wait until every rank has loaded the checkpoint before deleting it
        dist.barrier()
        if rank == 0:
            os.remove(CHECKPOINT)

        cleanup()

    if __name__ == "__main__":
        n_gpus = torch.cuda.device_count()
        assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
        world_size = n_gpus
        run_demo(demo_serialization, world_size)

Acknowledgements
----------------

Many thanks to PowerSGD paper author **Thijs Vogels** for the code review on
the PowerSGD communication hook, as well as for the comparison experiments,
which show that the performance of the PowerSGD communication hook is on par with
the implementation in the original paper.