Shortcuts

Source code for ignite.handlers.checkpoint

import os
import tempfile

from collections import namedtuple
import collections.abc as collections
import warnings

import torch

from ignite.engine import Events


class Checkpoint:
    """Checkpoint handler can be used to periodically save and load objects which have attribute
    `state_dict`/`load_state_dict`. This class can use specific save handlers to store on the disk or a cloud
    storage, etc.

    Args:
        to_save (dict): Dictionary with the objects to save. Objects should have implemented `state_dict` and
            `load_state_dict` methods.
        save_handler (callable): Method to use to save engine and other provided objects. Function receives a
            checkpoint and a filename. It should also expose a `remove(filename)` method, which is called to delete
            checkpoints that are no longer retained. In case if user needs to save engine's checkpoint on a disk,
            `save_handler` can be defined with :class:`~ignite.handlers.DiskSaver`.
        filename_prefix (str, optional): Prefix for the filename to which objects will be saved. See Note for
            details.
        score_function (callable, optional): If not None, it should be a function taking a single argument,
            :class:`~ignite.engine.Engine` object, and returning a score (`float`). Objects with highest scores
            will be retained.
        score_name (str, optional): If `score_function` not None, it is possible to store its value using
            `score_name`. See Notes for more details.
        n_saved (int, optional): Number of objects that should be kept. The checkpoints with the lowest priority
            (the oldest ones by default, or the lowest-scoring ones if `score_function` is provided) are removed.
            If set to `None`, all objects are kept.
        global_step_transform (callable, optional): global step transform function to output a desired global
            step. Input of the function is `(engine, event_name)`. Output of function should be an integer.
            Default is None, global_step based on attached engine. If provided, uses function output as
            global_step. To setup global step from another engine, please use
            :meth:`~ignite.handlers.global_step_from_engine`.
        archived (bool, optional): If True, saved checkpoint extension will be `.pth.tar`. Default value is False.

    Raises:
        TypeError: if `to_save` is not a mapping, `save_handler` is not callable, `global_step_transform` is not
            callable, or an object in `to_save` has no `state_dict` attribute.
        ValueError: if `to_save` is empty, or `score_name` is given without `score_function`.

    Note:
        This class stores a single file per save. If a single object is to store, its `state_dict` is saved as
        is; otherwise a dictionary `{key: state_dict}` of all provided objects is saved.

        The filename has the following structure: `{filename_prefix}_{name}_{suffix}{ext}` where

        - `filename_prefix` is the argument passed to the constructor (if it is empty, the following `_` is
          omitted as well),
        - `name` is the key in `to_save` if a single object is to store, otherwise `name` is "checkpoint".
        - `ext` is `.pth.tar` if `archived=True` otherwise `.pth`.
        - `suffix` is composed as following `{global_step}_{score_name}={score}`.

        Above `global_step` defined by the output of `global_step_transform` and `score` defined by the output
        of `score_function`.

        By default, none of `score_function`, `score_name`, `global_step_transform` is defined, then suffix is
        setup by attached engine's current iteration. The filename will be
        `{filename_prefix}_{name}_{engine.state.iteration}{ext}`. If only `global_step_transform` is provided,
        the suffix is `{global_step}`.

        If defined a `score_function`, but without `score_name`, then suffix is defined by provided score. The
        filename will be `{filename_prefix}_{name}_{score}{ext}`, or
        `{filename_prefix}_{name}_{global_step}_{score}{ext}` if `global_step_transform` is provided.

        If defined `score_function` and `score_name`, then the filename will be
        `{filename_prefix}_{name}_{score_name}={score}{ext}`. If `global_step_transform` is provided, then the
        filename will be `{filename_prefix}_{name}_{global_step}_{score_name}={score}{ext}`.

        The score is written exactly as returned by `score_function` (its sign is kept). For example,
        `score_name="neg_val_loss"` and `score_function` that returns `-loss` (as objects with highest scores
        will be retained), then saved filename will be `{filename_prefix}_{name}_neg_val_loss=-0.1234.pth`.

        To get the filename of the retained checkpoint with the highest priority (the latest one by default, or
        the best-scoring one if `score_function` is provided), handler exposes attribute `last_checkpoint`:

        .. code-block:: python

            handler = Checkpoint(...)
            ...
            print(handler.last_checkpoint)
            > checkpoint_12345.pth

    Examples:
        Attach the handler to make checkpoints during training:

        .. code-block:: python

            from ignite.engine import Engine, Events
            from ignite.handlers import Checkpoint, DiskSaver

            trainer = ...
            model = ...
            optimizer = ...
            lr_scheduler = ...

            to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}
            handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2)
            trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
            trainer.run(data_loader, max_epochs=6)
            > ["checkpoint_7000.pth", "checkpoint_8000.pth", ]

        Attach the handler to an evaluator to save best model during the training
        according to computed validation metric:

        .. code-block:: python

            from ignite.engine import Engine, Events
            from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

            trainer = ...
            evaluator = ...
            # Setup Accuracy metric computation on evaluator
            # Run evaluation on epoch completed event
            # ...

            def score_function(engine):
                return engine.state.metrics['accuracy']

            to_save = {'model': model}
            handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2,
                                 filename_prefix='best', score_function=score_function, score_name="val_acc",
                                 global_step_transform=global_step_from_engine(trainer))

            evaluator.add_event_handler(Events.COMPLETED, handler)

            trainer.run(data_loader, max_epochs=10)
            > ["best_model_9_val_acc=0.77.pth", "best_model_10_val_acc=0.78.pth", ]
    """

    # A retained checkpoint: its priority (score or iteration) and the filename it was saved under.
    Item = namedtuple("Item", ["priority", "filename"])

    def __init__(self, to_save, save_handler, filename_prefix="", score_function=None, score_name=None,
                 n_saved=1, global_step_transform=None, archived=False):

        if not isinstance(to_save, collections.Mapping):
            raise TypeError("Argument `to_save` should be a dictionary, but given {}".format(type(to_save)))

        if len(to_save) < 1:
            raise ValueError("No objects to checkpoint.")

        if not callable(save_handler):
            raise TypeError("Argument `save_handler` should be callable")

        if score_function is None and score_name is not None:
            raise ValueError("If `score_name` is provided, then `score_function` "
                             "should be also provided.")

        if global_step_transform is not None and not callable(global_step_transform):
            raise TypeError("global_step_transform should be a function, got {} instead."
                            .format(type(global_step_transform)))

        self._check_objects(to_save, "state_dict")
        # The separator is only added for a non-empty prefix, so an empty prefix yields "{name}_{suffix}{ext}".
        self._fname_prefix = filename_prefix + "_" if len(filename_prefix) > 0 else filename_prefix
        self.save_handler = save_handler
        self.to_save = to_save
        self._score_function = score_function
        self._score_name = score_name
        self._n_saved = n_saved
        # Retained checkpoints, kept sorted by ascending priority: index 0 is evicted first.
        self._saved = []
        self._ext = ".pth.tar" if archived else ".pth"
        self.global_step_transform = global_step_transform

    @property
    def last_checkpoint(self):
        """Filename of the retained checkpoint with the highest priority, or None if nothing was saved yet.

        Without `score_function` this is the most recent checkpoint; with it, the best-scoring one.
        """
        if len(self._saved) < 1:
            return None
        # `_saved` is sorted ascending by priority, so the last element has the highest priority.
        return self._saved[-1].filename

    def _check_lt_n_saved(self, or_equal=False):
        """Return True if fewer than `n_saved` (or at most `n_saved` with `or_equal`) checkpoints are retained.

        Always True when `n_saved` is None (unlimited).
        """
        if self._n_saved is None:
            return True
        return len(self._saved) < self._n_saved + int(or_equal)

    def __call__(self, engine):
        """Save a checkpoint if there is room or its priority beats the lowest retained one.

        The priority is the output of `score_function` if provided, otherwise the engine's current iteration.
        After saving, the lowest-priority checkpoint is removed via `save_handler.remove` if more than
        `n_saved` checkpoints would be retained.
        """
        suffix = ""
        if self.global_step_transform is not None:
            global_step = self.global_step_transform(engine, engine.last_event_name)
            suffix = "{}".format(global_step)

        if self._score_function is not None:
            priority = self._score_function(engine)
        else:
            priority = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)

        if not (self._check_lt_n_saved() or self._saved[0].priority < priority):
            # Storage is full and this checkpoint does not beat the worst retained one.
            return

        if self._score_name is not None:
            if len(suffix) > 0:
                suffix += "_"
            suffix = "{}{}={}".format(suffix, self._score_name, priority)
        elif self._score_function is not None:
            if len(suffix) > 0:
                suffix += "_"
            suffix = "{}{}".format(suffix, priority)
        elif len(suffix) == 0:
            # Neither score nor global step: the iteration number identifies the checkpoint.
            suffix = "{}".format(priority)

        checkpoint = self._setup_checkpoint()
        name = "checkpoint"
        if len(checkpoint) == 1:
            # A single object is saved unwrapped and the file is named after its key.
            name, checkpoint = next(iter(checkpoint.items()))

        filename = "{}{}_{}{}".format(self._fname_prefix, name, suffix, self._ext)
        self.save_handler(checkpoint, filename)

        self._saved.append(Checkpoint.Item(priority, filename))
        self._saved.sort(key=lambda item: item.priority)

        if not self._check_lt_n_saved(or_equal=True):
            item = self._saved.pop(0)
            self.save_handler.remove(item.filename)

    def _setup_checkpoint(self):
        """Return a dict mapping each key of `to_save` to the `state_dict()` of its object."""
        return {k: obj.state_dict() for k, obj in self.to_save.items()}

    @staticmethod
    def _check_objects(objs, attr):
        """Raise TypeError if any value of the mapping `objs` lacks the attribute `attr`."""
        for obj in objs.values():
            if not hasattr(obj, attr):
                raise TypeError("Object {} should have `{}` method".format(type(obj), attr))

    @staticmethod
    def load_objects(to_load, checkpoint):
        """Helper method to apply `load_state_dict` on the objects from `to_load` using states from `checkpoint`.

        Args:
            to_load (Mapping): a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
            checkpoint (Mapping): a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
                "optimizer": opt_state_dict}`

        Raises:
            TypeError: if `to_load` or `checkpoint` is not a mapping, or an object has no `load_state_dict`.
            ValueError: if a key of `to_load` is missing from `checkpoint`.
        """
        if not isinstance(to_load, collections.Mapping):
            raise TypeError("Argument to_load should be a dictionary, but given {}".format(type(to_load)))

        Checkpoint._check_objects(to_load, "load_state_dict")
        if not isinstance(checkpoint, collections.Mapping):
            raise TypeError("Argument checkpoint should be a dictionary, but given {}".format(type(checkpoint)))
        for k, obj in to_load.items():
            if k not in checkpoint:
                raise ValueError("Object labeled by '{}' from `to_load` is not found in the checkpoint".format(k))
            obj.load_state_dict(checkpoint[k])
class DiskSaver:
    """Handler that saves input checkpoint on a disk.

    Args:
        dirname (str): Directory path where the checkpoint will be saved. A leading `~` is expanded to the
            user's home directory.
        atomic (bool, optional): if True, checkpoint is serialized to a temporary file, and then moved to final
            destination, so that files are guaranteed to not be damaged (for example if exception occurs during
            saving).
        create_dir (bool, optional): if True, will create directory 'dirname' if it doesn't exist.
        require_empty (bool, optional): If True, will raise exception if there are any files with extension
            `.pth` or `.pth.tar` in the directory 'dirname'.

    Raises:
        ValueError: if 'dirname' does not exist (and `create_dir` is False), or if `require_empty` is True and
            the directory already contains `.pth`/`.pth.tar` files.
    """

    def __init__(self, dirname, atomic=True, create_dir=True, require_empty=True):
        self.dirname = os.path.expanduser(dirname)
        self._atomic = atomic

        # All filesystem checks use the expanded path: with the raw `dirname`, "~/models" would create a
        # literal "~" directory in the current working directory while checkpoints are written under $HOME.
        if create_dir and not os.path.exists(self.dirname):
            os.makedirs(self.dirname)

        # Ensure that dirname exists
        if not os.path.exists(self.dirname):
            raise ValueError("Directory path '{}' is not found".format(self.dirname))

        if require_empty:
            matched = [fname for fname in os.listdir(self.dirname) if fname.endswith((".pth", ".pth.tar"))]
            if len(matched) > 0:
                raise ValueError("Files {} with extension '.pth' or '.pth.tar' are already present "
                                 "in the directory {}. If you want to use this "
                                 "directory anyway, pass `require_empty=False`."
                                 "".format(matched, self.dirname))

    def __call__(self, checkpoint, filename):
        """Serialize `checkpoint` with `torch.save` to `dirname/filename`, overwriting any existing file."""
        path = os.path.join(self.dirname, filename)

        if not self._atomic:
            torch.save(checkpoint, path)
            return

        # Write into a temporary file in the same directory, then move it into place in one step.
        tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
        try:
            torch.save(checkpoint, tmp.file)
            tmp.close()
            # os.replace (unlike os.rename) also overwrites an existing destination on Windows.
            os.replace(tmp.name, path)
        except BaseException:
            # Never leave a partially written temporary file behind; closing twice is harmless.
            tmp.close()
            os.remove(tmp.name)
            raise

    def remove(self, filename):
        """Delete the checkpoint file `dirname/filename`."""
        path = os.path.join(self.dirname, filename)
        os.remove(path)
class ModelCheckpoint(Checkpoint):
    """ModelCheckpoint handler can be used to periodically save objects to disk only. If needed to store checkpoints
    to another storage type, please consider :class:`~ignite.handlers.checkpoint.Checkpoint`.

    This handler expects two arguments:

        - an :class:`~ignite.engine.Engine` object
        - a `dict` mapping names (`str`) to objects that should be saved to disk.

    See Examples for further details.

    .. warning::

        Behaviour of this class has been changed since v0.3.0.

        Argument `save_as_state_dict` is deprecated and should not be used. It is considered as True.

        Argument `save_interval` is deprecated and should not be used. Please, use events filtering instead, e.g.
        :attr:`~ignite.engine.Events.ITERATION_STARTED(every=1000)`

        There is no more internal counter that has been used to indicate the number of save actions. User could
        see its value `step_number` in the filename, e.g. `{filename_prefix}_{name}_{step_number}.pth`. Now the
        filename suffix is the engine's current iteration if neither `score_function` nor
        `global_step_transform` is given; otherwise it is built from the output of `global_step_transform`
        and/or the score, as described in the Notes of :class:`~ignite.handlers.Checkpoint`.

        A single `pth` file is created instead of multiple files.

    Args:
        dirname (str): Directory path where objects will be saved.
        filename_prefix (str): Prefix for the filenames to which objects will be saved. See Notes of
            :class:`~ignite.handlers.Checkpoint` for more details.
        save_interval (int, optional): Deprecated, should be None. The value 1 is still accepted with a warning,
            any other value raises ValueError.
        score_function (callable, optional): if not None, it should be a function taking a single argument, an
            :class:`~ignite.engine.Engine` object, and return a score (`float`). Objects with highest scores will be
            retained.
        score_name (str, optional): if `score_function` not None, it is possible to store its value using
            `score_name`. See Notes of :class:`~ignite.handlers.Checkpoint` for more details.
        n_saved (int, optional): Number of objects that should be kept on disk. The checkpoints with the lowest
            priority (the oldest ones by default, or the lowest-scoring ones if `score_function` is provided) are
            removed. If set to `None`, all objects are kept.
        atomic (bool, optional): If True, objects are serialized to a temporary file, and then moved to final
            destination, so that files are guaranteed to not be damaged (for example if exception occurs during
            saving).
        require_empty (bool, optional): If True, will raise exception if there are any files with extension
            `.pth` or `.pth.tar` in the directory 'dirname'.
        create_dir (bool, optional): If True, will create directory 'dirname' if it doesn't exist.
        save_as_state_dict (bool, optional): Deprecated, should be True. False raises ValueError.
        global_step_transform (callable, optional): global step transform function to output a desired global
            step. Input of the function is `(engine, event_name)`. Output of function should be an integer.
            Default is None, global_step based on attached engine. If provided, uses function output as
            global_step. To setup global step from another engine, please use
            :meth:`~ignite.handlers.global_step_from_engine`.
        archived (bool, optional): If True, saved checkpoint extension will be `.pth.tar`. Default value is False.

    Examples:
        >>> import os
        >>> from ignite.engine import Engine, Events
        >>> from ignite.handlers import ModelCheckpoint
        >>> from torch import nn
        >>> trainer = Engine(lambda engine, batch: None)
        >>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
        >>> model = nn.Linear(3, 3)
        >>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
        >>> trainer.run([0], max_epochs=6)
        >>> os.listdir('/tmp/models')
        ['myprefix_mymodel_4.pth', 'myprefix_mymodel_6.pth']
        >>> handler.last_checkpoint
        '/tmp/models/myprefix_mymodel_6.pth'
    """

    def __init__(self, dirname, filename_prefix,
                 save_interval=None,
                 score_function=None, score_name=None,
                 n_saved=1,
                 atomic=True, require_empty=True,
                 create_dir=True,
                 save_as_state_dict=True, global_step_transform=None, archived=False):

        if not save_as_state_dict:
            raise ValueError("Argument save_as_state_dict is deprecated and should be True")
        if save_interval is not None:
            msg = "Argument save_interval is deprecated and should be None. " \
                  "Please, use events filtering instead, e.g. Events.ITERATION_STARTED(every=1000)"
            if save_interval == 1:
                # Do not break for old version who used `save_interval=1`
                warnings.warn(msg)
            else:
                # No choice
                raise ValueError(msg)

        # Validate all arguments before DiskSaver touches the filesystem (it may create directories).
        if score_function is None and score_name is not None:
            raise ValueError("If `score_name` is provided, then `score_function` "
                             "should be also provided.")

        if global_step_transform is not None and not callable(global_step_transform):
            raise TypeError("global_step_transform should be a function, got {} instead."
                            .format(type(global_step_transform)))

        disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty)

        # Checkpoint.__init__ is not called because the objects to save are only known in __call__;
        # the same attributes are initialized here instead.
        self._fname_prefix = filename_prefix + "_" if len(filename_prefix) > 0 else filename_prefix
        self.save_handler = disk_saver
        self.to_save = None
        self._score_function = score_function
        self._score_name = score_name
        self._n_saved = n_saved
        self._saved = []
        self._ext = ".pth.tar" if archived else ".pth"
        self.global_step_transform = global_step_transform

    @property
    def last_checkpoint(self):
        """Full path of the retained checkpoint with the highest priority, or None if nothing was saved yet."""
        if len(self._saved) < 1:
            return None
        # `_saved` is sorted ascending by priority, so the last element has the highest priority.
        return os.path.join(self.save_handler.dirname, self._saved[-1].filename)

    def __call__(self, engine, to_save):
        """Save the objects of `to_save` (a non-empty dict of objects with `state_dict`) for `engine`.

        Raises:
            RuntimeError: if `to_save` is empty.
            TypeError: if an object of `to_save` has no `state_dict` attribute.
        """
        if len(to_save) == 0:
            raise RuntimeError("No objects to checkpoint found.")

        self._check_objects(to_save, "state_dict")
        self.to_save = to_save
        super(ModelCheckpoint, self).__call__(engine)

© Copyright 2024, PyTorch-Ignite Contributors. Last updated on 09/14/2024, 10:14:58 AM.

Built with Sphinx using a theme provided by Read the Docs.