Manager¶
This module implements the Manager, which runs the full fault tolerant training loop.
The Manager is responsible for the full training loop: it communicates with the Lighthouse server to establish quorum, reconfigures the ProcessGroups, and restores checkpoint state when recovering.
It uses wrapper classes around the standard PyTorch Optimizer and Module classes to provide fault tolerance. These wrappers are intended to add fault tolerance with minimal changes to the user's modeling code and training loop.
This is designed to work with the standard PyTorch DistributedDataParallel module and Hybrid FSDP.
- class torchft.manager.Manager(pg: ProcessGroup, load_state_dict: Callable[[T], None], state_dict: Callable[[], T], min_replica_size: int, use_async_quorum: bool = True, timeout: timedelta = datetime.timedelta(seconds=60), rank: Optional[int] = None, world_size: Optional[int] = None, world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC, store_addr: Optional[str] = None, store_port: Optional[int] = None, lighthouse_addr: Optional[str] = None, replica_id: Optional[str] = None, port: Optional[int] = None)[source]¶
Bases:
object
Manager manages the full fault tolerant training loop.
This requires that the TCPStore specified by the store_addr and store_port or the MASTER_ADDR and MASTER_PORT environment variables be started prior to creating this manager. If using a modern version of torchelastic this will already be the case. Otherwise, it should be started via torch.distributed.init_process_group prior to creating this manager.
NOTE: when saving periodic checkpoints you must save and restore the Manager’s state_dict as well to avoid synchronization issues.
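The typical pattern is to construct the Manager with callables that can save and restore the model, then wrap the model and optimizer. The following is a minimal sketch under stated assumptions: it presumes the TCPStore / MASTER_ADDR environment and a Lighthouse server described above are already set up, that torchft.ddp.DistributedDataParallel and torchft.process_group.ProcessGroupGloo exist at those paths, and that the wrappers accept (manager, wrapped_object); consult the torchft examples for the exact API.

```python
# Minimal wiring sketch, not a drop-in script.
import torch
from torch import nn, optim
from torchft.manager import Manager
from torchft.optim import OptimizerWrapper
from torchft.ddp import DistributedDataParallel      # assumed module path
from torchft.process_group import ProcessGroupGloo   # assumed module path

model = nn.Linear(10, 10)
base_optim = optim.SGD(model.parameters(), lr=0.01)

# The state_dict/load_state_dict callables are what the manager uses to
# transfer weights to replicas that are recovering.
manager = Manager(
    pg=ProcessGroupGloo(),
    load_state_dict=model.load_state_dict,
    state_dict=model.state_dict,
    min_replica_size=2,
)

model = DistributedDataParallel(manager, model)      # assumed argument order
optimizer = OptimizerWrapper(manager, base_optim)    # assumed argument order

for _ in range(10):
    inputs = torch.randn(8, 10)
    targets = torch.randint(0, 10, (8,))

    optimizer.zero_grad()   # expected to start the quorum via the manager
    loss = nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()        # expected to call should_commit() before stepping
```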
- allreduce(tensor: Tensor) Future[Tensor] [source]¶
Fault tolerant allreduce the tensor and return a Future that will be completed when the tensor is ready.
This will automatically scale the tensor by 1 / world_size.
If an error occurs during the allreduce:
- The Future will be completed with no error; the error is instead tracked asynchronously by the manager.
- After the first error, all subsequent calls will be noops and return immediately.
- The tensor must be zeroed before being used, as it may be corrupted.
- Parameters:
tensor – the tensor to allreduce
- Returns:
a Future that will be completed with the allreduced tensor
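As an illustration only (not taken verbatim from torchft), a manual gradient average might look like the sketch below; the wrapper classes normally handle this call for you, and `manager` is assumed to be a configured Manager with an active quorum.

```python
import torch

grad = torch.randn(1024)          # stand-in for a local gradient

fut = manager.allreduce(grad)     # result is pre-scaled by 1 / world_size
averaged = fut.wait()             # completes without raising, even on error

# On failure the tensor may be corrupted/zeroed, so consult the manager
# before trusting the result (or rely on should_commit() to reject the step).
if manager.errored() is None:
    print(averaged.norm())
```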
- batches_committed() int [source]¶
Get the total number of batches committed across all steps and replicas. For example, 5 replicas participating in 2 steps is 10 batches, though this may be more than 10 examples depending on the batch size.
This number is incremented on .step()
- Returns:
the total number of batches committed
- current_step() int [source]¶
Get the current step count.
This number is incremented on .step()
- Returns:
the current step count
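A hypothetical logging snippet using these two counters, assuming `manager` already exists:

```python
# Illustrative logging sketch.
step = manager.current_step()            # incremented on each .step()
batches = manager.batches_committed()    # batches across all steps and replicas

if step % 100 == 0:
    print(f"step={step} batches_committed={batches}")
```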
- errored() Optional[Exception] [source]¶
Get the error that has occurred, if any.
- Returns:
The error or None if no error has occurred.
- is_participating() bool [source]¶
Get whether this replica is participating in the current quorum.
- Returns:
whether this replica is participating in the current quorum
- load_state_dict(state_dict: Dict[str, int]) None [source]¶
Load the state dict from a previous checkpoint.
This will restore the step count and internal metadata.
- Parameters:
state_dict – the state dict to load
- num_participants() int [source]¶
Get the number of participants in the current quorum.
This is the number of replicas participating in the current step.
- Returns:
the number of participants in the current quorum
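A hedged sketch of using num_participants together with is_participating, for example to compute an effective global batch size; `local_batch_size` is a hypothetical value, not part of the API, and `manager` is assumed to exist.

```python
local_batch_size = 32

if manager.is_participating():
    # Non-participating replicas (e.g. ones that are still healing) are
    # excluded from this step's participant count.
    global_batch_size = manager.num_participants() * local_batch_size
    print(f"effective global batch size: {global_batch_size}")
```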
- report_error(e: Exception) None [source]¶
Report an error to the manager.
This will cause the manager to skip the current step, and the ProcessGroup will be reconfigured on the next step.
This should be called when an error occurs that leads to a corrupted gradient that needs to be discarded.
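A sketch of the intended usage, assuming `manager` exists in the surrounding training loop: wrap the failure-prone part of the step and report the exception so the step is discarded.

```python
import torch
from torch import nn

model = nn.Linear(10, 1)
inputs = torch.randn(4, 10)

try:
    loss = model(inputs).sum()
    loss.backward()
except Exception as e:
    # should_commit() will now return False for this step and the manager
    # will reconfigure on the next step.
    manager.report_error(e)
```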
- should_commit(timeout: Optional[timedelta] = None) bool [source]¶
Note
We recommend using torchft.optim.OptimizerWrapper instead of calling this directly.
Must be called after the backwards pass completes but before stepping the optimizer.
The optimizer must only be stepped if this returns True.
This must be called on all workers within a replica group. It uses a collective to ensure that all workers within the replica group return the same value. If an error occurs on any worker, all workers will return False. Different replica groups may return different values.
This should only be called once per step.
- Returns:
True if the optimizer should be stepped, False otherwise
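For reference, a hedged sketch of the manual sequence that OptimizerWrapper automates; `manager`, `model`, `optimizer`, and `inputs` are assumed to come from the surrounding training loop.

```python
manager.start_quorum()

optimizer.zero_grad()
loss = model(inputs).sum()
loss.backward()
# Gradients would normally be averaged here, e.g. via the DDP wrapper or
# manager.allreduce(...).

if manager.should_commit():
    # All workers in this replica group agreed the step is clean.
    optimizer.step()
```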
- start_quorum(room_id: str = 'default', allow_heal: bool = True, timeout: Optional[timedelta] = None) None [source]¶
Note
We recommend using torchft.optim.OptimizerWrapper instead of calling this directly.
Computes a new quorum (potentially asynchronously) and readies the manager for a new step.
For performance, it's best practice to call this before the forward pass of each step, as computing the quorum may take some time.
- Parameters:
allow_heal – (experimental) whether to allow healing at the beginning of the step. If allow_heal is set, the manager will attempt to heal either synchronously before returning or asynchronously prior to any network calls. All replicas must pass the same value for allow_heal.
room_id – (experimental) the room id to use for quorum, this allows for multiple quorums to be used within the same job.
timeout – the timeout for quorum and recovery operations, if None, the manager’s timeout will be used
- state_dict() Dict[str, int] [source]¶
Get the state dict for this manager.
This can be used to checkpoint the state of the manager to restore from a previous checkpoint.
- Returns:
the state dict for this manager
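A sketch of the periodic checkpointing pattern described in the NOTE above; `model`, `optimizer`, and `manager` are assumed to exist and the path is illustrative.

```python
import torch

# Saving: include the manager's state alongside the model/optimizer state.
checkpoint = {
    "model": model.state_dict(),
    "optim": optimizer.state_dict(),
    "manager": manager.state_dict(),   # step count and internal metadata
}
torch.save(checkpoint, "checkpoint.pt")

# Later, when restoring:
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optim"])
manager.load_state_dict(checkpoint["manager"])
```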
- wait_quorum() None [source]¶
Wait for the quorum to complete.
The ProcessGroup will be in a healthy state after this returns.
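A hedged sketch of pairing start_quorum with wait_quorum when use_async_quorum is enabled; `manager`, `model`, and `inputs` are assumed to exist.

```python
manager.start_quorum()    # may return before the quorum is resolved

output = model(inputs)    # the forward pass can overlap the quorum computation
loss = output.sum()

manager.wait_quorum()     # the ProcessGroup is healthy after this returns
loss.backward()           # collectives (e.g. allreduce) are now safe to issue
```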
- wrap_future(fut: Future[T], default: T, timeout: Optional[timedelta] = None) Future[T] [source]¶
Wrap a Future, swallowing any errors that occur and reporting them to the manager.
If an error occurs, the Future will be completed with the default value.
- Parameters:
fut – the Future to wrap
default – the default value to complete the Future with if an error occurs
timeout – the timeout for the Future, if None, the manager’s timeout will be used
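An illustrative sketch using a plain torch.futures.Future to show the error handling behavior; `manager` is assumed to be an existing Manager.

```python
import torch
from torch.futures import Future

fut: Future[torch.Tensor] = Future()
safe_fut = manager.wrap_future(fut, default=torch.zeros(16))

# Simulate the underlying operation failing: the wrapped future completes
# with the default tensor and the error is reported to the manager.
fut.set_exception(RuntimeError("simulated collective failure"))
print(safe_fut.wait())      # tensor of zeros
print(manager.errored())    # the reported error, if the manager records it
```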
- class torchft.manager.WorldSizeMode(value)[source]¶
Bases:
Enum
This controls the numerics for the job when doing allreduces across replicas while the world size is larger than min_replica_size. The world size will never be smaller than min_replica_size.
- DYNAMIC:
The world size will dynamically increase to use all available replicas and normalize the gradient by the world size.
- FIXED_WITH_SPARES:
The number of active replicas is min_replica_size and any spares will contribute zero gradients.
- DYNAMIC = 0¶
- FIXED_WITH_SPARES = 1¶
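A sketch of selecting a mode when constructing the Manager; `pg`, `load_fn`, and `save_fn` are placeholders for the arguments shown in the class signature above.

```python
from torchft.manager import Manager, WorldSizeMode

# Fix the numerics at min_replica_size and treat extra replicas as spares
# that contribute zero gradients.
manager = Manager(
    pg=pg,
    load_state_dict=load_fn,
    state_dict=save_fn,
    min_replica_size=2,
    world_size_mode=WorldSizeMode.FIXED_WITH_SPARES,
)
```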