TorchSnapshotSaver
- class torchtnt.framework.callbacks.TorchSnapshotSaver(dirpath: str, *, save_every_n_train_steps: Optional[int] = None, save_every_n_epochs: Optional[int] = None, replicated: Optional[List[str]] = None, storage_options: Optional[Dict[str, Any]] = None)
A callback which periodically saves the application state during training using TorchSnapshot.
This callback supplements the application state provided by torchtnt.unit.AppStateMixin with the train progress, train dataloader (if applicable), and random number generator state.
If used with torchtnt.framework.fit(), this class will also save the evaluation progress state.
Checkpoints will be saved under dirpath/epoch_{epoch}_step_{step}, where step is the total number of training steps completed across all epochs.
Parameters:
- dirpath – Parent directory to save snapshots to.
- save_every_n_train_steps – Frequency of steps with which to save snapshots during the train epoch. If None, no intra-epoch snapshots are generated.
- save_every_n_epochs – Frequency of epochs with which to save snapshots during training. If None, no end-of-epoch snapshots are generated.
- replicated – A glob-pattern of replicated key names that indicate which application state entries have the same state across all processes. For more information, see https://pytorch.org/torchsnapshot/main/api_reference.html#torchsnapshot.Snapshot.take .
- storage_options – Additional keyword options for the storage plugin to use, to be passed to torchsnapshot.Snapshot. See each storage plugin’s documentation for customizations.
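To illustrate how a glob pattern in replicated selects application state entries, here is a minimal sketch using the standard library's fnmatch module. It stands in for whatever matcher TorchSnapshot actually uses, and the key names below are hypothetical, chosen only for illustration.

```python
from fnmatch import fnmatchcase
from typing import Iterable, List


def match_replicated(keys: Iterable[str], patterns: List[str]) -> List[str]:
    """Return the keys matched by at least one glob pattern.

    Assumption: standard fnmatch-style globbing, where `*` also matches `/`.
    """
    return [k for k in keys if any(fnmatchcase(k, p) for p in patterns)]


# Hypothetical app state keys; real keys depend on the unit being saved.
app_state_keys = ["module/weight", "module/bias", "optimizer/state", "train_progress"]

# Mark every entry under "module/" as identical across processes.
print(match_replicated(app_state_keys, ["module/*"]))
```

A pattern such as `**` would mark every entry as replicated under this matcher, which is only appropriate when all ranks truly hold the same state.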
Note: If torch.distributed is available and the default process group is initialized, the constructor calls a collective operation in which rank 0 broadcasts the dirpath to all other ranks.
Note: If checkpointing an FSDP model, you can set the state_dict type by calling set_state_dict_type prior to starting training.
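The naming scheme and save frequencies described above can be sketched in plain Python. This is an illustration under stated assumptions, not the callback's implementation: the helper names snapshot_path and should_save are hypothetical, and the sketch assumes a snapshot is triggered whenever the running count is a multiple of the configured frequency.

```python
import os
from typing import Optional


def snapshot_path(dirpath: str, epoch: int, step: int) -> str:
    """Build the checkpoint location dirpath/epoch_{epoch}_step_{step}.

    `step` is the total number of train steps completed across all epochs.
    """
    return os.path.join(dirpath, f"epoch_{epoch}_step_{step}")


def should_save(count: int, every_n: Optional[int]) -> bool:
    """Assumed trigger rule: save when `count` is a multiple of `every_n`.

    A frequency of None disables this kind of snapshot entirely.
    """
    return every_n is not None and count % every_n == 0


# With save_every_n_train_steps=100, step 300 would produce a snapshot.
if should_save(300, every_n=100):
    print(snapshot_path("/tmp/checkpoints", epoch=1, step=300))
```

Leaving both frequencies as None means neither intra-epoch nor end-of-epoch snapshots are generated.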
- property dirpath: str
Returns the parent directory to save to.
- on_exception(state: State, unit: Union[TrainUnit[TTrainData], EvalUnit[TEvalData], PredictUnit[TPredictData]], exc: BaseException) → None
Hook called when an exception occurs.
- on_train_epoch_end(state: State, unit: TrainUnit[TTrainData]) → None
Hook called after a train epoch ends.
- on_train_start(state: State, unit: TrainUnit[TTrainData]) → None
Validate there’s no key collision for the app state.
- on_train_step_end(state: State, unit: TrainUnit[TTrainData]) → None
Hook called after a train step ends.
- static restore(path: str, unit: AppStateMixin, *, train_dataloader: Optional[Stateful] = None, restore_train_progress: bool = True, restore_eval_progress: bool = True, storage_options: Optional[Dict[str, Any]] = None) → None
Utility method to restore snapshot state from a path.
There are additional flags offered should the user want to skip loading the train and eval progress. By default, the train and eval progress are restored, if applicable.
Parameters:
- path – Path of the snapshot to restore.
- unit – An instance of TrainUnit, EvalUnit, or PredictUnit containing states to restore.
- train_dataloader – An optional train dataloader to restore.
- restore_train_progress – Whether to restore the training progress state.
- restore_eval_progress – Whether to restore the evaluation progress state.
- storage_options – Additional keyword options for the storage plugin to use, to be passed to torchsnapshot.Snapshot. See each storage plugin’s documentation for customizations.
- static restore_from_latest(dirpath: str, unit: AppStateMixin, *, train_dataloader: Optional[Stateful] = None, restore_train_progress: bool = True, restore_eval_progress: bool = True, storage_options: Optional[Dict[str, Any]] = None) → bool
Given a parent directory where checkpoints are saved, restore the snapshot state from the latest checkpoint in the directory.
There are additional flags offered should the user want to skip loading the train and eval progress. By default, the train and eval progress are restored, if applicable.
Parameters:
- dirpath – Parent directory from which to get the latest snapshot.
- unit – An instance of TrainUnit, EvalUnit, or PredictUnit containing states to restore.
- train_dataloader – An optional train dataloader to restore.
- restore_train_progress – Whether to restore the training progress state.
- restore_eval_progress – Whether to restore the evaluation progress state.
- storage_options – Additional keyword options for the storage plugin to use, to be passed to torchsnapshot.Snapshot. See each storage plugin’s documentation for customizations.
Returns: True if the latest snapshot directory was found and successfully restored, otherwise False.
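How the latest checkpoint is selected is not spelled out here. As a rough sketch only, one plausible approach is to parse the epoch_{epoch}_step_{step} directory names and pick the one with the highest cumulative step. The helper find_latest_snapshot is hypothetical and is not part of torchtnt; the ordering rule is an assumption.

```python
import os
import re
import tempfile
from typing import Optional, Tuple

# Matches directory names of the form epoch_{epoch}_step_{step}.
_SNAPSHOT_NAME = re.compile(r"^epoch_(\d+)_step_(\d+)$")


def find_latest_snapshot(dirpath: str) -> Optional[str]:
    """Return the snapshot directory with the highest (step, epoch), or None.

    Assumption: because `step` counts train steps across all epochs, the
    largest step identifies the most recent snapshot; epoch breaks ties.
    """
    if not os.path.isdir(dirpath):
        return None
    best_key: Optional[Tuple[int, int]] = None
    best_path: Optional[str] = None
    for name in os.listdir(dirpath):
        match = _SNAPSHOT_NAME.match(name)
        full_path = os.path.join(dirpath, name)
        if match is None or not os.path.isdir(full_path):
            continue  # skip unrelated files and directories
        key = (int(match.group(2)), int(match.group(1)))
        if best_key is None or key > best_key:
            best_key, best_path = key, full_path
    return best_path


# Demonstration on a throwaway directory with made-up snapshot names.
with tempfile.TemporaryDirectory() as parent:
    for name in ("epoch_0_step_9", "epoch_0_step_10", "epoch_1_step_20", "notes"):
        os.mkdir(os.path.join(parent, name))
    latest = find_latest_snapshot(parent)
    print(os.path.basename(latest))
```

Comparing the parsed integers rather than the raw names matters: sorted as strings, epoch_0_step_9 would rank after epoch_0_step_10.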