Source code for torchelastic.agent.server.local_elastic_agent

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from typing import Any, Dict

import torch.multiprocessing as mp
from torchelastic.agent.server.api import (
    MonitorResult,
    SimpleElasticAgent,
    WorkerGroup,
    WorkerSpec,
    WorkerState,
)
from torchelastic.metrics.api import prof

log = logging.getLogger(__name__)

class _DistInfo:
    """
    Container for information required to create a torch process group.
    To be created on the agent's process and passed to the worker sub-process.
    Hence this object needs to be a pure data object with no state and
    preferably only primitive member variables

    Each field is exported by ``_wrap`` as an environment variable of the
    worker process:

    * ``rank`` -> ``RANK``
    * ``group_rank`` -> ``GROUP_RANK``
    * ``local_world_size`` -> ``LOCAL_WORLD_SIZE``
    * ``world_size`` -> ``WORLD_SIZE``
    * ``master_addr`` -> ``MASTER_ADDR``
    * ``master_port`` -> ``MASTER_PORT``
    * ``restart_count`` -> ``TORCHELASTIC_RESTART_COUNT``
    * ``max_restarts`` -> ``TORCHELASTIC_MAX_RESTARTS``
    """

    # One slot per attribute assigned in __init__; no per-instance __dict__.
    __slots__ = [
        "rank",
        "group_rank",
        "local_world_size",
        "world_size",
        "master_addr",
        "master_port",
        "restart_count",
        "max_restarts",
    ]

    def __init__(
        self,
        rank: int,
        group_rank: int,
        local_world_size: int,
        world_size: int,
        master_addr: str,
        master_port: int,
        restart_count: int,
        max_restarts: int,
    ):
        self.rank = rank
        self.group_rank = group_rank
        self.local_world_size = local_world_size
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.restart_count = restart_count
        self.max_restarts = max_restarts

def _wrap(local_rank, ret_vals, dist_infos, fn, args):
    """
    Worker entry point: export the distributed settings for ``local_rank``
    as environment variables, run ``fn(*args)`` and record its return value.

    Args:
        local_rank: key used to look up this worker's entry in ``dist_infos``;
            also exported as ``LOCAL_RANK``.
        ret_vals: mapping that receives the return value of ``fn``, keyed by
            the worker's ``rank`` attribute (not by ``local_rank``).
        dist_infos: mapping from local rank to an object exposing ``rank``,
            ``group_rank``, ``local_world_size``, ``world_size``,
            ``master_addr``, ``master_port``, ``restart_count`` and
            ``max_restarts``.
        fn: callable to execute in this process.
        args: positional arguments unpacked into ``fn``.
    """
    worker_info = dist_infos[local_rank]

    def _exported_settings():
        # Yielded lazily so each value is read and converted right before it
        # is written to the environment, one variable at a time.
        yield "LOCAL_RANK", str(local_rank)
        yield "RANK", str(worker_info.rank)
        yield "GROUP_RANK", str(worker_info.group_rank)
        yield "LOCAL_WORLD_SIZE", str(worker_info.local_world_size)
        yield "WORLD_SIZE", str(worker_info.world_size)
        # the master address is exported as-is, without a str() conversion
        yield "MASTER_ADDR", worker_info.master_addr
        yield "MASTER_PORT", str(worker_info.master_port)
        yield "TORCHELASTIC_RESTART_COUNT", str(worker_info.restart_count)
        yield "TORCHELASTIC_MAX_RESTARTS", str(worker_info.max_restarts)

    for env_name, env_value in _exported_settings():
        os.environ[env_name] = env_value

    # fn runs first; its result is then stored under the worker's rank.
    ret_vals[worker_info.rank] = fn(*args)

