
ignite.engine

Main module of the library containing:

ignite.engine.engine

Engine

Runs a given process_function over each batch of a dataset, emitting events as it goes.

ignite.engine.events

CallableEventWithFilter

A single event with an attached filter that decides whether the event's handlers should run for the current occurrence of the event (if the event type matches).

EventEnum

Base class for all Events.

Events

Events that are fired by the Engine during execution.

EventsList

Collection of events stacked by operator __or__.

State

An object that is used to pass internal and user-defined state between event handlers.

RemovableEventHandle

A weakref handle to remove a registered event.

ignite.engine.deterministic

Helper methods for deterministic training.

DeterministicEngine

Deterministic engine derived from Engine.

ReproducibleBatchSampler

Reproducible batch sampler.

keep_random_state

Helper decorator to keep random state of torch, numpy and random intact while executing a function.

update_dataloader

Helper function to replace the current batch sampler of the dataloader with a new batch sampler.

Helper methods to define supervised trainer and evaluator

create_supervised_trainer

Factory function for creating a trainer for supervised models.

create_supervised_evaluator

Factory function for creating an evaluator for supervised models.

supervised_training_step

Factory function for supervised training.

supervised_training_step_amp

Factory function for supervised training using torch.cuda.amp.

supervised_training_step_apex

Factory function for supervised training using apex.

supervised_training_step_tpu

Factory function for supervised training using torch_xla.

supervised_evaluation_step

Factory function for supervised evaluation.

supervised_evaluation_step_amp

Factory function for supervised evaluation using torch.cuda.amp.

Resuming the training

It is possible to resume training from a checkpoint and approximately reproduce the original run’s behaviour. With Ignite, this is easily done using the Checkpoint handler. Engine provides two methods to serialize and deserialize its internal state: state_dict() and load_state_dict(). In addition to serializing the model, optimizer, lr scheduler, metrics, etc., the user can store the trainer and then resume training. For example:

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...
data_loader = ...
metric = ...

to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
trainer.run(data_loader, max_epochs=100)

$ ls /tmp/training
checkpoint_50000.pt

We can then restore the training from the last checkpoint.

import torch
from ignite.handlers import Checkpoint

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...
data_loader = ...
metric = ...
checkpoint_file = ...  # path to a saved checkpoint, e.g. a file in /tmp/training

to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
checkpoint = torch.load(checkpoint_file)
# Restore every object in to_load, including the trainer's iteration/epoch counters
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

trainer.run(data_loader, max_epochs=100)

It is also possible to store checkpoints every N iterations and resume training from one of these checkpoints, i.e. from a given iteration.

Complete examples that resume the training from a checkpoint can be found here:

Deterministic training

In general, achieving deterministic and reproducible training is rather difficult, as it depends on multiple aspects, e.g. data version, code version, software environment, hardware, etc. According to the PyTorch note on randomness, there are some steps to take in order to make computations deterministic for your specific problem on one specific platform and PyTorch release:

By default, these two options can be enough to run and rerun experiments in a deterministic way. Ignite’s engine does not impact this behaviour.

In this module we provide helper methods and classes that perform additional “Dataflow synchronization” to ensure that the model sees the same data for a given epoch:

Dataflow synchronization

Ignite provides an option to control the dataflow by synchronizing the random state on each epoch. In this way, for a given seed, the dataflow is the same at a given iteration/epoch. More precisely, it roughly looks like this:

for e in range(num_epochs):
    set_seed(seed + e)
    do_single_epoch_iterations(dataloader)
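The loop above can be made concrete with plain PyTorch. This sketch assumes a shuffled DataLoader that draws its permutation from the global torch generator (the default when no generator is passed); epoch_batches and run are hypothetical helpers, not Ignite functions, and epochs are counted from 0. Because the generator is re-seeded with seed + epoch, a run that starts at epoch 1 sees exactly the same batches as the full run does in epochs 1 and 2.

```python
import torch
from torch.utils.data import DataLoader


def epoch_batches(dataloader, seed, epoch):
    # Re-seed the global torch generator with seed + epoch before iterating, so
    # the shuffled order of this epoch depends only on the seed and the epoch.
    torch.manual_seed(seed + epoch)
    return [batch.tolist() for batch in dataloader]


def run(dataloader, seed, num_epochs, start_epoch=0):
    # Batches seen in each epoch, keyed by epoch index (0-based).
    return {e: epoch_batches(dataloader, seed, e) for e in range(start_epoch, num_epochs)}


loader = DataLoader(torch.arange(12), batch_size=4, shuffle=True)
full = run(loader, seed=56, num_epochs=3)
resumed = run(loader, seed=56, num_epochs=3, start_epoch=1)
# Epochs 1 and 2 are identical whether or not the run started at epoch 0.
print(all(full[e] == resumed[e] for e in (1, 2)))  # True
```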

In addition, if the data provider is a torch.utils.data.DataLoader, batch data indices can be made completely deterministic. Here is a trivial usage example:

import torch
from torch.utils.data import DataLoader
from ignite.engine import DeterministicEngine, Events
from ignite.utils import manual_seed


def random_train_data_loader(size):
    data = torch.arange(0, size)
    return DataLoader(data, batch_size=4, shuffle=True)


def print_train_data(engine, batch):
    i = engine.state.iteration
    e = engine.state.epoch
    print("train", e, i, batch.tolist())

trainer = DeterministicEngine(print_train_data)

print("Original Run")
manual_seed(56)
trainer.run(random_train_data_loader(40), max_epochs=2, epoch_length=5)

print("Resumed Run")
# Resume from 2nd epoch
trainer.load_state_dict({"epoch": 1, "epoch_length": 5, "max_epochs": 2, "rng_states": None})
manual_seed(56)
trainer.run(random_train_data_loader(40))

Original Run
train 1 1 [31, 13, 3, 4]
train 1 2 [23, 18, 6, 16]
train 1 3 [10, 8, 33, 36]
train 1 4 [1, 37, 19, 9]
train 1 5 [20, 30, 14, 26]
train 2 6 [29, 35, 38, 34]
train 2 7 [7, 22, 12, 17]
train 2 8 [25, 21, 24, 15]
train 2 9 [39, 5, 2, 28]
train 2 10 [27, 11, 32, 0]
Resumed Run
train 2 6 [29, 35, 38, 34]
train 2 7 [7, 22, 12, 17]
train 2 8 [25, 21, 24, 15]
train 2 9 [39, 5, 2, 28]
train 2 10 [27, 11, 32, 0]

We can see that the data samples are exactly the same between original and resumed runs.

Complete examples that simulate a crash at a given iteration and resume the training from a checkpoint can be found here:

Note

When the input data is a torch.utils.data.DataLoader, previous batches are skipped and the first provided batch corresponds to the batch after the checkpoint iteration. Internally, while resuming, the indices of previous data points are skipped without fetching the data.

Warning

While resuming from an iteration, random data augmentations are not synchronized in the middle of the epoch, so the batches remaining until the end of the epoch can differ from those of the initial run.

Warning

Keep in mind that dataflow synchronization on every epoch can break if a user handler resets the random state, for example by periodically calling torch.manual_seed(seed) during the run. This affects the dataflow:

import torch
from ignite.engine import DeterministicEngine, Events

# print_train_data is the handler defined in the previous example


def random_train_data_generator():
    while True:
        yield torch.randint(0, 100, size=(1, ))


trainer = DeterministicEngine(print_train_data)


@trainer.on(Events.ITERATION_COMPLETED(every=3))
def user_handler():
    # the handler re-seeds, and thus resets, the global random state
    torch.manual_seed(12)
    a = torch.rand(1)


trainer.run(random_train_data_generator(), max_epochs=3, epoch_length=5)

train 1 1 [32]
train 1 2 [29]
train 1 3 [40]
train 1 4 [3]  <---
train 1 5 [22]
train 2 6 [77]
train 2 7 [3]  <---
train 2 8 [22]
train 2 9 [77]
train 2 10 [3] <---
train 3 11 [22]
train 3 12 [77]
train 3 13 [3] <---
train 3 14 [22]
train 3 15 [77]

Initially, random_train_data_generator() generates random data batches using the random state set up by the trainer. This is the intended behaviour until user_handler() is called. After user_handler() executes, the random state has been altered, and random_train_data_generator() produces batches based on the altered random state, which is why the same values (3, 22, 77) keep repeating after every call, as marked with arrows above.

We provide the helper decorator keep_random_state() that saves and restores the random states of torch, numpy and random. The issue described above can therefore be avoided by applying this decorator to the handler:

from ignite.engine.deterministic import keep_random_state

@trainer.on(Events.ITERATION_COMPLETED(every=3))
@keep_random_state
def user_handler():
    # handler synchronizes the random state
    torch.manual_seed(12)
    a = torch.rand(1)