torchelastic requires you to implement a state object and a train_step function.

For details on what these are, refer to how torchelastic works.

While going through the sections below, refer to the imagenet example

for more complete implementation details.

Implement state

The State object has two categories of methods that need to be implemented:

synchronization and persistence.


Let's take a look at synchronization first. The sync method is responsible for

ensuring that all workers get a consistent view of state. It is called at

startup as well as on each event that potentially leaves the workers out of sync,

for instance, on membership changes and rollback events. torchelastic relies on

the sync() method for state recovery from surviving workers. For example, when

there are membership changes, either due to worker failure or elasticity,

the new workers receive the most up-to-date state from one of the surviving

workers, usually the one that has the most recent state. We call this worker

the most tenured worker.

Things you should consider doing in sync are:

  1. Broadcasting global parameters/data from a particular worker (e.g. rank 0).
  2. (re)Initializing data loaders based on markers (e.g. last known start index).
  3. (re)Initializing the model.

> IMPORTANT: state.sync() is not meant for synchronizing steps in training. For instance

you should not be synchronizing weights (e.g. all-reduce model weights for synchronous SGD).

These types of collective operations belong in the train_step.
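
To make the list of sync responsibilities concrete, here is a minimal sketch in plain Python. The MyState fields, the dataset, and the broadcast callable are all hypothetical: broadcast stands in for a real collective (for example, a broadcast from rank 0) so that the sketch runs without a process group. It illustrates the idea under those assumptions and is not torchelastic's own code.

```python
from typing import Callable, List


class MyState:
    """Hypothetical state: a weight vector plus a data-loader position marker."""

    def __init__(self, dataset: List[int], broadcast: Callable[[dict], dict]):
        self.dataset = dataset
        # Assumption of this sketch: `broadcast` takes this worker's payload and
        # returns the payload that every worker should adopt.
        self.broadcast = broadcast
        self.weights = [0.0, 0.0]  # bootstrap state S_0
        self.start_index = 0       # marker: next sample to read
        self.data_iter = None

    def sync(self):
        # 1. Agree on global parameters/data coming from a single worker.
        payload = self.broadcast(
            {"weights": self.weights, "start_index": self.start_index}
        )
        self.weights = list(payload["weights"])
        self.start_index = payload["start_index"]
        # 2. (Re)initialize the data iterator from the agreed marker.
        self.data_iter = iter(self.dataset[self.start_index:])


# A tenured worker is two samples in; a new worker starts from S_0 and
# receives the tenured worker's state through the stand-in broadcast.
tenured_payload = {"weights": [0.5, -0.25], "start_index": 2}
new_worker = MyState(dataset=[10, 11, 12, 13], broadcast=lambda _own: tenured_payload)
new_worker.sync()
```

Note that sync only aligns the workers' view of the state. It performs no per-step collective such as averaging gradients, which stays in the train_step.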

All workers initially create the state object with the same constructor arguments.

We refer to this initial state as S_0 and assume that any worker is able to create

S_0 without needing any assistance from torchelastic. Essentially, S_0 is the bootstrap

state. This concept will become important in the next sections when talking about

state persistence (rollbacks and checkpoints).

(optional) capture_snapshot() and apply_snapshot()

> You do not have to implement these methods if you do not want rollbacks

from failed train_steps

torchelastic has the ability to roll back a state if a train_step fails to

execute successfully, which may result in the state object being left partially

updated. It relies on properly implemented capture_snapshot() and apply_snapshot()

methods of the state to ensure that the state is restored to what it was before the

faulty train_step.

The capture_snapshot() method, as the name implies, takes a snapshot of the state

and returns the necessary information to be able to restore

the state object. You may return any object from capture_snapshot() so long as you

can use it in the apply_snapshot(snapshot) method. A possible implementation of a

rollback is:


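snapshot = state.capture_snapshot()
try:
    train_step(state)
except RuntimeError:
    # Restore the pre-step state, then re-initialize via sync().
    state.apply_snapshot(snapshot)
    state.sync()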

> NOTE: Since certain fields of the state may need to get re-initialized,

torchelastic calls the sync() method. For instance, data loaders may need

to be restarted as their iterators may end up in a corrupted state when the

train_step does not exit successfully.

Notice that the apply method is called on the existing state object. This implies

that an efficient implementation of snapshot should only return mutable, stateful

data. Immutable fields or fields that can be derived from other member variables or

restored in the sync method need not be included in the snapshot.

By default, the capture_snapshot() method returns None and the apply_snapshot() method

is a pass, which essentially means “rollback not supported”.

> IMPORTANT: The apply_snapshot method should make no assumptions about

which state object it is called on (e.g. the values of the member variables).

That is, applying a snapshot

to any state followed by state.sync() should effectively restore the

state object to when the corresponding capture_snapshot method was called.

A good rule of thumb is that apply_snapshot should act more like a set

method than an update method.
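
The set-like behavior described above can be sketched as follows. CounterState, its fields, and failing_train_step are hypothetical names chosen for illustration; the snapshot holds only the mutable fields, and apply_snapshot overwrites them without looking at their current values. This is a sketch under those assumptions, not a reference implementation.

```python
import copy


class CounterState:
    """Hypothetical state with one immutable and two mutable fields."""

    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate  # immutable: left out of the snapshot
        self.weights = [0.0, 0.0]
        self.num_steps = 0

    def capture_snapshot(self):
        # Deep-copy so later in-place updates cannot alter the snapshot.
        return {"weights": copy.deepcopy(self.weights), "num_steps": self.num_steps}

    def apply_snapshot(self, snapshot):
        # "Set" semantics: overwrite fields, never combine with current values.
        self.weights = copy.deepcopy(snapshot["weights"])
        self.num_steps = snapshot["num_steps"]

    def sync(self):
        pass  # nothing to re-initialize in this toy example


def failing_train_step(state):
    state.weights[0] += 1.0  # partial update...
    raise RuntimeError("simulated failure mid-step")  # ...then the step fails


state = CounterState(learning_rate=0.1)
snapshot = state.capture_snapshot()
try:
    failing_train_step(state)
except RuntimeError:
    state.apply_snapshot(snapshot)
    state.sync()

# The same snapshot can be applied to an unrelated state object.
other = CounterState(learning_rate=0.1)
other.weights, other.num_steps = [9.0, 9.0], 7
other.apply_snapshot(snapshot)
```

Because apply_snapshot ignores the current field values, both state and other end up with the values captured before the failed step.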

(optional) save(stream) and load(stream)

> You do not have to implement these methods if you do not plan on using checkpoints.


Much like capture_snapshot and apply_snapshot, the save and load methods form a pair.

They are responsible for persisting and restoring the state object to and from

a stream, which is a file-like object

that is compatible with torch.save.

torchelastic relies on these methods to provide checkpoint functionality for your job.

> We encourage users to use the torch.save and torch.load methods when implementing

save and load methods of their state class.

> NOTE: The default implementations of save and load use capture_snapshot

and apply_snapshot.
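
As an illustration of how save and load can be layered on top of the snapshot methods, here is a minimal sketch. PersistentState and its fields are hypothetical, and pickle is used only as a standard-library stand-in so the example runs without PyTorch; in a real job you would likely call torch.save and torch.load on the stream instead.

```python
import io
import pickle


class PersistentState:
    """Hypothetical state that persists itself through its snapshot."""

    def __init__(self):
        self.weights = [0.0, 0.0]
        self.epoch = 0

    def capture_snapshot(self):
        return {"weights": list(self.weights), "epoch": self.epoch}

    def apply_snapshot(self, snapshot):
        self.weights = list(snapshot["weights"])
        self.epoch = snapshot["epoch"]

    def save(self, stream):
        # Stand-in for torch.save(self.capture_snapshot(), stream).
        pickle.dump(self.capture_snapshot(), stream)

    def load(self, stream):
        # Stand-in for self.apply_snapshot(torch.load(stream)).
        self.apply_snapshot(pickle.load(stream))


trained = PersistentState()
trained.weights, trained.epoch = [1.5, -2.0], 3

buffer = io.BytesIO()   # any binary file-like object works as the stream
trained.save(buffer)
buffer.seek(0)          # rewind before reading the checkpoint back

restored = PersistentState()
restored.load(buffer)
```

Reusing the snapshot for persistence keeps a single definition of what counts as restorable state, which is why the two pairs of methods tend to be implemented together.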

Implement train_step

The train_step is a function that takes state as a single argument

and carries out a partition of the overall training job.

This is your unit of work and it is up to you to define what

a unit is. When deciding what your unit of work should be, keep in mind the following:


  1. Rollbacks and checkpoints are done at train_step granularity. This means

that torchelastic can only recover to the last successful train_step. Any failures

during the train_step are not recoverable.

  2. A train_step iteration in the train_loop has overhead due

to the work that goes in ensuring that your job is fault-tolerant and elastic.

How much overhead depends on your configurations for rollbacks and checkpoints as well

as how expensive your snapshot, apply, save and load functions are.

> In most cases, your job naturally lends itself to an

obvious train_step. The most canonical one for many training jobs is to map

the processing of a mini-batch of training data to a train_step.

There is a trade-off to be made between how much work you are

willing to lose versus how much overhead you want to pay for that security.
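
The mini-batch mapping mentioned above might look like the following sketch. BatchState, its fields, and the driver loop are hypothetical, the running sum is a placeholder for a real forward/backward/optimizer step, and raising StopIteration to signal the end of the data is a convention of this sketch rather than a documented requirement.

```python
class BatchState:
    """Hypothetical state: a dataset, a position marker and a running sum."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.start_index = 0
        self.total = 0


def train_step(state):
    """Process exactly one mini-batch; this is the unit of work."""
    end = state.start_index + state.batch_size
    batch = state.dataset[state.start_index:end]
    if not batch:
        raise StopIteration("no more data")
    state.total += sum(batch)        # placeholder for the real training work
    state.start_index += len(batch)  # advance the marker only after success
    return state


# Simplified driver: one train_step per mini-batch until the data runs out.
state = BatchState(dataset=[1, 2, 3, 4, 5], batch_size=2)
steps = 0
while True:
    try:
        train_step(state)
        steps += 1
    except StopIteration:
        break
```

With a batch size of 2, a failure costs at most two samples of work. A larger unit, such as a whole epoch per train_step, would lower the per-step overhead but lose more work on each failure.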

Write a main

Now that you have state and train_step implementations, all that remains

is to bring everything together and implement a main that will execute your

training. Your script should initialize torchelastic’s coordinator, create

your state object, and call the train_loop. Below is a simple example:


import torchelastic
from torchelastic.p2p import CoordinatorP2P

if __name__ == "__main__":
    min_workers = 1
    max_workers = 1
    run_id = 1234
    etcd_endpoint = "localhost:2379"

    # MyState and train_step are your implementations from the sections above.
    state = MyState()

    coordinator = CoordinatorP2P(
        ...  # arguments elided
    )

    torchelastic.train(coordinator, train_step, state)




See metrics documentation.

Checkpoint and Rollback

See checkpoint documentation


See rendezvous documentation

