Shortcuts

Source code for torch.distributed.elastic.agent.server.local_elastic_agent

#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import json
import os
import signal
import socket
import time
import uuid
from string import Template
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING

import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events
from torch.distributed.elastic.agent.server.api import (
    RunResult,
    SimpleElasticAgent,
    WorkerGroup,
    WorkerSpec,
    WorkerState,
)
from torch.distributed.elastic.agent.server.health_check_server import (
    create_healthcheck_server,
    HealthCheckServer,
)
from torch.distributed.elastic.metrics.api import prof
from torch.distributed.elastic.multiprocessing import (
    LogsSpecs,
    PContext,
    start_processes,
)
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger


if TYPE_CHECKING:
    from torch.distributed.elastic.events.api import EventMetadataValue

# Module-level logger shared by every method of LocalElasticAgent.
logger = get_logger(__name__)

# Public API of this module: the agent class plus the names of the
# environment variables it reads.
__all__ = [
    "LocalElasticAgent",
    "TORCHELASTIC_ENABLE_FILE_TIMER",
    "TORCHELASTIC_TIMER_FILE",
    "TORCHELASTIC_HEALTH_CHECK_PORT",
]

# Name of the env var that switches the file-based watchdog on: a
# FileTimerServer is started by LocalElasticAgent._setup_local_watchdog only
# when this variable is set to "1" in the agent process.
TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER"
# Name of the env var holding the health check port. When set,
# LocalElasticAgent._setup_healthcheck parses it as an integer and starts a
# health check server on that port; a non-integer value is logged and ignored.
TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT"
# Name of the env var holding the watchdog file path. If it is unset while the
# watchdog is enabled, the agent generates a path; whenever a path is known it
# is copied into every worker's environment under this same name.
TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE"


