Source code for torch.distributed.elastic.timer.local_timer
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import multiprocessing as mp
import os
import signal
import time
from queue import Empty
from typing import Any, Dict, List, Set, Tuple

from .api import RequestQueue, TimerClient, TimerRequest, TimerServer

__all__ = ['LocalTimerClient', 'MultiprocessingRequestQueue', 'LocalTimerServer']

log = logging.getLogger(__name__)
class LocalTimerClient(TimerClient):
    """
    Client side of ``LocalTimerServer``. This client is meant to be used
    on the same host that the ``LocalTimerServer`` is running on and uses
    pid to uniquely identify a worker. This is particularly useful in
    situations where one spawns a subprocess (trainer) per GPU on a host
    with multiple GPU devices.
    """

    def __init__(self, mp_queue):
        super().__init__()
        self._mp_queue = mp_queue

    def acquire(self, scope_id, expiration_time):
        pid = os.getpid()
        acquire_request = TimerRequest(pid, scope_id, expiration_time)
        self._mp_queue.put(acquire_request)

    def release(self, scope_id):
        pid = os.getpid()
        release_request = TimerRequest(pid, scope_id, -1)
        self._mp_queue.put(release_request)
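For context, a minimal usage sketch (not part of this module): rather than calling ``acquire``/``release`` directly, a worker typically installs this client via ``torch.distributed.elastic.timer.configure`` and guards work with the ``expires`` context manager, which issues the acquire on entry and the release on exit. The ``worker_main`` function, its queue argument, and the 60-second budget are illustrative assumptions.

import time

import torch.distributed.elastic.timer as timer

def worker_main(mp_queue):  # hypothetical entry point of a spawned worker
    # install LocalTimerClient as the process-wide timer client;
    # all expires() blocks in this process will route through it
    timer.configure(timer.LocalTimerClient(mp_queue))

    # expires() calls acquire() on entry and release() on exit; if the
    # body outlives the 60-second budget, the server reaps this process
    with timer.expires(after=60):
        time.sleep(1)  # stand-in for real work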
class MultiprocessingRequestQueue(RequestQueue):
    """
    A ``RequestQueue`` backed by python ``multiprocessing.Queue``
    """

    def __init__(self, mp_queue: mp.Queue):
        super().__init__()
        self._mp_queue = mp_queue

    def size(self) -> int:
        return self._mp_queue.qsize()

    def get(self, size, timeout: float) -> List[TimerRequest]:
        requests = []
        wait = timeout
        for _ in range(0, size):
            start = time.time()
            try:
                r = self._mp_queue.get(block=True, timeout=wait)
            except Empty:
                break
            requests.append(r)
            wait = wait - (time.time() - start)
            if wait <= 0:
                break
        return requests
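A small sketch of the shared-budget semantics of ``get`` (illustrative, not part of this module): ``timeout`` is a total budget spread across all ``size`` dequeues rather than a per-item timeout, so the call returns whatever arrived within the window. The queue contents and the 0.5-second budget below are assumptions for the demo.

import multiprocessing as mp
import os
import time

from torch.distributed.elastic.timer import TimerRequest
from torch.distributed.elastic.timer.local_timer import MultiprocessingRequestQueue

if __name__ == "__main__":
    mp_queue = mp.Queue()
    request_queue = MultiprocessingRequestQueue(mp_queue)

    pid = os.getpid()
    mp_queue.put(TimerRequest(pid, "scope_a", time.time() + 30))
    mp_queue.put(TimerRequest(pid, "scope_b", time.time() + 30))

    # ask for up to 5 requests with a 0.5 second *total* budget; the call
    # drains the two available requests, exhausts the remaining budget
    # waiting for a third, then returns what it collected
    requests = request_queue.get(size=5, timeout=0.5)
    assert len(requests) == 2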
class LocalTimerServer(TimerServer):
    """
    Server that works with ``LocalTimerClient``. Clients are expected to be
    subprocesses to the parent process that is running this server. Each host
    in the job is expected to start its own timer server locally and each
    server instance manages timers for local workers (running on processes
    on the same host).
    """

    def __init__(self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True):
        super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
        self._timers: Dict[Tuple[Any, str], TimerRequest] = {}

    def register_timers(self, timer_requests: List[TimerRequest]) -> None:
        for request in timer_requests:
            pid = request.worker_id
            scope_id = request.scope_id
            expiration_time = request.expiration_time

            # negative expiration is a proxy for a release call
            if expiration_time < 0:
                self._timers.pop((pid, scope_id), None)
            else:
                self._timers[(pid, scope_id)] = request

    def clear_timers(self, worker_ids: Set[int]) -> None:
        for (pid, scope_id) in list(self._timers.keys()):
            if pid in worker_ids:
                self._timers.pop((pid, scope_id))

    def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
        # pid -> [timer_requests...]
        expired_timers: Dict[Any, List[TimerRequest]] = {}
        for request in self._timers.values():
            if request.expiration_time <= deadline:
                expired_scopes = expired_timers.setdefault(request.worker_id, [])
                expired_scopes.append(request)
        return expired_timers

    def _reap_worker(self, worker_id: int) -> bool:
        try:
            os.kill(worker_id, signal.SIGKILL)
            return True
        except ProcessLookupError:
            log.info("Process with pid=%s does not exist. Skipping", worker_id)
            return True
        except Exception as e:
            log.error("Error terminating pid=%s", worker_id, exc_info=e)
            return False
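Putting the pieces together, a sketch of the host-side wiring (the ``trainer`` function and the 0.01-second polling interval are illustrative assumptions): the parent process starts one ``LocalTimerServer`` per host and hands the same multiprocessing queue to every worker it spawns.

import multiprocessing as mp

import torch.distributed.elastic.timer as timer

def trainer(mp_queue):
    # client side, as in the sketch after LocalTimerClient above
    timer.configure(timer.LocalTimerClient(mp_queue))
    with timer.expires(after=60):
        pass  # hypothetical work guarded by the timer

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    mp_queue = ctx.Queue()

    # one server per host; its watchdog polls the queue every max_interval seconds
    server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
    server.start()  # non-blocking; the watchdog runs on a daemon thread by default

    worker = ctx.Process(target=trainer, args=(mp_queue,))
    worker.start()
    worker.join()

    server.stop()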