ignite.handlers#

Complete list of handlers#

ModelCheckpoint

ModelCheckpoint handler can be used to periodically save objects to disk only.

Checkpoint

Checkpoint handler can be used to periodically save and load objects which have attribute state_dict/load_state_dict.

DiskSaver

Handler that saves input checkpoint on a disk.

Timer

Timer object can be used to measure (average) time between events.

EarlyStopping

EarlyStopping handler can be used to stop the training if no improvement after a given number of events.

TerminateOnNan

TerminateOnNan handler can be used to stop the training if the process_function's output contains a NaN or infinite value, either as a number or a torch.Tensor.

global_step_from_engine

Helper method to set up a global_step_transform function using another engine.

TimeLimit

TimeLimit handler can be used to control training time for computing environments where session time is limited.

class ignite.handlers.Checkpoint(to_save, save_handler, filename_prefix='', score_function=None, score_name=None, n_saved=1, global_step_transform=None, archived=False, filename_pattern=None, include_self=False, greater_or_equal=False)[source]#

Checkpoint handler can be used to periodically save and load objects which have attribute state_dict/load_state_dict. This class can use specific save handlers to store on disk, in cloud storage, etc. The Checkpoint handler (if used with DiskSaver) also automatically moves data from TPU to CPU before writing the checkpoint.

Parameters
  • to_save (Mapping) – Dictionary with the objects to save. Objects should implement state_dict and load_state_dict methods. If it contains objects of type torch DistributedDataParallel or DataParallel, their internal wrapped model is saved automatically (to avoid the extra module. prefix in state-dictionary keys).

  • save_handler (callable or BaseSaveHandler) – Method or callable class used to save the engine and other provided objects. The function receives two objects: the checkpoint as a dictionary and the filename. If save_handler is a callable class, it can inherit from BaseSaveHandler and optionally implement a remove method to keep a fixed number of saved checkpoints. If the checkpoint needs to be saved to disk, save_handler can be defined with DiskSaver.

  • filename_prefix (str, optional) – Prefix for the file name to which objects will be saved. See Note for details.

  • score_function (callable, optional) – If not None, it should be a function taking a single argument, Engine object, and returning a score (float). Objects with highest scores will be retained.

  • score_name (str, optional) – If score_function is not None, it is possible to store its value using score_name. See Notes for more details.

  • n_saved (int, optional) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.

  • global_step_transform (callable, optional) – Global step transform function to output a desired global step. Input of the function is (engine, event_name); the output should be an integer. Default is None, in which case global_step is based on the attached engine. If provided, the function output is used as global_step. To set up a global step from another engine, please use global_step_from_engine().

  • archived (bool, optional) – Deprecated argument as models saved by torch.save are already compressed.

  • filename_pattern (str, optional) – If filename_pattern is provided, this pattern will be used to render checkpoint filenames. If the pattern is not defined, the default pattern would be used. See Note for details.

  • include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, then there must not be another object in to_save with key checkpointer.

  • greater_or_equal (bool) – if True, the latest equally scored model is stored. Otherwise, the first model. Default, False.

Note

This class stores a single file as a dictionary of provided objects to save. The filename is defined by filename_pattern and by default has the following structure: {filename_prefix}_{name}_{suffix}.{ext} where

  • filename_prefix is the argument passed to the constructor,

  • name is the key in to_save if a single object is to be stored; otherwise, name is “checkpoint”.

  • suffix is composed as follows: {global_step}_{score_name}={score}.

score_function | score_name | global_step_transform | suffix
---------------|------------|-----------------------|------------------------------------
None           | None       | None                  | {engine.state.iteration}
X              | None       | None                  | {score}
X              | None       | X                     | {global_step}_{score}
X              | X          | X                     | {global_step}_{score_name}={score}
None           | None       | X                     | {global_step}
X              | X          | None                  | {score_name}={score}

Above, global_step is defined by the output of global_step_transform and score by the output of score_function.

By default, when none of score_function, score_name and global_step_transform is defined, the suffix is set from the attached engine's current iteration. The filename will be {filename_prefix}_{name}_{engine.state.iteration}.{ext}.

For example, with score_name="neg_val_loss" and a score_function that returns -loss (as objects with the highest scores are retained), the saved filename will be {filename_prefix}_{name}_neg_val_loss=-0.1234.pt.

Note

If filename_pattern is given, it will be used to render the filenames. filename_pattern is a string that can contain {filename_prefix}, {name}, {score}, {score_name} and {global_step} as templates.

For example, with filename_pattern="{global_step}-{name}-{score}.pt", a saved filename could be 30000-checkpoint-94.pt.

Warning: Please keep in mind that if the filename collides with one already used by a saved checkpoint, the new checkpoint will replace the older one. This means that a filename like checkpoint.pt will be written on every call and will always be overwritten by newer checkpoints.
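For example, a minimal sketch (paths illustrative, assuming a to_save mapping and a trainer as in the examples below) that deliberately pins saving to a single, always-overwritten file:

from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver

# "{name}.pt" renders to "checkpoint.pt" for a multi-object to_save, so each
# save overwrites the same file, as described in the warning above.
handler = Checkpoint(
    to_save, DiskSaver('/tmp/models', create_dir=True, require_empty=False),
    filename_pattern="{name}.pt",
)
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)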

Note

To get the last stored filename, handler exposes attribute last_checkpoint:

handler = Checkpoint(...)
...
print(handler.last_checkpoint)
> checkpoint_12345.pt

Note

This class is distributed configuration-friendly: it is not required to instantiate it only in the rank 0 process. The class automatically supports distributed configuration; if used with DiskSaver, the checkpoint is stored by the rank 0 process only.

Warning

When running on XLA devices, the handler should be created and attached in all processes; otherwise, the application can get stuck while saving the checkpoint.

# Wrong:
# if idist.get_rank() == 0:
#     handler = Checkpoint(...)
#     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

# Correct:
handler = Checkpoint(...)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

Examples

Attach the handler to make checkpoints during training:

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...

to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}

if checkpoint_iters:
    # A: Output is "checkpoint_<iteration>.pt"
    handler = Checkpoint(
        to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2
    )
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
else:
    # B: Output is "checkpoint_<epoch>.pt"
    gst = lambda *_: trainer.state.epoch
    handler = Checkpoint(
        to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2, global_step_transform=gst
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

trainer.run(data_loader, max_epochs=6)
> A: ["checkpoint_7000.pt", "checkpoint_8000.pt", ]
> B: ["checkpoint_5.pt", "checkpoint_6.pt", ]

Attach the handler to an evaluator to save best model during the training according to computed validation metric:

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

trainer = ...
evaluator = ...
# Setup Accuracy metric computation on evaluator
# Run evaluation on epoch completed event
# ...

score_function = Checkpoint.get_default_score_fn("accuracy")

to_save = {'model': model}
handler = Checkpoint(
    to_save, DiskSaver('/tmp/models', create_dir=True),
    n_saved=2, filename_prefix='best',
    score_function=score_function, score_name="val_acc",
    global_step_transform=global_step_from_engine(trainer)
)

evaluator.add_event_handler(Events.COMPLETED, handler)

trainer.run(data_loader, max_epochs=10)
> ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt", ]

Changed in version 0.4.3:

  • Checkpoint can save model with same filename.

  • Added greater_or_equal argument.

static get_default_score_fn(metric_name, score_sign=1.0)[source]#

Helper method to get default score function based on the metric name.

Parameters
  • metric_name (str) – metric name to get the value from engine.state.metrics. Engine is the one to which Checkpoint handler is added.

  • score_sign (float) – sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better, a negative score sign should be used (objects with larger score are retained). Default, 1.0.

Return type

Callable

Examples:

from ignite.handlers import Checkpoint

best_acc_score = Checkpoint.get_default_score_fn("accuracy")

best_model_handler = Checkpoint(
    to_save, save_handler, score_name="val_accuracy", score_function=best_acc_score
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

Usage with error-like metric:

from ignite.handlers import Checkpoint

neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)

best_model_handler = Checkpoint(
    to_save, save_handler, score_name="val_neg_loss", score_function=neg_loss_score
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

New in version 0.4.3.

static load_objects(to_load, checkpoint, **kwargs)[source]#

Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.

Examples:

import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint
trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
to_save = {"weights": model, "optimizer": optimizer}
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
trainer.run(torch.randn(10, 1), 5)

to_load = to_save
checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
checkpoint = torch.load(checkpoint_fp)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

Note

If to_load contains objects of type torch DistributedDataParallel or DataParallel, the load_state_dict method will be applied to their internal wrapped model (obj.module).

Parameters
  • to_load (Mapping) – a dictionary with objects, e.g. {“model”: model, “optimizer”: optimizer, …}

  • checkpoint (Mapping) – a dictionary with state_dicts to load, e.g. {“model”: model_state_dict, “optimizer”: opt_state_dict}. If to_load contains a single key, then checkpoint can directly be the corresponding state_dict.

  • **kwargs – Keyword arguments accepted by nn.Module.load_state_dict(). Passing strict=False enables the user to load part of a pretrained model (useful, for example, in transfer learning), as sketched below.

Return type

None
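A hedged sketch of partial loading for transfer learning (the checkpoint path and model variable are illustrative); strict=False is forwarded to nn.Module.load_state_dict so non-matching keys are skipped:

import torch
from ignite.handlers import Checkpoint

# Illustrative path to a checkpoint previously written by a Checkpoint handler
checkpoint = torch.load("/tmp/models/best_model_10_val_acc=0.78.pt")

# Load only the weights that match the current model definition
Checkpoint.load_objects(to_load={"model": model}, checkpoint=checkpoint, strict=False)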

load_state_dict(state_dict)[source]#

Method replaces the internal state of the class with the provided state dict data.

Parameters

state_dict (Mapping) – a dict with “saved” key and list of (priority, filename) pairs as values.

Return type

None

reset()[source]#

Method to reset saved checkpoint names.

Use this method if the engine will independently run multiple times:

from ignite.handlers import Checkpoint

trainer = ...
checkpointer = Checkpoint(...)

trainer.add_event_handler(Events.COMPLETED, checkpointer)
trainer.add_event_handler(Events.STARTED, checkpointer.reset)

# fold 0
trainer.run(data0, max_epochs=max_epochs)
print("Last checkpoint:", checkpointer.last_checkpoint)

# fold 1
trainer.run(data1, max_epochs=max_epochs)
print("Last checkpoint:", checkpointer.last_checkpoint)

New in version 0.4.3.

Return type

None

static setup_filename_pattern(with_prefix=True, with_score=True, with_score_name=True, with_global_step=True)[source]#

Helper method to get the default filename pattern for a checkpoint.

Parameters
  • with_prefix (bool) – If True, the filename_prefix is added to the filename pattern: {filename_prefix}_{name}.... Default, True.

  • with_score (bool) – If True, score is added to the filename pattern: ..._{score}.{ext}. Default, True. At least one of with_score and with_global_step should be True.

  • with_score_name (bool) – If True, score_name is added to the filename pattern: ..._{score_name}={score}.{ext}. If activated, with_score should also be True, otherwise an error is raised. Default, True.

  • with_global_step (bool) – If True, {global_step} is added to the filename pattern: ...{name}_{global_step}.... At least one of with_score and with_global_step should be True.

Return type

str

Example

from ignite.handlers import Checkpoint

filename_pattern = Checkpoint.setup_filename_pattern()

print(filename_pattern)
> "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"

New in version 0.4.3.
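A hedged follow-up sketch (to_save, save_handler and trainer are assumed from the earlier examples): the rendered pattern can be fed back into a Checkpoint handler, e.g. dropping the score name while keeping the global step and score:

from ignite.handlers import Checkpoint, global_step_from_engine

# Expected to look like "{filename_prefix}_{name}_{global_step}_{score}.{ext}"
pattern = Checkpoint.setup_filename_pattern(with_score_name=False)

handler = Checkpoint(
    to_save, save_handler,
    filename_pattern=pattern,
    score_function=Checkpoint.get_default_score_fn("accuracy"),
    global_step_transform=global_step_from_engine(trainer),
)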

state_dict()[source]#

Method returns a state dict with the saved items: a list of (priority, filename) pairs. Can be used to save the internal state of the class.

Return type

OrderedDict[str, List[Tuple[int, str]]]
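A small sketch (to_save, save_handler and an existing checkpointer are assumed) of persisting and restoring the handler's own bookkeeping so the n_saved rotation continues across runs; setting include_self=True stores the same information inside the checkpoint automatically:

# OrderedDict with a "saved" key listing (priority, filename) pairs
handler_state = checkpointer.state_dict()

# Restore it into a freshly created handler
new_checkpointer = Checkpoint(to_save, save_handler, n_saved=2)
new_checkpointer.load_state_dict(handler_state)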

class ignite.handlers.checkpoint.BaseSaveHandler[source]#

Base class for save handlers

Methods to override:

Note

In derived classes, please make sure that in a distributed configuration the overridden methods are called by a single process. A distributed configuration on XLA devices should be treated slightly differently: to save a checkpoint with xm.save(), all processes should call the function; otherwise, the application gets stuck.

abstract __call__(checkpoint, filename, metadata=None)[source]#

Method to save checkpoint with filename. Additionally, metadata dictionary is provided.

Metadata contains:

  • basename: file prefix (if provided) with checkpoint name, e.g. epoch_checkpoint.

  • score_name: score name if provided, e.g. val_acc.

  • priority: checkpoint priority value (higher is better), e.g. 12 or 0.6554435

Parameters
  • checkpoint (Mapping) – checkpoint dictionary to save.

  • filename (str) – filename associated with checkpoint.

  • metadata (Mapping, optional) – metadata on checkpoint to save.

Return type

None

abstract remove(filename)[source]#

Method to remove saved checkpoint.

Parameters

filename (str) – filename associated with checkpoint.

Return type

None
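A minimal sketch of a derived save handler (class and directory names are illustrative; for plain disk storage prefer DiskSaver):

import os
import torch

from ignite.handlers.checkpoint import BaseSaveHandler

class MyDiskHandler(BaseSaveHandler):
    # Hypothetical handler storing checkpoints with torch.save

    def __init__(self, dirname):
        self.dirname = dirname
        os.makedirs(dirname, exist_ok=True)

    def __call__(self, checkpoint, filename, metadata=None):
        # `checkpoint` is the dictionary of state_dicts assembled by Checkpoint
        torch.save(checkpoint, os.path.join(self.dirname, filename))

    def remove(self, filename):
        # called by Checkpoint to honour `n_saved` by deleting old files
        os.remove(os.path.join(self.dirname, filename))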

class ignite.handlers.DiskSaver(dirname, atomic=True, create_dir=True, require_empty=True, **kwargs)[source]#

Handler that saves input checkpoint on a disk.

Parameters
  • dirname (str) – Directory path where the checkpoint will be saved

  • atomic (bool, optional) – if True, checkpoint is serialized to a temporary file, and then moved to final destination, so that files are guaranteed to not be damaged (for example if exception occurs during saving).

  • create_dir (bool, optional) – if True, will create the directory dirname if it doesn't exist.

  • require_empty (bool, optional) – If True, will raise exception if there are any files in the directory dirname.

  • **kwargs – Accepted keyword arguments for torch.save or xm.save.

Changed in version 0.4.2: Accept kwargs for torch.save or xm.save.
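A short usage sketch (path and model are illustrative); keyword arguments are forwarded to torch.save (or xm.save on XLA):

from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver

# Reuse a possibly non-empty directory and keep the three latest checkpoints
saver = DiskSaver('/tmp/models', create_dir=True, require_empty=False)
handler = Checkpoint({'model': model}, saver, n_saved=3)
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)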

class ignite.handlers.ModelCheckpoint(dirname, filename_prefix, save_interval=None, score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, save_as_state_dict=True, global_step_transform=None, archived=False, include_self=False, **kwargs)[source]#

ModelCheckpoint handler can be used to periodically save objects to disk only. If needed to store checkpoints to another storage type, please consider Checkpoint.

This handler expects two arguments:

  • an Engine object

  • a dict mapping names (str) to objects that should be saved to disk.

See Examples for further details.

Warning

Behaviour of this class has been changed since v0.3.0.

Argument save_as_state_dict is deprecated and should not be used. It is considered to be True.

Argument save_interval is deprecated and should not be used. Please, use events filtering instead, e.g. ITERATION_STARTED(every=1000)

There is no longer an internal counter indicating the number of save actions. The step_number seen in the filename, e.g. {filename_prefix}_{name}_{step_number}.pt, is now the current engine epoch if score_function is specified and the current iteration otherwise.

A single pt file is created instead of multiple files.

Parameters
  • dirname (str) – Directory path where objects will be saved.

  • filename_prefix (str) – Prefix for the file names to which objects will be saved. See Notes of Checkpoint for more details.

  • score_function (callable, optional) – if not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with highest scores will be retained.

  • score_name (str, optional) – if score_function is not None, it is possible to store its value using score_name. See Notes for more details.

  • n_saved (int, optional) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.

  • atomic (bool, optional) – If True, objects are serialized to a temporary file, and then moved to final destination, so that files are guaranteed to not be damaged (for example if exception occurs during saving).

  • require_empty (bool, optional) – If True, will raise exception if there are any files starting with filename_prefix in the directory dirname.

  • create_dir (bool, optional) – If True, will create directory dirname if it does not exist.

  • global_step_transform (callable, optional) – global step transform function to output a desired global step. Input of the function is (engine, event_name); the output should be an integer. Default is None, in which case global_step is based on the attached engine. If provided, the function output is used as global_step. To set up a global step from another engine, please use global_step_from_engine().

  • archived (bool, optional) – Deprecated argument as models saved by torch.save are already compressed.

  • include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, then there must not be another object in to_save with key checkpointer.

  • **kwargs – Accepted keyword arguments for torch.save or xm.save in DiskSaver.

  • save_interval (Optional[int]) –

  • save_as_state_dict (bool) –

Changed in version 0.4.2: Accept kwargs for torch.save or xm.save

Examples

>>> import os
>>> from ignite.engine import Engine, Events
>>> from ignite.handlers import ModelCheckpoint
>>> from torch import nn
>>> trainer = Engine(lambda engine, batch: None)
>>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
>>> model = nn.Linear(3, 3)
>>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
>>> trainer.run([0, 1, 2, 3, 4], max_epochs=6)
>>> os.listdir('/tmp/models')
['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
>>> handler.last_checkpoint
['/tmp/models/myprefix_mymodel_30.pt']

class ignite.handlers.EarlyStopping(patience, score_function, trainer, min_delta=0.0, cumulative_delta=False)[source]#

EarlyStopping handler can be used to stop the training if no improvement after a given number of events.

Parameters
  • patience (int) – Number of events to wait if no improvement and then stop the training.

  • score_function (callable) – It should be a function taking a single argument, an Engine object, and returning a score (float). An improvement is considered if the score is higher.

  • trainer (Engine) – trainer engine to stop the run if no improvement.

  • min_delta (float, optional) – A minimum increase in the score to qualify as an improvement; an increase of less than or equal to min_delta will count as no improvement.

  • cumulative_delta (bool, optional) – If True, min_delta defines an increase since the last patience reset; otherwise, it defines an increase after the last event. Default value is False.

Examples:

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

def score_function(engine):
    val_loss = engine.state.metrics['nll']
    return -val_loss

handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
evaluator.add_event_handler(Events.COMPLETED, handler)
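A hedged variation of the example above using min_delta and cumulative_delta, so that only improvements larger than 0.01 over the best score since the last patience reset count as progress:

handler = EarlyStopping(
    patience=5,
    min_delta=0.01,
    cumulative_delta=True,
    score_function=score_function,
    trainer=trainer,
)
evaluator.add_event_handler(Events.COMPLETED, handler)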

class ignite.handlers.Timer(average=False)[source]#

Timer object can be used to measure (average) time between events.

Parameters

average (bool, optional) – if True, then when the .value() method is called, the returned value will be the total time measured divided by the value of the internal counter.

total#

total time elapsed when the Timer was running (in seconds).

Type

float

step_count#

internal counter, useful to measure average time, e.g. of processing a single batch. Incremented with the .step() method.

Type

int

running#

flag indicating if timer is measuring time.

Type

bool

Note

When using Timer(average=True) do not forget to call timer.step() every time an event occurs. See the examples below.

Examples

Measuring total time of the epoch:

>>> from ignite.handlers import Timer
>>> import time
>>> work = lambda : time.sleep(0.1)
>>> idle = lambda : time.sleep(0.1)
>>> t = Timer(average=False)
>>> for _ in range(10):
...    work()
...    idle()
...
>>> t.value()
2.003073937026784

Measuring average time of the epoch:

>>> t = Timer(average=True)
>>> for _ in range(10):
...    work()
...    idle()
...    t.step()
...
>>> t.value()
0.2003182829997968

Measuring average time it takes to execute a single work() call:

>>> t = Timer(average=True)
>>> for _ in range(10):
...    t.resume()
...    work()
...    t.pause()
...    idle()
...    t.step()
...
>>> t.value()
0.10016545779653825

Using the Timer to measure average time it takes to process a single batch of examples:

>>> from ignite.engine import Engine, Events
>>> from ignite.handlers import Timer
>>> trainer = Engine(training_update_function)
>>> timer = Timer(average=True)
>>> timer.attach(trainer,
...              start=Events.EPOCH_STARTED,
...              resume=Events.ITERATION_STARTED,
...              pause=Events.ITERATION_COMPLETED,
...              step=Events.ITERATION_COMPLETED)

attach(engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None)[source]#

Register callbacks to control the timer.

Parameters
  • engine (Engine) – Engine that this timer will be attached to.

  • start (Events) – Event which should start (reset) the timer.

  • pause (Events) – Event which should pause the timer.

  • resume (Events, optional) – Event which should resume the timer.

  • step (Events, optional) – Event which should call the step method of the counter.

Returns

self (Timer)

Return type

Timer
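As a follow-up sketch to the attach example above, the measured average batch time can be read with value() and cleared with reset() at the end of each epoch:

@trainer.on(Events.EPOCH_COMPLETED)
def log_batch_time(engine):
    # timer was attached with step=Events.ITERATION_COMPLETED above
    print(f"epoch {engine.state.epoch}: avg batch time {timer.value():.4f} s")
    timer.reset()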

class ignite.handlers.TerminateOnNan(output_transform=<function TerminateOnNan.<lambda>>)[source]#

TerminateOnNan handler can be used to stop the training if the process_function's output contains a NaN or infinite value, either as a number or a torch.Tensor. The output can be of type: number, tensor or a collection of them. The training is stopped if at least a single number/tensor has a NaN or infinite value. For example, if the output is [1.23, torch.tensor(...), torch.tensor(float('nan'))] the handler will stop the training.

Parameters

output_transform (callable, optional) – a callable that is used to transform the Engine’s process_function’s output into a number or torch.tensor or collection of them. This can be useful if, for example, you have a multi-output model and you want to check one or multiple values of the output.

Examples:

trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
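A hedged sketch for a multi-output process_function where only one entry should be checked (the output key is illustrative):

from ignite.engine import Events
from ignite.handlers import TerminateOnNan

# Assuming process_function returns e.g. {"loss": ..., "y_pred": ...},
# check only the loss value for NaN/Inf
trainer.add_event_handler(
    Events.ITERATION_COMPLETED,
    TerminateOnNan(output_transform=lambda out: out["loss"]),
)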

ignite.handlers.global_step_from_engine(engine, custom_event_name=None)[source]#

Helper method to set up a global_step_transform function using another engine. This can be helpful for logging the trainer's epoch/iteration while an output handler is attached to an evaluator.

Parameters
  • engine (Engine) – engine whose state is used to provide the global step

  • custom_event_name (optional) – registered event name to use for retrieving the global step value.

Returns

global step

Return type

Callable
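A short hedged sketch (to_save and save_handler are assumed): by default the wrapper reports the attribute matching the event it is triggered with, while custom_event_name pins it explicitly, e.g. to the trainer's iteration count:

from ignite.engine import Events
from ignite.handlers import Checkpoint, global_step_from_engine

# Use the trainer's iteration (rather than the value implied by the
# evaluator's COMPLETED event) in checkpoint filenames
gst = global_step_from_engine(trainer, custom_event_name=Events.ITERATION_COMPLETED)

handler = Checkpoint(to_save, save_handler, global_step_transform=gst)
evaluator.add_event_handler(Events.COMPLETED, handler)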

class ignite.handlers.TimeLimit(limit_sec=28800)[source]#

TimeLimit handler can be used to control training time for computing environments where session time is limited. The timer starts when the handler is created, not when training starts. This handler gracefully terminates the training if the time spent in training exceeds the limit.

Parameters

limit_sec (int, optional) – Maximum time before training terminates (in seconds). Defaults to 28800.

Examples

from ignite.engine import Events
from ignite.handlers import TimeLimit

handler = TimeLimit() # 8 hours of training
trainer.add_event_handler(Events.ITERATION_COMPLETED, handler)

New in version 0.4.3.