class JoinHook:
    r"""
    This defines a join hook, which provides two entry points in the join
    context manager: a main hook, which is called repeatedly while there
    exists a non-joined process, and a post-hook, which is called once all
    processes have joined.

    To implement a join hook for the generic join context manager, define a
    class that inherits from :class:`JoinHook` and override ``main_hook()``
    and ``post_hook()`` as appropriate.
    """
    def main_hook(self) -> None:
        r"""
        Call this hook while there exists a non-joined process to shadow
        collective communications in a training iteration, i.e. in one
        forward pass, backward pass, and optimizer step.
        """
    def post_hook(self, is_last_joiner: bool) -> None:
        r"""
        Call this hook after all processes have joined.

        It is passed an additional ``bool`` argument ``is_last_joiner``, which
        indicates if the rank is one of the last to join.

        Arguments:
            is_last_joiner (bool): ``True`` if the rank is one of the last to
                join; ``False`` otherwise.
        """
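# A minimal sketch of a custom join hook, assuming a hypothetical workload
# whose only per-iteration collective is a single all-reduce (the class and
# attribute names here are invented for illustration and are not part of this
# module).
class _ExampleJoinHook(JoinHook):
    def __init__(self, device: torch.device, process_group: Any):
        self._device = device
        self._process_group = process_group

    def main_hook(self) -> None:
        # Shadow the one all-reduce that each non-joined process issues per
        # iteration, contributing zeros so the result is unaffected.
        dist.all_reduce(torch.zeros(1, device=self._device), group=self._process_group)

    def post_hook(self, is_last_joiner: bool) -> None:
        # All processes have joined; this toy hook needs no final
        # synchronization, whereas e.g. DDP broadcasts the final model here.
        pass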
class Joinable(ABC):
    r"""
    This defines an abstract base class for joinable classes.

    A joinable class (inheriting from :class:`Joinable`) should implement
    :meth:`join_hook`, which returns a :class:`JoinHook` instance, in addition
    to :meth:`join_device` and :meth:`join_process_group` that return device
    and process group information, respectively.
    """

    @abstractmethod
    def __init__(self) -> None:
        super().__init__()
        self._join_config = _JoinConfig.construct_disabled_join_config()
    @abstractmethod
    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a :class:`JoinHook` instance for the given :class:`Joinable`.

        Arguments:
            kwargs (dict): a :class:`dict` containing any keyword arguments to
                modify the behavior of the join hook at run time; all
                :class:`Joinable` instances sharing the same join context
                manager are forwarded the same value for ``kwargs``.
        """
        ...
    @property
    @abstractmethod
    def join_device(self) -> torch.device:
        r"""Return the device from which to perform collective communications needed by the join context manager."""
        ...

    @property
    @abstractmethod
    def join_process_group(self) -> Any:
        r"""Return the process group for the collective communications needed by the join context manager itself."""
        ...
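# A minimal sketch of a ``Joinable`` implementation (hypothetical; the class
# name and attributes are invented for illustration). The constructor must
# call the ``Joinable`` constructor so that ``_join_config`` is initialized,
# and ``join_hook()`` returns the ``_ExampleJoinHook`` sketched above.
class _ExampleJoinable(Joinable):
    def __init__(self, device: torch.device, process_group: Any):
        super().__init__()  # required: initializes self._join_config
        self._device = device
        self._process_group = process_group

    def join_hook(self, **kwargs) -> JoinHook:
        return _ExampleJoinHook(self._device, self._process_group)

    @property
    def join_device(self) -> torch.device:
        return self._device

    @property
    def join_process_group(self) -> Any:
        return self._process_group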
class _JoinConfig(NamedTuple):
    r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side."""

    enable: bool
    throw_on_early_termination: bool
    is_first_joinable: bool

    @staticmethod
    def construct_disabled_join_config():
        r"""
        Return a :class:`_JoinConfig` instance indicating that join-related
        logic should be disabled, e.g. if the caller is not in a join context
        manager.
        """
        return _JoinConfig(
            enable=False,
            throw_on_early_termination=False,
            is_first_joinable=False,
        )
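# Illustration only (hypothetical helper, not part of this module): a freshly
# constructed ``Joinable`` carries the disabled config, so join-related logic
# such as ``Join.notify_join_context()`` short-circuits until a ``Join``
# context manager installs a real config via ``_set_joinable_configs()``.
def _example_disabled_config() -> _JoinConfig:
    config = _JoinConfig.construct_disabled_join_config()
    assert not config.enable
    assert not config.throw_on_early_termination
    assert not config.is_first_joinable
    return config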
class Join:
    r"""
    This class defines the generic join context manager, which allows custom
    hooks to be called after a process joins.

    These hooks should shadow the collective communications of non-joined
    processes to prevent hanging and erroring and to ensure algorithmic
    correctness. Refer to :class:`JoinHook` for details about the hook
    definition.

    .. warning::
        The context manager requires each participating :class:`Joinable` to
        call the method :meth:`notify_join_context()` before its own
        per-iteration collective communications to ensure correctness.

    .. warning::
        The context manager requires that all ``process_group`` attributes in
        the :class:`JoinHook` objects are the same. If there are multiple
        :class:`JoinHook` objects, then the ``device`` of the first is used.
        The process group and device information is used for checking for
        non-joined processes and for notifying processes to throw an
        exception if ``throw_on_early_termination`` is enabled, both of which
        use an all-reduce.

    Arguments:
        joinables (List[Joinable]): a list of the participating
            :class:`Joinable` s; their hooks are iterated over in the given
            order.
        enable (bool): a flag enabling uneven input detection; setting this to
            ``False`` disables the context manager's functionality and should
            only be done when the user knows the inputs will not be uneven
            (default: ``True``).
        throw_on_early_termination (bool): a flag controlling whether to throw
            an exception upon detecting uneven inputs (default: ``False``).

    Example::

        >>> import os
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.multiprocessing as mp
        >>> # xdoctest: +SKIP
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
        >>> from torch.distributed.algorithms.join import Join
        >>>
        >>> # On each spawned worker
        >>> def worker(rank):
        >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
        >>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
        >>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
        >>>     # Rank 1 gets one more input than rank 0
        >>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
        >>>     with Join([model, optim]):
        >>>         for input in inputs:
        >>>             loss = model(input).sum()
        >>>             loss.backward()
        >>>             optim.step()
        >>> # All ranks reach here without hanging/erroring
    """

    def __init__(
        self,
        joinables: List[Joinable],
        enable: bool = True,
        throw_on_early_termination: bool = False,
        **kwargs,
    ):
        if len(joinables) == 0:
            raise ValueError("The join context manager requires at least one joinable")
        self._joinables = joinables
        self._join_hooks = [
            joinable.join_hook(**kwargs) for joinable in self._joinables
        ]
        self._enable = enable
        self._throw_on_early_termination = throw_on_early_termination
        self._set_joinable_configs()
        self._extract_dist_info()

    def _set_joinable_configs(self) -> None:
        r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
        assert len(self._joinables) > 0
        is_first_joinable = True
        for joinable in self._joinables:
            joinable._join_config = _JoinConfig(
                enable=self._enable,
                throw_on_early_termination=self._throw_on_early_termination,
                is_first_joinable=is_first_joinable,
            )
            is_first_joinable = False

    def _extract_dist_info(self) -> None:
        r"""
        Extract the process group and device information from the joinables.

        If there are multiple joinables, then the context manager uses the
        first specified device.

        Preconditions:
            ``self._joinables`` is not ``None`` and is non-empty.

        Raises:
            ValueError
                If there are multiple conflicting ``process_group`` attributes
                among the ``Joinable`` objects.
"""process_group=Nonedevice=Noneforjoinableinself._joinables:ifprocess_groupisNone:process_group=joinable.join_process_groupelifprocess_group!=joinable.join_process_group:raiseValueError("Using join context manager with multiple process groups")ifdeviceisNone:device=joinable.join_deviceself._process_group=process_groupself._rank=dist.get_rank(self._process_group)self._device=devicedef__enter__(self):...def__exit__(self,type:Optional[Type[BaseException]],value:Optional[BaseException],traceback:Optional[TracebackType],):r""" Repeatedly runs the main hooks until all processes join; then, runs the post-hooks. Raises: RuntimeError If ``throw_on_early_termination=True``. """ifnotself._enableortype:return# propagate the exception directly if one was raisedall_procs_joined=Falseis_last_joiner=Truei=0WARN_THRESHOLD=1000warnings.simplefilter("once")whilenotall_procs_joined:ifi>WARN_THRESHOLD:warnings.warn("Detected uneven input skew of greater than "f"{WARN_THRESHOLD}. This means that rank "f"{self._rank} has at least {WARN_THRESHOLD} "f"fewer inputs than other currently-active ranks. ""This level of skew could lead to performance ""degradation during training.")# Shadow the all-reduce in non-joined processesnum_nonjoined_procs=self._get_num_nonjoined_procs()ifnum_nonjoined_procs==0:all_procs_joined=Trueelse:ifself._throw_on_early_termination:self._notify_procs_to_terminate()# Run main hooksforjoin_hookinself._join_hooks:join_hook.main_hook()is_last_joiner=Falsei+=1# Run post-hooksforjoin_hookinself._join_hooks:join_hook.post_hook(is_last_joiner)def_get_num_nonjoined_procs(self):r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes."""num_nonjoined_procs=torch.zeros(1,device=self._device)dist.all_reduce(num_nonjoined_procs,group=self._process_group)returnnum_nonjoined_procs.item()def_notify_procs_to_terminate(self):r"""Schedule an all-reduce to notify non-joined processes to terminate. Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs. """ones=torch.ones(1,device=self._device)dist.all_reduce(ones,group=self._process_group)raiseRuntimeError(f"Rank {self._rank} exhausted all inputs.")
    @staticmethod
    def notify_join_context(joinable: Joinable):
        r"""
        Notify the join context manager that the calling process has not yet
        joined; then, if ``throw_on_early_termination=True``, check if uneven
        inputs have been detected (i.e. if one process has already joined)
        and throw an exception if so.

        This method should be called from a :class:`Joinable` object before
        its per-iteration collective communications. For example, this should
        be called at the beginning of the forward pass in
        :class:`DistributedDataParallel`.

        Only the first :class:`Joinable` object passed into the context
        manager performs the collective communications in this method, and
        for the others, this method is vacuous.

        Arguments:
            joinable (Joinable): the :class:`Joinable` object calling this
                method.

        Returns:
            An async work handle for the all-reduce meant to notify the
            context manager that the process has not yet joined if
            ``joinable`` is the first one passed into the context manager;
            ``None`` otherwise.
        """
        assert hasattr(joinable, "_join_config"), (
            f"Check that the {type(joinable)} constructor calls the "
            "``Joinable`` constructor"
        )

        join_config = joinable._join_config
        # Only the first joinable is responsible for the collective
        # communications
        if not join_config.is_first_joinable or not join_config.enable:
            return None

        device = joinable.join_device
        process_group = joinable.join_process_group

        # Schedule an all-reduce to indicate that the caller has not yet
        # joined
        ones = torch.ones(1, device=device)
        work = dist.all_reduce(ones, group=process_group, async_op=True)

        if join_config.throw_on_early_termination:
            # Check if uneven inputs have been detected
            zeros = torch.zeros(1, device=device)
            dist.all_reduce(zeros, group=process_group)
            should_throw = zeros.item()
            if should_throw:
                raise RuntimeError(
                    "Detected at least one rank that exhausted inputs. "
                    "Throwing across all ranks."
                )
        return work
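# A hedged end-to-end sketch (illustration only; ``_ExampleJoinable`` is the
# hypothetical class defined above, and ``_example_train_loop`` is not part
# of this module). Each iteration calls ``Join.notify_join_context()`` before
# its per-iteration all-reduce, as the first warning on :class:`Join`
# requires.
def _example_train_loop(joinable: Joinable, inputs):
    try:
        with Join([joinable], throw_on_early_termination=True):
            for _ in inputs:
                Join.notify_join_context(joinable)
                # The workload's own per-iteration collective communication,
                # which _ExampleJoinHook.main_hook() shadows on joined ranks
                dist.all_reduce(
                    torch.ones(1, device=joinable.join_device),
                    group=joinable.join_process_group,
                )
    except RuntimeError:
        # With throw_on_early_termination=True, every rank raises as soon as
        # any rank exhausts its inputs, so all ranks land here together and
        # may safely resynchronize.
        pass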