Source code for torch.distributed.elastic.agent.server.local_elastic_agent
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import shutil
import signal
import socket
import tempfile
import uuid
from typing import Any, Dict, Optional, Tuple

import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events
from torch.distributed.elastic.agent.server.api import (
    RunResult,
    SimpleElasticAgent,
    WorkerGroup,
    WorkerSpec,
    WorkerState,
)
from torch.distributed.elastic.events.api import EventMetadataValue
from torch.distributed.elastic.metrics.api import prof
from torch.distributed.elastic.multiprocessing import PContext, start_processes
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger

log = get_logger(__name__)

__all__ = [
    "LocalElasticAgent",
    "TORCHELASTIC_ENABLE_FILE_TIMER",
    "TORCHELASTIC_TIMER_FILE",
]

TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER"
TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE"
class LocalElasticAgent(SimpleElasticAgent):
    """
    An implementation of :py:class:`torchelastic.agent.server.ElasticAgent`
    that handles host-local workers. This agent is deployed per host and is
    configured to spawn ``n`` workers. When using GPUs, ``n`` maps to the
    number of GPUs available on the host.

    The local agent does not communicate with other local agents deployed on
    other hosts, even if the workers may communicate inter-host. The worker id
    is interpreted to be a local process. The agent starts and stops all worker
    processes as a single unit.

    The worker function and the arguments passed to it must be compatible with
    python multiprocessing. To pass multiprocessing data structures to the
    workers, create the data structure in the same multiprocessing context as
    the specified ``start_method`` and pass it as a function argument.

    The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to
    wait for other agents to finish. This acts as a safety net to handle cases
    where workers finish at different times, and prevents agents from viewing
    workers that finished early as a scale-down event. It is strongly advised
    that user code ensure workers are terminated synchronously rather than
    relying on the ``exit_barrier_timeout``.

    A named-pipe based watchdog can be enabled in ``LocalElasticAgent`` if the
    environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` is set to 1 in the
    ``LocalElasticAgent`` process. Optionally, another environment variable
    ``TORCHELASTIC_TIMER_FILE`` can be set to a unique file name for the named
    pipe. If ``TORCHELASTIC_TIMER_FILE`` is not set, ``LocalElasticAgent``
    internally creates a unique file name and sets it as the value of
    ``TORCHELASTIC_TIMER_FILE``; this environment variable is propagated to the
    worker processes so that they can connect to the same named pipe that
    ``LocalElasticAgent`` uses.
    Example launching function

    ::

        def trainer(args) -> str:
            return "do train"

        def main():
            start_method = "spawn"
            shared_queue = multiprocessing.get_context(start_method).Queue()
            spec = WorkerSpec(
                        role="trainer",
                        local_world_size=nproc_per_process,
                        entrypoint=trainer,
                        args=("foobar",),
                        ...<OTHER_PARAMS...>)
            agent = LocalElasticAgent(spec, start_method)
            results = agent.run()

            if results.is_failed():
                print("trainer failed")
            else:
                print(f"rank 0 return value: {results.return_values[0]}")
                # prints -> rank 0 return value: do train

    Example launching binary

    ::

        def main():
            spec = WorkerSpec(
                        role="trainer",
                        local_world_size=nproc_per_process,
                        entrypoint="/usr/local/bin/trainer",
                        args=("--trainer-args", "foobar"),
                        ...<OTHER_PARAMS...>)
            agent = LocalElasticAgent(spec)
            results = agent.run()

            if not results.is_failed():
                print("binary launches do not have return values")
    """

    def __init__(
        self,
        spec: WorkerSpec,
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_dir: Optional[str] = None,
    ):
        super().__init__(spec, exit_barrier_timeout)
        self._start_method = start_method
        self._pcontext: Optional[PContext] = None
        rdzv_run_id = spec.rdzv_handler.get_run_id()
        self._log_dir = self._make_log_dir(log_dir, rdzv_run_id)
        self._worker_watchdog: Optional[timer.FileTimerServer] = None

    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
        os.makedirs(base_log_dir, exist_ok=True)
        dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
        log.info("log directory set to: %s", dir)
        return dir

    def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
        enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
        watchdog_enabled = os.getenv(enable_watchdog_env_name)
        watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
        watchdog_file_path = os.getenv(watchdog_file_env_name)
        if watchdog_enabled is not None and str(watchdog_enabled) == "1":
            if watchdog_file_path is None:
                watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4())
            log.info("Starting a FileTimerServer with %s ...", watchdog_file_path)
            self._worker_watchdog = timer.FileTimerServer(
                file_path=watchdog_file_path,
                max_interval=0.1,
                daemon=True,
                log_event=self._log_watchdog_event,
            )
            self._worker_watchdog.start()
            log.info("FileTimerServer started")
        else:
            log.info(
                "Environment variable '%s' not found. Do not start FileTimerServer.",
                enable_watchdog_env_name,
            )
        # Propagate the watchdog file env to worker processes
        if watchdog_file_path is not None:
            for worker_env in envs.values():
                worker_env[watchdog_file_env_name] = watchdog_file_path

    def _get_fq_hostname(self) -> str:
        return socket.getfqdn(socket.gethostname())

    def _log_watchdog_event(
        self,
        name: str,
        request: Optional[timer.FileTimerRequest],
    ) -> None:
        wg = self._worker_group
        spec = wg.spec
        md = {"watchdog_event": name}
        if request is not None:
            md["worker_pid"] = str(request.worker_pid)
            md["scope_id"] = request.scope_id
            md["expiration_time"] = str(request.expiration_time)
            md["signal"] = str(request.signal)
        md_str = json.dumps(md)
        state = "RUNNING"
        metadata: Dict[str, EventMetadataValue] = {
            "run_id": spec.rdzv_handler.get_run_id(),
            "global_rank": None,
            "group_rank": wg.group_rank,
            "worker_id": None,
            "role": spec.role,
            "hostname": self._get_fq_hostname(),
            "state": state,
            "total_run_time": self._total_execution_time,
            "rdzv_backend": spec.rdzv_handler.get_backend(),
            "raw_error": None,
            "metadata": md_str,
            "agent_restarts": spec.max_restarts - self._remaining_restarts,
        }
        # Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later.
        #       The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry.
        event = events.Event(
            name=name, source=events.EventSource.AGENT, metadata=metadata
        )
        events.record(event)

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _stop_workers(self, worker_group: WorkerGroup) -> None:
        self._shutdown()

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store = spec.rdzv_handler.get_backend() == "static"

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "NCCL_ASYNC_ERROR_HANDLING": os.getenv(
                    "NCCL_ASYNC_ERROR_HANDLING", str(1)
                ),
            }
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
            envs[local_rank] = worker_env
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # scaling events do not count towards restarts (gets same attempt #)
        # remove existing log dir if this restart is due to a scaling event
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        self._setup_local_watchdog(envs=envs)

        assert spec.entrypoint is not None
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

        return self._pcontext.pids()

    def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
        if self._worker_watchdog is not None:
            self._worker_watchdog.stop()
            self._worker_watchdog = None
        if self._pcontext:
            self._pcontext.close(death_sig)

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers}
        assert self._pcontext is not None
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            log.error(
                "[%s] worker pids do not match process_context pids."
                " Expected: %s, actual: %s",
                role, worker_pids, pc_pids,
            )
            return RunResult(state=WorkerState.UNKNOWN)

        result = self._pcontext.wait(0)
        if result:
            if result.is_failed():
                # map local rank failure to global rank
                worker_failures = {}
                for local_rank, failure in result.failures.items():
                    worker = worker_group.workers[local_rank]
                    worker_failures[worker.global_rank] = failure
                return RunResult(
                    state=WorkerState.FAILED,
                    failures=worker_failures,
                )
            else:
                # copy ret_val_queue into a map with global ranks
                workers_ret_vals = {}
                for local_rank, ret_val in result.return_values.items():
                    worker = worker_group.workers[local_rank]
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals,
                )
        else:
            return RunResult(state=WorkerState.HEALTHY)
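
The class docstring describes the named-pipe watchdog that ``_setup_local_watchdog`` starts when ``TORCHELASTIC_ENABLE_FILE_TIMER=1`` and propagates to workers via ``TORCHELASTIC_TIMER_FILE``. The following is a minimal, illustrative sketch (not part of this module) of how worker code might arm that watchdog; it assumes the ``FileTimerClient``, ``configure``, and ``expires`` helpers from ``torch.distributed.elastic.timer``, and ``train_step`` is a hypothetical placeholder ::

    import os

    import torch.distributed.elastic.timer as timer

    def train_step(batch):
        ...  # hypothetical worker-side work

    def worker_main():
        # The agent exports TORCHELASTIC_TIMER_FILE to every worker when the
        # file-based watchdog is enabled (TORCHELASTIC_ENABLE_FILE_TIMER=1).
        timer_file = os.environ.get("TORCHELASTIC_TIMER_FILE")
        if timer_file is None:
            # Watchdog not enabled by the agent; run without expirations.
            train_step(batch=None)
            return

        # Point the worker-side timer client at the agent's named pipe.
        timer.configure(timer.FileTimerClient(timer_file))

        # If this block takes longer than 60 seconds, the agent-side
        # FileTimerServer signals this worker process.
        with timer.expires(after=60):
            train_step(batch=None)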
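``_start_workers`` exports the rendezvous results to each worker process as environment variables (``LOCAL_RANK``, ``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT``, and so on). A minimal sketch of a worker entrypoint that consumes them is shown below; the default ``env://`` initialization of ``torch.distributed.init_process_group`` reads ``MASTER_ADDR``, ``MASTER_PORT``, ``RANK``, and ``WORLD_SIZE`` from the environment, so only device selection needs an explicit lookup ::

    import os

    import torch
    import torch.distributed as dist

    def worker_main():
        # Set by LocalElasticAgent._start_workers for every local worker.
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        # init_process_group with the default env:// method picks up
        # MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment.
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        if backend == "nccl":
            torch.cuda.set_device(local_rank)
        dist.init_process_group(backend=backend)

        print(f"worker {rank}/{world_size} (local rank {local_rank}) initialized")
        dist.destroy_process_group()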