ignite.handlers#
Complete list of handlers#
- class ignite.handlers.Checkpoint(to_save, save_handler, filename_prefix='', score_function=None, score_name=None, n_saved=1, global_step_transform=None, archived=False, filename_pattern=None, include_self=False)[source]#
Checkpoint handler can be used to periodically save and load objects which have attribute `state_dict`/`load_state_dict`. This class can use specific save handlers to store on the disk or a cloud storage, etc. The Checkpoint handler (if used with `DiskSaver`) also automatically handles moving data on TPU to CPU before writing the checkpoint.

- Parameters
to_save (Mapping) – Dictionary with the objects to save. Objects should have implemented `state_dict` and `load_state_dict` methods. If it contains objects of type torch DistributedDataParallel or DataParallel, their internal wrapped model is automatically saved (to avoid the additional key `module.` in the state dictionary).
save_handler (callable or `BaseSaveHandler`) – Method or callable class to use to save the engine and other provided objects. The function receives two objects: the checkpoint as a dictionary and the filename. If `save_handler` is a callable class, it can inherit from `BaseSaveHandler` and optionally implement a `remove` method to keep a fixed number of saved checkpoints. If the user needs to save the engine's checkpoint on disk, `save_handler` can be defined with `DiskSaver`.
filename_prefix (str, optional) – Prefix for the file name to which objects will be saved. See Note for details.
score_function (callable, optional) – If not None, it should be a function taking a single argument, an `Engine` object, and returning a score (float). Objects with the highest scores will be retained.
score_name (str, optional) – If `score_function` is not None, it is possible to store its value using `score_name`. See Notes for more details.
n_saved (int, optional) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.
global_step_transform (callable, optional) – Global step transform function to output a desired global step. Input of the function is `(engine, event_name)`. Output of the function should be an integer. Default is None, global_step based on the attached engine. If provided, uses the function output as global_step. To set up the global step from another engine, please use `global_step_from_engine()`.
archived (bool, optional) – Deprecated argument, as models saved by `torch.save` are already compressed.
filename_pattern (str, optional) – If `filename_pattern` is provided, this pattern will be used to render checkpoint filenames. If the pattern is not defined, the default pattern is used. See Note for details.
include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, there must not be another object in `to_save` with key `checkpointer` (see the sketch after this list).
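For illustration, a minimal sketch of enabling `include_self` so that the `Checkpoint` object's own state is stored alongside the other objects; `to_save`, `model` and `trainer` are assumed to be defined as in the Examples further below:

```python
from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver

# Assumed: to_save = {'model': model, ...} defined elsewhere
handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True),
                     n_saved=2, include_self=True)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
```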
Note

This class stores a single file as a dictionary of provided objects to save. The filename is defined by `filename_pattern` and by default has the following structure: `{filename_prefix}_{name}_{suffix}.{ext}`, where `filename_prefix` is the argument passed to the constructor, `name` is the key in `to_save` if a single object is to be stored (otherwise `name` is "checkpoint"), and `suffix` is composed as `{global_step}_{score_name}={score}`.
| score_function | score_name | global_step_transform | suffix |
| --- | --- | --- | --- |
| None | None | None | `{engine.state.iteration}` |
| X | None | None | `{score}` |
| X | None | X | `{global_step}_{score}` |
| X | X | X | `{global_step}_{score_name}={score}` |
| None | None | X | `{global_step}` |
| X | X | None | `{score_name}={score}` |
Above, global_step is defined by the output of global_step_transform and score by the output of score_function.
By default, when none of `score_function`, `score_name`, `global_step_transform` is defined, the suffix is set to the attached engine's current iteration. The filename will be `{filename_prefix}_{name}_{engine.state.iteration}.{ext}`.

For example, with `score_name="neg_val_loss"` and a `score_function` that returns -loss (as objects with the highest scores are retained), the saved filename will be `{filename_prefix}_{name}_neg_val_loss=-0.1234.pt`.

Note
If `filename_pattern` is given, it will be used to render the filenames. `filename_pattern` is a string that can contain `{filename_prefix}`, `{name}`, `{score}`, `{score_name}` and `{global_step}` as templates.

For example, with `filename_pattern="{global_step}-{name}-{score}.pt"` the saved filename will be `30000-checkpoint-94.pt`.
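A minimal sketch of passing a custom `filename_pattern`, assuming `model`, `trainer` and `evaluator` are already set up as in the Examples below:

```python
from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

handler = Checkpoint(
    {'model': model},                              # assumed to be defined elsewhere
    DiskSaver('/tmp/models', create_dir=True),
    score_function=lambda engine: engine.state.metrics['accuracy'],
    global_step_transform=global_step_from_engine(trainer),
    filename_pattern="{global_step}-{name}-{score}.pt",
)
evaluator.add_event_handler(Events.COMPLETED, handler)
```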
Warning: Please keep in mind that if a filename collides with one already used to save a checkpoint, the new checkpoint will not be stored. This means that a filename like `checkpoint.pt` will be saved only once and will not be overwritten by newer checkpoints.

Note
To get the last stored filename, the handler exposes the attribute `last_checkpoint`:

```python
handler = Checkpoint(...)
...
print(handler.last_checkpoint)
> checkpoint_12345.pt
```
Note
This class is distributed configuration-friendly: it is not required to instantiate the class only in the rank 0 process. This class automatically supports distributed configuration and, if used with `DiskSaver`, the checkpoint is stored by the rank 0 process.

Warning
When running on XLA devices, it should be run in all processes, otherwise the application can get stuck on saving the checkpoint.
```python
# Wrong:
# if idist.get_rank() == 0:
#     handler = Checkpoint(...)
#     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

# Correct:
handler = Checkpoint(...)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
```
Examples
Attach the handler to make checkpoints during training:
```python
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...

to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}
handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
trainer.run(data_loader, max_epochs=6)
> ["checkpoint_7000.pt", "checkpoint_8000.pt", ]
```
Attach the handler to an evaluator to save the best model during training, according to a computed validation metric:
```python
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

trainer = ...
evaluator = ...
# Setup Accuracy metric computation on evaluator
# Run evaluation on epoch completed event
# ...

def score_function(engine):
    return engine.state.metrics['accuracy']

to_save = {'model': model}
handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2,
                     filename_prefix='best', score_function=score_function, score_name="val_acc",
                     global_step_transform=global_step_from_engine(trainer))
evaluator.add_event_handler(Events.COMPLETED, handler)
trainer.run(data_loader, max_epochs=10)
> ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt", ]
```
- static load_objects(to_load, checkpoint, **kwargs)[source]#
Helper method to apply `load_state_dict` on the objects from `to_load` using states from `checkpoint`.

Examples:
```python
import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint

trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
to_save = {"weights": model, "optimizer": optimizer}
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
trainer.run(torch.randn(10, 1), 5)

to_load = to_save
checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
checkpoint = torch.load(checkpoint_fp)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
```
Note
If `to_load` contains objects of type torch DistributedDataParallel or DataParallel, the method `load_state_dict` will be applied to their internal wrapped model (`obj.module`).

- Parameters
to_load (Mapping) – a dictionary with objects, e.g. {“model”: model, “optimizer”: optimizer, …}
checkpoint (Mapping) – a dictionary with state_dicts to load, e.g. {“model”: model_state_dict, “optimizer”: opt_state_dict}. If to_load contains a single key, then checkpoint can contain directly corresponding state_dict.
**kwargs – Keyword arguments accepted by nn.Module.load_state_dict(). Passing strict=False enables the user to load part of a pretrained model (useful, for example, in transfer learning; see the sketch after this list).
- Return type
None
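As an illustration, a minimal sketch of partial loading, assuming `model` and a compatible `checkpoint` mapping have already been obtained (e.g. via `torch.load` as in the example above):

```python
# strict=False is forwarded to nn.Module.load_state_dict(), so keys missing
# from the checkpoint (or unexpected ones) are tolerated instead of raising.
Checkpoint.load_objects(to_load={"model": model}, checkpoint=checkpoint, strict=False)
```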
- class ignite.handlers.checkpoint.BaseSaveHandler[source]#
Base class for save handlers
Methods to override:
Note
In a derived class, please make sure that in a distributed configuration the overridden methods are called by a single process. Distributed configuration on XLA devices should be treated slightly differently: to save a checkpoint with xm.save(), all processes should enter the function, otherwise the application gets stuck.
- abstract __call__(checkpoint, filename, metadata=None)[source]#
Method to save the checkpoint with the given filename. Additionally, a metadata dictionary is provided.
Metadata contains:
basename: file prefix (if provided) with checkpoint name, e.g. epoch_checkpoint.
score_name: score name if provided, e.g. val_acc.
priority: checkpoint priority value (higher is better), e.g. 12 or 0.6554435.
- Parameters
checkpoint (Mapping) – checkpoint dictionary to save.
filename (str) – filename associated with checkpoint.
metadata (Mapping, optional) – metadata on checkpoint to save.
- Return type
None
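As an illustration, a hypothetical handler (not part of ignite) subclassing `BaseSaveHandler` could look like the sketch below; it simply delegates to `torch.save` and implements the optional `remove` method mentioned in the `Checkpoint` documentation:

```python
import os
import torch
from ignite.handlers.checkpoint import BaseSaveHandler


class SimpleDiskHandler(BaseSaveHandler):
    """Hypothetical save handler writing checkpoints into a local directory."""

    def __init__(self, dirname):
        self.dirname = dirname
        os.makedirs(dirname, exist_ok=True)

    def __call__(self, checkpoint, filename, metadata=None):
        # checkpoint: dict of state_dicts; metadata: basename / score_name / priority
        torch.save(checkpoint, os.path.join(self.dirname, filename))

    def remove(self, filename):
        # Optional: called by Checkpoint so that only `n_saved` files are kept
        os.remove(os.path.join(self.dirname, filename))
```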
- class ignite.handlers.DiskSaver(dirname, atomic=True, create_dir=True, require_empty=True, **kwargs)[source]#
Handler that saves input checkpoint on a disk.
- Parameters
dirname (str) – Directory path where the checkpoint will be saved
atomic (bool, optional) – if True, the checkpoint is serialized to a temporary file and then moved to the final destination, so that files are guaranteed to not be damaged (for example if an exception occurs during saving).
create_dir (bool, optional) – if True, will create directory `dirname` if it does not exist.
require_empty (bool, optional) – If True, will raise an exception if there are any files in the directory `dirname`.
**kwargs – Accepted keyword arguments for torch.save or xm.save.
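For instance, a minimal sketch of constructing a `DiskSaver` for use with `Checkpoint`; the path is an assumed example and `model` is assumed to be defined elsewhere:

```python
from ignite.handlers import Checkpoint, DiskSaver

# require_empty=False allows reusing a directory that already contains files
saver = DiskSaver('/tmp/models', create_dir=True, require_empty=False)
handler = Checkpoint({'model': model}, saver, n_saved=2)
```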
- class ignite.handlers.ModelCheckpoint(dirname, filename_prefix, save_interval=None, score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, save_as_state_dict=True, global_step_transform=None, archived=False, include_self=False, **kwargs)[source]#
ModelCheckpoint handler can be used to periodically save objects to disk only. If you need to store checkpoints to another storage type, please consider `Checkpoint`.

This handler expects two arguments:

- an `Engine` object
- a dict mapping names (str) to objects that should be saved to disk.
See Examples for further details.
Warning
Behaviour of this class has been changed since v0.3.0.
Argument `save_as_state_dict` is deprecated and should not be used. It is considered as True.

Argument `save_interval` is deprecated and should not be used. Please use event filtering instead, e.g. `ITERATION_STARTED(every=1000)` (see the sketch after this list).

There is no longer an internal counter used to indicate the number of save actions. The user can see its value as step_number in the filename, e.g. {filename_prefix}_{name}_{step_number}.pt. step_number is the current engine's epoch if score_function is specified, and the current iteration otherwise.
A single pt file is created instead of multiple files.
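For illustration, a rough equivalent of the removed `save_interval=1000` behaviour, assuming a `trainer`, `model` and `handler` set up as in the Examples below:

```python
# Save every 1000 iterations via event filtering instead of the old save_interval argument
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler, {'mymodel': model})
```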
- Parameters
dirname (str) – Directory path where objects will be saved.
filename_prefix (str) – Prefix for the file names to which objects will be saved. See Notes of `Checkpoint` for more details.
score_function (callable, optional) – If not None, it should be a function taking a single argument, an `Engine` object, and returning a score (float). Objects with the highest scores will be retained.
score_name (str, optional) – If `score_function` is not None, it is possible to store its value using `score_name`. See Notes for more details.
n_saved (int, optional) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.
atomic (bool, optional) – If True, objects are serialized to a temporary file, and then moved to the final destination, so that files are guaranteed to not be damaged (for example if an exception occurs during saving).
require_empty (bool, optional) – If True, will raise an exception if there are any files starting with `filename_prefix` in the directory `dirname`.
create_dir (bool, optional) – If True, will create directory `dirname` if it does not exist.
global_step_transform (callable, optional) – Global step transform function to output a desired global step. Input of the function is `(engine, event_name)`. Output of the function should be an integer. Default is None, global_step based on the attached engine. If provided, uses the function output as global_step. To set up the global step from another engine, please use `global_step_from_engine()`.
archived (bool, optional) – Deprecated argument, as models saved by torch.save are already compressed.
include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, there must not be another object in `to_save` with key `checkpointer`.
**kwargs – Accepted keyword arguments for torch.save or xm.save in DiskSaver.
save_as_state_dict (bool) –
Examples
```python
>>> import os
>>> from ignite.engine import Engine, Events
>>> from ignite.handlers import ModelCheckpoint
>>> from torch import nn
>>> trainer = Engine(lambda engine, batch: None)
>>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
>>> model = nn.Linear(3, 3)
>>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
>>> trainer.run([0], max_epochs=6)
>>> os.listdir('/tmp/models')
['myprefix_mymodel_4.pt', 'myprefix_mymodel_6.pt']
>>> handler.last_checkpoint
['/tmp/models/myprefix_mymodel_6.pt']
```
- class ignite.handlers.EarlyStopping(patience, score_function, trainer, min_delta=0.0, cumulative_delta=False)[source]#
EarlyStopping handler can be used to stop the training if there is no improvement after a given number of events.
- Parameters
patience (int) – Number of events to wait for improvement before stopping the training.
score_function (callable) – It should be a function taking a single argument, an `Engine` object, and returning a score (float). An improvement is considered if the score is higher.
trainer (Engine) – Trainer engine whose run is stopped if there is no improvement.
min_delta (float, optional) – A minimum increase in the score to qualify as an improvement, i.e. an increase of less than or equal to min_delta will count as no improvement.
cumulative_delta (bool, optional) – If True, min_delta defines an increase since the last patience reset, otherwise it defines an increase after the last event. Default value is False.
Examples:
```python
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

def score_function(engine):
    val_loss = engine.state.metrics['nll']
    return -val_loss

handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
evaluator.add_event_handler(Events.COMPLETED, handler)
```
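A small variant of the example above, sketching the `min_delta`/`cumulative_delta` arguments so that only increases of more than 0.01 since the last patience reset count as an improvement (same `score_function`, `trainer` and `evaluator` assumed):

```python
handler = EarlyStopping(patience=5, min_delta=0.01, cumulative_delta=True,
                        score_function=score_function, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)
```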
- class ignite.handlers.Timer(average=False)[source]#
Timer object can be used to measure (average) time between events.
- Parameters
average (bool, optional) – if True, then when the `.value()` method is called, the returned value will be equal to the total time measured, divided by the value of the internal counter.
- step_count#
Internal counter, useful to measure average time, e.g. of processing a single batch. Incremented with the `.step()` method.

- Type
Note
When using `Timer(average=True)` do not forget to call `timer.step()` every time an event occurs. See the examples below.

Examples
Measuring total time of the epoch:
```python
>>> from ignite.handlers import Timer
>>> import time
>>> work = lambda : time.sleep(0.1)
>>> idle = lambda : time.sleep(0.1)
>>> t = Timer(average=False)
>>> for _ in range(10):
...     work()
...     idle()
...
>>> t.value()
2.003073937026784
```
Measuring average time of the epoch:
```python
>>> t = Timer(average=True)
>>> for _ in range(10):
...     work()
...     idle()
...     t.step()
...
>>> t.value()
0.2003182829997968
```
Measuring average time it takes to execute a single `work()` call:

```python
>>> t = Timer(average=True)
>>> for _ in range(10):
...     t.resume()
...     work()
...     t.pause()
...     idle()
...     t.step()
...
>>> t.value()
0.10016545779653825
```
Using the Timer to measure average time it takes to process a single batch of examples:
```python
>>> from ignite.engine import Engine, Events
>>> from ignite.handlers import Timer
>>> trainer = Engine(training_update_function)
>>> timer = Timer(average=True)
>>> timer.attach(trainer,
...              start=Events.EPOCH_STARTED,
...              resume=Events.ITERATION_STARTED,
...              pause=Events.ITERATION_COMPLETED,
...              step=Events.ITERATION_COMPLETED)
```
- attach(engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None)[source]#
Register callbacks to control the timer.
- Parameters
engine (Engine) – Engine that this timer will be attached to.
start (Events) – Event which should start (reset) the timer.
pause (Events) – Event which should pause the timer.
resume (Events, optional) – Event which should resume the timer.
step (Events, optional) – Event which should call the step method of the counter.
- Returns
self (Timer)
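As a small follow-up sketch (reusing the `trainer` and `timer` from the example above), the measured average can be read back in an event handler:

```python
def log_batch_time(engine):
    # With the attach() configuration above, value() is the average time per iteration
    print(f"Average batch time: {timer.value():.4f} s")

trainer.add_event_handler(Events.EPOCH_COMPLETED, log_batch_time)
```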
- class ignite.handlers.TerminateOnNan(output_transform=<function TerminateOnNan.<lambda>>)[source]#
TerminateOnNan handler can be used to stop the training if the process_function's output contains a NaN or infinite value, either as a number or as a torch.tensor. The output can be of type: number, tensor, or collection of them. The training is stopped if at least one number or tensor has a NaN or infinite value. For example, if the output is [1.23, torch.tensor(...), torch.tensor(float('nan'))], the handler will stop the training.
- Parameters
output_transform (callable, optional) – a callable that is used to transform the `Engine`'s `process_function`'s output into a number, a torch.tensor, or a collection of them. This can be useful if, for example, you have a multi-output model and you want to check one or multiple values of the output.
Examples:
```python
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
```
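A hedged sketch of using `output_transform` when the process_function returns a dict (the key 'loss' below is a hypothetical example) and only that entry should be checked:

```python
trainer.add_event_handler(
    Events.ITERATION_COMPLETED,
    TerminateOnNan(output_transform=lambda output: output['loss'])
)
```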