
Checkpoint#

class ignite.handlers.checkpoint.Checkpoint(to_save, save_handler, filename_prefix='', score_function=None, score_name=None, n_saved=1, global_step_transform=None, archived=False, filename_pattern=None, include_self=False, greater_or_equal=False, save_on_rank=0)[source]#

Checkpoint handler can be used to periodically save and load objects which have state_dict/load_state_dict methods. This class can use specific save handlers to store checkpoints on disk, in cloud storage, etc. The Checkpoint handler (if used with DiskSaver) also automatically moves data from TPU to CPU before writing the checkpoint.

Parameters
  • to_save (Mapping) – Dictionary with the objects to save. Objects should implement state_dict and load_state_dict methods. If it contains objects of type torch DistributedDataParallel or DataParallel, their internal wrapped model is automatically saved (to avoid the additional module. prefix in the state dictionary keys).

  • save_handler (Union[str, Path, Callable, BaseSaveHandler]) – String, path, function or callable object used to save the engine and other provided objects. The function receives two objects: the checkpoint as a dictionary and the filename. If save_handler is a callable class, it can inherit from BaseSaveHandler and optionally implement a remove method to keep a fixed number of saved checkpoints. If the user needs to save the engine’s checkpoint to disk, save_handler can be defined with DiskSaver, or a string specifying the directory name can be passed as save_handler (see the sketch after this parameter list).

  • filename_prefix (str) – Prefix for the file name to which objects will be saved. See Note for details.

  • score_function (Optional[Callable]) – If not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained.

  • score_name (Optional[str]) – If score_function is not None, it is possible to store its value using score_name. If score_function is None, score_name can be used alone to define score_function as Checkpoint.get_default_score_fn(score_name) by default.

  • n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.

  • global_step_transform (Optional[Callable]) – Global step transform function to output a desired global step. Input of the function is (engine, event_name); the output should be an integer. Default is None, in which case the global step is based on the attached engine. If provided, the function output is used as global_step. To set up the global step from another engine, please use global_step_from_engine().

  • archived (bool) – Deprecated argument as models saved by torch.save are already compressed.

  • filename_pattern (Optional[str]) – If filename_pattern is provided, this pattern will be used to render checkpoint filenames. If the pattern is not defined, the default pattern would be used. See Note for details.

  • include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, then there must not be another object in to_save with key checkpointer.

  • greater_or_equal (bool) – If True, the latest equally scored model is stored. Otherwise, the first model is kept. Default, False.

  • save_on_rank (int) – Which rank to save the objects on, in the distributed configuration. If save_handler is a string or Path, this is also used to instantiate a DiskSaver.
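A minimal sketch of the accepted save_handler forms; the directory path and the custom callable below are illustrative assumptions, not defaults:

import torch

from ignite.handlers import Checkpoint, DiskSaver

model = torch.nn.Linear(3, 3)
to_save = {'model': model}

# 1. A directory name as a string (or Path): a DiskSaver is created internally.
handler = Checkpoint(to_save, '/tmp/models', n_saved=2)

# 2. An explicit DiskSaver instance, e.g. to control directory creation.
handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2)

# 3. A custom callable receiving the checkpoint dictionary and the filename.
def my_save_handler(checkpoint, filename):
    torch.save(checkpoint, f"/tmp/models/{filename}")

handler = Checkpoint(to_save, my_save_handler, n_saved=2)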

Note

This class stores all provided objects in a single file, as a dictionary. The filename is defined by filename_pattern and by default has the following structure: {filename_prefix}_{name}_{suffix}.{ext} where

  • filename_prefix is the argument passed to the constructor,

  • name is the key in to_save if a single object is to be stored, otherwise name is “checkpoint”.

  • suffix is composed as follows: {global_step}_{score_name}={score}, where the parts present depend on which arguments are defined (see the table below).

score_function | score_name | global_step_transform | suffix
-------------- | ---------- | --------------------- | -----------------------------------
None           | None       | None                  | {engine.state.iteration}
X              | None       | None                  | {score}
X              | None       | X                     | {global_step}_{score}
X              | X          | X                     | {global_step}_{score_name}={score}
None           | None       | X                     | {global_step}
X              | X          | None                  | {score_name}={score}

Above, global_step is defined by the output of global_step_transform and score by the output of score_function.

By default, when none of score_function, score_name and global_step_transform is defined, the suffix is set to the attached engine’s current iteration. The filename will be {filename_prefix}_{name}_{engine.state.iteration}.{ext}.

For example, with score_name="neg_val_loss" and a score_function that returns -loss (objects with the highest scores are retained), the saved filename will be {filename_prefix}_{name}_neg_val_loss=-0.1234.pt.
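A minimal sketch of this setup, assuming the attached engine exposes a "loss" metric and filename_prefix='best':

handler = Checkpoint(
    {'model': model}, '/tmp/models',
    filename_prefix='best',
    score_name='neg_val_loss',
    score_function=lambda engine: -engine.state.metrics['loss'],
)
> filenames then look like "best_model_neg_val_loss=-0.1234.pt"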

Note

If filename_pattern is given, it will be used to render the filenames. filename_pattern is a string that can contain {filename_prefix}, {name}, {score}, {score_name} and {global_step} as templates.

For example, if filename_pattern="{global_step}-{name}-{score}.pt", then the saved filename will be 30000-checkpoint-94.pt.
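A hedged sketch of passing this pattern, assuming trainer, to_save and an "accuracy" metric are set up as in the examples further below:

from ignite.handlers import Checkpoint, global_step_from_engine

handler = Checkpoint(
    to_save, '/tmp/models', n_saved=2,
    filename_pattern="{global_step}-{name}-{score}.pt",
    score_function=lambda engine: engine.state.metrics["accuracy"],
    global_step_transform=global_step_from_engine(trainer),
)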

Warning: Keep in mind that if the filename collides with one already used for a checkpoint, the new checkpoint will replace the older one. This means that a filename like checkpoint.pt will be written on every call and will always be overwritten by newer checkpoints.

Note

To get the last stored filename, the handler exposes the last_checkpoint attribute:

handler = Checkpoint(...)
...
print(handler.last_checkpoint)
> checkpoint_12345.pt

Note

This class is distributed configuration-friendly: it is not required to instantiate the class in the rank 0 process only. This class automatically supports distributed configuration, and if used with DiskSaver, the checkpoint is stored by the rank 0 process by default (see save_on_rank).
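For example, to have a different rank write the checkpoint, a hedged sketch assuming a distributed run with at least two processes and to_save/trainer defined as in the examples below:

handler = Checkpoint(to_save, '/tmp/models', n_saved=2, save_on_rank=1)
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)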

Warning

When running on XLA devices or using ZeroRedundancyOptimizer, the handler should be created and attached in all processes, otherwise the application can get stuck while saving the checkpoint.

# Wrong:
# if idist.get_rank() == 0:
#     handler = Checkpoint(...)
#     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

# Correct:
handler = Checkpoint(...)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

Examples

Attach the handler to make checkpoints during training:

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...

to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}

if checkpoint_iters:
    # A: Output is "checkpoint_<iteration>.pt"
    handler = Checkpoint(
        to_save, '/tmp/models', n_saved=2
    )
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
else:
    # B: Output is "checkpoint_<epoch>.pt"
    gst = lambda *_: trainer.state.epoch
    handler = Checkpoint(
        to_save, '/tmp/models', n_saved=2, global_step_transform=gst
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

trainer.run(data_loader, max_epochs=6)
> A: ["checkpoint_7000.pt", "checkpoint_8000.pt", ]
> B: ["checkpoint_5.pt", "checkpoint_6.pt", ]

Attach the handler to an evaluator to save the best model during training according to a computed validation metric:

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, global_step_from_engine

trainer = ...
evaluator = ...
# Setup Accuracy metric computation on evaluator.
# evaluator.state.metrics contain 'accuracy',
# which will be used to define ``score_function`` automatically.
# Run evaluation on epoch completed event
# ...

to_save = {'model': model}
handler = Checkpoint(
    to_save, '/tmp/models',
    n_saved=2, filename_prefix='best',
    score_name="accuracy",
    global_step_transform=global_step_from_engine(trainer)
)

evaluator.add_event_handler(Events.COMPLETED, handler)

trainer.run(data_loader, max_epochs=10)
> ["best_model_9_accuracy=0.77.pt", "best_model_10_accuracy=0.78.pt", ]

Customise the save_handler:

from ignite.handlers import DiskSaver

handler = Checkpoint(
    to_save, save_handler=DiskSaver('/tmp/models', create_dir=True, **kwargs), n_saved=2
)

Changed in version 0.4.3:

  • Checkpoint can save model with same filename.

  • Added greater_or_equal argument.

Changed in version 0.4.7:

  • score_name can be used to define score_function automatically without providing score_function.

  • save_handler automatically saves to disk if path to directory is provided.

  • save_on_rank saves objects on this rank in a distributed configuration.

Methods

  • get_default_score_fn: Helper method to get default score function based on the metric name.

  • load_objects: Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.

  • load_state_dict: Method replaces internal state of the class with provided state dict data.

  • reload_objects: Helper method to apply load_state_dict on the objects from to_load.

  • reset: Method to reset saved checkpoint names.

  • setup_filename_pattern: Helper method to get the default filename pattern for a checkpoint.

  • state_dict: Method returns state dict with saved items: list of (priority, filename) pairs.

class Item(priority, filename)#

Create new instance of Item(priority, filename)

Parameters
  • priority (int) –

  • filename (str) –

filename: str#

Alias for field number 1

priority: int#

Alias for field number 0

static get_default_score_fn(metric_name, score_sign=1.0)[source]#

Helper method to get default score function based on the metric name.

Parameters
  • metric_name (str) – metric name to get the value from engine.state.metrics. Engine is the one to which Checkpoint handler is added.

  • score_sign (float) – sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better, a negative score sign should be used (objects with larger score are retained). Default, 1.0.

Return type

Callable

Examples

from ignite.handlers import Checkpoint

best_acc_score = Checkpoint.get_default_score_fn("accuracy")

best_model_handler = Checkpoint(
    to_save, save_handler, score_name="val_accuracy", score_function=best_acc_score
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

Usage with error-like metric:

from ignite.handlers import Checkpoint

neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)

best_model_handler = Checkpoint(
    to_save, save_handler, score_name="val_neg_loss", score_function=neg_loss_score
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
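For intuition, the returned callable can be thought of as reading the metric from engine.state.metrics and multiplying by score_sign (an assumed equivalent form, not the exact implementation):

neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)
# behaves approximately like:
equivalent_score_fn = lambda engine: -1.0 * engine.state.metrics["loss"]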

New in version 0.4.3.

static load_objects(to_load, checkpoint, **kwargs)[source]#

Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.

Parameters
  • to_load (Mapping) – a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, ...}

  • checkpoint (Union[str, Mapping, Path]) – a path, a string filepath or a dictionary with state_dicts to load, e.g. {"model": model_state_dict, "optimizer": opt_state_dict}. If to_load contains a single key, then checkpoint can directly contain the corresponding state_dict.

  • kwargs (Any) – Keyword arguments accepted for nn.Module.load_state_dict(). Passing strict=False enables the user to load part of the pretrained model (useful for example, in Transfer Learning)

Return type

None

Examples

import tempfile
from pathlib import Path

import torch

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint

trainer = Engine(lambda engine, batch: None)

with tempfile.TemporaryDirectory() as tmpdirname:
    handler = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)

    model = torch.nn.Linear(3, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    to_save = {"weights": model, "optimizer": optimizer}

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
    trainer.run(torch.randn(10, 1), 5)

    to_load = to_save
    checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
    checkpoint = torch.load(checkpoint_fp)
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

    # or using a string for checkpoint filepath

    to_load = to_save
    checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)

Note

If to_load contains objects of type torch DistributedDataParallel or DataParallel, the load_state_dict method will be applied to their internal wrapped model (obj.module).

load_state_dict(state_dict)[source]#

Method replaces internal state of the class with provided state dict data.

Parameters

state_dict (Mapping) – a dict with a "saved" key and a list of (priority, filename) pairs as values.

Return type

None
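A minimal sketch of a state_dict / load_state_dict round trip, so that a new handler continues rotating the previously saved files (the directory path is an assumption):

handler = Checkpoint(to_save, '/tmp/models', n_saved=2)
...
handler_state = handler.state_dict()   # OrderedDict with a "saved" key

new_handler = Checkpoint(to_save, '/tmp/models', n_saved=2)
new_handler.load_state_dict(handler_state)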

reload_objects(to_load, load_kwargs=None, **filename_components)[source]#

Helper method to apply load_state_dict on the objects from to_load. Filename components such as name, score and global step can be configured.

Parameters
  • to_load (Mapping) – a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, ...}

  • load_kwargs (Optional[Dict]) – Keyword arguments accepted for nn.Module.load_state_dict(). Passing strict=False enables the user to load part of the pretrained model (useful for example, in Transfer Learning)

  • filename_components (Any) – Filename components used to define the checkpoint file path. Keyword arguments accepted are name, score and global_step.

Return type

None

Examples

import tempfile

import torch

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

trainer = Engine(lambda engine, batch: None)

with tempfile.TemporaryDirectory() as tmpdirname:
    checkpoint = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)

    model = torch.nn.Linear(3, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    to_save = {"weights": model, "optimizer": optimizer}

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), checkpoint, to_save)
    trainer.run(torch.randn(10, 1), 5)

    to_load = to_save
    # load checkpoint myprefix_checkpoint_40.pt
    checkpoint.reload_objects(to_load=to_load, global_step=40)

Note

If to_load contains objects of type torch DistributedDataParallel or DataParallel, the load_state_dict method will be applied to their internal wrapped model (obj.module).

Note

This method works only when the save_handler is of types string, Path or DiskSaver.

reset()[source]#

Method to reset saved checkpoint names.

Use this method if the engine will independently run multiple times:

from ignite.handlers import Checkpoint

trainer = ...
checkpointer = Checkpoint(...)

trainer.add_event_handler(Events.COMPLETED, checkpointer)
trainer.add_event_handler(Events.STARTED, checkpointer.reset)

# fold 0
trainer.run(data0, max_epochs=max_epochs)
print("Last checkpoint:", checkpointer.last_checkpoint)

# fold 1
trainer.run(data1, max_epochs=max_epochs)
print("Last checkpoint:", checkpointer.last_checkpoint)

New in version 0.4.3.

Return type

None

static setup_filename_pattern(with_prefix=True, with_score=True, with_score_name=True, with_global_step=True)[source]#

Helper method to get the default filename pattern for a checkpoint.

Parameters
  • with_prefix (bool) – If True, the filename_prefix is added to the filename pattern: {filename_prefix}_{name}.... Default, True.

  • with_score (bool) – If True, score is added to the filename pattern: ..._{score}.{ext}. Default, True. At least one of with_score and with_global_step should be True.

  • with_score_name (bool) – If True, score_name is added to the filename pattern: ..._{score_name}={score}.{ext}. If activated, argument with_score should be also True, otherwise an error is raised. Default, True.

  • with_global_step (bool) – If True, {global_step} is added to the filename pattern: ...{name}_{global_step}.... At least one of with_score and with_global_step should be True.

Return type

str

Examples

from ignite.handlers import Checkpoint

filename_pattern = Checkpoint.setup_filename_pattern()

print(filename_pattern)
> "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"

New in version 0.4.3.

state_dict()[source]#

Method returns state dict with saved items: list of (priority, filename) pairs. Can be used to save internal state of the class.

Return type

OrderedDict