Shortcuts

Source code for ignite.handlers.checkpoint

import os
import tempfile

import torch


class ModelCheckpoint(object):
    """ModelCheckpoint handler can be used to periodically save objects to disk.

    This handler expects two arguments:

        - an :class:`~ignite.engine.Engine` object
        - a `dict` mapping names (`str`) to objects that should be saved to disk.

    See Notes and Examples for further details.

    Args:
        dirname (str):
            Directory path where objects will be saved. A leading ``~`` is
            expanded to the user's home directory.
        filename_prefix (str):
            Prefix for the filenames to which objects will be saved. See Notes
            for more details.
        save_interval (int, optional):
            if not None, objects will be saved to disk every `save_interval`
            calls to the handler. Exactly one of (`save_interval`,
            `score_function`) arguments must be provided.
        score_function (callable, optional):
            if not None, it should be a function taking a single argument, an
            :class:`~ignite.engine.Engine` object, and return a score (`float`).
            Objects with highest scores will be retained. Exactly one of
            (`save_interval`, `score_function`) arguments must be provided.
        score_name (str, optional):
            if `score_function` not None, it is possible to store its absolute
            value using `score_name`. See Notes for more details.
        n_saved (int, optional):
            Number of objects that should be kept on disk. Older files will be
            removed.
        atomic (bool, optional):
            If True, objects are serialized to a temporary file, and then moved
            to final destination, so that files are guaranteed to not be
            damaged (for example if an exception occurs during saving).
        require_empty (bool, optional):
            If True, will raise exception if there are any files starting with
            `filename_prefix` in the directory 'dirname'.
        create_dir (bool, optional):
            If True, will create directory 'dirname' if it doesn't exist.
        save_as_state_dict (bool, optional):
            If True, will save only the `state_dict` of the objects specified,
            otherwise the whole object will be saved.

    Raises:
        ValueError: if not exactly one of `save_interval` / `score_function`
            is given, if `score_name` is given without `score_function`, if
            the directory does not exist, or if `require_empty` is True and
            files starting with `filename_prefix` are already present.

    Notes:
        This handler expects two arguments: an :class:`~ignite.engine.Engine`
        object and a `dict` mapping names to objects that should be saved.

        These names are used to specify filenames for saved objects.
        Each filename has the following structure:
        `{filename_prefix}_{name}_{step_number}.pth`.
        Here, `filename_prefix` is the argument passed to the constructor,
        `name` is the key in the aforementioned `dict`, and `step_number`
        is incremented by `1` with every call to the handler.

        If `score_function` is provided, user can store its absolute value
        using `score_name` in the filename. Each filename can have the
        following structure:
        `{filename_prefix}_{name}_{step_number}_{score_name}={abs(score_function_result)}.pth`.
        For example, `score_name="val_loss"` and `score_function` that returns
        `-loss` (as objects with highest scores will be retained), then saved
        models filenames will be `model_resnet_10_val_loss=0.1234.pth`.

    Examples:
        >>> import os
        >>> from ignite.engine import Engine, Events
        >>> from ignite.handlers import ModelCheckpoint
        >>> from torch import nn
        >>> trainer = Engine(lambda batch: None)
        >>> handler = ModelCheckpoint('/tmp/models', 'myprefix', save_interval=2, n_saved=2, create_dir=True)
        >>> model = nn.Linear(3, 3)
        >>> trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {'mymodel': model})
        >>> trainer.run([0], max_epochs=6)
        >>> os.listdir('/tmp/models')
        ['myprefix_mymodel_4.pth', 'myprefix_mymodel_6.pth']
    """

    def __init__(self, dirname, filename_prefix,
                 save_interval=None, score_function=None, score_name=None,
                 n_saved=1,
                 atomic=True, require_empty=True,
                 create_dir=True,
                 save_as_state_dict=True):

        self._dirname = os.path.expanduser(dirname)
        self._fname_prefix = filename_prefix
        self._n_saved = n_saved
        self._save_interval = save_interval
        self._score_function = score_function
        self._score_name = score_name
        self._atomic = atomic
        self._saved = []  # list of tuples (priority, saved_objects)
        self._iteration = 0
        self._save_as_state_dict = save_as_state_dict

        # `^` binds tighter than `not`: this rejects "both" and "neither".
        if not (save_interval is None) ^ (score_function is None):
            raise ValueError("Exactly one of `save_interval`, or `score_function` "
                             "arguments must be provided.")

        if score_function is None and score_name is not None:
            raise ValueError("If `score_name` is provided, then `score_function` "
                             "should be also provided.")

        # All filesystem checks below use the expanded path (self._dirname),
        # i.e. the very same directory that `_save` later writes into.
        if create_dir:
            if not os.path.exists(self._dirname):
                os.makedirs(self._dirname)

        # Ensure that dirname exists
        if not os.path.exists(self._dirname):
            raise ValueError("Directory path '{}' is not found.".format(self._dirname))

        if require_empty:
            matched = [fname
                       for fname in os.listdir(self._dirname)
                       if fname.startswith(self._fname_prefix)]

            if len(matched) > 0:
                raise ValueError("Files prefixed with {} are already present "
                                 "in the directory {}. If you want to use this "
                                 "directory anyway, pass `require_empty=False`."
                                 "".format(filename_prefix, dirname))

    def _save(self, obj, path):
        """Serialize `obj` to `path`, going through a temporary file if atomic.

        In atomic mode the temporary file is created inside the checkpoint
        directory and renamed to `path` only after serialization succeeded;
        on any failure the temporary file is removed and the error re-raised.
        """
        if not self._atomic:
            self._internal_save(obj, path)
        else:
            tmp = tempfile.NamedTemporaryFile(delete=False, dir=self._dirname)
            try:
                self._internal_save(obj, tmp.file)
            except BaseException:
                # Also covers KeyboardInterrupt: never leave a partial file.
                tmp.close()
                os.remove(tmp.name)
                raise
            else:
                tmp.close()
                os.rename(tmp.name, path)

    def _internal_save(self, obj, path):
        """Write `obj` (or its `state_dict()`) to `path` with `torch.save`.

        Raises:
            ValueError: if `save_as_state_dict` is enabled and `obj` has no
                callable `state_dict` attribute.
        """
        if not self._save_as_state_dict:
            torch.save(obj, path)
        else:
            if not hasattr(obj, "state_dict") or not callable(obj.state_dict):
                raise ValueError("Object should have `state_dict` method.")
            torch.save(obj.state_dict(), path)

    def __call__(self, engine, to_save):
        """Save the objects in `to_save` if this call is due for a checkpoint.

        Args:
            engine: passed to `score_function` when one was configured.
            to_save (dict): maps names (`str`) to the objects to be saved.

        Raises:
            RuntimeError: if `to_save` is empty.
        """
        if len(to_save) == 0:
            raise RuntimeError("No objects to checkpoint found.")

        self._iteration += 1

        if self._score_function is not None:
            priority = self._score_function(engine)

        else:
            priority = self._iteration
            if (self._iteration % self._save_interval) != 0:
                return

        # `self._saved` is kept sorted ascending by priority, so index 0 is
        # the entry that gets evicted first.
        if (len(self._saved) < self._n_saved) or (self._saved[0][0] < priority):
            saved_objs = []

            suffix = ""
            if self._score_name is not None:
                suffix = "_{}={:.7}".format(self._score_name, abs(priority))

            for name, obj in to_save.items():
                fname = '{}_{}_{}{}.pth'.format(self._fname_prefix, name, self._iteration, suffix)
                path = os.path.join(self._dirname, fname)

                self._save(obj=obj, path=path)
                saved_objs.append(path)

            self._saved.append((priority, saved_objs))
            self._saved.sort(key=lambda item: item[0])

        if len(self._saved) > self._n_saved:
            _, paths = self._saved.pop(0)
            for p in paths:
                os.remove(p)

© Copyright 2022, PyTorch-Ignite Contributors. Last updated on 05/04/2022, 8:31:22 PM.

Built with Sphinx using a theme provided by Read the Docs.