
ignite.handlers

Complete list of handlers

class ignite.handlers.Checkpoint(to_save, save_handler, filename_prefix='', score_function=None, score_name=None, n_saved=1, global_step_transform=None, archived=False)

The Checkpoint handler can be used to periodically save and load objects that have state_dict/load_state_dict methods. This class can use specific save handlers to store checkpoints on disk, in cloud storage, etc. When used with DiskSaver, the Checkpoint handler also automatically moves data on TPU to CPU before writing the checkpoint.

Parameters
  • to_save (Mapping) – Dictionary with the objects to save. Objects should implement the state_dict and load_state_dict methods. If it contains objects of type torch DistributedDataParallel or DataParallel, their internal wrapped model is saved automatically (to avoid the additional module. prefix in the state dictionary keys).

  • save_handler (callable or BaseSaveHandler) – Method or callable class used to save the engine and other provided objects. The function receives two arguments: the checkpoint as a dictionary and the filename. If save_handler is a callable class, it can inherit from BaseSaveHandler and optionally implement the remove method to keep a fixed number of saved checkpoints. If the user needs to save the engine's checkpoint on disk, save_handler can be defined with DiskSaver.

  • filename_prefix (str, optional) – Prefix for the filename to which objects will be saved. See Note for details.

  • score_function (callable, optional) – If not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained.

  • score_name (str, optional) – If score_function is not None, its value can be stored in the filename under score_name. See Notes for more details.

  • n_saved (int, optional) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.

  • global_step_transform (callable, optional) – Global step transform function to output a desired global step. The input of the function is (engine, event_name) and its output should be an integer. Default is None, in which case the global step is based on the attached engine. If provided, the function output is used as the global step. To set up the global step from another engine, use global_step_from_engine().

  • archived (bool, optional) – Deprecated argument as models saved by torch.save are already compressed.
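A global_step_transform is any callable with the (engine, event_name) -> int contract described above. The sketch below shows the kind of transform global_step_from_engine() is meant to provide: the step is read from another engine (here, the trainer's epoch) instead of from the engine the handler is attached to. The helper name step_from_other_engine and the SimpleNamespace stand-ins for engines are hypothetical; this is not ignite's implementation.

```python
from types import SimpleNamespace
from typing import Any, Callable


def step_from_other_engine(other: Any, attr: str = "epoch") -> Callable[[Any, Any], int]:
    """Hypothetical helper: build a global_step_transform that ignores the
    engine the handler is attached to and reads a counter from `other`."""

    def transform(engine: Any, event_name: Any) -> int:
        # The signature follows the (engine, event_name) contract, but the
        # returned step comes from the other engine's state.
        return int(getattr(other.state, attr))

    return transform


# Plain stand-ins for a trainer and an evaluator (not ignite Engines).
trainer = SimpleNamespace(state=SimpleNamespace(epoch=9, iteration=4500))
evaluator = SimpleNamespace(state=SimpleNamespace(epoch=1, iteration=50))

transform = step_from_other_engine(trainer)
print(transform(evaluator, "COMPLETED"))  # 9: the trainer's epoch, not the evaluator's
```

With a transform of this kind, a handler attached to an evaluator can label its files with the trainer's epoch, as in best_model_9_val_acc=0.77.pt in the examples below.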

Note

This class stores a single file containing a dictionary of the provided objects. The filename has the following structure: {filename_prefix}_{name}_{suffix}.{ext}, where

  • filename_prefix is the argument passed to the constructor,

  • name is the key in to_save if a single object is to be stored; otherwise, name is "checkpoint".

  • suffix is composed as follows: {global_step}_{score_name}={score}.

Above, global_step is defined by the output of global_step_transform and score by the output of score_function.

By default, none of score_function, score_name, global_step_transform is defined, and the suffix is set from the attached engine's current iteration. The filename will be {filename_prefix}_{name}_{engine.state.iteration}.{ext}.

If score_function is defined but score_name is not, the suffix is defined by the provided score. The filename will be {filename_prefix}_{name}_{score}.{ext}. If global_step_transform is provided, the filename will be {filename_prefix}_{name}_{global_step}_{score}.{ext}.

If both score_function and score_name are defined, the filename will be {filename_prefix}_{name}_{score_name}={score}.{ext}. If global_step_transform is provided, the filename will be {filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}.

For example, with score_name="neg_val_loss" and a score_function that returns -loss (as objects with the highest scores are retained), the saved filename will be {filename_prefix}_{name}_neg_val_loss=-0.1234.pt.

To get the last stored filename, the handler exposes the attribute last_checkpoint:

handler = Checkpoint(...)
...
print(handler.last_checkpoint)
> checkpoint_12345.pt
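The naming rules in the Note above can be condensed into one small function. This is a sketch under stated assumptions, not Checkpoint's code: checkpoint_filename is a hypothetical name, global_step and score stand for the outputs of global_step_transform and score_function (None meaning "not provided"), floats are assumed to be printed with four decimals as in the neg_val_loss example, and an empty filename_prefix is assumed to drop its underscore, as in checkpoint_12345.pt.

```python
from typing import Mapping, Optional, Union

Number = Union[int, float]


def checkpoint_filename(
    to_save: Mapping[str, object],
    filename_prefix: str = "",
    iteration: int = 0,
    global_step: Optional[int] = None,
    score: Optional[Number] = None,
    score_name: Optional[str] = None,
    ext: str = "pt",
) -> str:
    """Hypothetical helper reproducing the filename rules of the Note."""
    # A single object is saved under its own key, several under "checkpoint".
    name = next(iter(to_save)) if len(to_save) == 1 else "checkpoint"

    parts = []
    if global_step is not None:
        parts.append(str(global_step))
    if score is not None:
        # Assumption: floats are shown with four decimals, ints as-is.
        score_str = f"{score:.4f}" if isinstance(score, float) else str(score)
        parts.append(f"{score_name}={score_str}" if score_name else score_str)
    if not parts:
        # Nothing configured: fall back to the engine's current iteration.
        parts.append(str(iteration))
    suffix = "_".join(parts)

    prefix = f"{filename_prefix}_" if filename_prefix else ""
    return f"{prefix}{name}_{suffix}.{ext}"


print(checkpoint_filename({"model": 1, "optimizer": 2}, iteration=12345))
# checkpoint_12345.pt
print(checkpoint_filename({"model": 1}, "best", global_step=9, score=0.7712, score_name="val_acc"))
# best_model_9_val_acc=0.7712.pt
print(checkpoint_filename({"model": 1}, score=-0.1234, score_name="neg_val_loss"))
# model_neg_val_loss=-0.1234.pt
```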

Note

This class is distributed configuration-friendly: it is not required to instantiate it only in the rank 0 process. The class handles distributed configuration automatically, and if used with DiskSaver, the checkpoint is stored by the rank 0 process.

Warning

When running on TPUs, the handler should be created and attached in all processes; otherwise, the application can get stuck while saving the checkpoint.

# Wrong:
# if idist.get_rank() == 0:
#     handler = Checkpoint(...)
#     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

# Correct:
handler = Checkpoint(...)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

Examples

Attach the handler to make checkpoints during training:

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...

to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}
handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
trainer.run(data_loader, max_epochs=6)
> ["checkpoint_7000.pt", "checkpoint_8000.pt", ]

Attach the handler to an evaluator to save the best model during training according to a computed validation metric:

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

trainer = ...
evaluator = ...
model = ...
# Setup Accuracy metric computation on evaluator
# Run evaluation on epoch completed event
# ...

def score_function(engine):
    return engine.state.metrics['accuracy']

to_save = {'model': model}
handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2,
                     filename_prefix='best', score_function=score_function, score_name="val_acc",
                     global_step_transform=global_step_from_engine(trainer))

evaluator.add_event_handler(Events.COMPLETED, handler)

trainer.run(data_loader, max_epochs=10)
> ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt", ]
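Both examples above use n_saved=2, so only two files survive. The sketch below models that bookkeeping under an assumption drawn from the parameter descriptions: each save carries a priority (the score when score_function is given, otherwise the global step, so older files go first), and only the n_saved highest priorities are kept. RetentionSketch is a hypothetical class, not Checkpoint's internal code, and the save steps 1000 to 8000 are assumed for illustration.

```python
import heapq
from typing import List, Optional, Tuple


class RetentionSketch:
    """Hypothetical model of n_saved: keep the highest-priority entries and
    report which filename should be removed."""

    def __init__(self, n_saved: Optional[int] = 1) -> None:
        self.n_saved = n_saved
        self._heap: List[Tuple[float, str]] = []  # min-heap keyed on priority

    def save(self, priority: float, filename: str) -> Optional[str]:
        """Record a save; return the filename that falls out, if any."""
        if self.n_saved is None or len(self._heap) < self.n_saved:
            # n_saved=None keeps everything; otherwise there is still room.
            heapq.heappush(self._heap, (priority, filename))
            return None
        if priority <= self._heap[0][0]:
            # Not better than the worst kept entry: this sketch keeps nothing new.
            return None
        # Swap the lowest-priority entry for the new one.
        _, removed = heapq.heapreplace(self._heap, (priority, filename))
        return removed

    def kept(self) -> List[str]:
        return [name for _, name in sorted(self._heap)]


# Without score_function the priority is the global step.
r = RetentionSketch(n_saved=2)
for step in range(1000, 9000, 1000):  # assumed saves at 1000, 2000, ..., 8000
    r.save(step, f"checkpoint_{step}.pt")
print(r.kept())  # ['checkpoint_7000.pt', 'checkpoint_8000.pt']
```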

static load_objects(to_load, checkpoint, **kwargs)

Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.

Examples:

import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint
trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
to_save = {"weights": model, "optimizer": optimizer}
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
trainer.run(torch.randn(10, 1), 5)

to_load = to_save
# Saved at epochs 2 and 4 (10 iterations per epoch), so the latest file ends in _40
checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pt"
checkpoint = torch.load(checkpoint_fp)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

Parameters
  • to_load (Mapping) – a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, …}

  • checkpoint (Mapping) – a dictionary with state_dicts to load, e.g. {"model": model_state_dict, "optimizer": opt_state_dict}. If to_load contains a single key, then checkpoint can directly contain the corresponding state_dict.

  • **kwargs – Keyword arguments accepted by nn.Module.load_state_dict(). Passing strict=False enables the user to load part of a pretrained model (useful, for example, in transfer learning).

Return type

None
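The behaviour described for load_objects can be sketched without torch: each object in to_load receives the matching entry of checkpoint through load_state_dict, and a single-key to_load may be given the bare state_dict. load_objects_sketch and the Toy class are hypothetical; in particular, this sketch forwards **kwargs to every object, while the description above ties them to nn.Module.load_state_dict().

```python
from typing import Any, Dict, Mapping


class Toy:
    """Toy object implementing the state_dict/load_state_dict protocol."""

    def __init__(self, value: int = 0) -> None:
        self.value = value

    def state_dict(self) -> Dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state: Mapping[str, Any]) -> None:
        self.value = state["value"]


def load_objects_sketch(to_load: Mapping[str, Any], checkpoint: Mapping[str, Any], **kwargs: Any) -> None:
    """Hypothetical re-creation of the documented load_objects behaviour."""
    if len(to_load) == 1:
        # Single object: the checkpoint may be keyed or be the state_dict itself.
        key, obj = next(iter(to_load.items()))
        state = checkpoint[key] if key in checkpoint else checkpoint
        obj.load_state_dict(state, **kwargs)
        return
    for key, obj in to_load.items():
        if key not in checkpoint:
            raise KeyError(f"'{key}' is not found in the checkpoint")
        obj.load_state_dict(checkpoint[key], **kwargs)


model, optimizer = Toy(), Toy()
saved = {"model": Toy(7).state_dict(), "optimizer": Toy(3).state_dict()}
load_objects_sketch({"model": model, "optimizer": optimizer}, saved)
print(model.value, optimizer.value)  # 7 3

single = Toy()
load_objects_sketch({"model": single}, {"value": 5})  # bare state_dict
print(single.value)  # 5
```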

class ignite.handlers.checkpoint.BaseSaveHandler

Base class for save handlers.

Methods to override:

  • __call__()

  • remove()

Note

In a derived class, make sure that in a distributed configuration the overridden methods are called by a single process. Distributed configuration on XLA devices should be treated slightly differently: when saving a checkpoint with xm.save(), all processes should call the function. Otherwise, the application gets stuck.

abstract __call__(checkpoint, filename, metadata=None)

Method to save checkpoint with filename. Additionally, a metadata dictionary is provided.

Metadata contains:

  • basename: file prefix (if provided) with checkpoint name, e.g. epoch_checkpoint.

  • score_name: score name if provided, e.g. val_acc.

  • priority: checkpoint priority value (higher is better), e.g. 12 or 0.6554435.

Parameters
  • checkpoint (Mapping) – checkpoint dictionary to save.

  • filename (str) – filename associated with checkpoint.

  • metadata (Mapping, optional) – metadata on checkpoint to save.

Return type

None

abstract remove(filename)

Method to remove saved checkpoint.

Parameters

filename (str) – filename associated with checkpoint.

Return type

None
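A custom save handler only needs the two methods documented above. The sketch below keeps pickled checkpoints in a dictionary instead of on disk. InMemorySaver is a hypothetical class; it omits the BaseSaveHandler base class only so that it runs without ignite installed, and in practice you would inherit from ignite.handlers.checkpoint.BaseSaveHandler so that Checkpoint treats it as a full save handler. It is single-process only and ignores the distributed Note above.

```python
import pickle
from typing import Any, Dict, Mapping, Optional


class InMemorySaver:
    """Hypothetical save handler with the __call__/remove interface of
    BaseSaveHandler; checkpoints are pickled into a dict."""

    def __init__(self) -> None:
        self.store: Dict[str, bytes] = {}
        self.metadata: Dict[str, Optional[Mapping[str, Any]]] = {}

    def __call__(self, checkpoint: Mapping[str, Any], filename: str,
                 metadata: Optional[Mapping[str, Any]] = None) -> None:
        # Serialize eagerly so later mutations of the objects do not leak in.
        self.store[filename] = pickle.dumps(dict(checkpoint))
        self.metadata[filename] = metadata

    def remove(self, filename: str) -> None:
        # Drops a checkpoint, e.g. to keep a fixed number of saved ones.
        self.store.pop(filename, None)
        self.metadata.pop(filename, None)

    def load(self, filename: str) -> Dict[str, Any]:
        return pickle.loads(self.store[filename])


saver = InMemorySaver()
saver({"model": {"w": 1.0}}, "model_1.pt", metadata={"basename": "model", "score_name": None, "priority": 1})
saver({"model": {"w": 2.0}}, "model_2.pt", metadata={"basename": "model", "score_name": None, "priority": 2})
saver.remove("model_1.pt")
print(sorted(saver.store))       # ['model_2.pt']
print(saver.load("model_2.pt"))  # {'model': {'w': 2.0}}
```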

class ignite.handlers.DiskSaver(dirname, atomic=True, create_dir=True, require_empty=True)

Handler that saves the input checkpoint to disk.

Parameters
  • dirname (str) – Directory path where the checkpoint will be saved.

  • atomic (bool, optional) – If True, the checkpoint is serialized to a temporary file and then moved to its final destination, so that the file is guaranteed not to be damaged (for example, if an exception occurs during saving).

  • create_dir (bool, optional) – If True, will create the directory ‘dirname’ if it doesn't exist.

  • require_empty (bool, optional) – If True, will raise an exception if there are any files in the directory ‘dirname’.
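The atomic=True option follows the usual write-to-a-temporary-file-then-rename pattern. The sketch below shows that pattern with the standard library only: pickle stands in for torch.save, and atomic_pickle is a hypothetical helper, not a DiskSaver method. The temporary file is created in the target directory so that the final os.replace stays on one filesystem, where POSIX guarantees the rename is atomic.

```python
import os
import pickle
import tempfile
from typing import Any


def atomic_pickle(obj: Any, path: str) -> None:
    """Write `obj` to `path` so readers never see a half-written file."""
    dirname = os.path.dirname(os.path.abspath(path))
    # Temporary file next to the target, so the rename stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=dirname, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)  # the final file appears in one step
    except BaseException:
        # Serialization failed: delete the partial temporary file and re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


with tempfile.TemporaryDirectory() as d:
    target = os.path.join(d, "checkpoint_1.pt")
    atomic_pickle({"step": 1}, target)
    print(os.listdir(d))  # ['checkpoint_1.pt']
```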

class ignite.handlers.ModelCheckpoint(dirname, filename_prefix, save_interval=None, score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, save_as_state_dict=True, global_step_transform=None, archived=False)

The ModelCheckpoint handler can be used to periodically save objects to disk only. To store checkpoints in another storage type, consider Checkpoint.

This handler expects two arguments:

  • an Engine object

  • a dict mapping names (str) to objects that should be saved to disk.

See Examples for further details.

Warning

Behaviour of this class has been changed since v0.3.0.

Argument save_as_state_dict is deprecated and should not be used. It is always treated as True.

Argument save_interval is deprecated and should not be used. Use event filtering instead, e.g. Events.ITERATION_STARTED(every=1000).

There is no longer an internal counter indicating the number of save actions. Its value, step_number, used to appear in the filename, e.g. {filename_prefix}_{name}_{step_number}.pt. Now step_number is replaced by the current engine's epoch if score_function is specified, and by the current iteration otherwise.

A single .pt file is created instead of multiple files.

Parameters
  • dirname (str) – Directory path where objects will be saved.

  • filename_prefix (str) – Prefix for the filenames to which objects will be saved. See Notes of Checkpoint for more details.

  • score_function (callable, optional) – If not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained.

  • score_name (str, optional) – If score_function is not None, its value can be stored in the filename under score_name. See Notes for more details.

  • n_saved (int, optional) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.

  • atomic (bool, optional) – If True, objects are serialized to a temporary file and then moved to their final destination, so that files are guaranteed not to be damaged (for example, if an exception occurs during saving).

  • require_empty (bool, optional) – If True, will raise an exception if there are any files starting with filename_prefix in the directory ‘dirname’.

  • create_dir (bool, optional) – If True, will create the directory ‘dirname’ if it doesn't exist.

  • global_step_transform (callable, optional) – Global step transform function to output a desired global step. The input of the function is (engine, event_name) and its output should be an integer. Default is None, in which case the global step is based on the attached engine. If provided, the function output is used as the global step. To set up the global step from another engine, use global_step_from_engine().

  • archived (bool, optional) – Deprecated argument as models saved by torch.save are already compressed.

  • save_interval (Optional[Callable]) – Deprecated; see the Warning above.

  • save_as_state_dict (bool) – Deprecated; see the Warning above.

Examples

>>> import os
>>> from ignite.engine import Engine, Events
>>> from ignite.handlers import ModelCheckpoint
>>> from torch import nn
>>> trainer = Engine(lambda engine, batch: None)
>>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
>>> model = nn.Linear(3, 3)
>>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
>>> trainer.run([0], max_epochs=6)
>>> os.listdir('/tmp/models')
['myprefix_mymodel_4.pt', 'myprefix_mymodel_6.pt']
>>> handler.last_checkpoint
'/tmp/models/myprefix_mymodel_6.pt'

class ignite.handlers.EarlyStopping(patience, score_function, trainer, min_delta=0.0, cumulative_delta=False)

The EarlyStopping handler can be used to stop the training if there is no improvement after a given number of events.

Parameters
  • patience (int) – Number of events to wait without improvement before stopping the training.

  • score_function (callable) – It should be a function taking a single argument, an Engine object, and returning a score (float). A higher score is considered an improvement.

  • trainer (Engine) – Trainer engine to stop if there is no improvement.

  • min_delta (float, optional) – A minimum increase in the score to qualify as an improvement, i.e. an increase of less than or equal to min_delta will count as no improvement.

  • cumulative_delta (bool, optional) – If True, min_delta defines an increase since the last patience reset; otherwise, it defines an increase after the last event. Default value is False.

Examples:

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

def score_function(engine):
    val_loss = engine.state.metrics['nll']
    return -val_loss

handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
evaluator.add_event_handler(Events.COMPLETED, handler)
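The patience, min_delta and cumulative_delta rules can be modelled without an engine. In the hypothetical EarlyStoppingSketch below, update(score) returns True where the real handler would stop the trainer. The exact update rule is an assumption reconstructed from the parameter descriptions: in cumulative mode the reference score only moves on a real improvement, while in non-cumulative mode any increase, even one below min_delta, becomes the new reference.

```python
from typing import Optional


class EarlyStoppingSketch:
    """Hypothetical, engine-free model of the EarlyStopping rules.
    Call update(score) once per event; True means "stop training"."""

    def __init__(self, patience: int, min_delta: float = 0.0, cumulative_delta: bool = False) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.cumulative_delta = cumulative_delta
        self.counter = 0
        self.best_score: Optional[float] = None

    def update(self, score: float) -> bool:
        if self.best_score is None:
            self.best_score = score
        elif score <= self.best_score + self.min_delta:
            # Not a qualifying improvement: count it against patience.
            if not self.cumulative_delta and score > self.best_score:
                # Non-cumulative mode measures the next increase from here.
                self.best_score = score
            self.counter += 1
        else:
            # Qualifying improvement: remember it and reset patience.
            self.best_score = score
            self.counter = 0
        return self.counter >= self.patience


stopper = EarlyStoppingSketch(patience=2)
print([stopper.update(s) for s in [0.50, 0.50, 0.49]])  # [False, False, True]
```

Feeding [1.0, 1.05, 1.08, 1.15] with min_delta=0.1 shows the difference: with cumulative_delta=True the total gain of 0.15 since the last reset counts as an improvement and resets the counter, whereas with the default each step of less than 0.1 counts against patience.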

class ignite.handlers.Timer(average=False)

Timer object can be used to measure (average) time between events.

Parameters

average (bool, optional) – If True, the value returned by the .value() method is the total measured time divided by the value of the internal counter.

total

Total time elapsed while the Timer was running (in seconds).

Type

float

step_count

Internal counter, useful to measure average time, e.g. of processing a single batch. Incremented by the .step() method.

Type

int

running

Flag indicating whether the timer is measuring time.

Type

bool

Note

When using Timer(average=True) do not forget to call timer.step() every time an event occurs. See the examples below.

Examples

Measuring total time of the epoch:

>>> from ignite.handlers import Timer
>>> import time
>>> work = lambda : time.sleep(0.1)
>>> idle = lambda : time.sleep(0.1)
>>> t = Timer(average=False)
>>> for _ in range(10):
...    work()
...    idle()
...
>>> t.value()
2.003073937026784

Measuring average time of the epoch:

>>> t = Timer(average=True)
>>> for _ in range(10):
...    work()
...    idle()
...    t.step()
...
>>> t.value()
0.2003182829997968

Measuring average time it takes to execute a single work() call:

>>> t = Timer(average=True)
>>> for _ in range(10):
...    t.resume()
...    work()
...    t.pause()
...    idle()
...    t.step()
...
>>> t.value()
0.10016545779653825

Using the Timer to measure average time it takes to process a single batch of examples:

>>> from ignite.engine import Engine, Events
>>> from ignite.handlers import Timer
>>> trainer = Engine(training_update_function)
>>> timer = Timer(average=True)
>>> timer.attach(trainer,
...              start=Events.EPOCH_STARTED,
...              resume=Events.ITERATION_STARTED,
...              pause=Events.ITERATION_COMPLETED,
...              step=Events.ITERATION_COMPLETED)

attach(engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None)

Register callbacks to control the timer.

Parameters
  • engine (Engine) – Engine that this timer will be attached to.

  • start (Events) – Event which should start (reset) the timer.

  • pause (Events) – Event which should pause the timer.

  • resume (Events, optional) – Event which should resume the timer.

  • step (Events, optional) – Event which should call the step method of the timer.

Returns

self (Timer)
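The relationship between total, step_count, running and value() can be modelled in a few lines. TimerSketch below is a hypothetical re-implementation of the documented semantics, not ignite's Timer: the clock argument is an addition so the behaviour can be checked deterministically (the documented Timer takes only average), and guarding the division with max(step_count, 1) is an assumption.

```python
import time
from typing import Callable


class TimerSketch:
    """Hypothetical model of Timer: pause/resume accumulate into `total`,
    step() increments `step_count`, value() optionally averages."""

    def __init__(self, average: bool = False, clock: Callable[[], float] = time.perf_counter) -> None:
        self._average = average
        self._clock = clock
        self.reset()

    def reset(self) -> "TimerSketch":
        self._t0 = self._clock()
        self.total = 0.0
        self.step_count = 0
        self.running = True
        return self

    def pause(self) -> None:
        if self.running:
            self.total += self._clock() - self._t0
            self.running = False

    def resume(self) -> None:
        if not self.running:
            self._t0 = self._clock()
            self.running = True

    def step(self) -> None:
        self.step_count += 1

    def value(self) -> float:
        total = self.total
        if self.running:
            total += self._clock() - self._t0
        # Assumption: avoid dividing by zero before the first step().
        denominator = max(self.step_count, 1) if self._average else 1
        return total / denominator


# Fake clock: time only the "work" part of each loop, as in the third example above.
now = [0.0]


def fake_clock() -> float:
    return now[0]


t = TimerSketch(average=True, clock=fake_clock)  # running from t=0.0
t.pause()
for _ in range(3):
    t.resume()
    now[0] += 0.1  # simulated work, measured
    t.pause()
    now[0] += 0.1  # simulated idle time, not measured
    t.step()
print(round(t.value(), 6))  # 0.1
```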

class ignite.handlers.TerminateOnNan(output_transform=<function TerminateOnNan.<lambda>>)

The TerminateOnNan handler can be used to stop the training if the process_function's output contains a NaN or infinite number or torch.tensor. The output can be a number, a tensor or a collection of them. The training is stopped if at least one number or tensor has a NaN or infinite value. For example, if the output is [1.23, torch.tensor(…), torch.tensor(float('nan'))], the handler will stop the training.

Parameters

output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into a number or torch.tensor or collection of them. This can be useful if, for example, you have a multi-output model and you want to check one or multiple values of the output.

Examples:

trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
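The check that TerminateOnNan performs on a possibly nested output can be sketched as a recursive walk. has_non_finite is a hypothetical helper that only handles plain numbers inside sequences and mappings, so it runs without torch; with torch available, a tensor leaf could be tested with torch.isfinite(t).all() instead.

```python
import math
from typing import Any, Mapping, Sequence


def has_non_finite(output: Any) -> bool:
    """Hypothetical helper: True if any number in a (possibly nested)
    sequence or mapping is NaN or infinite. Tensors are not handled."""
    if isinstance(output, (int, float)):
        return not math.isfinite(output)
    if isinstance(output, Mapping):
        return any(has_non_finite(v) for v in output.values())
    if isinstance(output, Sequence) and not isinstance(output, (str, bytes)):
        return any(has_non_finite(v) for v in output)
    # Anything else (strings, None, unknown types) is ignored.
    return False


print(has_non_finite([1.23, (4.0, float("nan"))]))       # True
print(has_non_finite({"loss": 0.5, "acc": [0.9, 1.0]}))  # False
print(has_non_finite(float("-inf")))                     # True
```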