Source code for torch.distributed.elastic.agent.server.local_elastic_agent

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import signal
import tempfile
from typing import Any, Dict, Optional, Tuple

from torch.distributed.elastic.agent.server.api import (
    RunResult,
    SimpleElasticAgent,
    WorkerGroup,
    WorkerSpec,
    WorkerState,
)
from torch.distributed.elastic.metrics.api import prof
from torch.distributed.elastic.multiprocessing import PContext, start_processes
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger

log = get_logger()

class LocalElasticAgent(SimpleElasticAgent):
    """
    An implementation of :py:class:`torchelastic.agent.server.ElasticAgent`
    that handles host-local workers.
    This agent is deployed per host and is configured to spawn ``n`` workers.
    When using GPUs, ``n`` maps to the number of GPUs available on the host.

    The local agent does not communicate to other local agents deployed on
    other hosts, even if the workers may communicate inter-host. The worker id
    is interpreted to be a local process. The agent starts and stops all worker
    processes as a single unit.

    The worker function and argument passed to the worker function must be
    python multiprocessing compatible. To pass multiprocessing data structures
    to the workers you may create the data structure in the same multiprocessing
    context as the specified ``start_method`` and pass it as a function argument.

    The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait
    for other agents to finish. This acts as a safety net to handle cases where
    workers finish at different times, to prevent agents from viewing workers
    that finished early as a scale-down event. It is strongly advised that the
    user code deal with ensuring that workers are terminated in a synchronous
    manner rather than relying on the exit_barrier_timeout.

    Example launching function

    ::

        def trainer(args) -> str:
            return "do train"

        def main():
            start_method="spawn"
            shared_queue= multiprocessing.get_context(start_method).Queue()
            spec = WorkerSpec(
                        role="trainer",
                        local_world_size=nproc_per_process,
                        entrypoint=trainer,
                        args=("foobar",),
                        ...<OTHER_PARAMS...>)
            agent = LocalElasticAgent(spec, start_method)
            results = agent.run()

            if results.is_failed():
                print("trainer failed")
            else:
                print(f"rank 0 return value: {results.return_values[0]}")
                # prints -> rank 0 return value: do train

    Example launching binary

    ::

        def main():
            spec = WorkerSpec(
                        role="trainer",
                        local_world_size=nproc_per_process,
                        entrypoint="/usr/local/bin/trainer",
                        args=("--trainer_args", "foobar"),
                        ...<OTHER_PARAMS...>)
            agent = LocalElasticAgent(spec)
            results = agent.run()

            if not results.is_failed():
                print("binary launches do not have return values")

    """

    def __init__(
        self,
        spec: WorkerSpec,
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_dir: Optional[str] = None,
    ):
        """
        Args:
            spec: description of the workers this agent manages.
            start_method: multiprocessing start method forwarded to
                ``start_processes`` when the workers are launched.
            exit_barrier_timeout: seconds to wait for other agents to finish
                (forwarded to the base class).
            log_dir: base directory for worker logs; a fresh temporary
                directory is created when omitted.
        """
        super().__init__(spec, exit_barrier_timeout)
        self._start_method = start_method
        # Populated by ``_start_workers``; ``None`` until workers are launched.
        self._pcontext: Optional[PContext] = None
        rdzv_run_id = spec.rdzv_handler.get_run_id()
        self._log_dir = self._make_log_dir(log_dir, rdzv_run_id)

    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
        """
        Create and return a unique per-run log directory.

        The directory is created under ``log_dir`` (or under a new
        ``torchelastic_*`` temporary directory when ``log_dir`` is falsy) and
        its name is prefixed with ``"{rdzv_run_id}_"``.
        """
        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
        os.makedirs(base_log_dir, exist_ok=True)
        # Named ``run_log_dir`` rather than ``dir`` to avoid shadowing the builtin.
        run_log_dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
        log.info(f"log directory set to: {run_log_dir}")
        return run_log_dir

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _stop_workers(self, worker_group: WorkerGroup) -> None:
        """Stop all workers by closing the current process context, if any."""
        self._shutdown()

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        """
        Launch one process per worker in ``worker_group``.

        Builds the per-local-rank environment and argument tuples, recreates
        the log directory for the current attempt, starts the processes and
        returns the mapping produced by ``PContext.pids()``.
        """
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store = spec.rdzv_handler.get_backend() == "static"
        # The run id is the same for every worker; look it up once.
        run_id = spec.rdzv_handler.get_run_id()

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": run_id,
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "NCCL_ASYNC_ERROR_HANDLING": str(1),
            }
            # Propagate the agent's OMP_NUM_THREADS setting only when it is set.
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
            envs[local_rank] = worker_env
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # scaling events do not count towards restarts (gets same attempt #)
        # remove existing log dir if this restart is due to a scaling event
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        assert spec.entrypoint is not None
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

        return self._pcontext.pids()

    def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
        """Close the process context with ``death_sig``; no-op if none exists."""
        if self._pcontext:
            self._pcontext.close(death_sig)

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        """
        Poll the process context once and report the state of the workers.

        Returns:
            ``UNKNOWN`` when the worker ids do not match the process context
            pids, ``FAILED`` / ``SUCCEEDED`` (keyed by global rank) when the
            processes have finished, and ``HEALTHY`` when the non-blocking
            wait yields no result.
        """
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers}
        assert self._pcontext is not None
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            log.error(
                f"[{role}] worker pids do not match process_context pids."
                f" Expected: {worker_pids}, actual: {pc_pids}"
            )
            return RunResult(state=WorkerState.UNKNOWN)

        # Timeout of 0: check for completion without blocking.
        result = self._pcontext.wait(0)
        if result:
            if result.is_failed():
                # map local rank failure to global rank
                worker_failures = {}
                for local_rank, failure in result.failures.items():
                    worker = worker_group.workers[local_rank]
                    worker_failures[worker.global_rank] = failure
                return RunResult(
                    state=WorkerState.FAILED,
                    failures=worker_failures,
                )
            else:
                # copy ret_val_queue into a map with a global ranks
                workers_ret_vals = {}
                for local_rank, ret_val in result.return_values.items():
                    worker = worker_group.workers[local_rank]
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals,
                )
        else:
            return RunResult(state=WorkerState.HEALTHY)


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources