Distributed Data Parallel
This module implements a DistributedDataParallel wrapper that works with the Manager to provide fault tolerance.
- class torchft.ddp.DistributedDataParallel(manager: Manager, module: Module, **kwargs: object)
Bases: torch.nn.parallel.DistributedDataParallel
A patched version of PyTorch's DistributedDataParallel that is compatible with torchft.
Important notes:
This requires model state to be synchronized at step 0 by an external mechanism rather than by DDP's internal broadcast; torchft.Manager handles this.
DDP features beyond the basics have not been tested with torchft and may break training.
This does not perform any sanity checks, such as verifying that parameter sizes match across workers.
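Why the step-0 synchronization matters can be seen in a toy model. The sketch below is stdlib-only Python, not torchft code; `Replica`, `allreduce_mean`, `train_step`, and `run` are hypothetical names. Each replica minimizes 0.5 * (p - x)**2 for its own target x, gradients are averaged element-wise as a mean allreduce would, and every replica applies the same averaged update. Replicas that start identical stay identical, while replicas that start different keep their initial offset, which is why a restarted worker has to receive the current state before it participates.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Replica:
    """Hypothetical toy replica holding a flat list of float parameters."""
    params: List[float]

    def local_grads(self, target: float) -> List[float]:
        # Gradient of 0.5 * (p - target)**2 with respect to each parameter p.
        return [p - target for p in self.params]


def allreduce_mean(per_replica: List[List[float]]) -> List[float]:
    # Element-wise mean across replicas: what a mean allreduce delivers to every rank.
    n = len(per_replica)
    return [sum(values) / n for values in zip(*per_replica)]


def train_step(replicas: List[Replica], targets: List[float], lr: float = 0.1) -> None:
    grads = [r.local_grads(t) for r, t in zip(replicas, targets)]
    avg = allreduce_mean(grads)
    # Every replica applies the same averaged gradient.
    for r in replicas:
        r.params = [p - lr * g for p, g in zip(r.params, avg)]


def run(sync_first: bool, steps: int = 3) -> List[List[float]]:
    healthy = Replica([1.0, 2.0])
    restarted = Replica([0.0, 0.0])  # e.g. a worker that came back with a fresh init
    if sync_first:
        # Stand-in for the external step-0 sync (torchft.Manager's job in the real library).
        restarted.params = list(healthy.params)
    replicas = [healthy, restarted]
    for _ in range(steps):
        train_step(replicas, targets=[0.5, 1.5])
    return [r.params for r in replicas]


synced = run(sync_first=True)
unsynced = run(sync_first=False)
print(synced[0] == synced[1])      # True
print(unsynced[0] == unsynced[1])  # False
```

Gradient averaging only keeps replicas consistent; it never pulls divergent replicas back together, so the initial state has to come from somewhere outside the allreduce.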
- class torchft.ddp.PureDistributedDataParallel(manager: Manager, module: Module)
Bases: torch.nn.Module
A pure Python reimplementation of the DDP wrapper.
We recommend using DistributedDataParallel instead of this class.
It issues one allreduce per gradient tensor and does not use DDP's bucketing reducer, so it can be very slow for real models with many parameter tensors.
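The cost difference comes down to the number of collective calls. The stdlib-only sketch below counts them for a hypothetical model (100 small 4 KiB gradients plus four 16 MiB ones): one call per tensor, as described for this class, versus greedy packing into buckets capped at 25 MiB. The greedy packer is a simplification loosely modeled on the bucketing done by PyTorch DDP's reducer (configured through its `bucket_cap_mb` argument); the function names are hypothetical, and the sketch ignores details such as overlapping communication with the backward pass.

```python
from typing import List

KIB = 1024
MIB = 1024 * KIB


def per_tensor_calls(grad_sizes: List[int]) -> int:
    # One allreduce per gradient tensor, with no bucketing.
    return len(grad_sizes)


def bucketed_calls(grad_sizes: List[int], cap_bytes: int) -> int:
    # Greedily pack consecutive gradients into a bucket until the next one would
    # overflow the cap, then issue one allreduce for the whole bucket.
    calls = 0
    current = 0
    for size in grad_sizes:
        if current and current + size > cap_bytes:
            calls += 1
            current = 0
        current += size
    if current:
        calls += 1
    return calls


# Hypothetical gradient sizes in bytes: many small tensors (biases, norms) and a few large weights.
sizes = [4 * KIB] * 100 + [16 * MIB] * 4
print(per_tensor_calls(sizes))          # 104
print(bucketed_calls(sizes, 25 * MIB))  # 4
```

Every collective carries a fixed launch and network overhead, so going from 104 calls to 4 matters most for models with many small parameter tensors.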
- forward(*args: object) → object
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
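The difference between calling the instance and calling `forward` directly can be reproduced with a toy class. `ToyModule` and `Double` below are hypothetical, stdlib-only stand-ins that mimic the hook mechanism of `torch.nn.Module` (in PyTorch, a hook registered with `register_forward_hook` receives `(module, args, output)` and may return a replacement output); they are not the real implementation.

```python
from typing import Any, Callable, List

Hook = Callable[["ToyModule", tuple, Any], Any]


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module's call-and-hook behavior."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError("subclasses define the forward computation")

    def __call__(self, *args: Any) -> Any:
        # Calling the instance runs forward() and then every registered hook.
        out = self.forward(*args)
        for hook in self._forward_hooks:
            result = hook(self, args, out)
            if result is not None:  # a hook may replace the output
                out = result
        return out


class Double(ToyModule):
    def forward(self, x: int) -> int:
        return 2 * x


m = Double()
m.register_forward_hook(lambda module, args, out: out + 1)
print(m(3))          # 7: forward() plus the hook
print(m.forward(3))  # 6: the hook is silently skipped
```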