Shortcuts

Source code for torch.nn.parallel.distributed

# mypy: allow-untyped-defs
import copy
import functools
import inspect
import itertools
import logging
import os
import sys
import warnings
import weakref
from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING

import torch
import torch.distributed as dist
from torch.autograd import Function, Variable
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.utils._pytree import tree_flatten, tree_unflatten

RPC_AVAILABLE = False
if dist.is_available():
    from torch.distributed.distributed_c10d import (
        _get_default_group,
        _rank_not_in_group,
        ReduceOp,
    )
    from torch.distributed.utils import (
        _alloc_storage,
        _cast_forward_inputs,
        _free_storage,
        _sync_module_states,
        _to_kwargs,
        _verify_param_shape_across_processes,
    )
if torch.distributed.rpc.is_available():
    RPC_AVAILABLE = True
    from torch.distributed.rpc import RRef

from torch._utils import _get_device_index

from ..modules import Module
from .scatter_gather import gather, scatter_kwargs  # noqa: F401

if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle

__all__ = ["DistributedDataParallel"]

logger = logging.getLogger(__name__)


@dataclass
class _MixedPrecision:
    """
    This configures DDP-native mixed precision training.

    Attributes:
        param_dtype (torch.dtype): This specifies the dtype for model
            parameters, inputs (when ``cast_forward_inputs`` is set to
            ``True``), and therefore the dtype for computation.
            However, outside the forward and backward passes, parameters are in
            full precision. Model checkpointing always happens in full
            precision.
        reduce_dtype (torch.dtype): This specifies the dtype for gradient
            reduction, which is permitted to differ from ``param_dtype``.
        buffer_dtype (torch.dtype): This specifies the dtype for buffers.

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: ``state_dict`` checkpoints parameters and buffers in full
        precision.

    .. note:: Each low precision dtype must be specified explicitly. For
        example, ``_MixedPrecision(reduce_dtype=torch.float16)`` only specifies
        the reduction dtype to be low precision, and DDP will not cast
        parameters or buffers.

    .. note:: If a ``reduce_dtype`` is not specified, then gradient reduction
        happens in ``param_dtype`` if specified or the original parameter dtype
        otherwise. For example, ``_MixedPrecision(param_dtype=torch.float16)``
        would result in communication occurring in fp16.
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None
    # TODO (rohan-varma): keep_low_precision_grads: bool = False
    # TODO (rohan-varma): APIs to allow users to run batchnorm and layernorm
    # in full precision. For DDP, this can be implemented by not performing the
    # parameter cast for BN and LN units.


def _cast_buffers(mixed_precision_config, root_module):
    """Casts buffers to the given ``buffer_dtype``."""
    for buf in root_module.buffers():
        if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored:
            continue

        buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)


def _setup_mixed_precision_params(mixed_precision_config, root_module):
    """Create and free storage for the mixed precision parameters."""
    for param in root_module.parameters():
        # Do not setup mixed precision for DDP ignored parameters.
        if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
            continue

        if not hasattr(param, "_mp_param"):
            param._mp_param = torch.zeros_like(
                param,
                device=param.device,
                dtype=mixed_precision_config.param_dtype,
                requires_grad=param.requires_grad,
            )
            _free_storage(param._mp_param)
            # _fp_param will point to the full precision param so it can be switched
            # back to at the end of forward / backward.
            param._fp_param = param.data


def _tree_flatten_with_rref(output):
    output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
    if output_is_rref:
        output_tensor_list, treespec = tree_flatten(output.local_value())
    else:
        output_tensor_list, treespec = tree_flatten(output)
    # Need to return flattened tensors, spec to re-pack them, as well
    # as if the return type was actually an RRef to reconstruct.
    return output_tensor_list, treespec, output_is_rref


def _tree_unflatten_with_rref(output, treespec, output_is_rref):
    output = tree_unflatten(output, treespec)
    if output_is_rref:
        output = RRef(output)
    return output


def _find_tensors(obj):
    r"""Recursively find all tensors contained in the specified object."""
    if RPC_AVAILABLE and isinstance(obj, RRef):
        # If the current node is the owner of the RRef, unwrap it and try to
        # find Tensors.
        # TODO: Expand to remote RRefs.
        if obj.is_owner():
            return _find_tensors(obj.local_value())
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return itertools.chain.from_iterable(map(_find_tensors, obj))
    if isinstance(obj, dict):
        return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
    if is_dataclass(obj):
        return itertools.chain.from_iterable(
            map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
        )

    return []


def _dump_DDP_relevant_env_vars():
    relevant_env_vars = [
        "RANK",
        "LOCAL_RANK",
        "WORLD_SIZE",
        "MASTER_PORT",
        "MASTER_ADDR",
        "CUDA_VISIBLE_DEVICES",
        "GLOO_SOCKET_IFNAME",
        "GLOO_DEVICE_TRANSPORT",
        "NCCL_SOCKET_IFNAME",
        "TORCH_NCCL_BLOCKING_WAIT",
        "NCCL_DEBUG",
        "NCCL_DEBUG_SUBSYS",
        "NCCL_IB_DISABLE",
        # More NCCL env vars:
        "NCCL_P2P_DISABLE",
        "NCCL_P2P_LEVEL",
        "NCCL_SHM_DISABLE",
        "NCCL_SOCKET_NTHREADS",
        "NCCL_NSOCKS_PERTHREAD",
        "NCCL_BUFFSIZE",
        "NCCL_NTHREADS",
        "NCCL_RINGS",
        "NCCL_MAX_NCHANNELS",
        "NCCL_MIN_NCHANNELS",
        "NCCL_CHECKS_DISABLE",
        "NCCL_CHECK_POINTERS",
        "NCCL_LAUNCH_MODE",
        "NCCL_IB_HCA",
        "NCCL_IB_TIMEOUT",
        "NCCL_IB_RETRY_CNT",
        "NCCL_IB_GID_INDEX",
        "NCCL_IB_SL",
        "NCCL_IB_TC",
        "NCCL_IB_AR_THRESHOLD",
        "NCCL_IB_CUDA_SUPPORT",
        "NCCL_NET_GDR_LEVEL",
        "NCCL_NET_GDR_READ",
        "NCCL_SINGLE_RING_THRESHOLD",
        "NCCL_LL_THRESHOLD",
        "NCCL_TREE_THRESHOLD",
        "NCCL_ALGO",
        "NCCL_PROTO",
        "NCCL_IGNORE_CPU_AFFINITY",
        "NCCL_DEBUG_FILE",
        "NCCL_COLLNET_ENABLE",
        "NCCL_TOPO_FILE",
        "NCCL_TOPO_DUMP_FILE",
        "TORCH_NCCL_ASYNC_ERROR_HANDLING",
    ]
    formatted_output = ""
    for var in relevant_env_vars:
        value = os.environ[var] if var in os.environ else "N/A"
        formatted_output += f"env:{var}={value}\n"
    print(formatted_output)


class _BufferCommHookLocation(Enum):
    PRE_FORWARD = auto()
    POST_FORWARD = auto()


@dataclass
class _BufferCommHook:
    buffer_comm_hook: Callable
    buffer_comm_hook_state: Any
    buffer_comm_hook_location: _BufferCommHookLocation


# Add a DDPSink to run various functions when backwards starts, such as
# queueing call back of out-most backward/graph task,
# this helps call back is fired after all gradients' calculation
# is completed.
class _DDPSink(Function):
    @staticmethod
    def forward(ctx, ddp_weakref, *inputs):
        # set_materialize_grads(False) will ensure that None gradients stay as
        # None and are not filled with zeros.
        ctx.set_materialize_grads(False)
        ctx.ddp_weakref = ddp_weakref
        ret = inputs
        if ddp_weakref()._ddp_sink_clone:
            ret = tuple(
                inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
            )
        return ret

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Enqueue delay allreduce for static graph training on the first
        # iteration.
        ddp_weakref = ctx.ddp_weakref()
        reducer = ddp_weakref.reducer
        static_graph = ddp_weakref.static_graph
        delay_ar_enqueued = (
            static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued
        )
        if static_graph and not delay_ar_enqueued:
            Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]
                reducer._delay_all_reduce
            )
            ddp_weakref._static_graph_delay_allreduce_enqueued = True

        return (None, *grad_outputs)


class _DDPJoinHook(JoinHook):
    def __init__(self, ddp, divide_by_initial_world_size):
        """Set config variables for internal usage."""
        assert isinstance(ddp, DistributedDataParallel), (
            "DDP join hook requires passing in a DistributedDataParallel "
            "instance as the state"
        )
        assert ddp.logger is not None
        ddp.logger._set_uneven_input_join()
        self.ddp = ddp
        self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
        super().__init__()

    def main_hook(self):
        """Shadow the DDP collective communication operations in the forward and backward passes."""
        ddp = self.ddp
        # Buckets are rebuilt only once during a training period
        ddp.reducer._rebuild_buckets()

        # Schedule a broadcast if we are syncing module buffers in the
        # forward pass
        # TODO: make DDP uneven inputs context manager support buffer
        # comm hook (https://github.com/pytorch/pytorch/issues/65436)
        ddp._check_and_sync_module_buffers()

        # Check if need to sync in the backward pass
        should_sync_backwards = ddp._check_global_requires_backward_grad_sync(
            is_joined_rank=True
        )
        # Forward parameter sync is disabled in the next iteration if we
        # are skipping gradient sync this iteration, so set
        # `require_forward_param_sync` accordingly
        ddp.require_forward_param_sync = should_sync_backwards
        if not should_sync_backwards:
            return

        # Schedule one allreduce per gradient bucket to match the backward
        # pass allreduce
        ddp._match_all_reduce_for_bwd_pass()

        # Check if we need to allreduce locally unused parameters
        if ddp.find_unused_parameters:
            ddp._match_unused_params_allreduce()

        # Rebuilt parameters are pushed only once during a training period
        ddp.reducer._push_all_rebuilt_params()

    def post_hook(self, is_last_joiner: bool):
        """Sync the final model to ensure that the model is the same across all processes."""
        self.ddp._sync_final_model(is_last_joiner)


[docs]class DistributedDataParallel(Module, Joinable): r"""Implement distributed data parallelism based on ``torch.distributed`` at module level. This container provides data parallelism by synchronizing gradients across each model replica. The devices to synchronize across are specified by the input ``process_group``, which is the entire world by default. Note that ``DistributedDataParallel`` does not chunk or otherwise shard the input across participating GPUs; the user is responsible for defining how to do so, for example through the use of a :class:`DistributedSampler`. See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`. The same constraints on input as in :class:`torch.nn.DataParallel` apply. Creation of this class requires that ``torch.distributed`` to be already initialized, by calling :func:`torch.distributed.init_process_group`. ``DistributedDataParallel`` is proven to be significantly faster than :class:`torch.nn.DataParallel` for single-node multi-GPU data parallel training. To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn up ``N`` processes, ensuring that each process exclusively works on a single GPU from 0 to N-1. This can be done by either setting ``CUDA_VISIBLE_DEVICES`` for every process or by calling: >>> # xdoctest: +SKIP("undefined variables") >>> torch.cuda.set_device(i) where i is from 0 to N-1. In each process, you should refer the following to construct this module: >>> # xdoctest: +SKIP("undefined variables") >>> torch.distributed.init_process_group( >>> backend='nccl', world_size=N, init_method='...' >>> ) >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i) In order to spawn up multiple processes per node, you can use either ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``. .. note:: Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__ for a brief introduction to all features related to distributed training. .. note:: ``DistributedDataParallel`` can be used in conjunction with :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce per-rank optimizer states memory footprint. Please refer to `ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__ for more details. .. note:: ``nccl`` backend is currently the fastest and highly recommended backend when using GPUs. This applies to both single-node and multi-node distributed training. .. note:: This module also supports mixed-precision distributed training. This means that your model can have different types of parameters such as mixed types of ``fp16`` and ``fp32``, the gradient reduction on these mixed types of parameters will just work fine. .. note:: If you use ``torch.save`` on one process to checkpoint the module, and ``torch.load`` on some other processes to recover it, make sure that ``map_location`` is configured properly for every process. Without ``map_location``, ``torch.load`` would recover the module to devices where the module was saved from. .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the gradient will be ``M`` times smaller when compared to the same model trained on a single node with ``batch=M*N`` if the loss is summed (NOT averaged as usual) across instances in a batch (because the gradients between different nodes are averaged). You should take this into consideration when you want to obtain a mathematically equivalent training process compared to the local training counterpart. But in most cases, you can just treat a DistributedDataParallel wrapped model, a DataParallel wrapped model and an ordinary model on a single GPU as the same (E.g. using the same learning rate for equivalent batch size). .. note:: Parameters are never broadcast between processes. The module performs an all-reduce step on gradients and assumes that they will be modified by the optimizer in all processes in the same way. Buffers (e.g. BatchNorm stats) are broadcast from the module in process of rank 0, to all other replicas in the system in every iteration. .. note:: If you are using DistributedDataParallel in conjunction with the :ref:`distributed-rpc-framework`, you should always use :meth:`torch.distributed.autograd.backward` to compute gradients and :class:`torch.distributed.optim.DistributedOptimizer` for optimizing parameters. Example:: >>> # xdoctest: +SKIP("undefined variables") >>> import torch.distributed.autograd as dist_autograd >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> import torch >>> from torch import optim >>> from torch.distributed.optim import DistributedOptimizer >>> import torch.distributed.rpc as rpc >>> from torch.distributed.rpc import RRef >>> >>> t1 = torch.rand((3, 3), requires_grad=True) >>> t2 = torch.rand((3, 3), requires_grad=True) >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2)) >>> ddp_model = DDP(my_model) >>> >>> # Setup optimizer >>> optimizer_params = [rref] >>> for param in ddp_model.parameters(): >>> optimizer_params.append(RRef(param)) >>> >>> dist_optim = DistributedOptimizer( >>> optim.SGD, >>> optimizer_params, >>> lr=0.05, >>> ) >>> >>> with dist_autograd.context() as context_id: >>> pred = ddp_model(rref.to_here()) >>> loss = loss_func(pred, target) >>> dist_autograd.backward(context_id, [loss]) >>> dist_optim.step(context_id) .. note:: DistributedDataParallel currently offers limited support for gradient checkpointing with :meth:`torch.utils.checkpoint`. If the checkpoint is done with use_reentrant=False (recommended), DDP will work as expected without any limitations. If, however, the checkpoint is done with use_reentrant=True (the default), DDP will work as expected when there are no unused parameters in the model and each layer is checkpointed at most once (make sure you are not passing `find_unused_parameters=True` to DDP). We currently do not support the case where a layer is checkpointed multiple times, or when there unused parameters in the checkpointed model. .. note:: To let a non-DDP model load a state dict from a DDP model, :meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present` needs to be applied to strip the prefix "module." in the DDP state dict before loading. .. warning:: Constructor, forward method, and differentiation of the output (or a function of the output of this module) are distributed synchronization points. Take that into account in case different processes might be executing different code. .. warning:: This module assumes all parameters are registered in the model by the time it is created. No parameters should be added nor removed later. Same applies to buffers. .. warning:: This module assumes all parameters are registered in the model of each distributed processes are in the same order. The module itself will conduct gradient ``allreduce`` following the reverse order of the registered parameters of the model. In other words, it is users' responsibility to ensure that each distributed process has the exact same model and thus the exact same parameter registration order. .. warning:: This module allows parameters with non-rowmajor-contiguous strides. For example, your model may contain some parameters whose :class:`torch.memory_format` is ``torch.contiguous_format`` and others whose format is ``torch.channels_last``. However, corresponding parameters in different processes must have the same strides. .. warning:: This module doesn't work with :func:`torch.autograd.grad` (i.e. it will only work if gradients are to be accumulated in ``.grad`` attributes of parameters). .. warning:: If you plan on using this module with a ``nccl`` backend or a ``gloo`` backend (that uses Infiniband), together with a DataLoader that uses multiple workers, please change the multiprocessing start method to ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will likely experience deadlocks if you don't change this setting. .. warning:: You should never try to change your model's parameters after wrapping up your model with ``DistributedDataParallel``. Because, when wrapping up your model with ``DistributedDataParallel``, the constructor of ``DistributedDataParallel`` will register the additional gradient reduction functions on all the parameters of the model itself at the time of construction. If you change the model's parameters afterwards, gradient reduction functions no longer match the correct set of parameters. .. warning:: Using ``DistributedDataParallel`` in conjunction with the :ref:`distributed-rpc-framework` is experimental and subject to change. Args: module (Module): module to be parallelized device_ids (list of int or torch.device): CUDA devices. 1) For single-device modules, ``device_ids`` can contain exactly one device id, which represents the only CUDA device where the input module corresponding to this process resides. Alternatively, ``device_ids`` can also be ``None``. 2) For multi-device modules and CPU modules, ``device_ids`` must be ``None``. When ``device_ids`` is ``None`` for both cases, both the input data for the forward pass and the actual module must be placed on the correct device. (default: ``None``) output_device (int or torch.device): Device location of output for single-device CUDA modules. For multi-device modules and CPU modules, it must be ``None``, and the module itself dictates the output location. (default: ``device_ids[0]`` for single-device modules) broadcast_buffers (bool): Flag that enables syncing (broadcasting) buffers of the module at beginning of the ``forward`` function. (default: ``True``) process_group: The process group to be used for distributed data all-reduction. If ``None``, the default process group, which is created by :func:`torch.distributed.init_process_group`, will be used. (default: ``None``) bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. :attr:`bucket_cap_mb` controls the bucket size in MebiBytes (MiB). If ``None``, a default size of 25 MiB will be used. (default: ``None``) find_unused_parameters (bool): Traverse the autograd graph from all tensors contained in the return value of the wrapped module's ``forward`` function. Parameters that don't receive gradients as part of this graph are preemptively marked as being ready to be reduced. In addition, parameters that may have been used in the wrapped module's ``forward`` function but were not part of loss computation and thus would also not receive gradients are preemptively marked as ready to be reduced. (default: ``False``) check_reduction: This argument is deprecated. gradient_as_bucket_view (bool): When set to ``True``, gradients will be views pointing to different offsets of ``allreduce`` communication buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size. Moreover, it avoids the overhead of copying between gradients and ``allreduce`` communication buckets. When gradients are views, ``detach_()`` cannot be called on the gradients. If hitting such errors, please fix it by referring to the :meth:`~torch.optim.Optimizer.zero_grad` function in ``torch/optim/optimizer.py`` as a solution. Note that gradients will be views after first iteration, so the peak memory saving should be checked after first iteration. static_graph (bool): When set to ``True``, DDP knows the trained graph is static. Static graph means 1) The set of used and unused parameters will not change during the whole training loop; in this case, it does not matter whether users set ``find_unused_parameters = True`` or not. 2) How the graph is trained will not change during the whole training loop (meaning there is no control flow depending on iterations). When static_graph is set to be ``True``, DDP will support cases that can not be supported in the past: 1) Reentrant backwards. 2) Activation checkpointing multiple times. 3) Activation checkpointing when model has unused parameters. 4) There are model parameters that are outside of forward function. 5) Potentially improve performance when there are unused parameters, as DDP will not search graph in each iteration to detect unused parameters when static_graph is set to be ``True``. To check whether you can set static_graph to be ``True``, one way is to check ddp logging data at the end of your previous model training, if ``ddp_logging_data.get("can_set_static_graph") == True``, mostly you can set ``static_graph = True`` as well. Example:: >>> # xdoctest: +SKIP("undefined variables") >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model) >>> # Training loop >>> ... >>> ddp_logging_data = model_DDP._get_ddp_logging_data() >>> static_graph = ddp_logging_data.get("can_set_static_graph") delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter): a list of named parameters whose all reduce will be delayed when the gradient of the parameter specified in ``param_to_hook_all_reduce`` is ready. Other arguments of DDP do not apply to named params specified in this argument as these named params will be ignored by DDP reducer. param_to_hook_all_reduce (torch.nn.Parameter): a parameter to hook delayed all reduce of parameters specified in ``delay_all_reduce_named_params``. Attributes: module (Module): the module to be parallelized. Example:: >>> # xdoctest: +SKIP("undefined variables") >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') >>> net = torch.nn.parallel.DistributedDataParallel(model) """ # used to track whether the given thread is inside ddp forward for torchdynamo purposes _active_ddp_module: Optional["DistributedDataParallel"] = None def __init__( self, module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=None, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, static_graph=False, delay_all_reduce_named_params=None, param_to_hook_all_reduce=None, mixed_precision: Optional[_MixedPrecision] = None, device_mesh=None, ): super().__init__() Joinable.__init__(self) self.logger = None if bool(delay_all_reduce_named_params is not None) != bool( param_to_hook_all_reduce is not None ): self._log_and_throw( ValueError, "delay_all_reduce_named_params and param_to_hook_all_reduce " "need to be set at the same time.", ) if process_group and device_mesh is not None: raise RuntimeError( "Cannot specify both process_group and device_mesh arguments." ) elif process_group is None and device_mesh is None: self.process_group = _get_default_group() elif device_mesh is None: self.process_group = process_group else: if device_mesh.ndim != 1: raise RuntimeError( f"Only 1D device mesh is supported, but got {device_mesh}." ) self.device_mesh = device_mesh self.process_group = device_mesh.get_group(mesh_dim=0) from torch.distributed.device_mesh import _mesh_resources if _mesh_resources.get_parent_mesh(device_mesh) is not None: # TODO: This is a temporary work around to enable DDP + TP. # We should do the logic in DDP so that the 2D implementation is # sound and the state_dict works out of the box. # This has to be done before check UninitializedParameter. from torch.distributed.tensor.parallel.ddp import ( _pre_dp_module_transform, ) _pre_dp_module_transform(module) self._delay_all_reduce_params = [] if hasattr(module, "_ddp_params_and_buffers_to_ignore"): self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore) else: self.parameters_to_ignore = set() if delay_all_reduce_named_params is not None: for name, param in delay_all_reduce_named_params: self.parameters_to_ignore.add(name) self._delay_all_reduce_params.append(param) self._module_parameters = [ p for n, p in module.named_parameters() if n not in self.parameters_to_ignore ] if not any(p.requires_grad for p in self._module_parameters): if len(self._delay_all_reduce_params): logger.info("Delay the AllReduce of all parameters.") else: self._log_and_throw( RuntimeError, "DistributedDataParallel is not needed when a module " "doesn't have any parameter that requires a gradient.", ) if device_ids is not None and len(device_ids) > 1: self._log_and_throw( ValueError, "device_ids can only be None or contain a single element.", ) self.is_multi_device_module = ( len({p.device for p in self._module_parameters}) > 1 ) distinct_device_types = { p.device.type for p in self._module_parameters if p.device is not None } if len(distinct_device_types) != 1: self._log_and_throw( ValueError, "DistributedDataParallel's input module must be on " f"the same type of devices, but input module parameters locate in {distinct_device_types}.", ) self.device_type = next(iter(distinct_device_types)) if ( device_ids is None or len(device_ids) == 0 # For backward compatibility. or self.device_type == "cpu" or self.is_multi_device_module ): if device_ids or output_device: self._log_and_throw( ValueError, "DistributedDataParallel device_ids and output_device arguments " "only work with single-device/multiple-device GPU modules or CPU modules, " f"but got device_ids {device_ids}, output_device {output_device}, " f"and module parameters {({p.device for p in self._module_parameters})}.", ) self.device_ids = None self.output_device = None else: self.device_ids = [_get_device_index(x, True) for x in device_ids] if output_device is None: output_device = device_ids[0] self.output_device = _get_device_index(output_device, True) self.static_graph = False self.dim = dim self.module = module self.device = next(iter(self._module_parameters)).device self.broadcast_buffers = broadcast_buffers self.find_unused_parameters = find_unused_parameters self.require_backward_grad_sync = True self.require_forward_param_sync = True self.gradient_as_bucket_view = gradient_as_bucket_view self.mixed_precision = mixed_precision if self.mixed_precision is not None: logger.warning("Received mixed precision config %s", self.mixed_precision) if check_reduction: # This argument is no longer used since the reducer # will ensure reduction completes even if some parameters # do not receive gradients. warnings.warn( "The `check_reduction` argument in `DistributedDataParallel` " "module is deprecated. Please avoid using it.", FutureWarning, stacklevel=2, ) # Check that a module does not have Uninitialized parameters for param in self._module_parameters: if isinstance(param, torch.nn.parameter.UninitializedParameter): self._log_and_throw( RuntimeError, "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. " "Run a dummy forward pass to correctly initialize the modules", ) # used for intra-node param sync and inter-node sync as well self.broadcast_bucket_size = int(250 * 1024 * 1024) # reduction bucket size if bucket_cap_mb is None: # default case (bucket cap is 25 MiB) bucket_cap_mb = 25 self.bucket_bytes_cap_default = True else: self.bucket_bytes_cap_default = False self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) # Whether to perform input tensor CPU to GPU copies on a side-stream self.use_side_stream_for_tensor_copies = ( os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1" ) # Initialize gradient buffers and register all reduce hook self._delay_grad_buffer = None self._delay_grad_views: List[torch.Tensor] = [] self._delay_all_reduce_all_params = False if len(self._delay_all_reduce_params) != 0: self._register_delay_all_reduce_hook( bucket_cap_mb=bucket_cap_mb, param_to_hook_all_reduce=param_to_hook_all_reduce, device_ids=device_ids, ) if self._delay_all_reduce_all_params: return # Build parameters for reducer. parameters, expect_sparse_gradient = self._build_params_for_reducer() # Verify model equivalence. _verify_param_shape_across_processes(self.process_group, parameters) # Sync params and buffers. Ensures all DDP models start off at the same value. _sync_module_states( module=self.module, process_group=self.process_group, broadcast_bucket_size=self.broadcast_bucket_size, src=0, params_and_buffers_to_ignore=self.parameters_to_ignore, broadcast_buffers=self.broadcast_buffers, ) # In debug mode, build a mapping of parameter index -> parameter. param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters) # Builds reducer. self._ddp_init_helper( parameters, expect_sparse_gradient, param_to_name_mapping, static_graph, ) self._comm_hooks: List[Tuple[Callable, object]] = [] if self.mixed_precision is not None: _setup_mixed_precision_params(self.mixed_precision, self.module) _cast_buffers(self.mixed_precision, self.module) # Stream used for async low precision copies. self._mp_stream = torch.cuda.Stream() self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated] # Add forward pre-hook to root module to kick off copies to lower # precision. self.module.register_forward_pre_hook( self._root_copy_hook, prepend=False, with_kwargs=True ) # Add forward pre hook to all submodules to wait for copy events # before running computation. for module in self.module.modules(): module.register_forward_pre_hook( self._module_wait_for_copy_hook, prepend=False, with_kwargs=True, ) # Set up callbacks in backward to upcast and use full precision # params. TODO (rohan-varma): Make this compose with general # comm hooks and apply_optimizer_in_backward. Importing inline to # avoid circular import issue. from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import ( _AllreduceUpcastHookState, _reducer_allreduce_and_upcast_hook, ) upcast_hook_state = _AllreduceUpcastHookState( ddp_weakref=weakref.ref(self), upcast_stream=torch.cuda.Stream(), ) self.register_comm_hook( upcast_hook_state, _reducer_allreduce_and_upcast_hook, ) # Inform reducer of reduced precision param dtype for correctness # of type checks between gradient and bucket. self.reducer._set_mixed_precision_param_dtype( # type: ignore[attr-defined] self.mixed_precision.param_dtype ) self._has_rebuilt_buckets = False if static_graph: self._set_static_graph() self._lazy_init_ran = False # Register the AccumulateGrad post hooks if optimize_ddp is # True. The hooks will be deregistered if compiled_autograd is not # enabled. self._accum_grad_hooks: List[RemovableHandle] = [] optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode() self._use_python_reducer = optimize_ddp in ( "python_reducer", "python_reducer_without_compiled_forward", ) if self._use_python_reducer: torch._inductor.config._fuse_ddp_communication = True torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb # Directly adding this to the trace rule will disturb the users # who are using DDPOptimizer. torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add( "torch.nn.parallel.distributed" ) torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear() self._force_to_disable_cpp_reducer = ( optimize_ddp == "python_reducer_without_compiled_forward" ) if self._use_python_reducer: self._register_accum_grad_hook() # Whether or not DDPSink performs a clone. self._ddp_sink_clone = True def _register_accum_grad_hook(self): import torch.distributed._functional_collectives as fcol def compiled_accum_grad_hook( param, *, param_index: int, ): if not self.require_backward_grad_sync: return if param.grad is None: return if self._comm_hooks: for hook, state in self._comm_hooks: hook(state, (param.grad, param)) else: gradient = param.grad / self.process_group.size() gradient = fcol.all_reduce(gradient, "sum", self.process_group) param.grad.copy_(gradient) for index, param in enumerate(self._module_parameters): if not param.requires_grad: continue self._accum_grad_hooks.append( param.register_post_accumulate_grad_hook( functools.partial( compiled_accum_grad_hook, param_index=index, ) ) ) def _delayed_all_reduce_hook(self, grad): world_size = dist.get_world_size(self.process_group) self._delay_grad_buffer.div_(world_size) # type: ignore[union-attr] _ = dist.all_reduce( self._delay_grad_buffer, group=self.process_group, async_op=True ) return grad def _register_delay_all_reduce_hook( self, bucket_cap_mb, param_to_hook_all_reduce, device_ids, ): # 1. Create gradient buffer device = torch.device("cpu") if device_ids is None else device_ids[0] self._delay_grad_buffer = torch.zeros( sum(p.numel() for p in self._delay_all_reduce_params), device=device, ) # 2. Broadcast the parameters detached_params = [p.detach() for p in self._delay_all_reduce_params] dist._broadcast_coalesced(self.process_group, detached_params, bucket_cap_mb, 0) # 3. Hook all reduce to the specified parameter param_to_hook_all_reduce.register_hook(self._delayed_all_reduce_hook) # 4. Build tensor views for gradients offset = 0 for param in self._delay_all_reduce_params: grad_view = self._delay_grad_buffer[offset : (offset + param.numel())].view( param.shape ) self._delay_grad_views.append(grad_view) offset = offset + param.numel() # 5. Check whether the all reduce of all params requiring grad is delayed. for module_name, module in self.module.named_modules(): for param_name, param in module.named_parameters(recurse=False): if param.requires_grad: full_name = f"{module_name}.{param_name}" if full_name not in self.parameters_to_ignore: # There is at least a param whose all reduce will not be delayed. # In this case, we should not set self._delay_all_reduce_all_params # to True. return self._delay_all_reduce_all_params = True def _setup_in_backward_optimizers(self): # Check if user has used apply_optim_in_backward to overlap optimizer # step + DDP backward. Current constraints: # 1. Only allreduce is supported at the moment, no custom communication. # 2. For DDP-managed parameters that have their optimizer run in # backward, their gradients are set to ``None``. If your use case # requires DDP parameters grad not to be set to ``None`` after their # in-backward optimizer runs, please ping # https://github.com/pytorch/pytorch/issues/90052. # NOTE: we use self._module_parameters instead of .parameters() since # the former excludes ignored (non-DDP managed) parameters. if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters): torch._C._log_api_usage_once("ddp.optimizer_in_backward") # Remove hooks that apply_optim_in_backward had registered because # DDP customizes how optimizer is overlapped with backward due to # the allreduce. param_to_handle_map = ( dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map ) for p in self._module_parameters: for handle in param_to_handle_map.get(p, []): handle.remove() # Need a weakref to DDP instance to run all_reduce (from reducer) # and get managed DDP parameters. ddp_weakref = weakref.ref(self) # Note: importing in function, otherwise this will cause a circular # import. from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import ( _apply_optim_in_backward_hook, ) self.register_comm_hook( ddp_weakref, _apply_optim_in_backward_hook( gradient_is_bucket_view=self.gradient_as_bucket_view ), ) self.reducer._set_optimizer_in_backward() # type: ignore[attr-defined] def _fire_reducer_autograd_hook(self, idx, *unused): """ Fire the reducer's autograd hook to allreduce params in a Reducer bucket. Note that this is only used during mixed precision training as the Reducer's hooks installed during construction time would not be called as we're working in the low precision parameter setting. """ self.reducer._autograd_hook(idx) # type: ignore[attr-defined] def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None: """ For DDP mixed precision, put low precision copies on separate stream and create events to wait for them. When training with DDP mixed precision, this root pre-forward hook kicks off low precision copies on a separate stream and creates respective events to wait for them. """ # Clear out previous iteration submodule to event. This is because we # may have populated some events for modules that didn't end up being # used. self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated] with torch.cuda.stream(self._mp_stream): for submodule in self.module.modules(): for param in submodule.parameters(recurse=False): # Do not cast DDP ignored parameters. if hasattr(param, "_ddp_ignored") and param._ddp_ignored: continue _alloc_storage(param._mp_param, param.size()) # copy() implicitly casts to low precision with torch.no_grad(): param._mp_param.copy_(param.data) # TODO: when zero_grad(set_to_none=False) or in grad # accumulation case, accumulated grads can be in fp32 # which can cause errors when running DDP backwards due # to mismatched incoming and accumulated gradient types. # So we manually cast the accumulated grad down for now, # in the future we may shift to FSDP style gradient # accumulation management where the accumulated gradient # is saved and .grad field is set to None, bypassing # this issue. if param.grad is not None: param.grad.data = param.grad.to( self.mixed_precision.param_dtype # type: ignore[union-attr] ) param.data = param._mp_param copy_event = torch.cuda.Event() copy_event.record() self._submodule_to_event[submodule].append(copy_event) def _module_wait_for_copy_hook( self, module, *args: Any, **kwargs: Any, ) -> None: """Before carrying out computation, wait on the appropriate event to ensure low precision copies have finished.""" try: event = self._submodule_to_event[module].popleft() except IndexError: # copy event has already been waited on return event.wait(stream=torch.cuda.current_stream()) for p in module.parameters(recurse=False): # Don't register hooks if param does not require grad if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored): continue # We need to register autograd hook here instead of DDP's ctor # since we're working with the low precision param. Register them # via obtaining the gradient accumulator. tmp = p.expand_as(p) grad_acc = tmp.grad_fn.next_functions[0][0] hook = grad_acc.register_hook( functools.partial(self._fire_reducer_autograd_hook, p._idx) ) p._ddp_mp_hook_state = (grad_acc, hook) def _log_and_throw(self, err_type, err_msg): if self.logger is not None: self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}") raise err_type(err_msg) def _ddp_init_helper( self, parameters, expect_sparse_gradient, param_to_name_mapping, static_graph, ): """ DDP init helper function to manage parameters, grad hooks, logging, and SyncBatchNorm. Initialization helper function that does the following: (1) bucketing the parameters for reductions (2) resetting the bucketing states (3) registering the grad hooks (4) Logging construction-time DDP logging data (5) passing a handle of DDP to SyncBatchNorm Layer """ # Notice, the parameters order is not in the order in which they are used, # especially in models with control flow. # # Alongside parameters are not presented in the real execution order, # if a certain model happens to also # 1) have other collectives comm ops in its backward graph. # 2) have unused parameter in subset ranks of the whole world. # bucketing could insert ALL-REDUCE comm op too early on the rank with unused parameter, # matching up with other collectives comm ops on other ranks unexpectedly. # # In order to handle this corner case, when the parameters are not in the real execution order, # we don't do bucketing, thus only one ALL-REDUCE is inserted after all the gradients # of the whole graph are computed. # # Notice, here we only disable bucketing for the first iteration. # After the first iteration, it's OK to rebuild buckets, # because "bucket rebuild" bucketizes parameters based on its real execution order in backward graph. # Can remove this branching once #73732 is landed. if static_graph is True or self.find_unused_parameters is False: bucket_size_limits = [sys.maxsize] else: if self.bucket_bytes_cap_default: bucket_size_limits = [ dist._DEFAULT_FIRST_BUCKET_BYTES, self.bucket_bytes_cap, ] else: bucket_size_limits = [self.bucket_bytes_cap] ( bucket_indices, per_bucket_size_limits, ) = dist._compute_bucket_assignment_by_size( parameters, bucket_size_limits, expect_sparse_gradient, ) # Remember index for parameters if we are in mixed precision, as we # need to pass in index to Reducer's autograd hook via python. if self.mixed_precision is not None: for i, p in enumerate(parameters): p._idx = i # Note: reverse list of buckets because we want to approximate the # order in which their gradients are produced, and assume they # are used in the forward pass in the order they are defined. self.reducer = dist.Reducer( parameters, list(reversed(bucket_indices)), list(reversed(per_bucket_size_limits)), self.process_group, expect_sparse_gradient, # The bucket size limit is specified in the constructor. # Additionally, we allow for a single small bucket for parameters # that are defined first, such that their gradients don't spill into # a much larger bucket, adding unnecessary latency after gradient # computation finishes. Experiments showed 1MB is a reasonable value. self.bucket_bytes_cap, self.find_unused_parameters, self.gradient_as_bucket_view, param_to_name_mapping, # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first # bucket. dist._DEFAULT_FIRST_BUCKET_BYTES if self.bucket_bytes_cap_default else self.bucket_bytes_cap, ) self.logger = dist.Logger(self.reducer) # Set as a weak reference to avoid reference cycle between # logger and reducer. self.reducer.set_logger(self.logger) has_sync_bn = False for submodule in self.module.modules(): if isinstance(submodule, torch.nn.SyncBatchNorm): has_sync_bn = True break # Set logging data that can be got during construction time. self.logger.set_construction_data_and_log( self.module.__class__.__name__, [] if self.device_ids is None else self.device_ids, -1 if self.output_device is None else self.output_device, self.broadcast_buffers, has_sync_bn, static_graph, ) # passing a handle to torch.nn.SyncBatchNorm layer self._passing_sync_batchnorm_handle(self.module) def __getstate__(self): self._check_default_group() attrs = copy.copy(self.__dict__) del attrs["process_group"] del attrs["reducer"] del attrs["logger"] return attrs def __setstate__(self, state): # If serializable, then the process group should be the default one self.process_group = _get_default_group() super().__setstate__(state) self.__dict__.setdefault("require_forward_param_sync", True) self.__dict__.setdefault("require_backward_grad_sync", True) parameters, expect_sparse_gradient = self._build_params_for_reducer() # In debug mode, build a mapping of parameter index -> parameter. param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters) # Builds reducer. self._ddp_init_helper( parameters, expect_sparse_gradient, param_to_name_mapping, self.static_graph, ) if self.static_graph: self.reducer._set_static_graph() assert self.logger is not None self.logger._set_static_graph() def _build_params_for_reducer(self): # Build tuple of (module, parameter) for all parameters that require grads. modules_and_parameters = [ (module, parameter) for module_name, module in self.module.named_modules() for parameter in [ param # Note that we access module.named_parameters instead of # parameters(module). parameters(module) is only needed in the # single-process multi device case, where it accesses replicated # parameters through _former_parameters. for param_name, param in module.named_parameters(recurse=False) if param.requires_grad and f"{module_name}.{param_name}" not in self.parameters_to_ignore ] ] # Deduplicate any parameters that might be shared across child modules. memo = set() modules_and_parameters = [ # "p not in memo" is the deduplication check. # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed. (m, p) for m, p in modules_and_parameters if p not in memo and not memo.add(p) # type: ignore[func-returns-value] ] # Build list of parameters. parameters = [parameter for _, parameter in modules_and_parameters] # Checks if a module will produce a sparse gradient. def produces_sparse_gradient(module): if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)): return module.sparse return False # Build list of booleans indicating whether or not to expect sparse # gradients for the corresponding parameters. expect_sparse_gradient = [ produces_sparse_gradient(module) for module, _ in modules_and_parameters ] self._assign_modules_buffers() return parameters, expect_sparse_gradient def _assign_modules_buffers(self): """ Assign self.module.named_buffers to self.modules_buffers. Assigns module buffers to self.modules_buffers which are then used to broadcast across ranks when broadcast_buffers=True. Note that this must be called every time buffers need to be synced because buffers can be reassigned by user module, see https://github.com/pytorch/pytorch/issues/63916. """ # Collect buffers for modules, filtering out buffers that should be ignored. named_module_buffers = [ (buffer, buffer_name) for buffer_name, buffer in self.module.named_buffers() if buffer_name not in self.parameters_to_ignore ] self.modules_buffers = [ buffer for (buffer, buffer_name) in named_module_buffers ] # Dict[str, tensor] representing module buffers not ignored by DDP. self.named_module_buffers = { buffer_name: buffer for (buffer, buffer_name) in named_module_buffers } def _build_debug_param_to_name_mapping(self, parameters): param_to_param_index = {parameters[i]: i for i in range(len(parameters))} param_set = set(parameters) param_index_to_param_fqn = {} for module_name, module in self.module.named_modules(): for param_name, param in module.named_parameters(recurse=False): fqn = f"{module_name}.{param_name}" # Bypass ignored parameters since those are not reduced by DDP # to begin with. if fqn not in self.parameters_to_ignore and param.requires_grad: if param not in param_set: self._log_and_throw( ValueError, f"Param with name {fqn} found in module parameters, but not DDP parameters." " This indicates a bug in DDP, please report an issue to PyTorch.", ) param_index = param_to_param_index[param] param_index_to_param_fqn[param_index] = fqn # Ensure we covered all parameters if len(param_set) != len(param_index_to_param_fqn): self._log_and_throw( ValueError, ( "Expected param to name mapping to cover all parameters, but" f" got conflicting lengths: {len(param_set)} vs " f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP" ", please report an issue to PyTorch." ), ) return param_index_to_param_fqn def _get_parameters(self, m, recurse=True): """Return a generator of module parameters.""" def model_parameters(m): ps = ( m._former_parameters.values() if hasattr(m, "_former_parameters") else m.parameters(recurse=False) ) yield from ps for mod in m.modules() if recurse else [m]: yield from model_parameters(mod) def _check_default_group(self): pickle_not_supported = False try: if self.process_group != _get_default_group(): pickle_not_supported = True except RuntimeError: pickle_not_supported = True if pickle_not_supported: self._log_and_throw( RuntimeError, "DDP Pickling/Unpickling are only supported " "when using DDP with the default process " "group. That is, when you have called " "init_process_group and have not passed " "process_group argument to DDP constructor", )
[docs] @contextmanager def no_sync(self): r""" Context manager to disable gradient synchronizations across DDP processes. Within this context, gradients will be accumulated on module variables, which will later be synchronized in the first forward-backward pass exiting the context. Example:: >>> # xdoctest: +SKIP("undefined variables") >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg) >>> with ddp.no_sync(): >>> for input in inputs: >>> ddp(input).backward() # no synchronization, accumulate grads >>> ddp(another_input).backward() # synchronize grads .. warning:: The forward pass should be included inside the context manager, or else gradients will still be synchronized. """ old_require_backward_grad_sync = self.require_backward_grad_sync self.require_backward_grad_sync = False try: yield finally: self.require_backward_grad_sync = old_require_backward_grad_sync
@classmethod def _get_active_ddp_module(cls): """`TorchDynamo` requires DDP's status and module for cooperative optimization.""" return cls._active_ddp_module # note, this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only kicks in # for the 'module_to_run' underneath # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details @contextmanager @torch._disable_dynamo(recursive=False) def _inside_ddp_forward(self): DistributedDataParallel._active_ddp_module = self try: yield finally: DistributedDataParallel._active_ddp_module = None def _run_ddp_forward(self, *inputs, **kwargs): if self._use_python_reducer: return self.module(*inputs, **kwargs) # type: ignore[index] else: with self._inside_ddp_forward(): return self.module(*inputs, **kwargs) # type: ignore[index] def _clear_grad_buffer(self): # Making param.grad points to the grad buffers before backward is based on the # assumption that the grad accumulation is done in place in autograd engine, # for some edge cases, if the grad accumulation in autograd engine is not in # place, then the param.grad and grad buffers are detached. if self._delay_grad_buffer is not None: # We batch zero_grad for all params by resetting the whole grad # buffer when the grad of all params is set to None. all_param_grad_none = all( param.grad is None for param in self._delay_all_reduce_params ) for index, param in enumerate(self._delay_all_reduce_params): if param.grad is None: param.grad = self._delay_grad_views[index] if not all_param_grad_none: param.grad.zero_() if all_param_grad_none: self._delay_grad_buffer.zero_() def _lazy_init(self): # Initialization for DDP that occurs after construction, but lazily # before the first forward pass. self._setup_in_backward_optimizers() self._lazy_init_ran = True def _should_disable_cpp_reducer(self) -> bool: return self._use_python_reducer and ( torch._utils.is_compiling() or self._force_to_disable_cpp_reducer ) def _pre_forward(self, *inputs, **kwargs): if self._should_disable_cpp_reducer(): return inputs, kwargs # Disable the python reducer if compiled_autograd is not enabled. if self._accum_grad_hooks: for index, h in enumerate(self._accum_grad_hooks): h.remove() self._accum_grad_hooks.clear() if not self._lazy_init_ran and not torch._utils.is_compiling(): self._lazy_init() if self._delay_all_reduce_all_params: return inputs, kwargs if torch.is_grad_enabled() and self.require_backward_grad_sync: assert self.logger is not None self.logger.set_runtime_stats_and_log() self.reducer.prepare_for_forward() # Notify the join context that this process has not joined, if # needed work = Join.notify_join_context(self) if work: self.reducer._set_forward_pass_work_handle( work, self._divide_by_initial_world_size # type: ignore[arg-type] ) # Calling _rebuild_buckets before forward computation, # It may allocate new buckets before deallocating old buckets # inside _rebuild_buckets. To save peak memory usage, # call _rebuild_buckets before the peak memory usage increases # during forward computation. # This should be called only once during whole training period. if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): logger.info("Reducer buckets have been rebuilt in this iteration.") self._has_rebuilt_buckets = True # sync params according to location (before/after forward) user # specified as part of hook, if hook was specified. if self._check_sync_bufs_pre_fwd(): self._sync_buffers() if self._join_config.enable: # Notify joined ranks whether they should sync in backwards pass or not. self._check_global_requires_backward_grad_sync(is_joined_rank=False) if self.device_ids: moved_inputs, moved_kwargs = _to_kwargs( inputs, kwargs, torch.device(self.device_type, self.device_ids[0]), self.use_side_stream_for_tensor_copies, ) args, kwargs = moved_inputs[0], moved_kwargs[0] # Cast inputs to reduced precision if needed. if self.mixed_precision is not None: args, kwargs = _cast_forward_inputs( self.mixed_precision.param_dtype, *args, **kwargs, ) return args, kwargs else: # Cast inputs to reduced precision if needed. # TODO (rohan-varma) test this codepath. if self.mixed_precision is not None: inputs, kwargs = _cast_forward_inputs( self.mixed_precision.param_dtype, *inputs, **kwargs, ) return inputs, kwargs def _post_forward(self, output): if self._should_disable_cpp_reducer(): return output if self._delay_all_reduce_all_params: self._clear_grad_buffer() return output # sync params according to location (before/after forward) user # specified as part of hook, if hook was specified. if self._check_sync_bufs_post_fwd(): self._sync_buffers() if torch.is_grad_enabled() and self.require_backward_grad_sync: self.require_forward_param_sync = True # We'll return the output object verbatim since it is a freeform # object. We need to find any tensors in this object, though, # because we need to figure out which parameters were used during # this forward pass, to ensure we short circuit reduction for any # unused parameters. Only if `find_unused_parameters` is set. if self.find_unused_parameters and not self.static_graph: # Do not need to populate this for static graph. self.reducer.prepare_for_backward(list(_find_tensors(output))) else: self.reducer.prepare_for_backward([]) else: self.require_forward_param_sync = False # TODO: DDPSink is currently enabled for unused parameter detection and # static graph training for first iteration. if (self.find_unused_parameters and not self.static_graph) or ( self.static_graph and not self._static_graph_delay_allreduce_enqueued ): ( output_tensor_list, treespec, output_is_rref, ) = _tree_flatten_with_rref(output) output_placeholders = [None for _ in range(len(output_tensor_list))] # Do not touch tensors that have no grad_fn, which can cause issues # such as https://github.com/pytorch/pytorch/issues/60733 for i, output in enumerate(output_tensor_list): if torch.is_tensor(output) and output.grad_fn is None: output_placeholders[i] = output # When find_unused_parameters=True, makes tensors which require grad # run through the DDPSink backward pass. When not all outputs are # used in loss, this makes those corresponding tensors receive # undefined gradient which the reducer then handles to ensure # param.grad field is not touched and we don't error out. passthrough_tensor_list = _DDPSink.apply( weakref.ref(self), *output_tensor_list, ) for i in range(len(output_placeholders)): if output_placeholders[i] is None: output_placeholders[i] = passthrough_tensor_list[i] # Reconstruct output data structure. output = _tree_unflatten_with_rref( output_placeholders, treespec, output_is_rref ) # At the end of the forward pass, reset the grad buffer and grad views self._clear_grad_buffer() return output def forward(self, *inputs, **kwargs): with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): inputs, kwargs = self._pre_forward(*inputs, **kwargs) output = ( self.module.forward(*inputs, **kwargs) if self._delay_all_reduce_all_params else self._run_ddp_forward(*inputs, **kwargs) ) return self._post_forward(output) def scatter(self, inputs, kwargs, device_ids): return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) def to_kwargs(self, inputs, kwargs, device_id): # Kept for BC return _to_kwargs( inputs, kwargs, torch.device(self.device_type, device_id), self.use_side_stream_for_tensor_copies, ) def gather(self, outputs, output_device): return gather(outputs, output_device, dim=self.dim) def train(self, mode=True): super().train(mode) return self # When running in join mode, schedules an allreduce to notify joined ranks # of whether backwards pass synchronization will run this iteration or not. def _check_global_requires_backward_grad_sync(self, is_joined_rank): if not is_joined_rank and self.require_backward_grad_sync: requires_sync_tensor = torch.ones(1, device=self.device) else: requires_sync_tensor = torch.zeros(1, device=self.device) work = dist.all_reduce( requires_sync_tensor, group=self.process_group, async_op=True ) # (kwen2501) This if condition is a plain translation of previous # behavior, i.e. in the `is_joined_rank=False` case, `work.wait()` # is not called and it doesn't care about the result. I am guessing # that it just wants to fire a matching all-reduce and does not want # the main stream to wait. if is_joined_rank: work.wait() should_sync_backwards = requires_sync_tensor.item() != 0 return should_sync_backwards else: return None # Return value is not/should not be used. # When running in join mode, checks and performs sync of module buffers if # the models have buffers that should be synchronized in the forward pass. def _check_and_sync_module_buffers(self): if self._check_sync_bufs_pre_fwd(): authoritative_rank = self._find_common_rank(self._distributed_rank, False) self._sync_module_buffers(authoritative_rank) # When running in join model, agrees upon a common rank and broadcast model # parameters to all other ranks. def _sync_final_model(self, is_last_joiner): # Agree upon the process that will be the authoritative model copy. # The current rank is a candidate for being the authoritative copy if # is_last_joiner=True. We break ties via picking the larger rank. self._authoritative_rank = self._find_common_rank( self._distributed_rank, is_last_joiner ) _sync_module_states( module=self.module, process_group=self.process_group, broadcast_bucket_size=self.broadcast_bucket_size, src=self._authoritative_rank, params_and_buffers_to_ignore=self.parameters_to_ignore, broadcast_buffers=self.broadcast_buffers, ) # Schedule comm ops to match those scheduled in the reducer's backward # pass. def _match_all_reduce_for_bwd_pass(self): comm_work = [] # Schedule comm in the same order as Reducer schedules them, i.e. # the order of the buckets. Retrieving the bucket order from the reducer # ensures that we keep the same order in join mode, such as when bucket # order is rebuilt dynamically. # Returns grad_buckets in order, but real tensors are substituted with # zero tensors of the same shape. grad_buckets = self.reducer._get_zeros_like_grad_buckets() for grad_bucket in grad_buckets: # Joined processes contribute zero gradient. In the case that # divide_by_initial_world_size=True, we divide grads by the static # world size, if not, the dividing factor is reduced by the number # of joined processes. work = self.reducer._run_comm_hook(grad_bucket) comm_work.append(work) for work in comm_work: work.wait() # Allreduces the used parameter mapping across ranks. def _match_unused_params_allreduce(self): locally_used_param_map = self.reducer._get_local_used_map() self.process_group.allreduce(locally_used_param_map)
[docs] def join( self, divide_by_initial_world_size: bool = True, enable: bool = True, throw_on_early_termination: bool = False, ): r""" Context manager for training with uneven inputs across processes in DDP. This context manager will keep track of already-joined DDP processes, and "shadow" the forward and backward passes by inserting collective communication operations to match with the ones created by non-joined DDP processes. This will ensure each collective call has a corresponding call by already-joined DDP processes, preventing hangs or errors that would otherwise happen when training with uneven inputs across processes. Alternatively, if the flag ``throw_on_early_termination`` is specified to be ``True``, all trainers will throw an error once one rank runs out of inputs, allowing these errors to be caught and handled according to application logic. Once all DDP processes have joined, the context manager will broadcast the model corresponding to the last joined process to all processes to ensure the model is the same across all processes (which is guaranteed by DDP). To use this to enable training with uneven inputs across processes, simply wrap this context manager around your training loop. No further modifications to the model or data loading is required. .. warning:: If the model or training loop this context manager is wrapped around has additional distributed collective operations, such as ``SyncBatchNorm`` in the model's forward pass, then the flag ``throw_on_early_termination`` must be enabled. This is because this context manager is not aware of non-DDP collective communication. This flag will cause all ranks to throw when any one rank exhausts inputs, allowing these errors to be caught and recovered from across all ranks. Args: divide_by_initial_world_size (bool): If ``True``, will divide gradients by the initial ``world_size`` DDP training was launched with. If ``False``, will compute the effective world size (number of ranks that have not depleted their inputs yet) and divide gradients by that during allreduce. Set ``divide_by_initial_world_size=True`` to ensure every input sample including the uneven inputs have equal weight in terms of how much they contribute to the global gradient. This is achieved by always dividing the gradient by the initial ``world_size`` even when we encounter uneven inputs. If you set this to ``False``, we divide the gradient by the remaining number of nodes. This ensures parity with training on a smaller ``world_size`` although it also means the uneven inputs would contribute more towards the global gradient. Typically, you would want to set this to ``True`` for cases where the last few inputs of your training job are uneven. In extreme cases, where there is a large discrepancy in the number of inputs, setting this to ``False`` might provide better results. enable (bool): Whether to enable uneven input detection or not. Pass in ``enable=False`` to disable in cases where you know that inputs are even across participating processes. Default is ``True``. throw_on_early_termination (bool): Whether to throw an error or continue training when at least one rank has exhausted inputs. If ``True``, will throw upon the first rank reaching end of data. If ``False``, will continue training with a smaller effective world size until all ranks are joined. Note that if this flag is specified, then the flag ``divide_by_initial_world_size`` would be ignored. Default is ``False``. Example:: >>> # xdoctest: +SKIP("Distributed") >>> import torch >>> import torch.distributed as dist >>> import os >>> import torch.multiprocessing as mp >>> import torch.nn as nn >>> # On each spawned worker >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> torch.cuda.set_device(rank) >>> model = nn.Linear(1, 1, bias=False).to(rank) >>> model = torch.nn.parallel.DistributedDataParallel( >>> model, device_ids=[rank], output_device=rank >>> ) >>> # Rank 1 gets one more input than rank 0. >>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)] >>> with model.join(): >>> for _ in range(5): >>> for inp in inputs: >>> loss = model(inp).sum() >>> loss.backward() >>> # Without the join() API, the below synchronization will hang >>> # blocking for rank 1's allreduce to complete. >>> torch.cuda.synchronize(device=rank) """ return Join( [self], enable, throw_on_early_termination, divide_by_initial_world_size=divide_by_initial_world_size, )
[docs] def join_hook( self, **kwargs, ): r""" DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes. Arguments: kwargs (dict): a :class:`dict` containing any keyword arguments to modify the behavior of the join hook at run time; all :class:`Joinable` instances sharing the same join context manager are forwarded the same value for ``kwargs``. The hook supports the following keyword arguments: divide_by_initial_world_size (bool, optional): If ``True``, then gradients are divided by the initial world size that DDP was launched with. If ``False``, then gradients are divided by the effective world size (i.e. the number of non-joined processes), meaning that the uneven inputs contribute more toward the global gradient. Typically, this should be set to ``True`` if the degree of unevenness is small but can be set to ``False`` in extreme cases for possibly better results. Default is ``True``. """ divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True) return _DDPJoinHook( self, divide_by_initial_world_size=divide_by_initial_world_size )
@property def join_device(self): return self.device @property def join_process_group(self): return self.process_group def _register_buffer_comm_hook( self, state, hook: Callable, comm_hook_location=_BufferCommHookLocation.POST_FORWARD, ): r""" Allow custom registration of hooks that define how buffer are synchronized across ranks. The hook takes in an optional state and is passed in a Dict[str, Tensor] corresponding to buffer names and the buffers, and can run arbitrary reductions on buffers as opposed to DDP's default broadcast from rank 0. This is useful for example if a counter needs to be summed or averaged across ranks every iteration. Args: state (Any): Optional state that is passed to the hook. hook (Callable): Callable with the following signature: ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]`` comm_hook_location (_BufferCommHookLocation): Enum value indicating where to run the hook. _BufferCommHookLocation.PRE_FORWARD means that the hook will run _before_ the forward pass, and _BufferCommHookLocation.POST_FORWARD means that the hook will run _after_ the forward pass. NOTE: To maximize performance, users can return a List[torch.futures.Future] from their hook, and DDP will install and await these hooks appropriately at the end of the backward pass. This will ensure all buffers are synchronized by the end of the backward pass. If this setting is used, it is recommended to pass comm_hook_location=_BufferCommHookLocation.POST_FORWARD, which will trigger the hook after the forward pass. If _BufferCommHookLocation.PRE_FORWARD is used, users must ensure appropriate synchronization when manipulating GPU buffers in the forward pass. """ assert callable(hook) self.buffer_hook = _BufferCommHook( buffer_comm_hook=hook, buffer_comm_hook_state=state, buffer_comm_hook_location=comm_hook_location, )
[docs] def register_comm_hook(self, state: object, hook: Callable): r""" Register communication hook for user-defined DDP aggregation of gradients across multiple workers. This hook would be very useful for researchers to try out new ideas. For example, this hook can be used to implement several algorithms like GossipGrad and gradient compression which involve different communication strategies for parameter syncs while running Distributed DataParallel training. Args: state (object): Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. It is locally stored by each worker and shared by all the gradient tensors on the worker. hook (Callable): Callable with the following signature: ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``: This function is called once the bucket is ready. The hook can perform whatever processing is needed and return a Future indicating completion of any async work (ex: allreduce). If the hook doesn't perform any communication, it still must return a completed Future. The Future should hold the new value of grad bucket's tensors. Once a bucket is ready, c10d reducer would call this hook and use the tensors returned by the Future and copy grads to individual parameters. Note that the future's return type must be a single tensor. We also provide an API called ``get_future`` to retrieve a Future associated with the completion of ``c10d.ProcessGroup.Work``. ``get_future`` is currently supported for NCCL and also supported for most operations on GLOO and MPI, except for peer to peer operations (send/recv). .. warning :: Grad bucket's tensors will not be predivided by world_size. User is responsible to divide by the world_size in case of operations like allreduce. .. warning :: DDP communication hook can only be registered once and should be registered before calling backward. .. warning :: The Future object that hook returns should contain a single tensor that has the same shape with the tensors inside grad bucket. .. warning :: ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``. Example:: Below is an example of a noop hook that returns the same tensor. >>> # xdoctest: +SKIP('undefined name') >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: >>> fut = torch.futures.Future() >>> fut.set_result(bucket.buffer()) >>> return fut >>> ddp.register_comm_hook(state=None, hook=noop) Example:: Below is an example of a Parallel SGD algorithm where gradients are encoded before allreduce, and then decoded after allreduce. >>> # xdoctest: +SKIP('undefined name') >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: >>> encoded_tensor = encode(bucket.buffer()) # encode gradients >>> fut = torch.distributed.all_reduce(encoded_tensor).get_future() >>> # Define the then callback to decode. >>> def decode(fut): >>> decoded_tensor = decode(fut.value()[0]) # decode gradients >>> return decoded_tensor >>> return fut.then(decode) >>> ddp.register_comm_hook(state=None, hook=encode_and_decode) """ self._check_comm_hook(hook) assert self.logger is not None self.logger._set_comm_hook_name(hook.__qualname__) self._comm_hooks.append((hook, state)) dist._register_comm_hook(self.reducer, state, hook)
def _register_builtin_comm_hook(self, comm_hook_type): r""" Register a built-in communication hook that specifies how DDP aggregates gradients across multiple workers. The built-in hooks aim to provide efficient C++ implementations for certain hooks, which might not be as efficient if implemented in Python using a Python communication hook. Args: comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc. .. warning :: DDP communication hook can only be registered once and should be registered before calling backward. Example:: Below is an example of a FP16 compression where gradients are compressed into 16-bit floating-point numbers before allreduce, and then decompressed after allreduce. >>> # xdoctest: +SKIP('undefined name') >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS) """ assert self.logger is not None self.logger._set_comm_hook_name(str(comm_hook_type)) dist._register_builtin_comm_hook(self.reducer, comm_hook_type) def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs): r""" Register an optimizer in DDP to optimize parameter immediately after its gradient reduction. Registers an optimizer with DDP such that the optimization for a parameter will run immediately when that parameter's gradient is finished with reduction, instead of waiting for all parameters' gradients to finish reduction. This can result in a training speedup depending on your workload since the optimizer can run while gradient reduction for other parameters are still ongoing. In addition, this has the potential to reduce peak memory consumption during training, as it only needs to load the per-parameter optimizer states of a single parameter at a time, instead of loading all per-parameter optimizer states at once. Args: optim (Type): a ``torch.optim.Optimizer`` class to be registered as a fused optimizer. *args (Sequence[Any]): Arguments to forward to `optim`. optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters to optimize, similar to `params` argument of traditional `torch.optim` Optimizers. If this is omitted, all DDP model parameters will be optimized. **kwargs: (Dict[str, Any]): Keyword arguments to forward to `optim`. .. warning :: _register_fused_optim should only be called once on a DDP instance, and registering multiple fused optimizers for the same DDP model is not currently supported. Please ping https://github.com/pytorch/pytorch/issues/71595 if this is necessary for your use case. .. warning :: _register_fused_optim and register_comm_hook currently do not compose together, meaning that custom DDP communication hooks are not supported with overlapped optimizers. Please ping https://github.com/pytorch/pytorch/issues/71595 if this is necessary for your use case. .. warning :: Gradient accumulation and DDP `no_sync` are currently not supported with overlapped optimizer. Please ping https://github.com/pytorch/pytorch/issues/71595 if this is necessary for your use case. Example:: >>> # xdoctest: +SKIP("No rendezvous handler") >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') >>> net = torch.nn.parallel.DistributedDataParallel(model, pg) >>> lr = 1e-2 >>> betas = (0.9, 0.99) >>> eps = 1e-6 >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps) >>> # Example with subset of parameters >>> params_to_opt = [list(net.parameters())[0]] >>> net._register_fused_optim( ... torch.optim.Adam, lr, optim_params=params_to_opt, betas=betas, eps=eps ... ) """ # Note: importing in function, otherwise this will cause a circular # import as optimizer_overlap module needs to import DistributedDataParallel. from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs) try: overlapped_optim.register_ddp(self) except NotImplementedError as e: raise RuntimeError( f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}." ) from e def _distributed_broadcast_coalesced( self, tensors, buffer_size, authoritative_rank=0 ): dist._broadcast_coalesced( self.process_group, tensors, buffer_size, authoritative_rank ) def _check_sync_bufs_post_fwd(self): return ( self.will_sync_module_buffers() and hasattr(self, "buffer_hook") and self.buffer_hook.buffer_comm_hook_location == _BufferCommHookLocation.POST_FORWARD ) def _check_sync_bufs_pre_fwd(self): return self.will_sync_module_buffers() and ( not hasattr(self, "buffer_hook") or self.buffer_hook.buffer_comm_hook_location == _BufferCommHookLocation.PRE_FORWARD ) def will_sync_module_buffers(self): return ( self.require_forward_param_sync and self.broadcast_buffers and len(self.modules_buffers) > 0 ) def _find_common_rank(self, input_rank, rank_cond): # -1 indicates that this rank is not under consideration to be the # common_rank rank_to_use = torch.tensor( [input_rank if rank_cond else -1], device=self.device, ) dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group) if rank_to_use.item() == -1: self._log_and_throw( ValueError, "BUG! Expected rank_cond to be true for at least one process." " This indicates a bug in PyTorch, please report an issue.", ) return rank_to_use.item() def _sync_buffers(self): with torch.no_grad(): # module buffer sync # Synchronize buffers across processes. # If we are running DDP with the join manager, we have to agree # upon a rank to sync module buffers from, since rank 0 may # already have been joined and have stale module buffers. if self._join_config.enable: authoritative_rank = self._find_common_rank( self._distributed_rank, True ) else: # The process with rank 0 is considered the authoritative copy. authoritative_rank = 0 # Update self.modules_buffers incase any buffers were # reassigned. self._assign_modules_buffers() self._sync_module_buffers(authoritative_rank) def _sync_module_buffers(self, authoritative_rank): if not hasattr(self, "buffer_hook"): self._default_broadcast_coalesced(authoritative_rank=authoritative_rank) else: hook = self.buffer_hook.buffer_comm_hook state = self.buffer_hook.buffer_comm_hook_state futs = hook(state, self.named_module_buffers) if futs is not None: self.reducer._install_post_backward_futures(futs) def _default_broadcast_coalesced( self, bufs=None, bucket_size=None, authoritative_rank=0 ): """ Broadcasts buffers from rank 0 to rest of workers. If bufs, bucket_size are None, default values self.modules_buffers and self.broadcast_bucket_size are used instead. """ if bufs is None: bufs = self.modules_buffers if bucket_size is None: bucket_size = self.broadcast_bucket_size self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank) def _passing_sync_batchnorm_handle(self, module): for layer in module.modules(): if isinstance(layer, torch.nn.modules.SyncBatchNorm): if self.device_type == "cpu": self._log_and_throw( ValueError, "SyncBatchNorm layers only work with GPU modules", ) def _check_comm_hook(self, hook): if not callable(hook): self._log_and_throw(TypeError, "Communication hook must be callable.") sig = inspect.signature(hook) if ( sig.parameters["bucket"].annotation != inspect._empty and sig.parameters["bucket"].annotation != dist.GradBucket ): self._log_and_throw( ValueError, "Communication hook: bucket annotation should be dist.GradBucket.", ) if ( sig.return_annotation != inspect._empty and sig.return_annotation != torch.futures.Future[torch.Tensor] ): self._log_and_throw( ValueError, "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].", ) if hook.__name__ in [ "bf16_compress_hook", "bf16_compress_wrapper_hook", ] and ( (torch.version.cuda is None and torch.version.hip is None) or ( torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) < 11 ) or not dist.is_available() or not dist.is_nccl_available() or torch.cuda.nccl.version() < (2, 10) ): self._log_and_throw( TypeError, "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.10+.", ) @property def _distributed_rank(self): return dist.get_rank(self.process_group) @staticmethod def _get_data_parallel_params(module, named_params=False): """Return a generator of parameters managed by a given DDP unit.""" for param in ( module.parameters() if not named_params else module.named_parameters() ): if not hasattr(param, "_ddp_ignored"): yield param @staticmethod def _set_params_and_buffers_to_ignore_for_model( module, params_and_buffers_to_ignore ): """ Set parameters and buffers to be ignored by DDP. Expected format for parameters is the fully qualified name: {module_name}.{param_name}, and similarly, {module_name}.{buffer_name} for buffers. For example: params_to_ignore = [] # NB: model here is vanilla PyTorch module, not yet wrapped with DDP. for module_name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): if should_ignore(param): # Create expected format fqn = f"{module_name}.{param_name}" params_to_ignore.append(fqn) torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( model, params_to_ignore ) """ # This is a workaround to set parameters and buffers DDP should ignore # during synchronization. It will be removed when the API is finalized # as part of addressing https://github.com/pytorch/pytorch/issues/43690. module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore for name, param in module.named_parameters(): if name in params_and_buffers_to_ignore: param._ddp_ignored = True for name, buffer in module.named_buffers(): if name in params_and_buffers_to_ignore: buffer._ddp_ignored = True def _get_ddp_logging_data(self): r""" Return a dictionary of logging data for debugging and analysis. This interface can be called after DistributedDataParallel() is constructed. It returns a dictionary of logging data. It could help for debugging and analysis. The logging data includes DistributedDataParallel constructor input parameters, some internal states of DistributedDataParallel and performance metrics. Simply print the dictionary and see what these metrics are. This is a prototype interface and subject to change in the future. """ assert self.logger is not None ddp_logging_data = self.logger._get_ddp_logging_data() return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map} def _set_ddp_runtime_logging_sample_rate(self, sample_rate): r""" Set sample_rate of collecting runtime stats. This interface allows users to set sample_rate of collecting runtime stats. The runtime stats will be recorded for the first 10 iterations, after 10 iterations runtime stats will be recorded once every "sample_rate" training iterations. In default, runtime stats are recorded for the first 10 iterations, after 10 iterations runtime stats are recorded once every "kDDPRuntimeLoggingSampleRate=100" training iterations. This is a prototype interface and subject to change in the future. """ if sample_rate < 1: self._log_and_throw( ValueError, "DDP runtime logging sample rate should be equal or greater than 1", ) self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate) def _set_static_graph(self): """ Set static graph for DDP. It is recommended to set static graph in the DDP constructor, which will call this private API internally. """ # If self.static_graph has been set, no need to set it again if self.static_graph: warnings.warn( "You've set static_graph to be True, no need to set it again." ) return self.static_graph = True self._static_graph_delay_allreduce_enqueued = False self.reducer._set_static_graph() assert self.logger is not None self.logger._set_static_graph() if self.find_unused_parameters: warnings.warn( "You passed find_unused_parameters=true to DistributedDataParallel, " "`_set_static_graph` will detect unused parameters automatically, so " "you do not need to set find_unused_parameters=true, just be sure these " "unused parameters will not change during training loop while calling " "`_set_static_graph`." ) def _remove_autograd_hooks(self): """Remove autograd hooks registered by the reducer on the model parameters.""" self.reducer._remove_autograd_hooks() def _check_reducer_finalized(self): """ Check if the reducer has processed all buckets and finalized the backward appropriately. It is useful to call this method after calling .backward() in your training loop in order to avoid subsequent hard to debug errors down the road due to the reducer not finalizing backward. """ self.reducer._check_reducer_finalized() def _set_sparse_metadata(self, global_unique_ids): self.reducer._set_sparse_metadata(global_unique_ids) def _update_process_group(self, new_process_group): """ Dynamically updates the process group for DDP so that we can shrink/expand DDP world size without having to reinitialize DDP. NOTE: If you are using custom communications hooks via, register_comm_hook, you need to update the process groups for those hooks separately. """ # Force a rebuild of buckets for a new process group. This ensures all ranks # are synchronized in terms of when they will rebuild buckets and also # re-evaluates previous assumptions of buckets given the world size might have # changed. self._has_rebuilt_buckets = False self.reducer._reset_state() if not _rank_not_in_group(new_process_group): self.process_group = new_process_group self.reducer._update_process_group(new_process_group) def _set_ddp_sink_clone(self, val: bool): """ Sets whether or not DDPSink should clone the output tensors or not. The default is True since if the loss is modified in place we run into the view is modified in-place error. Although, cloning the tensors can add significant memory and performance hit if the number and size of tensors are large. As a result, this can be set to False if you are not modifying the loss in place. """ self._ddp_sink_clone = val

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources