API Reference¶
- class Snapshot(path: str, pg: Optional[ProcessGroup] = None, storage_options: Optional[Dict[str, Any]] = None)¶
Create a reference to an existing snapshot.
- Parameters:
  - path (str) – The path to the snapshot. This should be the same as the path argument used for Snapshot.take() when the snapshot was taken.
  - pg (ProcessGroup, optional) – The process group for the participants of Snapshot.restore(). If None, the default process group will be used.
  - storage_options (Dict[str, Any], optional) – Additional keyword options for the storage plugin to use. See each storage plugin's documentation for customizations.
- classmethod take(path: str, app_state: Dict[str, T], pg: Optional[ProcessGroup] = None, replicated: Optional[List[str]] = None, storage_options: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[Callable[[str, Tensor, bool], Tensor]] = None) → Snapshot¶
Takes a snapshot of the application state.
- Parameters:
  - app_state (Dict[str, Stateful]) – The application state to persist. It takes the form of a dictionary, with the keys being user-defined strings and the values being stateful objects. Stateful objects are objects that expose .state_dict() and .load_state_dict() methods. Common PyTorch objects such as torch.nn.Module, torch.optim.Optimizer, and LR schedulers all qualify as stateful objects.
  - path (str) – The location to save the snapshot. path can have a URI prefix (e.g. s3://) that specifies a storage backend. If no URI prefix is supplied, path is assumed to be a file system location. For distributed snapshot, if path is inconsistent across participating ranks, the value specified by rank 0 will be used. For multi-host snapshot, path needs to be a location accessible by all hosts.
    Note: path must not point to an existing snapshot.
  - pg (ProcessGroup, optional) – The process group for the participants of Snapshot.take(). If None, the default process group will be used.
  - replicated (List[str], optional) – Glob patterns for marking checkpoint content as replicated. Matching objects will be deduped and load-balanced across ranks.
    Note: The replication property is automatically inferred for DistributedDataParallel. Only specify this argument if your model has fully replicated states but does not use DistributedDataParallel.
  - storage_options (Dict[str, Any], optional) – Additional keyword options for the storage plugin to use. See each storage plugin's documentation for customizations.
- Returns:
The newly taken snapshot.
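The values in app_state only need to satisfy the stateful protocol described above. As a framework-free sketch of that protocol (the Counter class below is a hypothetical example, not part of the library):

```python
# Minimal sketch of the Stateful protocol that app_state values must satisfy.
# Counter is a hypothetical example class, not part of torchsnapshot.
from typing import Any, Dict


class Counter:
    """A toy stateful object: exposes .state_dict() and .load_state_dict()."""

    def __init__(self) -> None:
        self.steps = 0

    def state_dict(self) -> Dict[str, Any]:
        # Return a serializable view of this object's state.
        return {"steps": self.steps}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore this object's state from a previously saved state_dict.
        self.steps = state_dict["steps"]


# An app_state mapping user-defined string keys to stateful objects,
# in the shape that Snapshot.take(path=..., app_state=...) expects.
counter = Counter()
counter.steps = 7
app_state = {"counter": counter}

# What a snapshot conceptually persists for each entry:
saved = {key: obj.state_dict() for key, obj in app_state.items()}

# ...and what a restore conceptually does:
fresh = Counter()
fresh.load_state_dict(saved["counter"])
print(fresh.steps)  # → 7
```

torch.nn.Module and torch.optim.Optimizer satisfy this same protocol out of the box, which is why they can be placed in app_state directly.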
- classmethod async_take(path: str, app_state: Dict[str, T], pg: Optional[ProcessGroup] = None, replicated: Optional[List[str]] = None, storage_options: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[Callable[[str, Tensor, bool], Tensor]] = None) → PendingSnapshot¶
Asynchronously takes a snapshot of the application state.
This function is identical to Snapshot.take(), except that it returns early and performs as much of the I/O in the background as possible, allowing training to resume early.
- Parameters:
  - app_state (Dict[str, Stateful]) – Same as the app_state argument of Snapshot.take().
  - path (str) – Same as the path argument of Snapshot.take().
  - pg (ProcessGroup, optional) – Same as the pg argument of Snapshot.take().
  - replicated (List[str], optional) – Same as the replicated argument of Snapshot.take().
  - storage_options (Dict[str, Any], optional) – Same as the storage_options argument of Snapshot.take().
- Returns:
A handle to the pending snapshot. The handle exposes a .done() method for querying the progress and a .wait() method for waiting for the snapshot's completion.
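The .done()/.wait() interface shape can be sketched with the standard library alone (PendingWork below is a hypothetical stand-in that only mimics the handle's surface, not the actual PendingSnapshot implementation):

```python
import threading
import time


class PendingWork:
    """Hypothetical mimic of a pending-snapshot handle's .done()/.wait() surface."""

    def __init__(self, background_fn) -> None:
        self._finished = threading.Event()
        self._thread = threading.Thread(target=self._run, args=(background_fn,))
        self._thread.start()

    def _run(self, background_fn) -> None:
        background_fn()          # the background I/O work
        self._finished.set()

    def done(self) -> bool:
        # Non-blocking progress query.
        return self._finished.is_set()

    def wait(self) -> None:
        # Block until the background work completes.
        self._thread.join()


handle = PendingWork(lambda: time.sleep(0.05))
# Training can continue here while the work happens in the background...
handle.wait()
print(handle.done())  # → True
```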
- restore(app_state: Dict[str, T]) → None¶
Restores the application state from the snapshot.
- Parameters:
  - app_state (Dict[str, Stateful]) – The application state to restore. app_state needs to be either identical to or a subset of the app_state used for Snapshot.take() when the snapshot was taken.
- read_object(path: str, obj_out: Optional[T] = None, memory_budget_bytes: Optional[int] = None) → T¶
Reads an object from the snapshot’s content.
- Parameters:
  - path (str) – The path to the target object within the snapshot. path is equivalent to the target object's key in the snapshot manifest and can be obtained via Snapshot.get_manifest().
  - obj_out (Any, optional) – When specified, load the object in-place into obj_out if in-place load is supported for the object's type. Otherwise, obj_out is ignored.
    Note: When the target object is a ShardedTensor, obj_out must be specified.
  - memory_budget_bytes (int, optional) – When specified, the read operation will keep the temporary memory buffer size below this threshold.
- Returns:
The object read from the snapshot’s content.
- get_manifest() → Dict[str, Entry]¶
Returns the snapshot manifest.
Each entry in the dictionary corresponds to an object in the snapshot, with the keys being the logical paths to the objects and the values being the metadata describing the objects. For distributed snapshots, the manifest contains entries for objects saved by all ranks.
- Returns:
The snapshot manifest.
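Logical paths are derived from the app_state keys and the nested structure of each object's state dict. A rough, library-free sketch of how nested state flattens into path-like manifest keys (the flatten helper and the slash-separated path scheme here are illustrative assumptions, not the exact torchsnapshot manifest format):

```python
from typing import Any, Dict


def flatten(state: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Illustrative helper: flatten nested dicts into slash-separated keys."""
    out: Dict[str, Any] = {}
    for key, value in state.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            out.update(flatten(value, path))
        else:
            out[path] = value
    return out


# A nested application state, keyed the way app_state is.
state = {"model": {"layer1": {"weight": [1.0], "bias": [0.0]}}}
manifest_keys = sorted(flatten(state))
print(manifest_keys)  # → ['model/layer1/bias', 'model/layer1/weight']
```

Keys of this shape are what read_object() accepts as its path argument.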
- class StateDict(dict=None, /, **kwargs)¶
A dictionary that exposes .state_dict() and .load_state_dict() methods.
It can be used to capture objects that do not expose .state_dict() and .load_state_dict() methods (e.g. Tensors, Python primitive types) as part of the application state.
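The idea behind such a wrapper can be sketched in a few lines (SimpleStateDict below is an illustrative reimplementation, not the library's actual StateDict):

```python
class SimpleStateDict(dict):
    """Illustrative stand-in for StateDict: a dict that satisfies the
    .state_dict()/.load_state_dict() protocol."""

    def state_dict(self):
        # The dict's own contents are the state.
        return dict(self)

    def load_state_dict(self, state_dict):
        self.clear()
        self.update(state_dict)


# Capture plain values (tensors, ints, ...) as part of the application state.
progress = SimpleStateDict(epoch=3, best_loss=0.42)
saved = progress.state_dict()

restored = SimpleStateDict()
restored.load_state_dict(saved)
print(restored["epoch"])  # → 3
```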
- class RNGState¶
A special stateful object for saving and restoring global RNG state.
When captured in the application state, it is guaranteed that the global RNG state is set to the same values after restoring from the snapshot as it was after taking the snapshot.
Example:
>>> app_state = {"rng_state": RNGState()}
>>> Snapshot.take(
>>>     path="foo/bar",
>>>     app_state=app_state,
>>> )
>>> after_take = torch.rand(1)
>>> # In the same process or in another process
>>> snapshot = Snapshot(path="foo/bar")
>>> snapshot.restore(app_state)
>>> after_restore = torch.rand(1)
>>> torch.testing.assert_close(after_take, after_restore)
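The same save-then-restore guarantee can be illustrated with Python's built-in random module, which exposes its global RNG state directly (this is only an analogy for the mechanics, not how RNGState is implemented):

```python
import random

# Save the global RNG state, draw a number, then restore and draw again.
random.seed(1234)
saved_state = random.getstate()      # analogous to "taking" the RNG state
after_take = random.random()

random.setstate(saved_state)         # analogous to restoring the snapshot
after_restore = random.random()

print(after_take == after_restore)  # → True
```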