LocalSGD¶
This module implements a fault tolerant version of LocalSGD and related methods.
- class torchft.local_sgd.LocalSGD(manager: Manager, model: Module, optimizer: Optimizer, sync_every: int, backup_device: Optional[device] = None, pin_memory: bool = True)[source]¶
Bases:
Module
LocalSGD is a model wrapper similar to DistributedDataParallel that implements the algorithm described in https://arxiv.org/pdf/1805.09767
This will synchronize the model parameters periodically in a fault tolerant way using a torchft Manager. The allreduce on the parameters will happen every sync_every steps after the optimizer.step call.
To make this safe and fault tolerant, a backup copy of the weights is required. By default these are stored in CPU memory. If any error occurs during the LocalSGD step, the step is discarded and the model parameters are reset back to the last time LocalSGD synchronized.
The backup weights could be eliminated by relaxing the guarantee of exactly sync_every steps, but that would diverge from the LocalSGD algorithm. DiLoCo also needs this backup copy to compute the delta.
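The backup-and-restore idea can be sketched in plain Python. This is a toy illustration of the rollback concept only (the class and method names here are hypothetical, not torchft's API; the real implementation copies tensors to the backup device):

```python
import copy


class RollbackParams:
    """Toy sketch: keep a copy of the parameters from the last
    successful synchronization and restore it if any step in the
    current window fails."""

    def __init__(self, params):
        self.params = params                   # "live" parameters
        self._backup = copy.deepcopy(params)   # snapshot at the sync point

    def sync(self):
        # After a successful allreduce, refresh the backup.
        self._backup = copy.deepcopy(self.params)

    def restore(self):
        # On error, roll back to the last synchronized state.
        self.params[:] = copy.deepcopy(self._backup)


p = RollbackParams([1.0, 2.0])
p.params[0] += 0.5   # a local optimizer step mutates the parameters
p.restore()          # an error occurred -> discard the window
print(p.params)      # back to [1.0, 2.0]
```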
The torchft quorum is computed at the beginning of every sync_every steps. If any error occurs, or a worker fails between syncs, the current sync_every steps will be discarded and a new quorum will be computed on the next step.
If running in async mode, on a joining worker the first sync_every steps will be discarded, as the model will be recovering during that period. When using sync mode, the checkpoint will be restored prior to the first step.
TODO: add a way via Manager to detect workers failing early for shrink only
TODO: add DiLoCo support
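The schedule described above can be illustrated with a toy step loop. This only mirrors *when* the quorum and allreduce happen relative to sync_every; it is not torchft's real implementation:

```python
# Toy loop showing the sync_every window: the quorum starts at the
# beginning of each window, and the parameter allreduce happens at
# the end of it (after the optimizer step on that iteration).
sync_every = 4
events = []

for step in range(1, 9):
    if (step - 1) % sync_every == 0:
        events.append(f"step {step}: start quorum")          # window begins
    # ... forward / backward / optimizer.step() happen here ...
    if step % sync_every == 0:
        events.append(f"step {step}: allreduce parameters")  # window ends

print(events)
```

If an error were raised anywhere inside a window, the whole window's steps would be discarded and the parameters restored from the backup, as described above.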
- forward(*args: object, **kwargs: object) → object [source]¶
Run the wrapped model's forward pass.
This should be called before the optimizer step.
This will start the quorum and save the parameters if this is the first step.
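The required call order (forward before the optimizer step) can be sketched with mock objects. The class below is a stand-in, not the torchft API; it only records the ordering the docs describe:

```python
# Mock sketch of the expected per-step call order when training with
# a LocalSGD-style wrapper: forward() first (which starts the quorum
# and saves parameters on the first step), then optimizer.step().
class MockLocalSGD:
    def __init__(self):
        self.calls = []

    def forward(self, batch):
        self.calls.append("forward")         # would start quorum / save params
        return batch

    def optimizer_step(self):
        self.calls.append("optimizer.step")  # allreduce fires every sync_every steps


m = MockLocalSGD()
for _ in range(2):
    m.forward("batch")   # forward must precede the optimizer step
    m.optimizer_step()

print(m.calls)
```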