ModelCheckpoint

class ignite.handlers.checkpoint.ModelCheckpoint(dirname, filename_prefix='', save_interval=None, score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, save_as_state_dict=True, global_step_transform=None, archived=False, filename_pattern=None, include_self=False, greater_or_equal=False, **kwargs)

The ModelCheckpoint handler can be used to periodically save objects to disk only. If you need to store checkpoints in another storage type, consider Checkpoint instead.

This handler expects two arguments:

  • an Engine object

  • a dict mapping names (str) to objects that should be saved to disk.

See Examples for further details.

Warning

Behaviour of this class has been changed since v0.3.0.

Argument save_as_state_dict is deprecated and should not be used. It is always treated as True.

Argument save_interval is deprecated and should not be used. Please use event filtering instead, e.g. Events.ITERATION_STARTED(every=1000); see the sketch after this warning.

The internal counter that indicated the number of save actions has been removed. Its value, step_number, used to appear in the filename, e.g. {filename_prefix}_{name}_{step_number}.pt. Now step_number is replaced by the current engine’s epoch if score_function is specified, and by the current iteration otherwise.

A single .pt file is created instead of multiple files.
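
For instance, a minimal sketch of replacing a deprecated save_interval=1000 with event filtering; the engine and model here are illustrative stand-ins, not part of this API:

from torch import nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

trainer = Engine(lambda engine, batch: None)  # illustrative no-op process function
model = nn.Linear(3, 3)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True, require_empty=False)
# previously: ModelCheckpoint(..., save_interval=1000); now filter the event instead
trainer.add_event_handler(Events.ITERATION_STARTED(every=1000), handler, {'mymodel': model})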

Parameters
  • dirname (Union[str, Path]) – Directory path where objects will be saved.

  • filename_prefix (str) – Prefix for the file names to which objects will be saved. See Notes of Checkpoint for more details.

  • score_function (Optional[Callable]) – if not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained; see the sketch after this parameter list.

  • score_name (Optional[str]) – if score_function is not None, its value can be stored in the filename using score_name. See Notes for more details.

  • n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.

  • atomic (bool) – If True, objects are serialized to a temporary file and then moved to the final destination, so that files are guaranteed not to be damaged (for example, if an exception occurs during saving).

  • require_empty (bool) – If True, an exception is raised if there are any files starting with filename_prefix in the directory dirname.

  • create_dir (bool) – If True, will create directory dirname if it does not exist.

  • global_step_transform (Optional[Callable]) – global step transform function to output a desired global step. The function takes (engine, event_name) as input and should return an integer. Default is None, in which case the global step is taken from the attached engine. If provided, the function’s output is used as global_step. To set up the global step from another engine, use global_step_from_engine(); see the sketch after the version notes below.

  • archived (bool) – Deprecated argument as models saved by torch.save are already compressed.

  • filename_pattern (Optional[str]) – If filename_pattern is provided, this pattern will be used to render checkpoint filenames. Otherwise, the default pattern is used. See Checkpoint for details.

  • include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, then there must not be another object in to_save with key checkpointer.

  • greater_or_equal (bool) – if True, the latest model among those with equal scores is retained; otherwise, the first one. Default: False.

  • kwargs (Any) – Accepted keyword arguments for torch.save or xm.save in DiskSaver.

  • save_interval (Optional[int]) – Deprecated; see the warning above.

  • save_as_state_dict (bool) – Deprecated; see the warning above.
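
To illustrate score_function and score_name, here is a minimal sketch that keeps the two best models by validation accuracy. The evaluator engine, the model, and the 'accuracy' metric are illustrative stand-ins, not part of this API:

from torch import nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

model = nn.Linear(3, 3)
evaluator = Engine(lambda engine, batch: None)  # stand-in for a real validation engine

def score_function(engine):
    # assumes an 'accuracy' entry in the evaluator's metrics
    return engine.state.metrics['accuracy']

best_handler = ModelCheckpoint(
    '/tmp/models', 'best', n_saved=2,
    score_function=score_function, score_name='val_acc',
    create_dir=True, require_empty=False,
)
evaluator.add_event_handler(Events.COMPLETED, best_handler, {'mymodel': model})
# on each completed validation run, the 2 highest-scoring checkpoints are kept;
# filenames embed the score, e.g. best_mymodel_val_acc=0.9265.pt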

Changed in version 0.4.2: Accept kwargs for torch.save or xm.save.

Changed in version 0.4.9: Accept filename_pattern and greater_or_equal for parity with Checkpoint.
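
And a minimal sketch of global_step_from_engine(): checkpoints saved on an evaluator are tagged with the trainer’s epoch. Both engines and the model here are illustrative stand-ins:

from torch import nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, global_step_from_engine

model = nn.Linear(3, 3)
trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)

handler = ModelCheckpoint(
    '/tmp/models', 'myprefix', n_saved=2, create_dir=True, require_empty=False,
    global_step_transform=global_step_from_engine(trainer),  # tag files with the trainer's step
)
evaluator.add_event_handler(Events.COMPLETED, handler, {'mymodel': model})
# files are named by the trainer's epoch, e.g. myprefix_mymodel_4.pt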

Examples

import os
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint
from torch import nn
trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
model = nn.Linear(3, 3)
# save the model every 2 epochs; only the 2 most recent checkpoints are kept
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
trainer.run([0, 1, 2, 3, 4], max_epochs=6)
os.listdir('/tmp/models')
# ['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
handler.last_checkpoint
# '/tmp/models/myprefix_mymodel_30.pt'
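
To restore the saved weights, the file can be loaded directly with torch.load; since to_save held a single object, the file should contain that object’s state_dict. A sketch continuing the example above:

import torch

checkpoint_path = handler.last_checkpoint  # '/tmp/models/myprefix_mymodel_30.pt'
model.load_state_dict(torch.load(checkpoint_path))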

Methods