
LocalSGD

This module implements a fault tolerant version of LocalSGD and related methods.

class torchft.local_sgd.LocalSGD(manager: Manager, model: Module, optimizer: Optimizer, sync_every: int, backup_device: Optional[device] = None, pin_memory: bool = True)[source]

Bases: Module

LocalSGD is a model wrapper similar to DistributedDataParallel that implements the algorithm described in https://arxiv.org/pdf/1805.09767

This will synchronize the model parameters periodically in a fault tolerant way using a torchft Manager. The allreduce on the parameters will happen every sync_every steps after the optimizer.step call.

To implement safe and fault-tolerant synchronization, this requires a backup copy of the weights. By default these are stored in CPU memory. If any error occurs during the LocalSGD step, the step will be discarded and the model parameters will be reset back to the values from the last time LocalSGD synchronized.

The backup weights could be eliminated by relaxing the guarantee of exactly sync_every steps, but that would diverge from the LocalSGD algorithm. DiLoCo also needs this backup copy to compute the delta.

The torchft quorum is computed at the beginning of sync_every steps. If any error occurs, or a worker fails between syncs, sync_every steps will be discarded and a new quorum will be computed on the next step.

If running in async mode, the first sync_every steps on a joining worker will be discarded, as the model will be recovering during that period. When using sync mode, the checkpoint will be restored prior to the first step.

TODO: add a way via Manager to detect workers failing early for shrink only
TODO: add DiLoCo support
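The cycle described above (sync_every local steps, a fault-tolerant allreduce, and a rollback to the backup weights on error) can be sketched in plain Python. This is an illustrative simulation only, not torchft's API: `local_sgd_cycle`, `average`, and the scalar per-worker parameters are all invented for the example, and the allreduce is simulated as an in-process average.

```python
def average(worker_params):
    """Simulated allreduce: average one scalar parameter across workers."""
    mean = sum(worker_params) / len(worker_params)
    return [mean for _ in worker_params]


def local_sgd_cycle(params, backups, sync_every, grads, fail_at=None):
    """Run one sync_every-step cycle for a group of workers.

    params:   per-worker scalar parameters (live weights)
    backups:  per-worker copies taken at the last sync point
    grads:    grads[step][worker] gradient for each local step
    fail_at:  optional step index at which a simulated error occurs
    """
    lr = 0.1
    try:
        for step in range(sync_every):
            if step == fail_at:
                raise RuntimeError("worker failure mid-cycle")
            # Local SGD steps mutate only the live weights.
            params = [p - lr * g for p, g in zip(params, grads[step])]
    except RuntimeError:
        # Error during the cycle: discard every local step and reset
        # all workers to the last synchronized weights.
        return list(backups), list(backups)
    # Successful cycle: allreduce (here, average) the parameters and
    # take a fresh backup at the new sync point.
    params = average(params)
    return params, list(params)


# A clean cycle averages the workers; a failed one restores the backup.
params, backups = local_sgd_cycle([1.0, 1.0], [1.0, 1.0], 2,
                                  [[0.5, 1.5], [0.5, 1.5]])
```

In the real module the rollback is driven by the torchft Manager's quorum and error handling rather than a try/except around the loop, but the invariant is the same: either all sync_every steps land and a new backup is taken, or none of them do.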

forward(*args: object, **kwargs: object) → object[source]

Run the model's forward pass.

This should be called before the optimizer step.

This will start the quorum and save the parameters if this is the first step.

load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) → None[source]

Loads the state dict to the model and the backup parameters.

This must be called while the model weights aren’t being modified to avoid corrupting the backup weights.

state_dict() → Dict[str, object][source]

Returns the state_dict from the last time LocalSGD synchronized, not the current weights.
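These checkpoint semantics (state_dict reflects the last sync point, and load_state_dict must refresh the backup together with the live weights) can be illustrated with a small stand-in class. Everything here is hypothetical, not torchft's implementation: `LocalSGDSketch`, its dict-of-floats weights, and `local_step` are invented for the example.

```python
class LocalSGDSketch:
    """Illustrative stand-in for LocalSGD's backup/state_dict semantics.

    Not the real implementation: real LocalSGD wraps an nn.Module and
    keeps tensor backups, typically in CPU memory.
    """

    def __init__(self, weights):
        self._weights = dict(weights)
        # Backup taken at the last synchronization point.
        self._backup = dict(weights)

    def local_step(self, updates):
        # Local steps mutate only the live weights; the backup is untouched.
        for name, delta in updates.items():
            self._weights[name] += delta

    def sync(self):
        # After a successful allreduce, the backup is refreshed.
        self._backup = dict(self._weights)

    def state_dict(self):
        # Checkpoints must reflect the last *synchronized* state, because
        # unsynced local steps may still be discarded on failure.
        return dict(self._backup)

    def load_state_dict(self, state):
        # Loading replaces the live weights and the backup together; doing
        # this while local steps run would corrupt the backup.
        self._weights = dict(state)
        self._backup = dict(state)
```

Checkpointing the backup rather than the live weights means a restored worker always resumes from a state that every peer agreed on at a sync point.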
