Source code for torch.distributed.elastic.multiprocessing.api
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import abc
import logging
import os
import re
import signal
import subprocess
import sys
import time
from contextlib import nullcontext
from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
from torch.distributed.elastic.multiprocessing.redirects import (
    redirect_stderr,
    redirect_stdout,
)
from torch.distributed.elastic.multiprocessing.tail_log import TailLog

IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"


log = logging.getLogger(__name__)

__all__ = [
    "SignalException",
    "Std",
    "to_map",
    "RunProcsResult",
    "PContext",
    "get_std_cm",
    "MultiprocessContext",
    "SubprocessHandler",
    "SubprocessContext",
]


class SignalException(Exception):
    """
    Exception is raised inside the torchelastic agent process by the termination handler
    if the death signal got received by the process.
    """

    def __init__(self, msg: str, sigval: signal.Signals) -> None:
        super().__init__(msg)
        self.sigval = sigval


def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
    """Termination handler that raises exceptions on the main process.

    When the process receives a death signal (SIGTERM, SIGINT), this termination
    handler will be invoked. It raises the ``SignalException`` exception that should
    be processed by the user code. Python does not terminate the process after the
    termination handler finishes, so the exception should not be silently ignored,
    otherwise the process will never be terminated.
    """
    sigval = signal.Signals(signum)
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)


def _get_kill_signal() -> signal.Signals:
    """
    Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows.
    """
    if IS_WINDOWS:
        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
    else:
        return signal.SIGKILL


def _get_default_signal() -> signal.Signals:
    """
    Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.
    """
    if IS_WINDOWS:
        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
    else:
        return signal.SIGTERM


def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
    actual_keys = set(d.keys())
    expected_keys = set(range(nprocs))

    if actual_keys != expected_keys:
        raise RuntimeError(
            f"{what}, local rank mapping mismatch,"
            f" expected: {expected_keys}, actual: {actual_keys}"
        )


_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$"
_VALUE_REGEX = r"^[0123]$"


class Std(IntFlag):
    NONE = 0
    OUT = 1
    ERR = 2
    ALL = OUT | ERR

    @classmethod
    def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
        """
        Example:
        ::

         from_str("0") -> Std.NONE
         from_str("1") -> Std.OUT
         from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}

        Any other input raises an exception
        """

        def to_std(v: str) -> Std:  # type: ignore[return]
            s = Std(int(v))
            if s in Std:
                return s
            # return None -> should NEVER reach here since we regex check input

        if re.match(_VALUE_REGEX, vm):  # vm is a number (e.g. 0)
            return to_std(vm)
        elif re.match(_MAPPING_REGEX, vm):  # vm is a mapping (e.g. 0:1,1:2)
            d: Dict[int, Std] = {}
            for m in vm.split(","):
                i, v = m.split(":")
                d[int(i)] = to_std(v)
            return d
        else:
            raise ValueError(
                f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>"
            )


def to_map(
    val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
) -> Dict[int, Std]:
    """
    Certain APIs take redirect settings either as a single value (e.g. apply to all
    local ranks) or as an explicit user-provided mapping. This method is a convenience
    method that converts a value or mapping into a mapping.

    Example:
    ::

     to_map(Std.OUT, local_world_size=2)
     # returns: {0: Std.OUT, 1: Std.OUT}

     to_map({1: Std.OUT}, local_world_size=2)
     # returns: {0: Std.NONE, 1: Std.OUT}

     to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2)
     # returns: {0: Std.OUT, 1: Std.OUT}
    """
    if isinstance(val_or_map, Std):
        return {i: val_or_map for i in range(local_world_size)}
    else:
        map = {}
        for i in range(local_world_size):
            map[i] = val_or_map.get(i, Std.NONE)
        return map
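# A minimal usage sketch (illustrative, not part of the upstream module): from_str
# parses the string form of redirect settings (e.g. as accepted by launcher flags
# such as torchrun's --redirects), and to_map expands a single value or a sparse
# per-rank mapping into one entry per local rank.
def _example_std_parsing() -> None:
    # "0:3,1:0" -> rank 0 redirects both streams, rank 1 redirects neither
    assert Std.from_str("0:3,1:0") == {0: Std.ALL, 1: Std.NONE}
    # a single value applies to every local rank
    assert to_map(Std.OUT, local_world_size=2) == {0: Std.OUT, 1: Std.OUT}
    # ranks missing from a sparse mapping default to Std.NONE
    assert to_map({1: Std.ERR}, local_world_size=2) == {0: Std.NONE, 1: Std.ERR}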
@dataclass
class RunProcsResult:
    """
    Results of a completed run of processes started with ``start_processes()``.
    Returned by ``PContext``.

    Note the following:

    1. All fields are mapped by local rank
    2. ``return_values`` - only populated for functions (not the binaries).
    3. ``stdouts`` - path to stdout.log (empty string if no redirect)
    4. ``stderrs`` - path to stderr.log (empty string if no redirect)

    """

    return_values: Dict[int, Any] = field(default_factory=dict)
    failures: Dict[int, ProcessFailure] = field(default_factory=dict)
    stdouts: Dict[int, str] = field(default_factory=dict)
    stderrs: Dict[int, str] = field(default_factory=dict)

    def is_failed(self) -> bool:
        return len(self.failures) > 0
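# A minimal sketch (illustrative, not part of the upstream module) of consuming a
# RunProcsResult, e.g. as returned by PContext.wait() defined below:
def _example_handle_result(result: RunProcsResult) -> None:
    if result.is_failed():
        # failures are keyed by local rank; each ProcessFailure carries the pid,
        # exitcode, and the path to the error file of the failed worker
        for local_rank, failure in result.failures.items():
            log.error(
                "rank %s (pid %s) failed with exitcode %s",
                local_rank,
                failure.pid,
                failure.exitcode,
            )
    else:
        # only populated for function entrypoints; binary entrypoints yield None
        for local_rank, ret in result.return_values.items():
            log.info("rank %s returned %s", local_rank, ret)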
class PContext(abc.ABC):
    """
    The base class that standardizes operations over a set of processes
    that are launched via different mechanisms.
    The name ``PContext`` is intentional to disambiguate with
    ``torch.multiprocessing.ProcessContext``.

    .. warning:: stdouts and stderrs should ALWAYS be a superset of
                 tee_stdouts and tee_stderrs (respectively); this is because
                 tee is implemented as a redirect + tail -f <stdout/stderr.log>
    """

    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
    ):
        self.name = name
        # validate that all mappings have the same number of keys and
        # all local ranks are accounted for
        nprocs = len(args)
        _validate_full_rank(stdouts, nprocs, "stdouts")
        _validate_full_rank(stderrs, nprocs, "stderrs")

        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.stdouts = stdouts
        self.stderrs = stderrs
        self.error_files = error_files
        self.nprocs = nprocs

        self._stdout_tail = TailLog(name, tee_stdouts, sys.stdout)
        self._stderr_tail = TailLog(name, tee_stderrs, sys.stderr)

    def start(self) -> None:
        """
        Start processes using parameters defined in the constructor.
        """
        signal.signal(signal.SIGTERM, _terminate_process_handler)
        signal.signal(signal.SIGINT, _terminate_process_handler)
        if not IS_WINDOWS:
            signal.signal(signal.SIGHUP, _terminate_process_handler)
            signal.signal(signal.SIGQUIT, _terminate_process_handler)
        self._start()
        self._stdout_tail.start()
        self._stderr_tail.start()

    @abc.abstractmethod
    def _start(self) -> None:
        """
        Start processes using strategy defined in a particular context.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _poll(self) -> Optional[RunProcsResult]:
        """
        Polls the run status of the processes running under this context.
        This method follows an "all-or-nothing" policy and returns
        a ``RunProcsResult`` object if either all processes complete
        successfully or any process fails. Returns ``None`` if
        all processes are still running.
        """
        raise NotImplementedError()

    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Waits for the specified ``timeout`` seconds, polling every ``period`` seconds
        for the processes to be done. Returns ``None`` if the processes are still running
        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
        A timeout value of zero simply queries the status of the processes (e.g. equivalent
        to a poll).

        .. note:: ``start()`` registers SIGTERM and SIGINT signal handlers that raise
                  ``SignalException`` when the signals are received. It is up to the
                  consumer of the code to properly handle the exception. It is important
                  not to swallow the exception, otherwise the process would not terminate.
                  Example of the typical workflow:

        .. code-block:: python

            pc = start_processes(...)
            try:
                pc.wait(1)
                # .. do some other work
            except SignalException as e:
                pc.close(e.sigval, timeout=30)

        If SIGTERM or SIGINT occurs, the code above will try to shut down the child
        processes by propagating the received signal. If the child processes do not
        terminate within the timeout, they will be sent SIGKILL.
        """

        if timeout == 0:
            return self._poll()

        if timeout < 0:
            timeout = sys.maxsize

        expiry = time.time() + timeout
        while time.time() < expiry:
            pr = self._poll()
            if pr:
                return pr
            time.sleep(period)

        return None

    @abc.abstractmethod
    def pids(self) -> Dict[int, int]:
        """
        Returns pids of processes mapped by their respective local_ranks
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).
        """
        raise NotImplementedError()

    def close(
        self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
    ) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).

        Args:
            death_sig: Death signal to terminate processes.
            timeout: Time to wait for processes to finish; if a process is
                still alive after this time, it will be terminated via SIGKILL.
        """
        if not death_sig:
            death_sig = _get_default_signal()
        self._close(death_sig=death_sig, timeout=timeout)
        if self._stdout_tail:
            self._stdout_tail.stop()
        if self._stderr_tail:
            self._stderr_tail.stop()
def get_std_cm(std_rd: str, redirect_fn):
    if IS_WINDOWS or IS_MACOS or not std_rd:
        return nullcontext()
    else:
        return redirect_fn(std_rd)


def _wrap(
    local_rank: int,
    fn: Callable,
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    stdout_redirects: Dict[int, str],  # redirect file for stdout (to console if None)
    stderr_redirects: Dict[int, str],  # redirect file for stderr (to console if None)
    ret_vals: Dict[int, mp.SimpleQueue],
    queue_finished_reading_event: synchronize.Event,
) -> None:
    # get the per-rank params up front so we fail fast if no mapping is found
    args_ = args[local_rank]
    env_ = envs[local_rank]
    ret_val_ = ret_vals[local_rank]
    stdout_rd = stdout_redirects[local_rank]
    stderr_rd = stderr_redirects[local_rank]

    stdout_cm = get_std_cm(stdout_rd, redirect_stdout)
    stderr_cm = get_std_cm(stderr_rd, redirect_stderr)

    for k, v in env_.items():
        os.environ[k] = v

    with stdout_cm, stderr_cm:
        ret = record(fn)(*args_)
    ret_val_.put(ret)
    queue_finished_reading_event.wait()
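# A small sketch (illustrative, not part of the upstream module) of the redirect
# context manager: on Linux, get_std_cm returns redirect_fn(path), an fd-level
# redirect into the given file; on Windows and macOS, or when no path is given,
# it degrades to a no-op nullcontext.
def _example_redirect(log_file: str) -> None:
    with get_std_cm(log_file, redirect_stdout):
        print("written to log_file on Linux, to the console elsewhere")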
class MultiprocessContext(PContext):
    """
    ``PContext`` holding worker processes invoked as a function.
    """

    def __init__(
        self,
        name: str,
        entrypoint: Callable,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
        start_method: str,
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            stdouts,
            stderrs,
            tee_stdouts,
            tee_stderrs,
            error_files,
        )

        self.start_method = start_method
        # each ret_val queue will always contain a single element.
        self._ret_vals = {
            local_rank: mp.get_context(self.start_method).SimpleQueue()
            for local_rank in range(self.nprocs)
        }

        # see comments in ``join()`` for what this is
        self._return_values: Dict[int, Any] = {}
        self._pc: Optional[mp.ProcessContext] = None
        # Note: set method should ONLY be invoked for the use case when all processes
        # finished successfully. If any process died on event.wait(), calling set()
        # will deadlock.
        self._worker_finished_event = mp.get_context(self.start_method).Event()

    def _start(self):
        if self._pc:
            raise ValueError(
                "The process context already initialized."
                " Most likely the start method got called twice."
            )
        self._pc = mp.start_processes(
            fn=_wrap,
            args=(
                self.entrypoint,
                self.args,
                self.envs,
                self.stdouts,
                self.stderrs,
                self._ret_vals,
                self._worker_finished_event,
            ),
            nprocs=self.nprocs,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )

    def _is_done(self) -> bool:
        return len(self._return_values) == self.nprocs

    def _poll(self) -> Optional[RunProcsResult]:
        assert self._pc is not None  # assertion for mypy type checker

        try:
            # torch.mp.ProcessContext throws an exception if some/all of
            # the worker processes failed.
            # timeout < 0 checks worker status and returns immediately.
            # Join will never return success since we use synchronize.Event to wait
            # for all processes to finish.
            self._pc.join(-1)

            # IMPORTANT: we use multiprocessing.Queue to carry worker return values
            # back to the parent; the worker process will wait before terminating
            # until all the buffered items are fed by the feeder thread to the underlying
            # pipe. Hence to prevent deadlocks on large return values,
            # we opportunistically try queue.get on each join call.
            # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
            for local_rank in range(0, self.nprocs):
                return_queue = self._ret_vals[local_rank]
                if not return_queue.empty():
                    # save the return values temporarily into a member var
                    self._return_values[local_rank] = return_queue.get()

            if self._is_done():
                # we should ALWAYS have ALL the return values when all the processes are done
                self._worker_finished_event.set()
                # Wait until all processes are finished. At this point workers have
                # finished executing the user function.
                self._pc.join()
                _validate_full_rank(
                    self._return_values, self.nprocs, "return_value queue"
                )
                self.close()
                return RunProcsResult(
                    return_values=self._return_values,
                    stdouts=self.stdouts,
                    stderrs=self.stderrs,
                )
            else:
                return None
        except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
            failed_local_rank = e.error_index

            # entrypoint for MultiprocessContext will always be a Callable
            fn_name = self.entrypoint.__qualname__  # type: ignore[union-attr]
            failed_proc = self._pc.processes[failed_local_rank]
            error_filepath = self.error_files[failed_local_rank]

            log.error(
                "failed (exitcode: %s)"
                " local_rank: %s (pid: %s)"
                " of fn: %s (start_method: %s)",
                failed_proc.exitcode,
                failed_local_rank,
                e.pid,
                fn_name,
                self.start_method,
                exc_info=True,
            )

            self.close()
            return RunProcsResult(
                failures={
                    failed_local_rank: ProcessFailure(
                        local_rank=failed_local_rank,
                        pid=e.pid,
                        exitcode=failed_proc.exitcode,
                        error_file=error_filepath,
                    )
                },
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )

    def pids(self) -> Dict[int, int]:
        assert self._pc is not None  # assertion for mypy type checking
        return dict(enumerate(self._pc.pids()))

    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        if not self._pc:
            return
        for proc in self._pc.processes:
            if proc.is_alive():
                log.warning(
                    "Closing process %s via signal %s", proc.pid, death_sig.name
                )
                try:
                    os.kill(proc.pid, death_sig)
                except ProcessLookupError:
                    # If the process exited for some reason before we could signal it,
                    # `ProcessLookupError` is raised and it is safe to ignore.
                    pass
        end = time.monotonic() + timeout
        for proc in self._pc.processes:
            time_to_wait = end - time.monotonic()
            if time_to_wait <= 0:
                break
            proc.join(time_to_wait)
        for proc in self._pc.processes:
            if proc.is_alive():
                log.warning(
                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
                    proc.pid,
                    death_sig,
                    _get_kill_signal(),
                )
                try:
                    os.kill(proc.pid, _get_kill_signal())
                except ProcessLookupError:
                    # If the process exited for some reason before we could signal it,
                    # `ProcessLookupError` is raised and it is safe to ignore.
                    pass
            proc.join()
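# A hypothetical end-to-end sketch (illustrative names and paths, not part of the
# upstream module): running a picklable module-level function on two local ranks.
# With the "spawn"/"forkserver" start methods this must run under an
# ``if __name__ == "__main__":`` guard.
def _example_multiprocess_context(double: Callable) -> None:
    # ``double`` is assumed to be e.g. a module-level ``def double(x): return 2 * x``
    ctx = MultiprocessContext(
        name="trainer",
        entrypoint=double,
        args={0: (1,), 1: (2,)},
        envs={0: {"LOCAL_RANK": "0"}, 1: {"LOCAL_RANK": "1"}},
        stdouts={0: "", 1: ""},  # empty string -> no redirect
        stderrs={0: "", 1: ""},
        tee_stdouts={},  # must be a subset of stdouts (see the PContext warning)
        tee_stderrs={},
        error_files={0: "/tmp/err0.json", 1: "/tmp/err1.json"},
        start_method="spawn",
    )
    ctx.start()
    result = ctx.wait()  # blocks until all ranks finish or any rank fails
    if result and not result.is_failed():
        assert result.return_values == {0: 2, 1: 4}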
class SubprocessHandler:
    """
    Convenience wrapper around python's ``subprocess.Popen``. Keeps track of
    meta-objects associated to the process (e.g. stdout and stderr redirect fds).
    """

    def __init__(
        self,
        entrypoint: str,
        args: Tuple,
        env: Dict[str, str],
        stdout: str,
        stderr: str,
    ):
        self._stdout = open(stdout, "w") if stdout else None
        self._stderr = open(stderr, "w") if stderr else None
        # inherit parent environment vars
        env_vars = os.environ.copy()
        env_vars.update(env)

        args_str = (entrypoint, *[str(e) for e in args])
        self.proc: subprocess.Popen = self._popen(args_str, env_vars)

    def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
        return subprocess.Popen(
            # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
            #  _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
            #  `Tuple[str, *Tuple[Any, ...]]`.
            args=args,
            env=env,
            stdout=self._stdout,
            stderr=self._stderr,
        )

    def close(self, death_sig: Optional[signal.Signals] = None) -> None:
        if not death_sig:
            death_sig = _get_default_signal()
        self.proc.send_signal(death_sig)
        if self._stdout:
            self._stdout.close()
        if self._stderr:
            self._stderr.close()
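# A minimal sketch of the handler in isolation (binary and arguments are
# illustrative, not part of the upstream module):
def _example_subprocess_handler() -> None:
    handler = SubprocessHandler(
        entrypoint="/bin/echo",
        args=("hello", "world"),
        env={},  # merged on top of a copy of os.environ
        stdout="",  # empty -> no redirect file, inherit parent's stdout
        stderr="",
    )
    handler.proc.wait()  # a plain subprocess.Popen underneath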
class SubprocessContext(PContext):
    """
    ``PContext`` holding worker processes invoked as a binary.
    """

    def __init__(
        self,
        name: str,
        entrypoint: str,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            stdouts,
            stderrs,
            tee_stdouts,
            tee_stderrs,
            error_files,
        )

        # state vector; _vdone[local_rank] -> is local_rank finished or not
        self._running_local_ranks: Set[int] = set(range(self.nprocs))
        self._failures: Dict[int, ProcessFailure] = {}
        self.subprocess_handlers: Dict[int, SubprocessHandler] = {}

    def _start(self):
        if self.subprocess_handlers:
            raise ValueError(
                "The subprocess handlers already initialized."
                " Most likely the start method got called twice."
            )
        self.subprocess_handlers = {
            local_rank: SubprocessHandler(
                entrypoint=self.entrypoint,  # type: ignore[arg-type] # entrypoint is always a str
                args=self.args[local_rank],
                env=self.envs[local_rank],
                stdout=self.stdouts[local_rank],
                stderr=self.stderrs[local_rank],
            )
            for local_rank in range(self.nprocs)
        }

    def _poll(self) -> Optional[RunProcsResult]:
        done_local_ranks = set()
        for local_rank in self._running_local_ranks:
            handler = self.subprocess_handlers[local_rank]
            exitcode = handler.proc.poll()
            if exitcode is not None:
                done_local_ranks.add(local_rank)
                if exitcode != 0:  # failed or signaled
                    self._failures[local_rank] = ProcessFailure(
                        local_rank=local_rank,
                        pid=handler.proc.pid,
                        exitcode=exitcode,
                        error_file=self.error_files[local_rank],
                    )
                # else: --> succeeded; nothing to do

        self._running_local_ranks.difference_update(done_local_ranks)

        # if ALL procs are finished or ANY have failed
        if not self._running_local_ranks or self._failures:
            self.close()  # terminate all running procs
            result = RunProcsResult(
                failures=self._failures,
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )
            if result.is_failed():
                first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
                log.error(
                    "failed (exitcode: %s)"
                    " local_rank: %s (pid: %s)"
                    " of binary: %s",
                    first_failure.exitcode,
                    first_failure.local_rank,
                    first_failure.pid,
                    self.entrypoint,
                )
            else:
                # Populate return_values with dummy values for consistency with
                # MultiprocessContext.
                result.return_values = {
                    local_rank: None for local_rank in range(self.nprocs)
                }

            return result
        else:  # there are no failures and procs are still running
            return None

    def pids(self) -> Dict[int, int]:
        return {
            local_rank: sh.proc.pid
            for local_rank, sh in self.subprocess_handlers.items()
        }

    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        if not self.subprocess_handlers:
            return
        for handler in self.subprocess_handlers.values():
            if handler.proc.poll() is None:
                log.warning(
                    "Sending process %s closing signal %s",
                    handler.proc.pid,
                    death_sig.name,
                )
                handler.close(death_sig=death_sig)
        end = time.monotonic() + timeout
        for handler in self.subprocess_handlers.values():
            time_to_wait = end - time.monotonic()
            if time_to_wait <= 0:
                break
            try:
                handler.proc.wait(time_to_wait)
            except subprocess.TimeoutExpired:
                # Ignore the timeout expired exception, since
                # the child process will be forcefully terminated via SIGKILL
                pass
        for handler in self.subprocess_handlers.values():
            if handler.proc.poll() is None:
                log.warning(
                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
                    handler.proc.pid,
                    death_sig,
                    _get_kill_signal(),
                )
                handler.close(death_sig=_get_kill_signal())
                handler.proc.wait()
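# A hypothetical sketch (illustrative paths, not part of the upstream module):
# running a binary entrypoint on two local ranks with per-rank stdout redirects.
def _example_subprocess_context() -> None:
    ctx = SubprocessContext(
        name="echo",
        entrypoint="/bin/echo",
        args={0: ("rank", "0"), 1: ("rank", "1")},
        envs={0: {}, 1: {}},
        stdouts={0: "/tmp/echo0.out", 1: "/tmp/echo1.out"},
        stderrs={0: "", 1: ""},
        tee_stdouts={},
        tee_stderrs={},
        error_files={0: "/tmp/err0.json", 1: "/tmp/err1.json"},
    )
    ctx.start()
    result = ctx.wait()
    # binaries have no return values; a successful run yields dummy None values
    if result and not result.is_failed():
        assert result.return_values == {0: None, 1: None}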