Checkpoint#

- class ignite.handlers.checkpoint.Checkpoint(to_save, save_handler, filename_prefix='', score_function=None, score_name=None, n_saved=1, global_step_transform=None, archived=False, filename_pattern=None, include_self=False, greater_or_equal=False, save_on_rank=0)[source]#

The Checkpoint handler can be used to periodically save and load objects that implement state_dict/load_state_dict. This class can use specific save handlers to store checkpoints on disk, in cloud storage, etc. The Checkpoint handler (if used with DiskSaver) also automatically moves data from TPU to CPU before writing the checkpoint.

- Parameters
to_save (Mapping) – Dictionary with the objects to save. Objects must implement state_dict and load_state_dict methods. If it contains objects of type torch DistributedDataParallel or DataParallel, their internal wrapped model is saved automatically (to avoid the additional module. prefix in the state dictionary keys).

save_handler (Union[str, Path, Callable, BaseSaveHandler]) – String, path, method or callable class used to save the engine and other provided objects. The function receives two objects: the checkpoint as a dictionary and the filename. If save_handler is a callable class, it can inherit from BaseSaveHandler and optionally implement a remove method to keep a fixed number of saved checkpoints. To save the engine's checkpoint on disk, save_handler can be defined with DiskSaver, or a string specifying the directory name can be passed as save_handler.

filename_prefix (str) – Prefix for the file name to which objects will be saved. See the Note below for details.

score_function (Optional[Callable]) – If not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained.

score_name (Optional[str]) – If score_function is not None, it is possible to store its value using score_name. If score_function is None, score_name can be used alone to define score_function as Checkpoint.get_default_score_fn(score_name) by default.

n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.

global_step_transform (Optional[Callable]) – Global step transform function to output a desired global step. The input of the function is (engine, event_name); the output should be an integer. Default is None, in which case the global step is based on the attached engine. If provided, the function's output is used as global_step. To set up the global step from another engine, please use global_step_from_engine().

archived (bool) – Deprecated argument, as models saved by torch.save are already compressed.

filename_pattern (Optional[str]) – If filename_pattern is provided, this pattern will be used to render checkpoint filenames. If the pattern is not defined, the default pattern is used. See the Note below for details.

include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, there must not be another object in to_save with the key checkpointer; see the sketch after this parameter list.

greater_or_equal (bool) – If True, the latest equally scored model is stored. Otherwise, the first one. Default, False.

save_on_rank (int) – Which rank to save the objects on, in a distributed configuration. If save_handler is a string or Path, this is also used to instantiate a DiskSaver.
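As an illustration of include_self, here is a hedged sketch of persisting the handler's own state so its bookkeeping survives a restart (the checkpoint filename, event choice and surrounding objects below are assumptions, not defaults):

    import torch
    from ignite.engine import Events
    from ignite.handlers import Checkpoint

    handler = Checkpoint(to_save, '/tmp/models', n_saved=2, include_self=True)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

    # When resuming: the saved file contains a "checkpointer" entry that
    # restores the handler's list of saved checkpoints.
    checkpoint = torch.load('/tmp/models/checkpoint_10.pt')
    handler.load_state_dict(checkpoint["checkpointer"])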
Note

This class stores a single file as a dictionary of the provided objects to save. The filename is defined by filename_pattern and by default has the following structure: {filename_prefix}_{name}_{suffix}.{ext} where:

- filename_prefix is the argument passed to the constructor;
- name is the key in to_save if a single object is to be stored, otherwise name is "checkpoint";
- suffix is composed as {global_step}_{score_name}={score}, depending on which arguments are defined:
score_function | score_name | global_step_transform | suffix
-------------- | ---------- | --------------------- | -----------------------------------
None           | None       | None                  | {engine.state.iteration}
X              | None       | None                  | {score}
X              | None       | X                     | {global_step}_{score}
X              | X          | X                     | {global_step}_{score_name}={score}
None           | None       | X                     | {global_step}
X              | X          | None                  | {score_name}={score}
Above, global_step is defined by the output of global_step_transform and score by the output of score_function.

By default, none of score_function, score_name and global_step_transform is defined, so the suffix is set from the attached engine's current iteration. The filename will then be {filename_prefix}_{name}_{engine.state.iteration}.{ext}.

For example, with score_name="neg_val_loss" and a score_function that returns -loss (as objects with the highest scores are retained), the saved filename will be {filename_prefix}_{name}_neg_val_loss=-0.1234.pt.
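A minimal sketch of such a score function, assuming the attached engine stores a "loss" entry in its metrics:

    def neg_val_loss(engine):
        # Smaller loss is better, so negate it: a higher score means a better model.
        return -engine.state.metrics["loss"]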
Note

If filename_pattern is given, it will be used to render the filenames. filename_pattern is a string that can contain {filename_prefix}, {name}, {score}, {score_name} and {global_step} as templates.

For example, with filename_pattern="{global_step}-{name}-{score}.pt", the saved filename will be 30000-checkpoint-94.pt.

Warning

Keep in mind that if a filename collides with one already used to save a checkpoint, the new checkpoint will replace the older one. This means that a filename like checkpoint.pt will be written on every call and always overwritten by newer checkpoints.
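For illustration, a hedged sketch of such a constant pattern (the pattern below is a hypothetical choice, not a default):

    # "{name}.{ext}" renders to the constant "checkpoint.pt",
    # so each save overwrites the previous file.
    handler = Checkpoint(to_save, '/tmp/models', filename_pattern="{name}.{ext}")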
Note

To get the last stored filename, the handler exposes the attribute last_checkpoint:

    handler = Checkpoint(...)
    ...
    print(handler.last_checkpoint)
    > checkpoint_12345.pt
Note

This class is distributed configuration-friendly: it is not required to instantiate the class in the rank 0 process only. This class supports distributed configurations automatically, and if used with DiskSaver, the checkpoint is stored by the rank 0 process.

Warning

When running on XLA devices or using ZeroRedundancyOptimizer, it should be run in all processes, otherwise the application can get stuck while saving the checkpoint.

    # Wrong:
    # if idist.get_rank() == 0:
    #     handler = Checkpoint(...)
    #     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

    # Correct:
    handler = Checkpoint(...)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
Examples
Attach the handler to make checkpoints during training:

    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint

    trainer = ...
    model = ...
    optimizer = ...
    lr_scheduler = ...

    to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}

    if checkpoint_iters:
        # A: Output is "checkpoint_<iteration>.pt"
        handler = Checkpoint(
            to_save, '/tmp/models', n_saved=2
        )
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
    else:
        # B: Output is "checkpoint_<epoch>.pt"
        gst = lambda *_: trainer.state.epoch
        handler = Checkpoint(
            to_save, '/tmp/models', n_saved=2, global_step_transform=gst
        )
        trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

    trainer.run(data_loader, max_epochs=6)

    > A: ["checkpoint_7000.pt", "checkpoint_8000.pt", ]
    > B: ["checkpoint_5.pt", "checkpoint_6.pt", ]
Attach the handler to an evaluator to save the best model during training according to a computed validation metric:

    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint, global_step_from_engine

    trainer = ...
    evaluator = ...
    # Setup Accuracy metric computation on evaluator.
    # evaluator.state.metrics contain 'accuracy',
    # which will be used to define ``score_function`` automatically.
    # Run evaluation on epoch completed event
    # ...

    to_save = {'model': model}
    handler = Checkpoint(
        to_save, '/tmp/models',
        n_saved=2, filename_prefix='best',
        score_name="accuracy",
        global_step_transform=global_step_from_engine(trainer)
    )

    evaluator.add_event_handler(Events.COMPLETED, handler)

    trainer.run(data_loader, max_epochs=10)

    > ["best_model_9_accuracy=0.77.pt", "best_model_10_accuracy=0.78.pt", ]
Customise the save_handler:

    handler = Checkpoint(
        to_save,
        save_handler=DiskSaver('/tmp/models', create_dir=True, **kwargs),
        n_saved=2
    )
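For fully custom storage, a callable class can be used instead. Below is a minimal sketch implementing the BaseSaveHandler interface described above; SimpleDiskHandler is a hypothetical name, not part of ignite:

    import os

    import torch
    from ignite.handlers.checkpoint import BaseSaveHandler

    class SimpleDiskHandler(BaseSaveHandler):
        # Hypothetical handler: stores each checkpoint as a file in `dirname`.
        def __init__(self, dirname):
            self.dirname = dirname
            os.makedirs(dirname, exist_ok=True)

        def __call__(self, checkpoint, filename, metadata=None):
            # Called by Checkpoint with the assembled checkpoint dict and filename.
            torch.save(checkpoint, os.path.join(self.dirname, filename))

        def remove(self, filename):
            # Optional: called by Checkpoint so that at most n_saved files remain.
            os.remove(os.path.join(self.dirname, filename))

    handler = Checkpoint(to_save, save_handler=SimpleDiskHandler('/tmp/models'), n_saved=2)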
Changed in version 0.4.3:

- Checkpoint can save models with the same filename.
- Added greater_or_equal argument.

Changed in version 0.4.7:

- score_name can be used to define score_function automatically, without providing score_function.
- save_handler automatically saves to disk if a path to a directory is provided.
- save_on_rank saves objects on this rank in a distributed configuration.
Methods

- get_default_score_fn – Helper method to get the default score function based on the metric name.
- load_objects – Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.
- load_state_dict – Method that replaces the internal state of the class with the provided state dict data.
- reload_objects – Helper method to apply load_state_dict on the objects from to_load.
- reset – Method to reset saved checkpoint names.
- setup_filename_pattern – Helper method to get the default filename pattern for a checkpoint.
- state_dict – Method that returns a state dict with saved items: a list of (priority, filename) pairs.

- class Item(priority, filename)#

Create a new instance of Item(priority, filename).
- static get_default_score_fn(metric_name, score_sign=1.0)[source]#
Helper method to get default score function based on the metric name.
- Parameters

metric_name (str) – metric name from which to get the value in engine.state.metrics. The engine is the one to which the Checkpoint handler is added.

score_sign (float) – sign of the score: 1.0 or -1.0. For error-like metrics, where smaller is better, a negative score sign should be used (objects with larger scores are retained). Default, 1.0.

- Return type

Callable
Examples
    from ignite.handlers import Checkpoint

    best_acc_score = Checkpoint.get_default_score_fn("accuracy")

    best_model_handler = Checkpoint(
        to_save, save_handler, score_name="val_accuracy", score_function=best_acc_score
    )
    evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
Usage with error-like metric:
    from ignite.handlers import Checkpoint

    neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)

    best_model_handler = Checkpoint(
        to_save, save_handler, score_name="val_neg_loss", score_function=neg_loss_score
    )
    evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
New in version 0.4.3.
- static load_objects(to_load, checkpoint, **kwargs)[source]#
Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.

- Parameters
to_load (Mapping) – a dictionary with objects, e.g. {“model”: model, “optimizer”: optimizer, …}
checkpoint (Union[str, Mapping, Path]) – a path, a string filepath or a dictionary with state_dicts to load, e.g. {“model”: model_state_dict, “optimizer”: opt_state_dict}. If to_load contains a single key, then checkpoint can contain directly corresponding state_dict.
kwargs (Any) – Keyword arguments accepted for nn.Module.load_state_dict(). Passing strict=False enables the user to load part of a pretrained model (useful, for example, in transfer learning).
- Return type
None
Examples
    import tempfile
    from pathlib import Path

    import torch

    from ignite.engine import Engine, Events
    from ignite.handlers import ModelCheckpoint, Checkpoint

    trainer = Engine(lambda engine, batch: None)

    with tempfile.TemporaryDirectory() as tmpdirname:
        handler = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)

        model = torch.nn.Linear(3, 3)
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        to_save = {"weights": model, "optimizer": optimizer}

        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
        trainer.run(torch.randn(10, 1), 5)

        to_load = to_save
        checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
        checkpoint = torch.load(checkpoint_fp)
        Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

        # or using a string for the checkpoint filepath
        to_load = to_save
        checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
        Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
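As the kwargs parameter suggests, partial loading is possible by forwarding strict=False to nn.Module.load_state_dict. A hedged sketch, assuming a checkpoint_fp like the one above whose "weights" state_dict only partially matches the model:

    # Load only the matching subset of weights, e.g. for transfer learning.
    Checkpoint.load_objects(
        to_load={"weights": model},
        checkpoint=checkpoint_fp,
        strict=False,  # forwarded to nn.Module.load_state_dict
    )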
Note

If to_load contains objects of type torch DistributedDataParallel or DataParallel, the method load_state_dict will be applied to their internal wrapped model (obj.module).
- load_state_dict(state_dict)[source]#
Method that replaces the internal state of the class with the provided state dict data.

- Parameters

state_dict (Mapping) – a dict with a "saved" key whose value is a list of (priority, filename) pairs.

- Return type
None
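A hedged sketch of the state_dict/load_state_dict round trip (the handler names below are hypothetical), useful for carrying the handler's bookkeeping over to a new run:

    state = handler.state_dict()        # {"saved": [(priority, filename), ...]}
    new_handler = Checkpoint(to_save, '/tmp/models', n_saved=2)
    new_handler.load_state_dict(state)  # restores the list of saved checkpoints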
- reload_objects(to_load, load_kwargs=None, **filename_components)[source]#
Helper method to apply load_state_dict on the objects from to_load. Filename components such as name, score and global_step can be configured.

- Parameters

to_load (Mapping) – a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, ...}

load_kwargs (Optional[Dict]) – Keyword arguments accepted for nn.Module.load_state_dict(). Passing strict=False enables the user to load part of a pretrained model (useful, for example, in transfer learning).

filename_components (Any) – Filename components used to define the checkpoint file path. Keyword arguments accepted are name, score and global_step.
- Return type
None
Examples
    import tempfile

    import torch

    from ignite.engine import Engine, Events
    from ignite.handlers import ModelCheckpoint

    trainer = Engine(lambda engine, batch: None)

    with tempfile.TemporaryDirectory() as tmpdirname:
        checkpoint = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)

        model = torch.nn.Linear(3, 3)
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        to_save = {"weights": model, "optimizer": optimizer}

        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), checkpoint, to_save)
        trainer.run(torch.randn(10, 1), 5)

        to_load = to_save
        # load checkpoint myprefix_checkpoint_40.pt
        checkpoint.reload_objects(to_load=to_load, global_step=40)
Note

If to_load contains objects of type torch DistributedDataParallel or DataParallel, the method load_state_dict will be applied to their internal wrapped model (obj.module).
- reset()[source]#
Method to reset saved checkpoint names.
Use this method if the engine will independently run multiple times:
    from ignite.handlers import Checkpoint

    trainer = ...
    checkpointer = Checkpoint(...)

    trainer.add_event_handler(Events.COMPLETED, checkpointer)
    trainer.add_event_handler(Events.STARTED, checkpointer.reset)

    # fold 0
    trainer.run(data0, max_epochs=max_epochs)
    print("Last checkpoint:", checkpointer.last_checkpoint)

    # fold 1
    trainer.run(data1, max_epochs=max_epochs)
    print("Last checkpoint:", checkpointer.last_checkpoint)
New in version 0.4.3.
- Return type
None
- static setup_filename_pattern(with_prefix=True, with_score=True, with_score_name=True, with_global_step=True)[source]#
Helper method to get the default filename pattern for a checkpoint.
- Parameters

with_prefix (bool) – If True, the filename_prefix is added to the filename pattern: {filename_prefix}_{name}.... Default, True.

with_score (bool) – If True, score is added to the filename pattern: ..._{score}.{ext}. Default, True. At least one of with_score and with_global_step should be True.

with_score_name (bool) – If True, score_name is added to the filename pattern: ..._{score_name}={score}.{ext}. If activated, the argument with_score should also be True, otherwise an error is raised. Default, True.

with_global_step (bool) – If True, {global_step} is added to the filename pattern: ...{name}_{global_step}.... At least one of with_score and with_global_step should be True.
- Return type

str
Examples
    from ignite.handlers import Checkpoint

    filename_pattern = Checkpoint.setup_filename_pattern()

    print(filename_pattern)
    > "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"
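Components can be switched off individually; for instance, a hedged sketch assuming the default pattern construction shown above:

    pattern = Checkpoint.setup_filename_pattern(with_prefix=False, with_score_name=False)

    print(pattern)
    > "{name}_{global_step}_{score}.{ext}"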
New in version 0.4.3.