
Manager

This module implements the Manager that manages the full fault tolerant training loop.

The Manager is responsible for driving the full training loop: it communicates with the Lighthouse server to determine quorum, reconfigures the ProcessGroups, and restores checkpoint state when recovering.

This uses wrapper classes around the standard PyTorch Optimizer and Module classes to provide fault tolerance. These wrappers are intended to add fault tolerance with minimal changes to the user's modeling code and training loop.

This is designed to work with the standard PyTorch DistributedDataParallel module and Hybrid FSDP.
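For orientation, here is a minimal sketch of how the Manager and the wrapper classes fit together in a DDP-style training loop. The wrapper imports and their constructor signatures (DistributedDataParallel, OptimizerWrapper, ProcessGroupGloo) as well as the model and dataloader are illustrative assumptions; only the Manager arguments follow the signature documented below.

    import torch
    from torch import nn, optim

    from torchft import Manager
    from torchft.optim import OptimizerWrapper
    # Assumed locations/signatures for the wrapper classes; adjust to your setup.
    from torchft.ddp import DistributedDataParallel
    from torchft.process_group import ProcessGroupGloo

    model = nn.Linear(128, 10)
    base_optimizer = optim.AdamW(model.parameters())

    def state_dict():
        # Snapshot everything a recovering replica needs to catch up.
        return {"model": model.state_dict(), "optim": base_optimizer.state_dict()}

    def load_state_dict(sd):
        model.load_state_dict(sd["model"])
        base_optimizer.load_state_dict(sd["optim"])

    manager = Manager(
        pg=ProcessGroupGloo(),
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        min_replica_size=2,
    )

    dist_model = DistributedDataParallel(manager, model)
    optimizer = OptimizerWrapper(manager, base_optimizer)

    for inputs, labels in dataloader:  # dataloader is assumed to exist
        optimizer.zero_grad()          # wrapper is assumed to start the quorum here
        loss = nn.functional.cross_entropy(dist_model(inputs), labels)
        loss.backward()                # gradients are allreduced via the manager
        optimizer.step()               # assumed to be gated by should_commit()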

class torchft.manager.Manager(pg: ProcessGroup, load_state_dict: Callable[[T], None], state_dict: Callable[[], T], min_replica_size: int, use_async_quorum: bool = True, timeout: timedelta = datetime.timedelta(seconds=60), rank: Optional[int] = None, world_size: Optional[int] = None, world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC, store_addr: Optional[str] = None, store_port: Optional[int] = None, lighthouse_addr: Optional[str] = None, replica_id: Optional[str] = None, port: Optional[int] = None)[source]

Bases: object

Manager manages the full fault tolerant training loop.

This requires that the TCPStore specified by store_addr and store_port, or by the MASTER_ADDR and MASTER_PORT environment variables, be started prior to creating this manager. If using a modern version of torchelastic this will already be the case. Otherwise, it should be started via torch.distributed.init_process_group prior to creating this manager.

NOTE: when saving periodic checkpoints you must save and restore the Manager’s state_dict as well to avoid synchronization issues.
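Following the note above, a sketch (under the same illustrative assumptions as the earlier example) of a periodic checkpoint that saves and restores the Manager's state_dict alongside the application state; the checkpoint path is hypothetical.

    import torch

    CHECKPOINT_PATH = "/tmp/checkpoint.pt"  # hypothetical path

    def save_checkpoint() -> None:
        torch.save(
            {
                "model": model.state_dict(),
                "optim": base_optimizer.state_dict(),
                "manager": manager.state_dict(),  # step count and internal metadata
            },
            CHECKPOINT_PATH,
        )

    def restore_checkpoint() -> None:
        ckpt = torch.load(CHECKPOINT_PATH)
        model.load_state_dict(ckpt["model"])
        base_optimizer.load_state_dict(ckpt["optim"])
        manager.load_state_dict(ckpt["manager"])  # keeps the step count in sync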

allreduce(tensor: Tensor) Future[Tensor][source]

Performs a fault tolerant allreduce on the tensor and returns a Future that will be completed when the tensor is ready.

This will automatically scale the tensor by 1 / world_size.

If an error occurs during the allreduce:

  • The Future will be completed without raising; the error is instead tracked asynchronously.

  • After the first error, all subsequent calls will be noops and immediately return.

  • The tensor must be zeroed before being used as it may be corrupted.

Parameters:

tensor – the tensor to allreduce

Returns:

a Future that will be completed with the allreduced tensor
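A hedged sketch of calling allreduce directly, for example from a custom gradient synchronization routine instead of a DDP wrapper; the model and parameter loop are illustrative.

    # Illustrative: manually allreduce gradients. The manager scales each
    # tensor by 1 / world_size and swallows errors; the optimizer step must
    # still be gated by should_commit().
    futures = []
    for param in model.parameters():
        if param.grad is not None:
            futures.append(manager.allreduce(param.grad))

    for fut in futures:
        fut.wait()  # tensors may be corrupted on error; see the bullets above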

batches_committed() int[source]

Get the total number of batches committed across all steps and replicas. For example, 5 replicas participating in 2 steps is 10 batches, which may correspond to more than 10 examples depending on batch size.

This number is incremented on .step()

Returns:

the total number of batches committed
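A small illustrative use of this counter for progress logging, continuing the earlier sketch:

    # With 5 participating replicas, each committed step increases
    # batches_committed() by 5.
    print(
        f"step={manager.current_step()} "
        f"batches_committed={manager.batches_committed()} "
        f"participants={manager.num_participants()}"
    )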

current_step() int[source]

Get the current step count.

This number is incremented on .step()

Returns:

the current step count

errored() Optional[Exception][source]

Get the error, if any, that has occurred.

Returns:

The error or None if no error has occurred.
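An illustrative check of the asynchronously tracked error, e.g. for logging after a step:

    err = manager.errored()
    if err is not None:
        print(f"fault tolerant operation failed: {err}")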

is_participating() bool[source]

Get whether this replica is participating in the current quorum.

Returns:

whether this replica is participating in the current quorum

load_state_dict(state_dict: Dict[str, int]) None[source]

Load the state dict from a previous checkpoint.

This will restore the step count and internal metadata.

Parameters:

state_dict – the state dict to load

num_participants() int[source]

Get the number of participants in the current quorum.

This is the number of replicas participating in the current step.

Returns:

the number of participants in the current quorum

report_error(e: Exception) None[source]

Report an error to the manager.

This will cause the manager to skip the current step; the process group will be reconfigured on the next step.

This should be called when an error occurs that leads to a corrupted gradient that needs to be discarded.
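An illustrative sketch of reporting an error from a custom training loop; the non-finite loss check and the surrounding names (dist_model, inputs, labels) are assumptions carried over from the earlier sketch.

    try:
        loss = nn.functional.cross_entropy(dist_model(inputs), labels)
        if not torch.isfinite(loss):
            raise RuntimeError("non-finite loss; discarding gradients")
        loss.backward()
    except Exception as e:
        # The manager skips this step and reconfigures on the next one.
        manager.report_error(e)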

should_commit(timeout: Optional[timedelta] = None) bool[source]

Note

We recommend using the torchft.optim.OptimizerWrapper instead of calling this directly.

Must be called after the backwards pass completes but before stepping the optimizer.

The optimizer must only be stepped if this returns True.

This must be called on all workers within a replica group. This uses a collective to ensure all workers within a replica return the same value. If an error occurs on any worker, all workers will return False. Different replica groups may return different values.

This should only be called once per step.

Returns:

True if the optimizer should be stepped, False otherwise
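For reference, a sketch of the manual calls that the recommended torchft.optim.OptimizerWrapper is meant to replace; compute_loss is an illustrative helper and base_optimizer is the unwrapped optimizer from the earlier sketch.

    manager.start_quorum()               # before the forward pass
    base_optimizer.zero_grad()
    loss = compute_loss(inputs, labels)  # illustrative forward pass
    loss.backward()                      # gradients are allreduced via the manager
    if manager.should_commit():          # collective agreement within the replica group
        base_optimizer.step()            # only step when the gradients are valid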

shutdown() None[source]

Shutdown the manager and checkpoint server.

start_quorum(room_id: str = 'default', allow_heal: bool = True, timeout: Optional[timedelta] = None) None[source]

Note

We recommend using the torchft.optim.OptimizerWrapper instead of calling this directly.

Computes a new quorum (potentially asynchronously) and readies the manager for a new step.

It's best practice to call this before the forward pass of each step for performance, as computing the quorum may take some time.

Parameters:
  • allow_heal – (experimental) whether to allow healing at the beginning of the step. If allow_heal is set, the manager will attempt to heal either synchronously before returning or asynchronously prior to any network calls. All replicas must pass the same value to allow_heal.

  • room_id – (experimental) the room id to use for quorum. This allows multiple quorums to be used within the same job.

  • timeout – the timeout for quorum and recovery operations, if None, the manager’s timeout will be used
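A hedged sketch of overlapping an asynchronous quorum with the forward pass; dist_model and inputs are the illustrative names from the earlier sketch, and wait_quorum() is documented below.

    manager.start_quorum()        # may compute the quorum asynchronously
    outputs = dist_model(inputs)  # forward pass can overlap with quorum computation
    # If a collective must be issued outside the manager, block until the
    # ProcessGroup is healthy first.
    manager.wait_quorum()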

state_dict() Dict[str, int][source]

Get the state dict for this manager.

This can be used to checkpoint the state of the manager to restore from a previous checkpoint.

Returns:

the state dict for this manager

wait_quorum() None[source]

Wait for the quorum to complete.

The ProcessGroup will be in a healthy state after this returns.

wrap_future(fut: Future[T], default: T, timeout: Optional[timedelta] = None) Future[T][source]

Wrap a Future, swallowing any errors that occur and reporting them to the manager.

If an error occurs, the Future will be completed with the default value.

Parameters:
  • fut – the Future to wrap

  • default – the default value to complete the Future with if an error occurs

  • timeout – the timeout for the Future, if None, the manager’s timeout will be used
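An illustrative sketch; some_async_collective is a hypothetical helper that returns a torch.futures.Future.

    import torch
    from torch.futures import Future

    # If the underlying work fails, the wrapped future completes with the
    # default value and the error is reported to the manager.
    pending: Future[torch.Tensor] = some_async_collective()  # hypothetical helper
    safe = manager.wrap_future(pending, default=torch.zeros(10))
    result = safe.wait()  # does not raise; yields the default on failure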

class torchft.manager.WorldSizeMode(value)[source]

Bases: Enum

This controls the numerics for the job when doing allreduces across replicas and the world size is larger than min_replica_size. The world size will never be smaller than min_replica_size.

DYNAMIC:

The world size will dynamically increase to use all available replicas, and the gradient will be normalized by the world size.

FIXED_WITH_SPARES:

The number of active replicas is min_replica_size and any spares will contribute zero gradients.

DYNAMIC = 0
FIXED_WITH_SPARES = 1
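
A sketch of selecting a mode when constructing the Manager; the process group and the state dict callables are the illustrative ones from the earlier sketch.

    from torchft.manager import Manager, WorldSizeMode

    # Keep the effective world size fixed at min_replica_size; extra replicas
    # act as hot spares that contribute zero gradients, so numerics do not
    # change as replicas come and go.
    manager = Manager(
        pg=ProcessGroupGloo(),            # illustrative process group
        load_state_dict=load_state_dict,  # application callbacks from earlier
        state_dict=state_dict,
        min_replica_size=2,
        world_size_mode=WorldSizeMode.FIXED_WITH_SPARES,
    )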
