ModelCheckpoint
class ignite.handlers.checkpoint.ModelCheckpoint(dirname, filename_prefix='', score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, global_step_transform=None, filename_pattern=None, include_self=False, greater_or_equal=False, save_on_rank=0, **kwargs)
The ModelCheckpoint handler inherits from Checkpoint and can be used to periodically save objects to disk only. If you need to store checkpoints in another type of storage, consider Checkpoint instead. It also provides a last_checkpoint attribute that holds the path of the last saved checkpoint.
This handler expects two arguments:
- an Engine object
- a dict mapping names (str) to objects that should be saved to disk.
See Examples for further details.
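The calling convention described above (an engine plus a dict of named objects, one file per save, and a last_checkpoint attribute) can be illustrated without Ignite. This is a minimal sketch under stated assumptions, not Ignite's implementation: `TinyDiskCheckpoint` is hypothetical, the engine is a `SimpleNamespace` stub exposing `state.iteration`, `pickle` stands in for `torch.save`, and the "single key, else `checkpoint`" naming rule is an assumption made for the sketch. The write goes through a temporary file followed by `os.replace`, which is the general technique behind an option like `atomic=True`.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace


class TinyDiskCheckpoint:
    """Hypothetical stand-in for a disk checkpoint handler (not Ignite's code)."""

    def __init__(self, dirname, filename_prefix=""):
        self.dirname = dirname
        self.filename_prefix = filename_prefix
        self.last_checkpoint = None  # path of the most recent file written
        os.makedirs(dirname, exist_ok=True)  # similar in spirit to create_dir=True

    def __call__(self, engine, to_save):
        # Assumption: use the key if one object is saved, else "checkpoint".
        name = next(iter(to_save)) if len(to_save) == 1 else "checkpoint"
        step = engine.state.iteration  # no score_function: step is the iteration
        parts = [p for p in (self.filename_prefix, name, str(step)) if p]
        path = os.path.join(self.dirname, "_".join(parts) + ".pt")

        # Atomic write: serialize into a temp file in the same directory, then
        # rename it onto the final path so readers never see a partial file.
        fd, tmp_path = tempfile.mkstemp(dir=self.dirname)
        try:
            with os.fdopen(fd, "wb") as f:
                pickle.dump(to_save, f)  # stand-in for torch.save
            os.replace(tmp_path, path)
        except BaseException:
            os.remove(tmp_path)  # do not leave a half-written temp file behind
            raise
        self.last_checkpoint = path


# Usage with a stub engine whose state reports iteration 10.
demo_dir = tempfile.mkdtemp()
handler = TinyDiskCheckpoint(demo_dir, "myprefix")
engine = SimpleNamespace(state=SimpleNamespace(iteration=10))
handler(engine, {"mymodel": {"weight": [1.0, 2.0]}})
print(os.path.basename(handler.last_checkpoint))  # myprefix_mymodel_10.pt
```

Writing to a temporary file in the same directory keeps the final rename on one filesystem, which is what makes the replacement atomic on POSIX systems.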
Warning
Behaviour of this class has changed since v0.3.0.
There is no longer an internal counter indicating the number of save actions. Its value, step_number, used to appear in the filename, e.g. {filename_prefix}_{name}_{step_number}.pt. Now, step_number is replaced by the current engine's epoch if score_function is specified, and by the current iteration otherwise.
A single .pt file is created instead of multiple files.
Parameters
dirname (Union[str, Path]) – Directory path where objects will be saved.
filename_prefix (str) – Prefix for the file names to which objects will be saved. See Notes of Checkpoint for more details.
score_function (Optional[Callable]) – If not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained.
score_name (Optional[str]) – If score_function is not None, its value can be stored using score_name. See Examples of Checkpoint for more details.
n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.
atomic (bool) – If True, objects are serialized to a temporary file, and then moved to final destination, so that files are guaranteed to not be damaged (for example if exception occurs during saving).
require_empty (bool) – If True, raises an exception if there are any files starting with filename_prefix in the directory dirname.
create_dir (bool) – If True, creates the directory dirname if it does not exist.
global_step_transform (Optional[Callable]) – Global step transform function to output a desired global step. Its input is (engine, event_name) and its output should be an integer. Default is None, in which case global_step is based on the attached engine. If provided, the function output is used as global_step. To set up the global step from another engine, use global_step_from_engine().
filename_pattern (Optional[str]) – If filename_pattern is provided, this pattern is used to render checkpoint filenames. Otherwise, the default pattern is used. See Checkpoint for details.
include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, there must not be another object in to_save with key checkpointer.
greater_or_equal (bool) – If True, the latest equally scored model is stored; otherwise, the first one is kept. Default: False.
save_on_rank (int) – Which rank to save the objects on, in a distributed configuration. Used to instantiate a DiskSaver and also passed to the parent class.
kwargs (Any) – Keyword arguments accepted by torch.save or xm.save in DiskSaver.
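How score_function, n_saved and greater_or_equal interact can be modelled in plain Python. This is a simplified sketch of the behaviour the parameter descriptions above describe, not Ignite's code, and its bookkeeping may differ in details: `update_saved` is a hypothetical helper, and each checkpoint is ranked by a "priority" (its score, or in this model its step when no score_function is set, so older files go first). Once n_saved entries exist, a new checkpoint is written only if it beats the lowest one (or ties it, with greater_or_equal=True), and that lowest one is evicted. The scores are made up for illustration.

```python
def update_saved(saved, priority, filename, n_saved=1, greater_or_equal=False):
    """Hypothetical model of checkpoint retention (not Ignite's code).

    `saved` holds (priority, filename) pairs sorted by ascending priority and is
    updated in place. Returns (stored, removed): whether the new checkpoint is
    kept, and which filename (if any) would be deleted from disk.
    """
    removed = None
    if n_saved is not None and len(saved) >= n_saved:
        lowest = saved[0][0]
        better = priority >= lowest if greater_or_equal else priority > lowest
        if not better:
            return False, None  # not good enough: nothing is written
        removed = saved.pop(0)[1]  # evict the lowest-priority checkpoint
    saved.append((priority, filename))
    saved.sort(key=lambda item: item[0])  # stable: ties keep insertion order
    return True, removed


scores = [0.70, 0.80, 0.80, 0.75]  # made-up validation scores for epochs 1..4
results = {}
for ge in (False, True):
    saved = []
    for epoch, score in enumerate(scores, start=1):
        update_saved(saved, score, f"model_epoch{epoch}.pt",
                     n_saved=1, greater_or_equal=ge)
    results[ge] = saved
    print(f"greater_or_equal={ge}: {saved}")
# greater_or_equal=False: [(0.8, 'model_epoch2.pt')]
# greater_or_equal=True: [(0.8, 'model_epoch3.pt')]
```

The tie at epochs 2 and 3 is what separates the two settings: the strict comparison keeps the first model that reached 0.80, while greater_or_equal=True keeps the latest one.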
Changed in version 0.4.2: Accept kwargs for torch.save or xm.save.
Changed in version 0.4.9: Accept filename_pattern and greater_or_equal for parity with Checkpoint.
Changed in version 0.4.10: Added save_on_rank arg to save objects on this rank in a distributed configuration.
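A filename_pattern is rendered by filling named fields into a template. The sketch below assumes the pattern is a Python format string over the fields filename_prefix, name, global_step, score_name, score and ext; these field names are taken to match the Checkpoint notes and should be treated as an assumption. The field values and the custom pattern are hypothetical.

```python
# Hypothetical field values for a single save event.
fields = {
    "filename_prefix": "myprefix",
    "name": "mymodel",
    "global_step": 30,
    "score_name": "accuracy",
    "score": 0.9125,
    "ext": "pt",
}

# A pattern shaped like the default one, and a made-up custom pattern.
default_like = "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"
custom = "{name}-step{global_step}.{ext}"

# str.format ignores fields that a pattern does not use.
print(default_like.format(**fields))  # myprefix_mymodel_30_accuracy=0.9125.pt
print(custom.format(**fields))        # mymodel-step30.pt
```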
Examples
import os
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint
from torch import nn

# A trainer whose process function does nothing; run below on 5 batches per epoch.
trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True, require_empty=False)
model = nn.Linear(3, 3)
# Save {'mymodel': model} at the end of every 2nd epoch (epochs 2, 4 and 6).
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
trainer.run([0, 1, 2, 3, 4], max_epochs=6)
# Only the n_saved=2 most recent checkpoints remain on disk.
print(sorted(os.listdir('/tmp/models')))
print(handler.last_checkpoint)
['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
/tmp/models/myprefix_mymodel_30.pt
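The file names in this output follow from the rule that, without a score_function, the step in the filename is the engine's iteration count. With 5 batches per epoch, epoch e ends at iteration 5·e, so saving every 2 epochs over 6 epochs happens at iterations 10, 20 and 30, and n_saved=2 keeps the two most recent. The arithmetic can be checked with a small helper; `expected_checkpoints` is hypothetical and not part of Ignite.

```python
def expected_checkpoints(prefix, name, epoch_length, every, max_epochs, n_saved):
    """Hypothetical helper: file names left on disk when no score_function is set."""
    # Without a score_function the current iteration is used as the step,
    # and later (larger) steps are kept in preference to earlier ones.
    steps = [epoch * epoch_length for epoch in range(every, max_epochs + 1, every)]
    kept = steps if n_saved is None else steps[max(len(steps) - n_saved, 0):]
    return [f"{prefix}_{name}_{step}.pt" for step in kept]


print(expected_checkpoints("myprefix", "mymodel", epoch_length=5, every=2,
                           max_epochs=6, n_saved=2))
# ['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
```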
Methods