Source code for torchft.manager
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Manager
=========
This module implements the Manager that manages the full fault tolerant training
loop.
The Manager is responsible for running the
full training loop, communicating with the Lighthouse server to establish a
quorum, reconfiguring the ProcessGroups, and restoring checkpoint state when
recovering.
This uses wrapper classes to wrap the standard PyTorch Optimizer and Module
classes to provide fault tolerance. These wrappers are intended to add fault
tolerance with minimal changes to the user's modeling code and training loop.
This is designed to work with the standard PyTorch DistributedDataParallel module
and Hybrid FSDP.
"""
import concurrent.futures
import logging
import os
import socket
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from enum import Enum
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
import torch
from torch.distributed import ReduceOp, TCPStore
from torchft.checkpointing import CheckpointServer
from torchft.futures import future_timeout
from torchft.torchft import Manager as _Manager, ManagerClient
if TYPE_CHECKING:
from torchft.process_group import ProcessGroup
MANAGER_ADDR_KEY: str = "manager_addr"
MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
REPLICA_ID_KEY: str = "replica_id"
T = TypeVar("T")
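# The example below is an illustrative sketch, not part of the torchft API: it
# shows how a manual training loop could construct and drive the Manager
# defined in this module (start_quorum() before the forward pass, allreduce()
# on each gradient, should_commit() to gate the optimizer step). The ``model``,
# ``optimizer``, ``pg``, ``batch``, ``labels`` and ``loss_fn`` names are
# assumed user objects; torchft.optim.OptimizerWrapper automates the
# start_quorum()/should_commit() calls for you.
def _example_build_manager(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    pg: "ProcessGroup",
    min_replica_size: int = 2,
) -> "Manager":
    """Illustrative sketch of constructing a Manager for a replica group."""

    def load_state_dict(state_dict: Dict[str, object]) -> None:
        model.load_state_dict(state_dict["model"])
        optimizer.load_state_dict(state_dict["optim"])

    def state_dict() -> Dict[str, object]:
        return {"model": model.state_dict(), "optim": optimizer.state_dict()}

    # store_addr/store_port (or MASTER_ADDR/MASTER_PORT), RANK, WORLD_SIZE and
    # TORCHFT_LIGHTHOUSE are read from the environment when not passed in.
    return Manager(
        pg=pg,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        min_replica_size=min_replica_size,
    )


def _example_train_step(
    manager: "Manager",
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    batch: torch.Tensor,
    labels: torch.Tensor,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> None:
    """Illustrative sketch of a single fault tolerant training step."""
    # compute (or start computing) the quorum before the forward pass
    manager.start_quorum()
    loss = loss_fn(model(batch), labels)
    loss.backward()
    # fault tolerant allreduce of each gradient; errors are swallowed here and
    # surface as should_commit() returning False
    for p in model.parameters():
        if p.grad is not None:
            manager.allreduce(p.grad)
    # only step the optimizer if the quorum agreed this step is valid
    if manager.should_commit():
        optimizer.step()
    # gradients may be corrupted after a failure, so always zero them
    optimizer.zero_grad()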
class WorldSizeMode(Enum):
"""
This controls the numerics for the job when doing allreduces across replicas
when the world size is larger than ``min_replica_size``. The world size will
never be smaller than ``min_replica_size``.
DYNAMIC:
The world size will dynamically increase to use all available
replicas and normalize the gradient by the world size.
FIXED_WITH_SPARES:
The number of active replicas is ``min_replica_size`` and any spares
will contribute zero gradients.
"""
DYNAMIC = 0
FIXED_WITH_SPARES = 1
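# Illustrative sketch (not part of the torchft API) of how the two modes affect
# the gradient normalization performed by Manager.allreduce(): with DYNAMIC the
# divisor grows with the number of healthy replicas, while with
# FIXED_WITH_SPARES it is capped at ``min_replica_size`` and any extra replicas
# contribute zeroed gradients. The replica counts below are assumptions for the
# example only.
def _example_allreduce_divisor(
    mode: WorldSizeMode, healthy_replicas: int, min_replica_size: int
) -> int:
    """Return the world size used to normalize gradients in this example."""
    if mode == WorldSizeMode.DYNAMIC:
        # every healthy replica participates and contributes a gradient
        return healthy_replicas
    # FIXED_WITH_SPARES: participation is capped and spares send zeros
    return min(healthy_replicas, min_replica_size)


# e.g. with 7 healthy replicas and min_replica_size=5:
#   DYNAMIC           -> gradients are averaged over 7 replicas
#   FIXED_WITH_SPARES -> gradients are averaged over 5; the 2 spares contribute zeros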
class Manager:
"""
Manager manages the full fault tolerant training loop.
This requires that the TCPStore specified by ``store_addr`` and
``store_port`` or the ``MASTER_ADDR`` and ``MASTER_PORT`` environment variables be
started prior to creating this manager. If using a modern version of
torchelastic this will already be the case. Otherwise, it should be started
via torch.distributed.init_process_group prior to creating this manager.
NOTE: when saving periodic checkpoints you must save and restore the
Manager's state_dict as well to avoid synchronization issues.
"""
def __init__(
self,
pg: "ProcessGroup",
load_state_dict: Callable[[T], None],
state_dict: Callable[[], T],
min_replica_size: int,
use_async_quorum: bool = True,
timeout: timedelta = timedelta(seconds=60),
rank: Optional[int] = None,
world_size: Optional[int] = None,
world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
store_addr: Optional[str] = None,
store_port: Optional[int] = None,
lighthouse_addr: Optional[str] = None,
replica_id: Optional[str] = None,
port: Optional[int] = None,
hostname: str = socket.gethostname(),
) -> None:
"""
Args:
load_state_dict: function to load the state dict when recovering
state_dict: function to save the state dict when recovering
min_replica_size: minimum number of replicas on each step
port: if rank==0, the port to run the manager server on.
Port assignment priority:
1. this argument
2. TORCHFT_MANAGER_PORT env var
3. arbitrary port assigned via 0
use_async_quorum: whether to run the quorum asynchronously during the forward pass
timeout:
the default timeout for all operations; if you're using per-request
timeouts this should be longer than the longest request timeout.
rank: the replica group local rank
world_size: the replica group local world size
store_addr: TCPStore address for this replica group
store_port: TCPStore port for this replica group
lighthouse_addr: if rank==0, the address of the lighthouse server
replica_id: if rank==0, the replica_id for this group
hostname: if rank==0, the hostname to advertise to the lighthouse server
"""
self._load_state_dict = load_state_dict
self._state_dict = state_dict
self._pending_state_dict: Optional[Dict[str, object]] = None
self._use_async_quorum = use_async_quorum
self._timeout = timeout
self._world_size_mode = world_size_mode
store_addr = store_addr or os.environ["MASTER_ADDR"]
store_port = store_port or int(os.environ["MASTER_PORT"])
self._rank: int = rank if rank is not None else int(os.environ["RANK"])
rank = self._rank
world_size = world_size or int(os.environ["WORLD_SIZE"])
self._min_replica_size = min_replica_size
def _manager_state_dict() -> Dict[str, T]:
return {
"user": state_dict(),
"torchft": cast(T, self.state_dict()),
}
self._ckpt_server = CheckpointServer[Dict[str, T]](_manager_state_dict)
self._executor = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="async_quorum"
)
self._quorum_future: Optional[concurrent.futures.Future] = None
self._store = TCPStore(
host_name=store_addr,
port=store_port,
is_master=False,
wait_for_workers=False,
)
self._pg = pg
self._manager: Optional[_Manager] = None
if rank == 0:
if port is None:
port = int(os.environ.get(MANAGER_PORT_ENV, 0))
bind = f"[::]:{port}"
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]
if replica_id is None:
replica_id = ""
replica_id = replica_id + str(uuid.uuid4())
self._manager = _Manager(
replica_id=replica_id,
lighthouse_addr=lighthouse_addr,
hostname=hostname,
bind=bind,
store_addr=f"{store_addr}:{store_port}",
world_size=world_size,
)
self._store.set(MANAGER_ADDR_KEY, self._manager.address())
self._store.set(REPLICA_ID_KEY, replica_id)
addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8")
self._client = ManagerClient(addr, timeout=timeout)
replica_id = self._store.get(REPLICA_ID_KEY).decode("utf-8")
self._logger = _ManagerLogger(
manager=self, replica_id=replica_id or "", rank=rank
)
self._step = 0
self._quorum_id = -1
self._errored: Optional[Exception] = None
self._healing = False
self._pending_work: List[torch.futures.Future[object]] = []
self._batches_committed = 0
# first step is 1
self._participating_rank: Optional[int] = None
self._participating_world_size: int = 0
def shutdown(self) -> None:
"""
Shutdown the manager and checkpoint server.
"""
self._ckpt_server.shutdown()
if self._manager is not None:
self._manager.shutdown()
self._executor.shutdown()
def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
"""
Performs a fault tolerant allreduce on the tensor and returns a Future that will be completed when
the tensor is ready.
This will automatically scale the tensor by 1 / world_size.
If an error occurs during the allreduce:
* The Future will be completed with no error and instead tracked asynchronously.
* After the first error, all subsequent calls will be noops and immediately return.
* The tensor must be zeroed before being used as it may be corrupted.
Args:
tensor: the tensor to allreduce
Returns:
a Future that will be completed with the allreduced tensor
"""
if self.errored():
fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut.set_result(tensor)
return fut
self.wait_quorum()
if not self.is_participating():
tensor.zero_()
# TODO: increase the timeout when waiting while healing
try:
# Run the allreduce async and save the work object so we can wait on
# it later.
work = self._pg.allreduce([tensor], ReduceOp.SUM)
fut = work.get_future()
# schedule grad normalization as a continuation
# on the Future
def callback(
fut: torch.futures.Future[List[torch.Tensor]],
) -> torch.Tensor:
nonlocal tensor
# check for exceptions
fut.value()
tensor /= self.num_participants()
return tensor
fut = fut.then(callback)
fut = self.wrap_future(fut, tensor)
return fut
except Exception as e:
self._logger.exception(
f"got exception in all reduce -- skipping remaining: {e}"
)
self.report_error(e)
fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut.set_result(tensor)
return fut
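    # Usage note (illustrative, not part of the API): because errors are
    # swallowed, a failed allreduce leaves the tensor potentially corrupted and
    # should_commit() returning False, so callers are expected to discard the
    # gradients before reusing them, e.g.:
    #
    #     manager.allreduce(param.grad)   # failures are tracked, not raised
    #     if not manager.should_commit():
    #         optimizer.zero_grad()       # drop possibly corrupted gradients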
def report_error(self, e: Exception) -> None:
"""
Report an error to the manager.
This will cause the manager to skip the current step and the process
groups to be reconfigured on the next step.
This should be called when an error occurs that leads to a corrupted
gradient that needs to be discarded.
"""
self._errored = e
def errored(self) -> Optional[Exception]:
"""
Get whether an error has occurred.
Returns:
The error or None if no error has occurred.
"""
return self._errored
def wrap_future(
self,
fut: torch.futures.Future[T],
default: T,
timeout: Optional[timedelta] = None,
) -> torch.futures.Future[T]:
"""
Wrap a Future, swallowing any errors that occur and reporting them to the manager.
If an error occurs, the Future will be completed with the default value.
Args:
fut: the Future to wrap
default: the default value to complete the Future with if an error occurs
timeout: the timeout for the Future, if None, the manager's timeout will be used
"""
# add a timeout to the future
fut = future_timeout(fut, timeout or self._timeout)
# schedule error handling as a continuation on the Future
def callback(
fut: torch.futures.Future[T],
) -> T:
nonlocal default
try:
return fut.value()
except Exception as e:
self._logger.exception(
f"got exception in future -- skipping remaining: {e}"
)
self.report_error(e)
return default
fut = fut.then(callback)
self._pending_work.append(cast(torch.futures.Future[object], fut))
return fut
def start_quorum(
self,
room_id: str = "default",
allow_heal: bool = True,
timeout: Optional[timedelta] = None,
) -> None:
"""
.. note::
We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
Computes a new quorum (potentially asynchronously) and readies the
manager for a new step.
It's best practice to call this before the forward pass of each step for
performance, as computing the quorum may take some time.
Args:
allow_heal: (experimental) whether to allow healing at the beginning of the step
If allow_heal is set, the manager will attempt to heal either
synchronously before returning or asynchronously prior to any network
calls. All replicas must pass the same value to allow_heal.
room_id: (experimental) the room id to use for the quorum; this allows
for multiple quorums to be used within the same job.
timeout: the timeout for quorum and recovery operations, if None, the manager's timeout will be used
"""
# wait for previous quorum to complete
if self._quorum_future is not None:
self._quorum_future.result()
self._errored = None
self._healing = False
self._ckpt_server.allow_checkpoint(self._step)
# TODO: we should really be wrapping this whole section in a try-except
# block to allow gracefully recovering from issues in PG setup and quorum.
self._quorum_future = self._executor.submit(
self._async_quorum,
room_id=room_id,
allow_heal=allow_heal,
timeout=timeout or self._timeout,
)
if not self._use_async_quorum:
self.wait_quorum()
if self._healing:
# eagerly apply pending state_dict so we can run the forwards pass
self._apply_pending_state_dict()
# we are forcing healing at the beginning so we're in a good state
# and don't need to zero_grad
self._healing = False
def wait_quorum(self) -> None:
"""
Wait for the quorum to complete.
ProcessGroup will be in a healthy state after this returns.
"""
assert (
self._quorum_future is not None
), "must call start_quorum before wait_quorum"
self._quorum_future.result()
def _async_quorum(self, room_id: str, allow_heal: bool, timeout: timedelta) -> None:
(
quorum_id,
replica_rank,
replica_world_size,
address,
store_address,
max_step,
max_rank,
max_world_size,
heal,
) = self._client.quorum(
room_id=room_id,
rank=self._rank,
step=self._step,
checkpoint_server_addr=self._ckpt_server.address(),
timeout=timeout,
)
# When using async quorum we need to take the recovered workers.
# When not using async quorum we need to take the max world size as all
# workers will be healthy.
self._participating_rank, self._participating_world_size = (
(max_rank, max_world_size)
if self._use_async_quorum or not allow_heal
else (replica_rank, replica_world_size)
)
# For fixed with spares we need to ensure that we don't have more
# participating replicas than the min replica size.
if self._world_size_mode == WorldSizeMode.FIXED_WITH_SPARES:
self._participating_world_size = min(
self._participating_world_size, self._min_replica_size
)
if (
self._participating_rank is not None
and self._participating_rank >= self._min_replica_size
):
self._participating_rank = None
if quorum_id != self._quorum_id:
store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
# We use the replica rank and world as we want all replicas in the PG.
self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
self._quorum_id = quorum_id
# See manager.rs for healing conditions
if heal and allow_heal:
self._healing = True
self._logger.info(
f"healing required, fetching checkpoint server address from {address=} {max_step=}"
)
primary_client = ManagerClient(address, timeout=timeout)
checkpoint_server_address = primary_client.checkpoint_address(
self._rank, timeout=timeout
)
self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}")
self._pending_state_dict = CheckpointServer.load_from_address(
checkpoint_server_address
)
self.load_state_dict(self._pending_state_dict["torchft"])
# we apply the user state dict only when safe from the main thread
# This isn't strictly needed as loading the state_dict above should
# restore the correct step but it makes writing tests simpler.
self._step = max_step
def _apply_pending_state_dict(self) -> None:
assert self._healing, "must be in healing state"
# synchronize on future
assert self._quorum_future is not None, "must call step before should_commit"
self._quorum_future.result()
self._logger.info("applying pending state dict")
assert self._pending_state_dict is not None, "checkpoint was not staged"
self._load_state_dict(self._pending_state_dict["user"])
self._pending_state_dict = None
def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
"""
.. note::
We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
Must be called after the backwards pass completes but before stepping the optimizer.
The optimizer must only be stepped if this returns True.
This must be called on all workers within a replica group. This uses a
collective to ensure all workers within a replica return the same value.
If an error occurs on any worker, all workers will return False.
Different replica groups may return different values.
This should only be called once per step.
Returns:
True if the optimizer should be stepped, False otherwise
"""
for work in self._pending_work:
# check at the beginning of each iteration since .wait() may trigger errors
if self._errored is not None:
break
# We swallow the error in a future then() callback so this will
# never return an error.
work.wait()
self._pending_work = []
# apply state_dict if healing
if self._healing:
self._apply_pending_state_dict()
enough_replicas = self.num_participants() >= self._min_replica_size
local_should_commit = enough_replicas and self._errored is None
should_commit = self._client.should_commit(
self._rank,
self._step,
local_should_commit,
timeout=timeout or self._timeout,
)
self._logger.info(
f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}"
)
self._ckpt_server.disallow_checkpoint()
# decide whether we're in a healthy state to increase the step count
if should_commit:
self._step += 1
self._batches_committed += self.num_participants()
return should_commit
def load_state_dict(self, state_dict: Dict[str, int]) -> None:
"""
Load the state dict from a previous checkpoint.
This will restore the step count and internal metadata.
Args:
state_dict: the state dict to load
"""
self._step = state_dict["step"]
self._batches_committed = state_dict["batches_committed"]
def state_dict(self) -> Dict[str, int]:
"""
Get the state dict for this manager.
This can be used to checkpoint the state of the manager to restore
from a previous checkpoint.
Returns:
the state dict for this manager
"""
return {"step": self._step, "batches_committed": self._batches_committed}
def current_step(self) -> int:
"""
Get the current step count.
This number is incremented on .step()
Returns:
the current step count
"""
return self._step
def batches_committed(self) -> int:
"""
Get the total number of batches committed across all steps and replicas.
5 replicas participating in 2 steps is 10 batches but may be more than
10 examples depending on batch size.
This number is incremented on .step()
Returns:
the total number of batches committed
"""
return self._batches_committed
def num_participants(self) -> int:
"""
Get the number of participants in the current quorum.
This is the number of replicas participating in the current step.
Returns:
the number of participants in the current quorum
"""
assert self._participating_world_size >= 0, "internal error"
return self._participating_world_size
def is_participating(self) -> bool:
"""
Get whether this replica is participating in the current quorum.
Returns:
whether this replica is participating in the current quorum
"""
if self._participating_rank is None:
return False
if self._healing:
assert self._use_async_quorum
return False
return True
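# Illustrative sketch (not part of the torchft API) of the periodic
# checkpointing pattern described in the Manager docstring: the Manager's own
# state_dict must be saved and restored alongside the user state so the step
# count stays in sync across a restore. ``path``, ``model`` and ``optimizer``
# are assumed user objects.
def _example_save_checkpoint(
    path: str,
    manager: Manager,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
) -> None:
    torch.save(
        {
            "model": model.state_dict(),
            "optim": optimizer.state_dict(),
            # without this the restored job would disagree on the step count
            "manager": manager.state_dict(),
        },
        path,
    )


def _example_load_checkpoint(
    path: str,
    manager: Manager,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
) -> None:
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optim"])
    # restores the step and batches_committed counters
    manager.load_state_dict(checkpoint["manager"])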
class _ManagerLogger:
def __init__(self, manager: Manager, replica_id: str, rank: int) -> None:
self._logger: logging.Logger = logging.getLogger(__name__)
self._replica_id = replica_id
self._rank = rank
self._manager = manager
def prefix(self) -> str:
return (
f"[{self._replica_id}/{self._rank} - step {self._manager.current_step()}]"
)
def info(self, msg: str) -> None:
self._logger.info(f"{self.prefix()} {msg}")
def warn(self, msg: str) -> None:
self._logger.warning(f"{self.prefix()} {msg}")
def exception(self, msg: str) -> None:
self._logger.exception(f"{self.prefix()} {msg}")