[docs]class LocalElasticAgent(SimpleElasticAgent): """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers. This agent is deployed per host and is configured to spawn ``n`` workers. When using GPUs, ``n`` maps to the number of GPUs available on the host. The local agent does not communicate to other local agents deployed on other hosts, even if the workers may communicate inter-host. The worker id is interpreted to be a local process. The agent starts and stops all worker processes as a single unit. The worker function and argument passed to the worker function must be python multiprocessing compatible. To pass multiprocessing data structures to the workers you may create the data structure in the same multiprocessing context as the specified ``start_method`` and pass it as a function argument. The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait for other agents to finish. This acts as a safety net to handle cases where workers finish at different times, to prevent agents from viewing workers that finished early as a scale-down event. It is strongly advised that the user code deal with ensuring that workers are terminated in a synchronous manner rather than relying on the exit_barrier_timeout. A named pipe based watchdog can be enabled in ```LocalElasticAgent``` if an environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has been defined in the ```LocalElasticAgent``` process. Optionally, another environment variable ```TORCHELASTIC_TIMER_FILE``` can be set with a unique file name for the named pipe. If the environment variable ```TORCHELASTIC_TIMER_FILE``` is not set, ```LocalElasticAgent``` will internally create a unique file name and set it to the environment variable ```TORCHELASTIC_TIMER_FILE```, and this environment variable will be propagated to the worker processes to allow them to connect to the same named pipe that ```LocalElasticAgent``` uses. 
Logs are written to the specified log directory. Each log line will be by default prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``). Log prefixes can be customized by passing a `template string <https://docs.python.org/3/library/string.html#template-strings>`_ as the ``log_line_prefix_template`` argument. The following macros (identifiers) are substituted at runtime: ``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:``. Example launching function :: def trainer(args) -> str: return "do train" def main(): start_method="spawn" shared_queue= multiprocessing.get_context(start_method).Queue() spec = WorkerSpec( role="trainer", local_world_size=nproc_per_process, entrypoint=trainer, args=("foobar",), ...<OTHER_PARAMS...>) agent = LocalElasticAgent(spec, start_method) results = agent.run() if results.is_failed(): print("trainer failed") else: print(f"rank 0 return value: {results.return_values[0]}") # prints -> rank 0 return value: do train Example launching binary :: def main(): spec = WorkerSpec( role="trainer", local_world_size=nproc_per_process, entrypoint="/usr/local/bin/trainer", args=("--trainer-args", "foobar"), ...<OTHER_PARAMS...>) agent = LocalElasticAgent(spec) results = agent.run() if not results.is_failed(): print("binary launches do not have return values") """ def __init__( self, spec: WorkerSpec, logs_specs: LogsSpecs, start_method="spawn", exit_barrier_timeout: float = 300, log_line_prefix_template: Optional[str] = None, ): super().__init__(spec, exit_barrier_timeout) self._start_method = start_method self._pcontext: Optional[PContext] = None self._rdzv_handler = spec.rdzv_handler self._log_line_prefix_template = log_line_prefix_template self._worker_watchdog: Optional[timer.FileTimerServer] = None self._logs_specs = logs_specs self._health_check_server: Optional[HealthCheckServer] = None def 
_setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None: enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER watchdog_enabled = os.getenv(enable_watchdog_env_name) watchdog_file_env_name = TORCHELASTIC_TIMER_FILE watchdog_file_path = os.getenv(watchdog_file_env_name) if watchdog_enabled is not None and str(watchdog_enabled) == "1": if watchdog_file_path is None: watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4()) logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path) if not envs: logger.warning( "Empty envs variables, using empty run_id for FileTimerServer" ) run_id = "" else: run_id = envs[0]["TORCHELASTIC_RUN_ID"] self._worker_watchdog = timer.FileTimerServer( file_path=watchdog_file_path, run_id=run_id, max_interval=0.1, daemon=True, log_event=self._log_watchdog_event, ) self._worker_watchdog.start() logger.info("FileTimerServer started") else: logger.info( "Environment variable '%s' not found. Do not start FileTimerServer.", enable_watchdog_env_name, ) # Propagate the watchdog file env to worker processes if watchdog_file_path is not None: for worker_env in envs.values(): worker_env[watchdog_file_env_name] = watchdog_file_path @staticmethod def _get_current_time_secs() -> int: return int(time.time()) def _setup_healthcheck(self) -> None: healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT healthcheck_port = os.getenv(healthcheck_port_env_name) if healthcheck_port is not None: logger.info( "Found healthcheck port %s: %s", healthcheck_port_env_name, healthcheck_port, ) if self._worker_watchdog is None: logger.info( "FileTimerServer doesn't exist, using current time as dummy callback" ) alive_callback = LocalElasticAgent._get_current_time_secs else: alive_callback = self._worker_watchdog.get_last_progress_time try: healthcheck_port_as_int = int(healthcheck_port) self._health_check_server = create_healthcheck_server( alive_callback=alive_callback, port=healthcheck_port_as_int, timeout=60, ) 
self._health_check_server.start() except ValueError: logger.info( "Invalid healthcheck port value: '%s', expecting integer. Not starting healthcheck server.", healthcheck_port, ) else: logger.info( "Environment variable '%s' not found. Do not start health check.", healthcheck_port_env_name, ) def _get_fq_hostname(self) -> str: return socket.getfqdn(socket.gethostname()) def _log_watchdog_event( self, name: str, request: Optional[timer.FileTimerRequest], ) -> None: wg = self._worker_group spec = wg.spec md = {"watchdog_event": name} if request is not None: md["worker_pid"] = str(request.worker_pid) md["scope_id"] = request.scope_id md["expiration_time"] = str(request.expiration_time) md["signal"] = str(request.signal) md_str = json.dumps(md) state = "RUNNING" metadata: Dict[str, EventMetadataValue] = { "run_id": spec.rdzv_handler.get_run_id(), "global_rank": None, "group_rank": wg.group_rank, "worker_id": None, "role": spec.role, "hostname": self._get_fq_hostname(), "state": state, "total_run_time": self._total_execution_time, "rdzv_backend": spec.rdzv_handler.get_backend(), "raw_error": None, "metadata": md_str, "agent_restarts": spec.max_restarts - self._remaining_restarts, } # Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later. # The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry. event = events.Event( name=name, source=events.EventSource.AGENT, metadata=metadata ) events.record(event) # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. @prof def _stop_workers( self, worker_group: WorkerGroup, is_restart: bool = False ) -> None: self._shutdown(is_restart=is_restart) # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. 
@prof def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: spec = worker_group.spec store = worker_group.store assert store is not None restart_count = spec.max_restarts - self._remaining_restarts use_agent_store: bool = spec.rdzv_handler.use_agent_store logger.info("use_agent_store: %s", use_agent_store) args: Dict[int, Tuple] = {} envs: Dict[int, Dict[str, str]] = {} log_line_prefixes: Optional[Dict[int, str]] = ( {} if self._log_line_prefix_template else None ) for worker in worker_group.workers: local_rank = worker.local_rank worker_env = { "LOCAL_RANK": str(local_rank), "RANK": str(worker.global_rank), "GROUP_RANK": str(worker_group.group_rank), "ROLE_RANK": str(worker.role_rank), "ROLE_NAME": spec.role, "LOCAL_WORLD_SIZE": str(spec.local_world_size), "WORLD_SIZE": str(worker.world_size), "GROUP_WORLD_SIZE": str(worker_group.group_world_size), "ROLE_WORLD_SIZE": str(worker.role_world_size), "MASTER_ADDR": worker_group.master_addr, "MASTER_PORT": str(worker_group.master_port), "TORCHELASTIC_RESTART_COUNT": str(restart_count), "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts), "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(), "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store), "TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv( "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1) ), } if "OMP_NUM_THREADS" in os.environ: worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"] if self._log_line_prefix_template: log_line_prefix = Template( self._log_line_prefix_template ).safe_substitute( role_name=spec.role, rank=worker.global_rank, local_rank=local_rank, ) log_line_prefixes[local_rank] = log_line_prefix envs[local_rank] = worker_env worker_args = list(spec.args) worker_args = macros.substitute(worker_args, str(local_rank)) args[local_rank] = tuple(worker_args) self._setup_local_watchdog(envs=envs) self._setup_healthcheck() assert spec.entrypoint is not None assert self._logs_specs is not None self._pcontext = start_processes( name=spec.role, 
entrypoint=spec.entrypoint, args=args, envs=envs, logs_specs=self._logs_specs, log_line_prefixes=log_line_prefixes, start_method=self._start_method, ) return self._pcontext.pids() def _shutdown( self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False ) -> None: if self._worker_watchdog is not None: self._worker_watchdog.stop() self._worker_watchdog = None if self._health_check_server is not None: self._health_check_server.stop() self._health_check_server = None if self._pcontext: self._pcontext.close(death_sig) if not is_restart and self._rdzv_handler: self._rdzv_handler.shutdown() # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. @prof def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: role = worker_group.spec.role worker_pids = {w.id for w in worker_group.workers} assert self._pcontext is not None pc_pids = set(self._pcontext.pids().values()) if worker_pids != pc_pids: logger.error( "[%s] worker pids do not match process_context pids." " Expected: %s, actual: %s", role, worker_pids, pc_pids, ) return RunResult(state=WorkerState.UNKNOWN) result = self._pcontext.wait(0) if result: if result.is_failed(): # map local rank failure to global rank worker_failures = {} for local_rank, failure in result.failures.items(): worker = worker_group.workers[local_rank] worker_failures[worker.global_rank] = failure return RunResult( state=WorkerState.FAILED, failures=worker_failures, ) else: # copy ret_val_queue into a map with a global ranks workers_ret_vals = {} for local_rank, ret_val in result.return_values.items(): worker = worker_group.workers[local_rank] workers_ret_vals[worker.global_rank] = ret_val return RunResult( state=WorkerState.SUCCEEDED, return_values=workers_ret_vals, ) else: return RunResult(state=WorkerState.HEALTHY)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources