
TorchSnapshotSaver

class torchtnt.framework.callbacks.TorchSnapshotSaver(dirpath: str, *, save_every_n_train_steps: Optional[int] = None, save_every_n_epochs: Optional[int] = None, replicated: Optional[List[str]] = None, storage_options: Optional[Dict[str, Any]] = None)

A callback which periodically saves the application state during training using TorchSnapshot.

This callback supplements the application state provided by torchtnt.unit.AppStateMixin with the train progress, train dataloader (if applicable), and random number generator state.

If used with torchtnt.framework.fit(), this class will also save the evaluation progress state.

Checkpoints will be saved under dirpath/epoch_{epoch}_step_{step} where step is the total number of training steps completed across all epochs.
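This layout can be sketched in a few lines; `checkpoint_dirpath` below is a hypothetical helper for illustration, not part of torchtnt's public API:

```python
import os

def checkpoint_dirpath(dirpath: str, epoch: int, global_step: int) -> str:
    # Mirrors the documented layout: dirpath/epoch_{epoch}_step_{step},
    # where step is the total number of training steps completed across all epochs.
    return os.path.join(dirpath, f"epoch_{epoch}_step_{global_step}")

# e.g. a snapshot taken during epoch 2, after 1000 total train steps:
print(checkpoint_dirpath("/tmp/snapshots", 2, 1000))  # /tmp/snapshots/epoch_2_step_1000
```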

Parameters:
  • dirpath – Parent directory to save snapshots to.
  • save_every_n_train_steps – Frequency of steps with which to save snapshots during the train epoch. If None, no intra-epoch snapshots are generated.
  • save_every_n_epochs – Frequency of epochs with which to save snapshots during training. If None, no end-of-epoch snapshots are generated.
  • replicated – A glob-pattern of replicated key names that indicate which application state entries have the same state across all processes. For more information, see https://pytorch.org/torchsnapshot/main/api_reference.html#torchsnapshot.Snapshot.take .
  • storage_options – Additional keyword options for the storage plugin, to be passed to torchsnapshot.Snapshot. See each storage plugin’s documentation for customizations.
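The two frequency parameters can be illustrated with a small sketch. `should_snapshot_step` below is a hypothetical re-implementation of the cadence logic for illustration only; torchtnt's internal bookkeeping may differ:

```python
from typing import Optional

def should_snapshot_step(
    num_steps_completed: int, save_every_n_train_steps: Optional[int]
) -> bool:
    # Intra-epoch cadence: snapshot whenever the number of completed train
    # steps is a multiple of save_every_n_train_steps. A value of None
    # disables intra-epoch snapshots entirely.
    if save_every_n_train_steps is None:
        return False
    return num_steps_completed % save_every_n_train_steps == 0

# With save_every_n_train_steps=100, steps 100, 200, 300, ... trigger a save.
triggered = [s for s in range(1, 501) if should_snapshot_step(s, 100)]
print(triggered)  # [100, 200, 300, 400, 500]
```

`save_every_n_epochs` follows the same pattern, counting completed epochs instead of steps.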

Note: If torch.distributed is available and the default process group is initialized, the constructor will perform a collective operation so that rank 0 can broadcast dirpath to all other ranks.

Note: If checkpointing an FSDP model, you can set the state dict type by calling set_state_dict_type prior to starting training.

property dirpath: str

Returns parent directory to save to.

on_exception(state: State, unit: Union[TrainUnit[TTrainData], EvalUnit[TEvalData], PredictUnit[TPredictData]], exc: BaseException) → None

Hook called when an exception occurs.

on_train_end(state: State, unit: TrainUnit[TTrainData]) → None

Hook called after training ends.

on_train_epoch_end(state: State, unit: TrainUnit[TTrainData]) → None

Hook called after a train epoch ends.

on_train_start(state: State, unit: TrainUnit[TTrainData]) → None

Validate there’s no key collision for the app state.

on_train_step_end(state: State, unit: TrainUnit[TTrainData]) → None

Hook called after a train step ends.

static restore(path: str, unit: AppStateMixin, *, train_dataloader: Optional[Stateful] = None, restore_train_progress: bool = True, restore_eval_progress: bool = True, storage_options: Optional[Dict[str, Any]] = None) → None

Utility method to restore snapshot state from a path.

Additional flags are provided to skip loading the train and eval progress. By default, both are restored, if applicable.

Parameters:
  • path – Path of the snapshot to restore.
  • unit – An instance of TrainUnit, EvalUnit, or PredictUnit containing states to restore.
  • train_dataloader – An optional train dataloader to restore.
  • restore_train_progress – Whether to restore the training progress state.
  • restore_eval_progress – Whether to restore the evaluation progress state.
  • storage_options – Additional keyword options for the storage plugin, to be passed to torchsnapshot.Snapshot. See each storage plugin’s documentation for customizations.

static restore_from_latest(dirpath: str, unit: AppStateMixin, *, train_dataloader: Optional[Stateful] = None, restore_train_progress: bool = True, restore_eval_progress: bool = True, storage_options: Optional[Dict[str, Any]] = None) → bool

Given a parent directory where checkpoints are saved, restore the snapshot state from the latest checkpoint in the directory.

Additional flags are provided to skip loading the train and eval progress. By default, both are restored, if applicable.

Parameters:
  • dirpath – Parent directory from which to get the latest snapshot.
  • unit – An instance of TrainUnit, EvalUnit, or PredictUnit containing states to restore.
  • train_dataloader – An optional train dataloader to restore.
  • restore_train_progress – Whether to restore the training progress state.
  • restore_eval_progress – Whether to restore the evaluation progress state.
  • storage_options – Additional keyword options for the storage plugin, to be passed to torchsnapshot.Snapshot. See each storage plugin’s documentation for customizations.

Returns:

True if the latest snapshot directory was found and successfully restored, otherwise False.
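Given the `epoch_{epoch}_step_{step}` naming convention, the "latest" checkpoint is the one with the highest total step count. The selection logic can be sketched as follows; `latest_checkpoint` is a hypothetical re-implementation for illustration, not torchtnt's actual code:

```python
import re
from typing import List, Optional

# Matches checkpoint directory names of the form epoch_{epoch}_step_{step}.
_CKPT_RE = re.compile(r"^epoch_(\d+)_step_(\d+)$")

def latest_checkpoint(candidates: List[str]) -> Optional[str]:
    # Pick the candidate with the highest global step count; names that
    # don't match the convention are ignored. Returns None if none match.
    best: Optional[str] = None
    best_step = -1
    for name in candidates:
        m = _CKPT_RE.match(name)
        if m and int(m.group(2)) > best_step:
            best_step = int(m.group(2))
            best = name
    return best

print(latest_checkpoint(["epoch_0_step_100", "epoch_1_step_250", "notes.txt"]))
# epoch_1_step_250
```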
