ModelCheckpoint

class ignite.handlers.checkpoint.ModelCheckpoint(dirname, filename_prefix='', score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, global_step_transform=None, filename_pattern=None, include_self=False, greater_or_equal=False, save_on_rank=0, **kwargs)

The ModelCheckpoint handler, which inherits from Checkpoint, can be used to periodically save objects to disk only. To store checkpoints in another type of storage, consider using Checkpoint instead. It also provides a last_checkpoint attribute holding the path of the last saved checkpoint.

This handler expects two arguments:

  • an Engine object

  • a dict mapping names (str) to objects that should be saved to disk.

See Examples for further details.

Warning

The behaviour of this class has changed since v0.3.0.

The internal counter that was used to indicate the number of save actions has been removed. Its value, step_number, could be seen in the filename, e.g. {filename_prefix}_{name}_{step_number}.pt. Now step_number is replaced by the current engine's epoch if score_function is specified, and by the current iteration otherwise.

A single .pt file is created instead of multiple files.
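To make the naming rule in this warning concrete, here is a minimal pure-Python sketch. The helper step_number_filename is hypothetical and mirrors only the rule stated above; Ignite's actual filename logic (see filename_pattern) also handles scores, score names and custom patterns.

```python
def step_number_filename(filename_prefix: str, name: str, epoch: int,
                         iteration: int, has_score_function: bool) -> str:
    # Hypothetical helper: reproduces only the rule described in the warning.
    # step_number is the engine's epoch when a score_function is given,
    # and the engine's iteration otherwise.
    step_number = epoch if has_score_function else iteration
    return f"{filename_prefix}_{name}_{step_number}.pt"


# Values matching the Examples section: epoch 6 ends at iteration 30
print(step_number_filename("myprefix", "mymodel", epoch=6, iteration=30,
                           has_score_function=False))
# prints myprefix_mymodel_30.pt
```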

Parameters
  • dirname (Union[str, Path]) – Directory path where objects will be saved.

  • filename_prefix (str) – Prefix for the file names to which objects will be saved. See Notes of Checkpoint for more details.

  • score_function (Optional[Callable]) – If not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained.

  • score_name (Optional[str]) – If score_function is not None, its value can be stored using score_name. See Examples of Checkpoint for more details.

  • n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.

  • atomic (bool) – If True, objects are serialized to a temporary file and then moved to the final destination, so that files are guaranteed not to be damaged (for example, if an exception occurs during saving).

  • require_empty (bool) – If True, raises an exception if there are any files starting with filename_prefix in the directory dirname.

  • create_dir (bool) – If True, creates the directory dirname if it does not exist.

  • global_step_transform (Optional[Callable]) – Function that outputs the desired global step. It takes (engine, event_name) as input and should return an integer. Default is None, in which case the global step is taken from the attached engine. If provided, the function's output is used as the global step. To set up the global step from another engine, use global_step_from_engine().

  • filename_pattern (Optional[str]) – If provided, this pattern is used to render checkpoint filenames. Otherwise, the default pattern is used. See Checkpoint for details.

  • include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, the dict of objects to save (to_save) must not contain another object with the key checkpointer.

  • greater_or_equal (bool) – If True, the latest of several equally scored models is stored; otherwise, the first one is kept. Default: False.

  • save_on_rank (int) – Which rank to save the objects on, in the distributed configuration. Used to instantiate a DiskSaver and is also passed to the parent class.

  • kwargs (Any) – Keyword arguments passed to torch.save or xm.save in DiskSaver.
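The atomic option follows the common "write to a temporary file, then rename" pattern. Below is a minimal standard-library sketch of that pattern. The helper atomic_write_bytes is hypothetical and writes raw bytes instead of calling torch.save, so it illustrates the idea rather than reproducing DiskSaver's code.

```python
import os
import tempfile


def atomic_write_bytes(path: str, data: bytes) -> None:
    """Hypothetical helper: write data to path via a temp file plus rename."""
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temporary file next to the destination so the final
    # rename happens within a single filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as tmp_file:
            tmp_file.write(data)
            tmp_file.flush()
            os.fsync(tmp_file.fileno())  # push the bytes to disk before renaming
        # os.replace is atomic on POSIX: readers see either the old file
        # or the complete new one, never a partially written file.
        os.replace(tmp_path, path)
    except BaseException:
        # On failure, drop the partial temp file; the destination is untouched.
        os.remove(tmp_path)
        raise


target = os.path.join(tempfile.mkdtemp(), "checkpoint.bin")
atomic_write_bytes(target, b"serialized weights")
```

Creating the temporary file in the destination directory, rather than in the system temp directory, is what keeps os.replace a same-filesystem rename.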

Changed in version 0.4.2: Accept kwargs for torch.save or xm.save

Changed in version 0.4.9: Accept filename_pattern and greater_or_equal for parity with Checkpoint

Changed in version 0.4.10: Added save_on_rank arg to save objects on this rank in a distributed configuration
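The retention rules described by n_saved, score_function and greater_or_equal can be illustrated with a small pure-Python model. This is a hypothetical sketch, not Ignite's implementation: update_saved keeps a list of (score, filename) pairs sorted by score, admits a new entry while fewer than n_saved are stored, and otherwise admits it only if it beats the lowest stored score (or ties it, when greater_or_equal is True).

```python
from typing import List, Optional, Tuple

Entry = Tuple[float, str]  # (score, filename)


def update_saved(saved: List[Entry], score: float, filename: str,
                 n_saved: Optional[int], greater_or_equal: bool = False
                 ) -> Tuple[List[Entry], List[str]]:
    """Hypothetical model of the retention rule; returns (new_saved, removed)."""
    if n_saved is None or len(saved) < n_saved:
        admit = True
    else:
        lowest = saved[0][0]  # saved is kept sorted by ascending score
        admit = score >= lowest if greater_or_equal else score > lowest
    if not admit:
        return saved, []
    # sorted() is stable, so the new entry lands after existing entries with
    # the same score, and an older tied entry is evicted first.
    new_saved = sorted(saved + [(score, filename)], key=lambda e: e[0])
    removed = []
    while n_saved is not None and len(new_saved) > n_saved:
        removed.append(new_saved.pop(0)[1])
    return new_saved, removed


def simulate(scores, n_saved, greater_or_equal=False):
    saved: List[Entry] = []
    for step, score in enumerate(scores, start=1):
        saved, _ = update_saved(saved, score, f"model_{step}.pt",
                                n_saved, greater_or_equal)
    return [filename for _, filename in saved]


print(simulate([0.5, 0.9, 0.9], n_saved=1))                         # ['model_2.pt']
print(simulate([0.5, 0.9, 0.9], n_saved=1, greater_or_equal=True))  # ['model_3.pt']
```

Using the global step as the score in this model reproduces the "older files are removed" behaviour of n_saved when no score_function is given.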

Examples

import os
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint
from torch import nn

# A trainer whose process function does nothing
trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True, require_empty=False)
model = nn.Linear(3, 3)
# Save every 2 epochs; without a score_function the iteration number is used in
# the filename, and n_saved=2 keeps only the two most recent files
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
trainer.run([0, 1, 2, 3, 4], max_epochs=6)
print(sorted(os.listdir('/tmp/models')))
# ['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
print(handler.last_checkpoint)
# /tmp/models/myprefix_mymodel_30.pt

Methods