Usage
=====

torchelastic requires you to implement a `state` object and a `train_step` function.
For details on what these are, refer to `how torch elastic works <README.md>`_. While
going through the sections below, refer to the imagenet
`example <examples/imagenet/main.py>`_ for more complete implementation details.

Implement `state`
------------------

The `State` object has two categories of methods that need to be implemented:
synchronization and persistence.

`sync()`
~~~~~~~~~

Let's take a look at synchronization first. The `sync` method is responsible for
ensuring that all workers get a consistent view of `state`. It is called at startup
as well as on each event that potentially leaves the workers out of sync, for
instance, on membership changes and rollback events.

torchelastic relies on the `sync()` method for `state` recovery from surviving
workers. When there are membership changes, either due to worker failure or
elasticity, the new workers receive the most up-to-date `state` from one of the
surviving workers, usually the one that has the most recent `state`; we call this
worker the most **tenured** worker.

Things you should consider doing in `sync` are:

1. Broadcasting global parameters/data from a particular worker (e.g. rank 0).
2. (re)Initializing data loaders based on markers (e.g. last known start index).
3. (re)Initializing the model.

> IMPORTANT: `state.sync()` is **not** meant for synchronizing steps in training.
> For instance, you should not be synchronizing weights (e.g. all-reducing model
> weights for synchronous SGD). These types of collective operations belong in the
> `train_step`.

All workers initially create the `state` object with the same constructor arguments.
We refer to this initial state as `S_0` and assume that any worker is able to create
`S_0` without needing any assistance from torchelastic. Essentially `S_0` is the
**bootstrap** state. This concept will become important in the next sections when
talking about state persistence (rollbacks and checkpoints).

(optional) `capture_snapshot()` and `apply_snapshot()`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

> You do not have to implement these methods if you do not want rollbacks from
> failed `train_steps`.

torchelastic has the ability to roll back the `state` if a `train_step` fails to
execute successfully, which may leave the `state` object partially updated. It relies
on properly implemented `capture_snapshot()` and `apply_snapshot()` methods of the
`state` to ensure that the `state` is restored to what it was before the faulty
`train_step`.

The `capture_snapshot()` method, as the name implies, takes a snapshot of the `state`
and returns the information necessary to restore the `state` object. You may return
**any** object from `capture_snapshot()` so long as you can use it in the
`apply_snapshot(snapshot)` method. A possible implementation of a rollback is:

```python
snapshot = state.capture_snapshot()
try:
    train_step(state)
except RuntimeError:
    state.apply_snapshot(snapshot)
    state.sync()
```

> NOTE: Since certain fields of the `state` may need to be re-initialized,
> torchelastic calls the `sync()` method after applying the snapshot. For instance,
> data loaders may need to be restarted because their iterators may end up in a
> corrupted state when the `train_step` does not exit successfully.

Notice that `apply_snapshot` is called on the existing `state` object; this implies
that an efficient implementation of `capture_snapshot` should only return mutable,
stateful data. Immutable fields, or fields that can be derived from other member
variables or restored in the `sync` method, need not be included in the snapshot.

By default the `capture_snapshot()` method returns `None` and the `apply_snapshot()`
method is a `pass`, which essentially means "rollback not supported".

> IMPORTANT: The `apply_snapshot` method should make **no** assumptions about which
> `state` object it is called on (e.g. the values of the member variables). That is,
> applying a `snapshot` to **any** state followed by `state.sync()` should effectively
> restore the state object to when the corresponding `capture_snapshot` method was
> called. A good rule of thumb is that `apply_snapshot` should act more like a `set`
> method rather than an `update` method.
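To tie the `sync` and snapshot concepts together, here is a minimal sketch of what a
`state` might look like. The method names match the ones described above, but the
constructor arguments, the broadcast of the start index, and the `make_data_loader`
helper are illustrative assumptions rather than torchelastic APIs; see the imagenet
example for a complete implementation.

```python
import torch
import torch.distributed as dist


class MyState(object):
    """A minimal sketch of a trainer state; not the complete interface."""

    def __init__(self, model, optimizer, dataset):
        # Every worker can construct this bootstrap state (S_0) on its own,
        # given the same constructor arguments.
        self.model = model
        self.optimizer = optimizer
        self.dataset = dataset
        self.start_index = 0   # last known start index into the dataset
        self.data_iter = None  # derived field; rebuilt in sync()

    def sync(self):
        # Broadcast progress markers from a particular worker (rank 0 here) so
        # that all workers share a consistent view of the state. Assumes the
        # process group has already been initialized by the coordinator.
        marker = torch.tensor([self.start_index])
        dist.broadcast(marker, src=0)
        self.start_index = int(marker.item())
        # (Re)initialize the data iterator from the last known start index.
        # `make_data_loader` is a hypothetical helper, not a torchelastic API.
        self.data_iter = iter(make_data_loader(self.dataset, self.start_index))

    def capture_snapshot(self):
        # Return only the mutable, stateful data needed to restore this object;
        # `data_iter` is derivable and is rebuilt by sync(), so it is omitted.
        return {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "start_index": self.start_index,
        }

    def apply_snapshot(self, snapshot):
        # Acts like a "set": overwrite fields regardless of their current values.
        self.model.load_state_dict(snapshot["model"])
        self.optimizer.load_state_dict(snapshot["optimizer"])
        self.start_index = snapshot["start_index"]
```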
(optional) `save(stream)` and `load(stream)`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

> You do not have to implement these methods if you do not plan on using
> checkpointing.

Much like `capture_snapshot` and `apply_snapshot`, the `save` and `load` methods form
a pair. They are responsible for persisting and restoring the `state` object to and
from a `stream`, which is a *file-like* object compatible with
`torch.save <https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save>`_.
torchelastic relies on these methods to provide checkpoint functionality for your job.

> We encourage users to use the `torch.save` and `torch.load` methods when
> implementing the `save` and `load` methods of their `state` class.

> NOTE: The default implementations of `save` and `load` use `capture_snapshot` and
> `apply_snapshot`.

Implement `train_step`
-----------------------

The `train_step` is a function that takes `state` as its single argument and carries
out a partition of the overall training job. This is your unit of work and it is up
to you to define what a *unit* is. When deciding what your unit of work should be,
keep in mind the following:

1. Rollbacks and checkpoints are done at `train_step` granularity. This means that
   torchelastic can only recover to the last successful `train_step`; any work done
   **during** a failed `train_step` is not recoverable.
2. A `train_step` iteration in the `train_loop` has overhead due to the work that
   goes into ensuring that your job is fault-tolerant and elastic. How much overhead
   depends on your configuration of rollbacks and checkpoints as well as how
   expensive your `capture_snapshot`, `apply_snapshot`, `save`, and `load` functions
   are.

> In most cases, your job naturally lends itself to an obvious `train_step`. The most
> canonical one for many training jobs is to map the processing of a mini-batch of
> training data to a `train_step`.

There is a trade-off to be made between how much work you are willing to lose versus
how much overhead you want to pay for that security.
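As an illustration, the sketch below maps the processing of one mini-batch to a
`train_step`. The fields it reads from `state` (`data_iter`, `model`, `optimizer`,
`start_index`) come from the hypothetical sketch earlier in this document and are not
required by torchelastic; the only fixed part of the contract is that `train_step`
takes `state` as its single argument.

```python
import torch.nn.functional as F


def train_step(state):
    # One mini-batch is the unit of work: rollbacks and checkpoints
    # happen at this granularity.
    inputs, targets = next(state.data_iter)

    outputs = state.model(inputs)
    loss = F.cross_entropy(outputs, targets)

    state.optimizer.zero_grad()
    loss.backward()
    state.optimizer.step()

    # Advance the progress marker so that sync() and capture_snapshot()
    # reflect the work completed so far.
    state.start_index += inputs.size(0)
```

A real job would also handle epoch boundaries (the `StopIteration` raised when the
iterator is exhausted) and would place any synchronous-SGD collectives (e.g. gradient
all-reduce) inside `train_step`, as noted earlier; see the imagenet example for a full
implementation.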
Write a `main.py`
-----------------

Now that you have `state` and `train_step` implementations, all that remains is to
bring everything together and implement a `main` that will execute your training.
Your script should initialize torchelastic's `coordinator`, create your `state`
object, and call the `train_loop`. Below is a simple example:

```python
import torchelastic
from torchelastic.p2p import CoordinatorP2P

if __name__ == "__main__":
    min_workers = 1
    max_workers = 1
    run_id = 1234
    etcd_endpoint = "localhost:2379"

    # MyState and train_step are your implementations from the sections above.
    state = MyState()
    coordinator = CoordinatorP2P(
        c10d_backend="gloo",
        init_method=f"etcd://{etcd_endpoint}/{run_id}?min_workers={min_workers}&max_workers={max_workers}",
        max_num_trainers=max_workers,
        process_group_timeout=60000,
    )
    torchelastic.train(coordinator, train_step, state)
```

Configuring
------------

Metrics
~~~~~~~~

See metrics `documentation <torchelastic/metrics/README.md>`_.

Checkpoint and Rollback
~~~~~~~~~~~~~~~~~~~~~~~~~

See checkpoint `documentation <torchelastic/checkpoint/README.md>`_.

Rendezvous
~~~~~~~~~~~~

See rendezvous `documentation <torchelastic/rendezvous/README.md>`_.