Elastic Agent

Server

The elastic agent is the control plane of torchelastic. It is a process that launches and manages underlying worker processes. The agent is responsible for:

  1. Working with distributed torch: the workers are started with all the necessary information to successfully and trivially call torch.distributed.init_process_group().

  2. Fault tolerance: monitors workers and upon detecting worker failures or unhealthiness, tears down all workers and restarts everyone.

  3. Elasticity: Reacts to membership changes and restarts workers with the new members.
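As an illustration of the first point, the sketch below shows how a worker script might collect the information handed to it by the agent. It assumes the information arrives through the environment variables read by the `env://` init method of torch.distributed (`RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`) plus `LOCAL_RANK`; this page does not list them, and `read_dist_env` is a hypothetical helper, not a torchelastic function.

```python
import os
from typing import Dict, Mapping


def read_dist_env(env: Mapping[str, str] = os.environ) -> Dict[str, object]:
    """Collect the process-group settings a worker would find in its environment.

    Hypothetical helper: the variable names are an assumption based on the
    env:// init method, not something stated on this page.
    """
    return {
        "rank": int(env["RANK"]),
        "local_rank": int(env["LOCAL_RANK"]),
        "world_size": int(env["WORLD_SIZE"]),
        "master_addr": env["MASTER_ADDR"],
        "master_port": int(env["MASTER_PORT"]),
    }


# Example environment for the worker with global rank 3 in a job of 8.
example = read_dist_env(
    {
        "RANK": "3",
        "LOCAL_RANK": "1",
        "WORLD_SIZE": "8",
        "MASTER_ADDR": "node0.example",
        "MASTER_PORT": "29500",
    }
)
```

With these values already in the environment, a worker script would typically call torch.distributed.init_process_group() without passing rank or world size explicitly.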

The simplest agents are deployed per node and work with local processes. A more advanced agent can launch and manage workers remotely. Agents can be completely decentralized, each making decisions based only on the workers it manages, or they can be coordinated, communicating with other agents (that manage workers in the same job) to make a collective decision.

Below is a diagram of an agent that manages a local group of workers.

../_images/agent_diagram.jpg

Concepts

This section describes the high-level classes and concepts that are relevant to understanding the role of the agent in torchelastic.

class torch.distributed.elastic.agent.server.ElasticAgent[source]

Agent process responsible for managing one or more worker processes. The worker processes are assumed to be regular distributed PyTorch scripts. When the worker process is created by the agent, the agent provides the necessary information for the worker processes to properly initialize a torch process group.

The exact deployment topology and ratio of agent-to-worker is dependent on the specific implementation of the agent and the user’s job placement preferences. For instance, to run a distributed training job on GPU with 8 trainers (one per GPU) one can:

  1. Use 8 x single GPU instances, place an agent per instance, managing 1 worker per agent.

  2. Use 4 x double GPU instances, place an agent per instance, managing 2 workers per agent.

  3. Use 2 x quad GPU instances, place an agent per instance, managing 4 workers per agent.

  4. Use 1 x 8 GPU instance, place an agent per instance, managing 8 workers per agent.
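All four placements above produce the same world size of 8; only the agent-to-worker ratio changes. A minimal sketch of that arithmetic follows. `placement` is a hypothetical helper, and it assumes global ranks are handed out contiguously, agent by agent.

```python
from typing import Dict, List


def placement(num_agents: int, workers_per_agent: int) -> Dict[int, List[int]]:
    """Map each agent index to the global ranks of the workers it manages.

    Hypothetical helper for illustration; assumes ranks are assigned
    contiguously in agent order.
    """
    return {
        agent: list(range(agent * workers_per_agent, (agent + 1) * workers_per_agent))
        for agent in range(num_agents)
    }


# The four topologies from the list above: (agents, workers per agent).
topologies = [(8, 1), (4, 2), (2, 4), (1, 8)]
world_sizes = [
    sum(len(ranks) for ranks in placement(agents, workers).values())
    for agents, workers in topologies
]
```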

Usage

group_result = agent.run()
if group_result.is_failed():
    # workers failed
    failure = group_result.failures[0]
    log.exception(f"worker 0 failed with exit code : {failure.exit_code}")
else:
    return group_result.return_values[0]  # return rank 0's results

abstract get_worker_group(role='default')[source]
Returns

The WorkerGroup for the given role. Note that the worker group is a mutable object and hence in a multi-threaded/process environment it may change state. Implementors are encouraged (but not required) to return a defensive read-only copy.

abstract run(role='default')[source]

Runs the agent, retrying the worker group on failures up to max_restarts.

Returns

The result of the execution, containing the return values or failure details for each worker mapped by the worker’s global rank.

Raises

Exception - any other failures NOT related to worker process
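The retry behavior of run() can be pictured as a loop with a restart budget. The sketch below is a simplified model rather than the torchelastic implementation: `run_with_restarts` and `attempt` are hypothetical names, and `attempt` stands in for "start the worker group and wait for it to finish".

```python
from typing import Callable, Tuple


def run_with_restarts(attempt: Callable[[], bool], max_restarts: int) -> Tuple[bool, int]:
    """Call `attempt` until it succeeds or the restart budget is used up.

    `attempt` returns True when the whole worker group succeeded.
    Returns (succeeded, restarts_used).
    """
    restarts_used = 0
    while True:
        if attempt():
            return True, restarts_used
        if restarts_used >= max_restarts:
            # Budget exhausted: report the failure to the caller.
            return False, restarts_used
        restarts_used += 1  # tear down and restart the whole group


# A group that fails twice and then succeeds.
outcomes = iter([False, False, True])
recovered = run_with_restarts(lambda: next(outcomes), max_restarts=3)

# A group that never succeeds uses the whole budget.
exhausted = run_with_restarts(lambda: False, max_restarts=3)
```

In this model a budget of 3 restarts allows up to 4 attempts in total: the initial run plus three retries.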

class torch.distributed.elastic.agent.server.WorkerSpec(role, local_world_size, rdzv_handler, fn=None, entrypoint=None, args=(), max_restarts=3, monitor_interval=30.0, master_port=None, master_addr=None, redirects=<Std.NONE: 0>, tee=<Std.NONE: 0>)[source]

Contains blueprint information about a particular type of worker. For a given role, there must only exist a single worker spec. The worker spec is expected to be homogeneous across all nodes (machines), that is, each node runs the same number of workers for a particular spec.

Parameters
  • role – user-defined role for the workers with this spec

  • local_world_size – number of local workers to run

  • fn – deprecated; use entrypoint instead

  • entrypoint – worker function or command

  • args – arguments to pass to entrypoint

  • rdzv_handler – handles rdzv for this set of workers

  • max_restarts – number of max retries for the workers

  • monitor_interval – monitor status of workers every n seconds

  • master_port – fixed port on which to run the c10d store on rank 0; if not specified, a random free port is chosen

  • master_addr – fixed master_addr on which to run the c10d store on rank 0; if not specified, the hostname of agent rank 0 is chosen

  • redirects – redirect std streams to a file, selectively redirect for a particular local rank by passing a map

  • tee – tees the specified std stream(s) to console + file, selectively tee for a particular local rank by passing a map, takes precedence over redirects settings.

get_entrypoint_name()[source]

If the entrypoint is a function (e.g. Callable) returns its __qualname__, else if the entrypoint is a binary (e.g. str), returns the binary name.

class torch.distributed.elastic.agent.server.WorkerState(value)[source]

State of the WorkerGroup. Workers in a worker group change state as a unit. If a single worker in a worker group fails the entire set is considered failed:

UNKNOWN - agent lost track of worker group state, unrecoverable
INIT - worker group object created, not yet started
HEALTHY - workers running and healthy
UNHEALTHY - workers running and unhealthy
STOPPED - workers stopped (interrupted) by the agent
SUCCEEDED - workers finished running (exit 0)
FAILED - workers failed to successfully finish (exit !0)

A worker group starts from an initial INIT state, then progresses to HEALTHY or UNHEALTHY states, and finally reaches a terminal SUCCEEDED or FAILED state.

Worker groups can be interrupted and temporarily put into STOPPED state by the agent. Workers in STOPPED state are scheduled to be restarted in the near future by the agent. Some examples of workers being put into STOPPED state are:

  1. Worker group failure or unhealthy state observed

  2. Membership change detected

When an action (start, stop, rdzv, retry, etc.) on a worker group fails and leaves the action only partially applied to the worker group, the state will be UNKNOWN. Typically this happens on uncaught or unhandled exceptions during state change events on the agent. The agent is not expected to recover worker groups in the UNKNOWN state and is better off terminating itself and allowing the job manager to retry the node.

static is_running(state)[source]
Returns

True if the worker state represents workers that are still running (i.e. the process exists but is not necessarily healthy).
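The grouping of states described above can be sketched with a small stand-in enum. `State`, `RUNNING` and `TERMINAL` are hypothetical names used for illustration. Treating HEALTHY and UNHEALTHY as the "running" states is one reading of the is_running description, not a copy of the library code.

```python
from enum import Enum


class State(Enum):
    """Stand-in mirroring the WorkerState member names (illustrative only)."""

    UNKNOWN = "UNKNOWN"
    INIT = "INIT"
    HEALTHY = "HEALTHY"
    UNHEALTHY = "UNHEALTHY"
    STOPPED = "STOPPED"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"


# Assumption: "still running" means the processes exist, healthy or not.
RUNNING = frozenset({State.HEALTHY, State.UNHEALTHY})
# Terminal states named in the text above.
TERMINAL = frozenset({State.SUCCEEDED, State.FAILED})


def is_running(state: State) -> bool:
    """True while worker processes exist, whether or not they are healthy."""
    return state in RUNNING


def is_terminal(state: State) -> bool:
    """True once the worker group has finished for good."""
    return state in TERMINAL


running_names = sorted(s.name for s in State if is_running(s))
```

Note that STOPPED is neither running nor terminal in this model: the group has been interrupted and is waiting for the agent to restart it.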

class torch.distributed.elastic.agent.server.Worker(local_rank, global_rank=-1, role_rank=-1, world_size=-1, role_world_size=-1)[source]

Represents a worker instance. Contrast this with WorkerSpec that represents the specifications of a worker. A Worker is created from a WorkerSpec. A Worker is to a WorkerSpec as an object is to a class.

The id of the worker is interpreted by the specific implementation of ElasticAgent. For a local agent, it could be the pid (int) of the worker, for a remote agent it could be encoded as host:port (string).

Parameters
  • id (Any) – uniquely identifies a worker (interpreted by the agent)

  • local_rank (int) – local rank of the worker

  • global_rank (int) – global rank of the worker

  • role_rank (int) – rank of the worker across all workers that have the same role

  • world_size (int) – number of workers (globally)

  • role_world_size (int) – number of workers that have the same role

class torch.distributed.elastic.agent.server.WorkerGroup(spec)[source]

Represents the set of Worker instances for the given WorkerSpec managed by ElasticAgent. Whether the worker group contains cross instance workers or not depends on the implementation of the agent.

Implementations

Below are the agent implementations provided by torchelastic.

class torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent(spec, start_method='spawn', exit_barrier_timeout=300, log_dir=None)[source]

An implementation of torchelastic.agent.server.ElasticAgent that handles host-local workers. This agent is deployed per host and is configured to spawn n workers. When using GPUs, n maps to the number of GPUs available on the host.

The local agent does not communicate with other local agents deployed on other hosts, even if the workers may communicate inter-host. The worker id is interpreted to be a local process. The agent starts and stops all worker processes as a single unit.

The worker function and argument passed to the worker function must be python multiprocessing compatible. To pass multiprocessing data structures to the workers you may create the data structure in the same multiprocessing context as the specified start_method and pass it as a function argument.

The exit_barrier_timeout specifies the amount of time (in seconds) to wait for other agents to finish. This acts as a safety net to handle cases where workers finish at different times, to prevent agents from viewing workers that finished early as a scale-down event. It is strongly advised that the user code deal with ensuring that workers are terminated in a synchronous manner rather than relying on the exit_barrier_timeout.

Example launching function

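import multiprocessing

from torch.distributed.elastic.agent.server import WorkerSpec
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent

def trainer(args) -> str:
    return "do train"

def main():
    start_method = "spawn"
    # Create multiprocessing data structures in the same context as start_method.
    shared_queue = multiprocessing.get_context(start_method).Queue()
    spec = WorkerSpec(
                role="trainer",
                local_world_size=nproc_per_process,
                entrypoint=trainer,
                args=("foobar",),
                ...<OTHER_PARAMS...>)
    agent = LocalElasticAgent(spec, start_method)
    results = agent.run()

    if results.is_failed():
        print("trainer failed")
    else:
        print(f"rank 0 return value: {results.return_values[0]}")
        # prints -> rank 0 return value: do train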

Example launching binary

def main():
    spec = WorkerSpec(
                role="trainer",
                local_world_size=nproc_per_process,
                entrypoint="/usr/local/bin/trainer",
                args=("--trainer_args", "foobar"),
                ...<OTHER_PARAMS...>)
    agent = LocalElasticAgent(spec)
    results = agent.run()

    if not results.is_failed():
        print("binary launches do not have return values")

Extending the Agent

To extend the agent you can implement ElasticAgent directly; however, we recommend that you extend SimpleElasticAgent instead, which provides most of the scaffolding and leaves you with a few specific abstract methods to implement.

class torch.distributed.elastic.agent.server.SimpleElasticAgent(spec, exit_barrier_timeout=300)[source]

An ElasticAgent that manages workers (WorkerGroup) for a single WorkerSpec (e.g. one particular type of worker role).

_assign_worker_ranks(store, group_rank, group_world_size, spec)[source]

Determines proper ranks for worker processes. The rank assignment is done according to the following algorithm:

  1. Each agent writes its configuration (group_rank, group_world_size, num_workers) to the common store.

  2. Each agent retrieves configuration for all agents and performs two level sort using role and rank.

  3. Determine the global rank: the global rank of the workers for the current agent is the offset of the infos array up to group_rank of the agent. The offset is computed as a sum of local_world_size of all agents that have rank less than the group_rank. The workers would have the ranks: [offset, offset+local_world_size)

  4. Determine the role rank: the role rank is determined using the algorithm in step 3, except that the offset is computed starting from the agent that has the same role as the current one and the minimum group rank.
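The offset computation in steps 3 and 4 can be sketched in plain Python. This is an illustration of the description above, not the body of _assign_worker_ranks: `AgentInfo` and `assign_ranks` are hypothetical names, and the store exchange in steps 1 and 2 is replaced by a list that every agent is assumed to see in full.

```python
from typing import List, NamedTuple, Tuple


class AgentInfo(NamedTuple):
    """Configuration that one agent publishes (hypothetical container)."""

    role: str
    group_rank: int
    local_world_size: int


def assign_ranks(infos: List[AgentInfo], me: AgentInfo) -> Tuple[List[int], List[int]]:
    """Return (global_ranks, role_ranks) for the workers of agent `me`."""
    # Step 3: offset is the total number of workers on agents with a lower group rank.
    global_offset = sum(
        info.local_world_size for info in infos if info.group_rank < me.group_rank
    )
    # Step 4: same sum, restricted to agents that share this agent's role.
    role_offset = sum(
        info.local_world_size
        for info in infos
        if info.role == me.role and info.group_rank < me.group_rank
    )
    n = me.local_world_size
    return (
        list(range(global_offset, global_offset + n)),
        list(range(role_offset, role_offset + n)),
    )


# Three agents: two "trainer" agents with 2 workers each, one "ps" agent with 1.
infos = [AgentInfo("trainer", 0, 2), AgentInfo("ps", 1, 1), AgentInfo("trainer", 2, 2)]
ranks_by_agent = {info.group_rank: assign_ranks(infos, info) for info in infos}
```

In this example the second trainer agent (group rank 2) gets global ranks 3 and 4 because three workers precede it, but role ranks 2 and 3 because only two of those workers are trainers.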

_exit_barrier()[source]

Wait for exit_barrier_timeout seconds for all agents to finish executing their local workers (either successfully or not). This acts as a safety guard against user scripts that terminate at different times. This barrier keeps the agent process alive until all workers finish.

_initialize_workers(worker_group)[source]

Starts a fresh set of workers for the worker_group. Essentially a rendezvous followed by a start_workers.

The caller should first call _stop_workers() to stop running workers prior to calling this method.

Optimistically sets the state of the worker group that just started to HEALTHY and delegates the actual monitoring of state to the _monitor_workers() method.

abstract _monitor_workers(worker_group)[source]

Checks on the workers for the worker_group and returns the new state of the worker group.

_rendezvous(worker_group)[source]

Runs rendezvous for the workers specified by worker spec. Assigns workers a new global rank and world size. Updates the rendezvous store for the worker group.

_restart_workers(worker_group)[source]

Restarts (stops, rendezvous, starts) all local workers in the group.

abstract _shutdown(death_sig=<Signals.SIGTERM: 15>)[source]

Cleans up any resources that were allocated during the agent’s work.

Parameters

death_sig – Signal to send to the child process, SIGTERM is default

abstract _start_workers(worker_group)[source]

Starts worker_group.spec.local_world_size number of workers according to the worker spec for the worker group.

Returns a map of local_rank to worker id.

abstract _stop_workers(worker_group)[source]

Stops all workers in the given worker group. Implementors must deal with workers in all states defined by WorkerState. That is, it must gracefully handle stopping non-existent workers, unhealthy (stuck) workers, etc.

class torch.distributed.elastic.agent.server.api.RunResult(state, return_values=<factory>, failures=<factory>)[source]

Results returned by the worker executions. Run results follow an “all-or-nothing” policy where the run is successful if and only if ALL local workers managed by this agent complete successfully.

If the result is successful (i.e. is_failed() returns False) then the return_values field contains the outputs (return values) of the workers managed by THIS agent mapped by their GLOBAL ranks. That is, result.return_values[0] is the return value of global rank 0.

Note

return_values are only meaningful when the worker entrypoint is a function. Workers specified as a binary entrypoint do not canonically have a return value, so the return_values field is meaningless and may be empty.

If is_failed() returns True then the failures field contains the failure information, again, mapped by the GLOBAL rank of the worker that failed.

The keys in return_values and failures are mutually exclusive, that is, a worker’s final state can only be one of: succeeded, failed. Workers intentionally terminated by the agent according to the agent’s restart policy are not represented in either return_values or failures.
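Because both fields are keyed by GLOBAL rank and only cover workers managed by the local agent, an agent that does not manage global rank 0 has no entry for key 0. The sketch below models this with a stand-in dataclass. `Result` and `summarize` are hypothetical and only mimic the is_failed(), return_values and failures members described above.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class Result:
    """Illustrative stand-in for RunResult; not the torchelastic class."""

    failed: bool
    return_values: Dict[int, Any] = field(default_factory=dict)
    failures: Dict[int, str] = field(default_factory=dict)

    def is_failed(self) -> bool:
        return self.failed


def summarize(result: Result) -> str:
    """Describe a result using the global ranks it actually contains."""
    if result.is_failed():
        rank = min(result.failures)  # lowest failed global rank on this agent
        return f"global rank {rank} failed: {result.failures[rank]}"
    return f"ranks {sorted(result.return_values)} succeeded"


# An agent managing global ranks 4 and 5 (hypothetical values).
ok = Result(failed=False, return_values={4: "do train", 5: "do train"})
bad = Result(failed=True, failures={5: "exit code 1"})

ok_summary = summarize(ok)
bad_summary = summarize(bad)
rank0_value = ok.return_values.get(0)  # None: rank 0 lives on another agent
```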