class LocalElasticAgent(SimpleElasticAgent):
    """
    An implementation of :py:class:`torchelastic.agent.server.ElasticAgent`
    that handles host-local workers.
    This agent is deployed per host and is configured to spawn ``n`` workers.
    When using GPUs, ``n`` maps to the number of GPUs available on the host.

    The local agent does not communicate to other local agents deployed on
    other hosts, even if the workers may communicate inter-host. The worker id
    is interpreted to be a local process. The agent starts and stops all worker
    processes as a single unit.

    The worker function and argument passed to the worker function must be
    python multiprocessing compatible. To pass multiprocessing data structures
    to the workers you may create the data structure in the same multiprocessing
    context as the specified ``start_method`` and pass it as a function argument.

    Example

    ::

        def trainer(shared_queue):
            pass

        def main():
            start_method="spawn"
            shared_queue= multiprocessing.get_context(start_method).Queue()
            spec = WorkerSpec(
                        role="trainer",
                        local_world_size=nproc_per_process,
                        fn=trainer,
                        args=(shared_queue,),
                        ...<OTHER_PARAMS...>)
            agent = LocalElasticAgent(spec, start_method)
    """

    def __init__(self, spec: WorkerSpec, start_method="spawn"):
        """
        Args:
            spec: worker specification forwarded to the base agent.
            start_method: multiprocessing start method used both for the
                manager that backs the return-value map and for launching
                the worker processes.
        """
        super().__init__(spec)
        self._start_method = start_method
        # pyre-fixme[8]: Attribute has type `ProcessContext`; used as `None`.
        self._process_context: mp.ProcessContext = None
        # a map that holds return values for each worker fn
        # ret_val[0] holds the return value for worker_0 (global rank 0)
        self._manager = mp.get_context(start_method).Manager()
        self._ret_vals = self._manager.dict()

    @prof
    def _stop_workers(self, worker_group: WorkerGroup) -> None:
        """Terminate every worker process that is still alive and join them."""
        for proc in self._process_context.processes:
            if proc.is_alive():
                proc.terminate()
            # NOTE(review): the flattened source lost the nesting of this call;
            # joining unconditionally also reaps processes that already exited.
            proc.join()

    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        """
        Launch one process per local worker and return ``{local_rank: pid}``.

        A ``_DistInfo`` is built for every worker in ``worker_group`` and
        handed to ``_wrap``, which runs ``spec.fn(*spec.args)`` in each
        spawned process. The shared return-value map is cleared first.
        """
        spec = worker_group.spec
        # NOTE(review): the right-hand side was truncated in the extracted
        # source; ``worker_group.store`` is the assumed original -- confirm.
        store = worker_group.store
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        dist_infos: Dict[int, _DistInfo] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            dist_infos[local_rank] = _DistInfo(
                worker.global_rank,
                worker_group.group_rank,
                worker_group.spec.local_world_size,
                worker.world_size,
                master_addr,
                master_port,
                restart_count,
                spec.max_restarts,
            )

        self._ret_vals.clear()
        # join=False: return immediately; completion is polled later in
        # _monitor_workers().
        self._process_context = mp.start_processes(
            fn=_wrap,
            args=(self._ret_vals, dist_infos, spec.fn, spec.args),
            nprocs=spec.local_world_size,
            join=False,
            daemon=False,
            start_method=self._start_method,
        )

        return {
            local_rank: pid
            for local_rank, pid in enumerate(self._process_context.pids())
        }

    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> MonitorResult:
        """
        Poll the worker processes once, without blocking.

        Returns a ``MonitorResult`` whose state is ``UNKNOWN`` when the
        recorded worker ids do not match the process context's pids,
        ``SUCCEEDED`` (with a plain-dict copy of the return values) when all
        workers finished, ``HEALTHY`` while workers are still running, and
        ``FAILED`` (with the raised exception attributed to every worker's
        global rank) when the join raises.
        """
        role = worker_group.spec.role

        # torch process context join() isn't really a join in the
        # traditional sense, it returns True if all the workers have
        # successfully finished, False if some/all are still running
        # and throws an Exception if some/all of them failed
        # passing timeout < 0 means check worker status and return immediately

        # NOTE(review): the set element was truncated in the extracted source;
        # ``w.id`` is the assumed original expression -- confirm.
        worker_pids = {w.id for w in worker_group.workers}
        pc_pids = set(self._process_context.pids())
        if worker_pids != pc_pids:
            log.error(f"[{role}] worker pids do not match process_context pids")
            return MonitorResult(WorkerState.UNKNOWN)

        try:
            if self._process_context.join(timeout=-1):
                # copy ret_vals since we do not want to return an mp map
                return MonitorResult(WorkerState.SUCCEEDED, dict(self._ret_vals))
            else:
                return MonitorResult(WorkerState.HEALTHY)
        except Exception as e:
            log.exception(f"[{role}] Worker group failed")
            return MonitorResult(
                WorkerState.FAILED,
                exceptions={w.global_rank: e for w in worker_group.workers},
            )


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources