class BaseSaveHandler(metaclass=ABCMeta):
    """Base class for save handlers.

    Methods to override:

    - :meth:`~ignite.handlers.checkpoint.BaseSaveHandler.__call__`
    - :meth:`~ignite.handlers.checkpoint.BaseSaveHandler.remove`

    Note:
        In a derived class, please make sure that, in a distributed configuration, the overridden methods are
        called by a single process only. A distributed configuration on XLA devices should be treated slightly
        differently: to save a checkpoint with
        `xm.save() <https://pytorch.org/xla/release/1.5/index.html#torch_xla.core.xla_model.save>`_
        all processes should pass into the function. Otherwise, the application gets stuck.

    """
    @abstractmethod
    def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
        """Method to save `checkpoint` with `filename`. Additionally, a metadata dictionary is provided.

        Metadata contains:

        - `basename`: file prefix (if provided) with checkpoint name, e.g. `epoch_checkpoint`.
        - `score_name`: score name if provided, e.g. `val_acc`.
        - `priority`: checkpoint priority value (higher is better), e.g. `12` or `0.6554435`.

        Args:
            checkpoint (Mapping): checkpoint dictionary to save.
            filename (str): filename associated with checkpoint.
            metadata (Mapping, optional): metadata on checkpoint to save.

        """
        pass
    @abstractmethod
    def remove(self, filename: str) -> None:
        """Method to remove saved checkpoint.

        Args:
            filename (str): filename associated with checkpoint.

        """
        pass
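
# Illustrative sketch (not part of this module): a minimal custom save handler built on the abstract
# interface above. It keeps checkpoints in an in-memory dictionary instead of on disk; the class name
# `InMemorySaver` is hypothetical and chosen for the example only. A handler like this could be passed
# to `Checkpoint(..., save_handler=InMemorySaver())`.
class InMemorySaver(BaseSaveHandler):
    def __init__(self) -> None:
        self.storage = {}  # maps filename -> checkpoint dictionary

    def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
        # keep a shallow copy of the checkpoint dictionary under its filename
        self.storage[filename] = dict(checkpoint)

    def remove(self, filename: str) -> None:
        # called by Checkpoint so that at most `n_saved` checkpoints are retained
        self.storage.pop(filename, None)
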
class Checkpoint:
    """Checkpoint handler can be used to periodically save and load objects which have attribute
    `state_dict`/`load_state_dict`. This class can use specific save handlers to store on a disk, a cloud
    storage, etc. The Checkpoint handler (if used with :class:`~ignite.handlers.DiskSaver`) also handles
    automatically moving data on TPU to CPU before writing the checkpoint.

    Args:
        to_save (Mapping): Dictionary with the objects to save. Objects should have implemented `state_dict` and
            `load_state_dict` methods. If it contains objects of type torch `DistributedDataParallel`_ or
            `DataParallel`_, their internal wrapped model is automatically saved (to avoid an additional key
            ``module.`` in the state dictionary).
        save_handler (callable or :class:`~ignite.handlers.checkpoint.BaseSaveHandler`): Method or callable class to
            use to save engine and other provided objects. Function receives two objects: checkpoint as a dictionary
            and filename. If `save_handler` is a callable class, it can inherit from
            :class:`~ignite.handlers.checkpoint.BaseSaveHandler` and optionally implement a `remove` method to keep
            a fixed number of saved checkpoints. If the user needs to save the engine's checkpoint on a disk,
            `save_handler` can be defined with :class:`~ignite.handlers.DiskSaver`.
        filename_prefix (str, optional): Prefix for the filename to which objects will be saved. See Note for
            details.
        score_function (callable, optional): If not None, it should be a function taking a single argument,
            an :class:`~ignite.engine.Engine` object, and returning a score (`float`). Objects with the highest
            scores will be retained.
        score_name (str, optional): If `score_function` is not None, it is possible to store its value using
            `score_name`. See Note for more details.
        n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. If set
            to `None`, all objects are kept.
        global_step_transform (callable, optional): global step transform function to output a desired global step.
            Input of the function is `(engine, event_name)`. Output of the function should be an integer.
            Default is None, global_step based on attached engine. If provided, uses function output as global_step.
            To setup the global step from another engine, please use :meth:`~ignite.handlers.global_step_from_engine`.
        archived (bool, optional): Deprecated argument as models saved by `torch.save` are already compressed.

    .. _DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel
    .. _DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel

    Note:
        This class stores a single file as a dictionary of provided objects to save.
        The filename has the following structure: `{filename_prefix}_{name}_{suffix}.{ext}` where

        - `filename_prefix` is the argument passed to the constructor,
        - `name` is the key in `to_save` if a single object is to be stored, otherwise `name` is "checkpoint",
        - `suffix` is composed as follows: `{global_step}_{score_name}={score}`, where `global_step` is defined by
          the output of `global_step_transform` and `score` by the output of `score_function`.

        By default, none of `score_function`, `score_name`, `global_step_transform` is defined, and the suffix is
        set from the attached engine's current iteration. The filename will be
        `{filename_prefix}_{name}_{engine.state.iteration}.{ext}`.

        If a `score_function` is defined, but no `score_name`, then the suffix is defined by the provided score.
        The filename will be `{filename_prefix}_{name}_{global_step}_{score}.pt`.
        If both `score_function` and `score_name` are defined, then the filename will be
        `{filename_prefix}_{name}_{score_name}={score}.{ext}`. If `global_step_transform` is also provided, then
        the filename will be `{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}`.

        For example, with `score_name="neg_val_loss"` and a `score_function` that returns `-loss` (as objects with
        the highest scores will be retained), the saved filename will be
        `{filename_prefix}_{name}_neg_val_loss=-0.1234.pt`.

        To get the last stored filename, the handler exposes the attribute `last_checkpoint`:

        .. code-block:: python

            handler = Checkpoint(...)
            ...
            print(handler.last_checkpoint)
            > checkpoint_12345.pt

    Note:
        This class is distributed configuration-friendly: it is not required to instantiate the class in the rank 0
        process only. This class supports distributed configuration automatically and, if used with
        :class:`~ignite.handlers.DiskSaver`, the checkpoint is stored by the rank 0 process.

    .. warning::

        When running on TPUs, it should be run in all processes, otherwise the application can get stuck while
        saving the checkpoint.

        .. code-block:: python

            # Wrong:
            # if idist.get_rank() == 0:
            #     handler = Checkpoint(...)
            #     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

            # Correct:
            handler = Checkpoint(...)
            trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

    Examples:

        Attach the handler to make checkpoints during training:

        .. code-block:: python

            from ignite.engine import Engine, Events
            from ignite.handlers import Checkpoint, DiskSaver

            trainer = ...
            model = ...
            optimizer = ...
            lr_scheduler = ...

            to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}
            handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2)
            trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

            trainer.run(data_loader, max_epochs=6)
            > ["checkpoint_7000.pt", "checkpoint_8000.pt", ]

        Attach the handler to an evaluator to save the best model during the training according to a computed
        validation metric:

        .. code-block:: python

            from ignite.engine import Engine, Events
            from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

            trainer = ...
            evaluator = ...
            # Setup Accuracy metric computation on evaluator
            # Run evaluation on epoch completed event
            # ...
            def score_function(engine):
                return engine.state.metrics['accuracy']

            to_save = {'model': model}
            handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2,
                                 filename_prefix='best', score_function=score_function, score_name="val_acc",
                                 global_step_transform=global_step_from_engine(trainer))

            evaluator.add_event_handler(Events.COMPLETED, handler)

            trainer.run(data_loader, max_epochs=10)
            > ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt", ]

    """

    Item = namedtuple("Item", ["priority", "filename"])

    def __init__(
        self,
        to_save: Mapping,
        save_handler: Union[Callable, BaseSaveHandler],
        filename_prefix: str = "",
        score_function: Optional[Callable] = None,
        score_name: Optional[str] = None,
        n_saved: Optional[int] = 1,
        global_step_transform: Callable = None,
        archived: bool = False,
    ):
        if to_save is not None:  # for compatibility with ModelCheckpoint
            if not isinstance(to_save, collections.Mapping):
                raise TypeError("Argument `to_save` should be a dictionary, but given {}".format(type(to_save)))

            if len(to_save) < 1:
                raise ValueError("No objects to checkpoint.")

            self._check_objects(to_save, "state_dict")

        if not (callable(save_handler) or isinstance(save_handler, BaseSaveHandler)):
            raise TypeError("Argument `save_handler` should be callable or inherit from BaseSaveHandler")

        if score_function is None and score_name is not None:
            raise ValueError("If `score_name` is provided, then `score_function` should be also provided.")

        if global_step_transform is not None and not callable(global_step_transform):
            raise TypeError(
                "global_step_transform should be a function, got {} instead.".format(type(global_step_transform))
            )

        if archived:
            warnings.warn("Argument archived is deprecated and will be removed in 0.5.0")

        self.to_save = to_save
        self._fname_prefix = filename_prefix + "_" if len(filename_prefix) > 0 else filename_prefix
        self.save_handler = save_handler
        self._score_function = score_function
        self._score_name = score_name
        self._n_saved = n_saved
        self._saved = []
        self._ext = ".pt"
        self.global_step_transform = global_step_transform

    @property
    def last_checkpoint(self) -> Optional[str]:
        if len(self._saved) < 1:
            return None
        return self._saved[-1].filename

    def _check_lt_n_saved(self, or_equal=False):
        if self._n_saved is None:
            return True
        return len(self._saved) < self._n_saved + int(or_equal)

    def __call__(self, engine: Engine) -> None:

        suffix = ""
        if self.global_step_transform is not None:
            global_step = self.global_step_transform(engine, engine.last_event_name)
            suffix = "{}".format(global_step)

        if self._score_function is not None:
            priority = self._score_function(engine)
            if not isinstance(priority, numbers.Number):
                raise ValueError("Output of score_function should be a number")
number")else:priority=engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)ifself._check_lt_n_saved()orself._saved[0].priority<priority:priority_str=("{}".format(priority)ifisinstance(priority,numbers.Integral)else"{:.4f}".format(priority))ifself._score_nameisnotNone:iflen(suffix)>0:suffix+="_"suffix="{}{}={}".format(suffix,self._score_name,priority_str)elifself._score_functionisnotNone:iflen(suffix)>0:suffix+="_"suffix="{}{}".format(suffix,priority_str)eliflen(suffix)==0:suffix="{}".format(priority_str)checkpoint=self._setup_checkpoint()name="checkpoint"iflen(checkpoint)==1:forkincheckpoint:name=kcheckpoint=checkpoint[name]filename="{}{}_{}{}".format(self._fname_prefix,name,suffix,self._ext)ifany(item.filename==filenameforiteminself._saved):returnmetadata={"basename":"{}{}".format(self._fname_prefix,name),"score_name":self._score_name,"priority":priority,}try:self.save_handler(checkpoint,filename,metadata)exceptTypeError:self.save_handler(checkpoint,filename)self._saved.append(Checkpoint.Item(priority,filename))self._saved.sort(key=lambdaitem:item[0])ifnotself._check_lt_n_saved(or_equal=True):item=self._saved.pop(0)ifisinstance(self.save_handler,BaseSaveHandler):self.save_handler.remove(item.filename)def_setup_checkpoint(self)->dict:checkpoint={}fork,objinself.to_save.items():ifisinstance(obj,(nn.DataParallel,nn.parallel.DistributedDataParallel)):obj=obj.modulecheckpoint[k]=obj.state_dict()returncheckpoint@staticmethoddef_check_objects(objs:Mapping,attr:str)->None:fork,objinobjs.items():ifnothasattr(obj,attr):raiseTypeError("Object {} should have `{}` method".format(type(obj),attr))
    @staticmethod
    def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs) -> None:
        """Helper method to apply `load_state_dict` on the objects from `to_load` using states from `checkpoint`.

        Examples:

        .. code-block:: python

            import torch
            from ignite.engine import Engine, Events
            from ignite.handlers import ModelCheckpoint, Checkpoint

            trainer = Engine(lambda engine, batch: None)
            handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
            model = torch.nn.Linear(3, 3)
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            to_save = {"weights": model, "optimizer": optimizer}

            trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
            trainer.run(torch.randn(10, 1), 5)

            to_load = to_save
            checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pt"
            checkpoint = torch.load(checkpoint_fp)
            Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

        Args:
            to_load (Mapping): a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
            checkpoint (Mapping): a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then `checkpoint` can directly
                contain the corresponding state_dict.
            **kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
                the user to load part of the pretrained model (useful, for example, in transfer learning).

        """
        Checkpoint._check_objects(to_load, "load_state_dict")
        if not isinstance(checkpoint, collections.Mapping):
            raise TypeError("Argument checkpoint should be a dictionary, but given {}".format(type(checkpoint)))

        if len(kwargs) > 1 or any(k for k in kwargs.keys() if k not in ["strict"]):
            warnings.warn("kwargs contains keys other than strict and these will be ignored")

        is_state_dict_strict = kwargs.get("strict", True)
        if len(to_load) == 1:
            # single object and checkpoint is directly a state_dict
            key, obj = list(to_load.items())[0]
            if key not in checkpoint:
                obj.load_state_dict(checkpoint, strict=is_state_dict_strict)
                return

        # multiple objects to load
        for k, obj in to_load.items():
            if k not in checkpoint:
                raise ValueError("Object labeled by '{}' from `to_load` is not found in the checkpoint".format(k))
            if isinstance(obj, torch.nn.Module):
                obj.load_state_dict(checkpoint[k], strict=is_state_dict_strict)
            else:
                obj.load_state_dict(checkpoint[k])
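
# Illustrative sketch (not part of this module): restoring training state with `Checkpoint.load_objects`.
# The helper name and the checkpoint path argument are hypothetical; `strict=False` is forwarded to
# `nn.Module.load_state_dict` and allows partial loading (e.g. for transfer learning), as described in
# the docstring above.
def _example_resume_from_checkpoint(to_load: Mapping, checkpoint_path: str) -> None:
    # load the serialized dictionary written earlier by Checkpoint/DiskSaver
    checkpoint = torch.load(checkpoint_path)
    # apply each state_dict onto the corresponding object in `to_load`
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint, strict=False)

# Hypothetical usage:
# _example_resume_from_checkpoint({"model": model, "optimizer": optimizer}, "/tmp/models/checkpoint_8000.pt")
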
class DiskSaver(BaseSaveHandler):
    """Handler that saves input checkpoint on a disk.

    Args:
        dirname (str): Directory path where the checkpoint will be saved.
        atomic (bool, optional): if True, checkpoint is serialized to a temporary file, and then moved to the final
            destination, so that files are guaranteed to not be damaged (for example if an exception occurs during
            saving).
        create_dir (bool, optional): if True, will create directory 'dirname' if it doesn't exist.
        require_empty (bool, optional): If True, will raise an exception if there are any files in the
            directory 'dirname'.

    """

    def __init__(self, dirname: str, atomic: bool = True, create_dir: bool = True, require_empty: bool = True):
        self.dirname = os.path.expanduser(dirname)
        self._atomic = atomic
        self._check_and_setup(dirname, create_dir, require_empty)

    @staticmethod
    @idist.one_rank_only()
    def _check_and_setup(dirname, create_dir, require_empty):
        if create_dir:
            if not os.path.exists(dirname):
                os.makedirs(dirname)
        # Ensure that dirname exists
        if not os.path.exists(dirname):
            raise ValueError("Directory path '{}' is not found".format(dirname))

        if require_empty:
            matched = [fname for fname in os.listdir(dirname) if fname.endswith(".pt")]
            if len(matched) > 0:
                raise ValueError(
                    "Files {} with extension '.pt' are already present "
                    "in the directory {}. If you want to use this "
                    "directory anyway, pass `require_empty=False`."
                    "".format(matched, dirname)
                )

    def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
        path = os.path.join(self.dirname, filename)

        if idist.has_xla_support:
            self._save_xla(checkpoint, path)
        else:
            self._save_native(checkpoint, path)

    @idist.one_rank_only()
    def _save_native(self, checkpoint: Mapping, path: str):
        self._save_func(checkpoint, path, torch.save)

    def _save_xla(self, checkpoint: Mapping, path: str):
        import torch_xla.core.xla_model as xm

        # all tpu procs should enter here as internally performs sync across device
        self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())

    def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int = 0):
        if not self._atomic:
            func(checkpoint, path)
        else:
            tmp_file = None
            tmp_name = None
            tmp = None
            if rank == 0:
                tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
                tmp_file = tmp.file
                tmp_name = tmp.name
            try:
                func(checkpoint, tmp_file)
            except BaseException:
                if tmp is not None:
                    tmp.close()
                    os.remove(tmp_name)
                raise
            else:
                if tmp is not None:
                    tmp.close()
                    os.rename(tmp.name, path)

    @idist.one_rank_only()
    def remove(self, filename: str) -> None:
        path = os.path.join(self.dirname, filename)
        os.remove(path)
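
# Illustrative sketch (not part of this module): DiskSaver can also be used directly, outside of the
# Checkpoint handler, to write a state-dict mapping to disk. The function name, directory and filename
# below are hypothetical. With `atomic=True` (the default) the data is first written to a temporary file
# and then renamed into place, as implemented in `_save_func` above.
def _example_disk_saver_usage(model: nn.Module) -> None:
    saver = DiskSaver("/tmp/manual_ckpts", create_dir=True, require_empty=False)
    # writes the checkpoint dictionary to /tmp/manual_ckpts/manual_checkpoint.pt
    saver({"model": model.state_dict()}, "manual_checkpoint.pt")
    # and removes it again via the handler's remove() method
    saver.remove("manual_checkpoint.pt")
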
class ModelCheckpoint(Checkpoint):
    """ModelCheckpoint handler can be used to periodically save objects to disk only. If you need to store
    checkpoints to another storage type, please consider :class:`~ignite.handlers.checkpoint.Checkpoint`.

    This handler expects two arguments:

    - an :class:`~ignite.engine.Engine` object
    - a `dict` mapping names (`str`) to objects that should be saved to disk.

    See Examples for further details.

    .. warning::

        Behaviour of this class has been changed since v0.3.0.

        Argument `save_as_state_dict` is deprecated and should not be used. It is considered as True.

        Argument `save_interval` is deprecated and should not be used. Please, use events filtering instead, e.g.
        :attr:`~ignite.engine.Events.ITERATION_STARTED(every=1000)`

        There is no longer an internal counter to indicate the number of save actions. Instead, its value
        `step_number` appears in the filename, e.g. `{filename_prefix}_{name}_{step_number}.pt`. `step_number` is
        the current engine's epoch if `score_function` is specified, and the current iteration otherwise.

        A single `pt` file is created instead of multiple files.

    Args:
        dirname (str): Directory path where objects will be saved.
        filename_prefix (str): Prefix for the filenames to which objects will be saved. See Notes of
            :class:`~ignite.handlers.Checkpoint` for more details.
        score_function (callable, optional): if not None, it should be a function taking a single argument,
            an :class:`~ignite.engine.Engine` object, and returning a score (`float`). Objects with the highest
            scores will be retained.
        score_name (str, optional): if `score_function` is not None, it is possible to store its value using
            `score_name`. See Notes for more details.
        n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. If set
            to `None`, all objects are kept.
        atomic (bool, optional): If True, objects are serialized to a temporary file, and then moved to the final
            destination, so that files are guaranteed to not be damaged (for example if an exception occurs during
            saving).
        require_empty (bool, optional): If True, will raise an exception if there are any files starting with
            `filename_prefix` in the directory 'dirname'.
        create_dir (bool, optional): If True, will create directory 'dirname' if it doesn't exist.
        global_step_transform (callable, optional): global step transform function to output a desired global step.
            Input of the function is `(engine, event_name)`. Output of the function should be an integer. Default is
            None, global_step based on attached engine. If provided, uses function output as global_step. To setup
            the global step from another engine, please use :meth:`~ignite.handlers.global_step_from_engine`.
        archived (bool, optional): Deprecated argument as models saved by `torch.save` are already compressed.
    Examples:
        >>> import os
        >>> from ignite.engine import Engine, Events
        >>> from ignite.handlers import ModelCheckpoint
        >>> from torch import nn
        >>> trainer = Engine(lambda engine, batch: None)
        >>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
        >>> model = nn.Linear(3, 3)
        >>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
        >>> trainer.run([0], max_epochs=6)
        >>> os.listdir('/tmp/models')
        ['myprefix_mymodel_4.pt', 'myprefix_mymodel_6.pt']
        >>> handler.last_checkpoint
        '/tmp/models/myprefix_mymodel_6.pt'
    """

    def __init__(
        self,
        dirname: str,
        filename_prefix: str,
        save_interval: Optional[Callable] = None,
        score_function: Optional[Callable] = None,
        score_name: Optional[str] = None,
        n_saved: Union[int, None] = 1,
        atomic: bool = True,
        require_empty: bool = True,
        create_dir: bool = True,
        save_as_state_dict: bool = True,
        global_step_transform: Optional[Callable] = None,
        archived: bool = False,
    ):

        if not save_as_state_dict:
            raise ValueError(
                "Argument save_as_state_dict is deprecated and should be True. "
                "This argument will be removed in 0.5.0."
            )
        if save_interval is not None:
            msg = (
                "Argument save_interval is deprecated and should be None. This argument will be removed in 0.5.0. "
                "Please, use events filtering instead, e.g. Events.ITERATION_STARTED(every=1000)"
            )
            if save_interval == 1:
                # Do not break for old versions who used `save_interval=1`
                warnings.warn(msg)
            else:
                # No choice
                raise ValueError(msg)

        disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty)

        super(ModelCheckpoint, self).__init__(
            to_save=None,
            save_handler=disk_saver,
            filename_prefix=filename_prefix,
            score_function=score_function,
            score_name=score_name,
            n_saved=n_saved,
            global_step_transform=global_step_transform,
            archived=archived,
        )

    @property
    def last_checkpoint(self) -> Union[str, None]:
        if len(self._saved) < 1:
            return None
        return os.path.join(self.save_handler.dirname, self._saved[-1].filename)

    def __call__(self, engine: Engine, to_save: Mapping) -> None:

        if len(to_save) == 0:
            raise RuntimeError("No objects to checkpoint found.")

        self._check_objects(to_save, "state_dict")
        self.to_save = to_save
        super(ModelCheckpoint, self).__call__(engine)
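
# Illustrative sketch (not part of this module): keeping only the best models on disk with ModelCheckpoint.
# The function name and the metric key 'accuracy' are hypothetical; the pattern mirrors the score-based
# example in the Checkpoint docstring above, but uses ModelCheckpoint so the files go straight to disk.
def _example_best_model_saving(trainer, evaluator, model) -> None:
    # score used to rank checkpoints: higher is better, so the best models are retained
    def score_function(engine):
        return engine.state.metrics["accuracy"]

    # imported lazily here to avoid a circular import with ignite.handlers in this sketch
    from ignite.handlers import global_step_from_engine

    handler = ModelCheckpoint(
        "/tmp/best_models",
        "best",
        n_saved=2,
        score_function=score_function,
        score_name="val_acc",
        global_step_transform=global_step_from_engine(trainer),
        require_empty=False,
    )
    # the handler receives the engine and the mapping of objects to save
    evaluator.add_event_handler(Events.COMPLETED, handler, {"model": model})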