API Reference

class Snapshot(path: str, pg: Optional[ProcessGroup] = None, storage_options: Optional[Dict[str, Any]] = None)

Create a reference to an existing snapshot.

Parameters:
  • path (str) – The path to the snapshot. This should be the same as the path argument used for Snapshot.take() when the snapshot was taken.

  • pg (ProcessGroup, optional) – The process group for the participants of Snapshot.restore(). If None, the default process group will be used.

  • storage_options (Dict[str, Any], optional) – Additional keyword options for the storage plugin to use. See each storage plugin’s documentation for customizations.
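
Example (a minimal sketch; the path is illustrative and must point to a previously taken snapshot):

>>> from torchsnapshot import Snapshot
>>> snapshot = Snapshot(path="foo/bar")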

classmethod take(path: str, app_state: Dict[str, T], pg: Optional[ProcessGroup] = None, replicated: Optional[List[str]] = None, storage_options: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[Callable[[str, Tensor, bool], Tensor]] = None) → Snapshot

Takes a snapshot of the application state.

Parameters:
  • app_state (Dict[str, Stateful]) – The application state to persist. It takes the form of a dictionary, with the keys being user-defined strings and the values being stateful objects. Stateful objects are objects that expose .state_dict() and .load_state_dict() methods. Common PyTorch objects such as torch.nn.Module, torch.optim.Optimizer, and LR schedulers all qualify as stateful objects.

  • path (str) –

    The location to save the snapshot. path can have a URI prefix (e.g. s3://) that specifies a storage backend. If no URI prefix is supplied, path is assumed to be a file system location. For distributed snapshots, if path is inconsistent across participating ranks, the value specified by rank 0 will be used. For multi-host snapshots, path needs to be a location accessible by all hosts.

    Note

    path must not point to an existing snapshot.

  • pg (ProcessGroup, optional) – The process group for the participants of Snapshot.take(). If None, the default process group will be used.

  • replicated (List[str], optional) –

    Glob patterns for marking checkpoint content as replicated. Matching objects will be deduped and load-balanced across ranks.

    Note

    The replication property is automatically inferred for DistributedDataParallel. Only specify this argument if your model has fully replicated states but does not use DistributedDataParallel.

  • storage_options (Dict[str, Any], optional) – Additional keyword options for the storage plugin to use. See each storage plugin’s documentation for customizations.

Returns:

The newly taken snapshot.
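
Example (a minimal sketch; the model, optimizer, and path are placeholders):

>>> import torch
>>> from torchsnapshot import Snapshot
>>> model = torch.nn.Linear(4, 2)
>>> optim = torch.optim.SGD(model.parameters(), lr=0.1)
>>> app_state = {"model": model, "optim": optim}
>>> snapshot = Snapshot.take(path="foo/bar", app_state=app_state)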

classmethod async_take(path: str, app_state: Dict[str, T], pg: Optional[ProcessGroup] = None, replicated: Optional[List[str]] = None, storage_options: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[Callable[[str, Tensor, bool], Tensor]] = None) → PendingSnapshot

Asynchronously takes a snapshot of the application state.

This function is identical to Snapshot.take(), except that it returns early and performs as much I/O in the background as possible, allowing training to resume sooner.

Parameters:
  • app_state (Dict[str, Stateful]) – Same as the app_state argument of Snapshot.take().

  • path (str) – Same as the path argument of Snapshot.take().

  • pg (ProcessGroup, optional) – Same as the pg argument of Snapshot.take().

  • replicated (List[str], optional) – Same as the replicated argument of Snapshot.take().

  • storage_options (Dict[str, Any], optional) – Same as the storage_options argument of Snapshot.take().

Returns:

A handle to the pending snapshot. The handle exposes a .done() method for querying progress and a .wait() method for waiting for the snapshot’s completion.
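
Example (a sketch reusing the app_state from the Snapshot.take() example above; the path is a placeholder):

>>> pending = Snapshot.async_take(path="foo/baz", app_state=app_state)
>>> # Training can resume here while the remaining I/O happens in the background.
>>> pending.done()  # query progress without blocking
>>> pending.wait()  # block until the snapshot is complete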

restore(app_state: Dict[str, T]) → None

Restores the application state from the snapshot.

Parameters:

app_state (Dict[str, Stateful]) – The application state to restore. app_state needs to be either identical to or a subset of the app_state used for Snapshot.take() when the snapshot was taken.
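
Example (a sketch reusing the model and optim objects from the Snapshot.take() example above):

>>> snapshot = Snapshot(path="foo/bar")
>>> snapshot.restore(app_state={"model": model, "optim": optim})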

read_object(path: str, obj_out: Optional[T] = None, memory_budget_bytes: Optional[int] = None) → T

Reads an object from the snapshot’s content.

Parameters:
  • path (str) – The path to the target object within the snapshot. path is equivalent to the target object’s key in the snapshot manifest and can be obtained via Snapshot.get_manifest().

  • obj_out (Any, optional) –

    When specified, load the object in-place into obj_out if in-place load is supported for the object’s type. Otherwise, obj_out is ignored.

    Note

    When the target object is a ShardedTensor, obj_out must be specified.

  • memory_budget_bytes (int, optional) – When specified, the read operation will keep the temporary memory buffer size below this threshold.

Returns:

The object read from the snapshot’s content.
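
Example (a sketch; the manifest key below is hypothetical; use Snapshot.get_manifest() to discover the actual keys):

>>> snapshot = Snapshot(path="foo/bar")
>>> weight = snapshot.read_object(path="0/model/weight")  # hypothetical manifest key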

get_manifest() → Dict[str, Entry]

Returns the snapshot manifest.

Each entry in the dictionary corresponds to an object in the snapshot, with the keys being the logical paths to the objects and the values being the metadata describing each object. For distributed snapshots, the manifest contains entries for objects saved by all ranks.

Returns:

The snapshot manifest.
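
Example (a sketch of listing a snapshot’s content):

>>> snapshot = Snapshot(path="foo/bar")
>>> for logical_path, entry in snapshot.get_manifest().items():
>>>     print(logical_path, type(entry).__name__)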

class StateDict(dict=None, /, **kwargs)

A dictionary that exposes .state_dict() and .load_state_dict() methods.

It can be used to capture objects that do not expose .state_dict() and .load_state_dict() methods (e.g. Tensors, Python primitive types) as part of the application state.
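
Example (a sketch capturing values that are not themselves stateful; the names and path are placeholders):

>>> import torch
>>> from torchsnapshot import Snapshot, StateDict
>>> extra_state = StateDict(step=0, features=torch.rand(4))
>>> Snapshot.take(path="foo/qux", app_state={"extra": extra_state})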

class RNGState

A special stateful object for saving and restoring global RNG state.

When captured in the application state, it is guaranteed that the global RNG state is set to the same values after restoring from the snapshot as it was after taking the snapshot.

Example:

>>> import torch
>>> from torchsnapshot import RNGState, Snapshot
>>> app_state = {"rng_state": RNGState()}
>>> Snapshot.take(path="foo/bar", app_state=app_state)
>>> after_take = torch.rand(1)

>>> # In the same process or in another process
>>> snapshot = Snapshot(path="foo/bar")
>>> snapshot.restore(app_state)
>>> after_restore = torch.rand(1)

>>> torch.testing.assert_close(after_take, after_restore)
