Distributed Checkpoint - torch.distributed.checkpoint
Distributed Checkpoint (DCP) supports loading and saving models from multiple ranks in parallel. It handles load-time resharding, which enables saving in one cluster topology and loading into another.
DCP differs from torch.save and torch.load in a few significant ways:
It produces multiple files per checkpoint, with at least one per rank.
It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
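To make the "in place" point concrete, here is a minimal pure-Python sketch. It does not use DCP itself; `load_in_place` is a hypothetical helper and plain lists stand in for tensors. It only shows the contract: the caller allocates storage first, and loading fills that storage rather than returning new objects.

```python
def load_in_place(destination, saved):
    """Copy saved values into buffers that the caller already allocated."""
    for key, buffer in destination.items():
        values = saved[key]
        if len(values) != len(buffer):
            raise ValueError(f"size mismatch for {key!r}")
        buffer[:] = values  # overwrite the contents, keep the same list object


# The "model" allocates its storage first ...
weights = [0.0, 0.0, 0.0]
state = {"weights": weights}

# ... and loading fills that storage instead of handing back a new object.
load_in_place(state, {"weights": [1.0, 2.0, 3.0]})
```

After the call, `weights` holds the loaded values and `state["weights"]` is still the same object, which is why tensors must be allocated before a DCP load.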
The entrypoints to load and save a checkpoint are the following:
- torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)
Save a distributed model in SPMD style.
This function is different from torch.save() as it handles ShardedTensor and DTensor by having each rank only save their local shards.
For each Stateful object (having both a state_dict and a load_state_dict), save will call state_dict before serialization.
Warning
There are no guarantees of backwards compatibility across PyTorch versions for saved state_dicts.
Warning
If using the process_group argument, make sure that only its ranks call save_state_dict and that all data in state_dict belong to it.
Note
When saving checkpoint for FSDP’s ShardingStrategy.HYBRID_SHARD, only one of the shard_group should be calling save_state_dict and the corresponding process group needs to be passed in.
Note
If no process group is available, this function assumes the intention is to save the state_dict in the local process.
- Parameters
state_dict (Dict[str, Any]) – The state_dict to save.
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)
storage_writer (Optional[StorageWriter]) – Instance of StorageWriter used to perform writes. If this is not specified, DCP will automatically infer the writer based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)
planner (Optional[SavePlanner]) – Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: None)
process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)
- Returns
Metadata object for the saved checkpoint.
- Return type
Metadata
Example
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> torch.distributed.checkpoint.save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )
Note
save_state_dict uses collectives to coordinate writes across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device(), and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().
- torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)
Asynchronous version of save. This code first stages the state_dict onto the staging storage (defaults to CPU memory), and then calls save in a separate thread.
Warning
This feature is experimental and subject to change.
- Parameters
state_dict (Dict[str, Any]) – The state_dict to save.
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)
storage_writer (Optional[StorageWriter]) – Instance of StorageWriter used to perform ‘stage’ and ‘save’. If this is not specified, DCP will automatically infer the writer based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)
planner (Optional[SavePlanner]) – Instance of SavePlanner. If this is not specified, the default planner will be used. (Default: None)
process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)
- Returns
A future holding the resultant Metadata object from save.
- Return type
Future
Example
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )
>>>
>>> # ... do some work ...
>>>
>>> checkpoint_future.result()
- torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)
This method is deprecated. Please switch to ‘save’.
- Return type
Metadata
- torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None)
Load a distributed state_dict in SPMD style.
Each rank will try to read the least amount of data necessary to fulfill the requested state_dict. When loading ShardedTensor or DTensor instances, each rank only reads data for their local shards.
For each Stateful object (having both a state_dict and a load_state_dict), load will first call state_dict before attempting deserialization, followed by load_state_dict once the deserialization is complete. For each non-Stateful object, load will deserialize the object, and then replace it in the state_dict with the deserialized object.
Warning
All tensors in state_dict must be allocated on their destination device prior to calling this function.
All non-tensor data is loaded using torch.load() and modified in place on state_dict.
Warning
Users must call load_state_dict on the root module to ensure load post-processing and non-tensor data properly propagates.
- Parameters
state_dict (Dict[str, Any]) – The state_dict to load the checkpoint into.
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)
storage_reader (Optional[StorageReader]) – Instance of StorageReader used to perform reads. If this is not specified, DCP will automatically infer the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: None)
planner (Optional[LoadPlanner]) – Instance of LoadPlanner. If this is not specified, the default planner will be used. (Default: None)
process_group (Optional[ProcessGroup]) – ProcessGroup to be used for cross-rank synchronization. (Default: None)
- Returns
None.
- Return type
None
- Examples
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>>
>>> torch.distributed.checkpoint.load(
>>>     state_dict=model_state_dict,
>>>     storage_reader=fs_storage_reader,
>>> )
>>>
>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to
>>> # ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)
Note
load_state_dict uses collectives to coordinate reads across ranks. For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device(), and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().
- torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)
This method is deprecated. Please switch to ‘load’.
The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (torch.distributed.checkpoint.async_save):
- class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)
This protocol is meant to provide customization and extensibility for dcp.async_save, allowing users to customize how data is staged prior to executing the usual dcp.save path in parallel. The expected order of operations (concretely defined in torch.distributed.checkpoint.state_dict_saver.async_save) is the following:
- AsyncStager.stage_data(state_dict):
This call gives the AsyncStager the opportunity to ‘stage’ the state_dict. The expectation and purpose of staging in this context is to create a “training-safe” representation of the state dict, meaning that any updates to module data after staging is complete should not be reflected in the state dict returned from this method. For example, in the default case a copy of the entire state dict is created on CPU RAM and returned here, allowing users to continue training without risking changes to data which is being serialized.
- dcp.save is called on the state_dict returned from stage in parallel. This call is responsible for serializing the state_dict and writing it to storage.
- If AsyncStager.should_synchronize_after_execute is True, AsyncStager.synchronize_staging will be called immediately after the serialization thread starts and before returning from dcp.async_save. If this is set to False, the assumption is that the user has defined a custom synchronization point for the purpose of further optimizing save latency in the training loop (for example, by overlapping staging with the forward/backward pass), and it is the responsibility of the user to call AsyncStager.synchronize_staging at the appropriate time.
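The stage-then-save order described above can be sketched with the standard library alone. This is an illustration of the idea, not DCP's implementation: `async_save_sketch` and `write_to_dict` are hypothetical names, `copy.deepcopy` stands in for staging to CPU RAM, and a thread pool stands in for the serialization thread.

```python
import copy
from concurrent.futures import Future, ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)


def async_save_sketch(state_dict, write_fn) -> Future:
    # Step 1: stage. The deep copy is a "training-safe" snapshot, so later
    # updates to the live state_dict are not reflected in what gets written.
    staged = copy.deepcopy(state_dict)
    # Step 2: serialize and write the staged copy in a background thread.
    return _executor.submit(write_fn, staged)


written = {}


def write_to_dict(staged):
    # Stand-in for a storage writer: records what was "saved".
    written.update(staged)
    return len(staged)


state = {"step": 1, "weights": [0.1, 0.2]}
future = async_save_sketch(state, write_to_dict)

# Training continues and mutates the live state while the save runs.
state["weights"][0] = 9.9
state["step"] = 2

count = future.result()  # block until the background write has finished
```

Because staging completes before `async_save_sketch` returns, the mutations made afterwards cannot leak into the written snapshot.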
- class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)
An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete. This implementation also provides an option to optimize stage latency using pinned memory.
N.B. synchronize_staging is a no-op in this case.
In addition to the above entrypoints, Stateful objects, as described below, provide additional customization during saving/loading.
- class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)
Stateful protocol for objects that can be checkpointed and restored.
- state_dict()
Objects should return their state_dict representation as a dictionary. The output of this function will be checkpointed, and later restored in load_state_dict().
Warning
Because of the in-place nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load.
- Returns
The object's state dict
- Return type
Dict
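As an illustration of the Stateful protocol, the sketch below defines a small object with both a state_dict and a load_state_dict method. `TrainingProgress` and `looks_stateful` are hypothetical names introduced for this example; the class needs no torch import because the protocol is described only in terms of those two methods.

```python
class TrainingProgress:
    """Hypothetical helper tracking counters that should survive a restart."""

    def __init__(self):
        self.epoch = 0
        self.step = 0

    def state_dict(self):
        # Return a plain dictionary; this is what would be checkpointed.
        return {"epoch": self.epoch, "step": self.step}

    def load_state_dict(self, state_dict):
        # Restore from the dictionary produced by state_dict().
        self.epoch = state_dict["epoch"]
        self.step = state_dict["step"]


def looks_stateful(obj):
    # Duck-typed check: the object has both methods the protocol names.
    return callable(getattr(obj, "state_dict", None)) and callable(
        getattr(obj, "load_state_dict", None)
    )


progress = TrainingProgress()
progress.epoch, progress.step = 3, 1200
snapshot = progress.state_dict()

restored = TrainingProgress()
restored.load_state_dict(snapshot)
```

An object like this could be placed as a value in the state_dict passed to save or load, alongside the model and optimizer.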
This example shows how to use PyTorch Distributed Checkpoint to save an FSDP model.
The following types define the IO interface used during checkpoint:
- class torch.distributed.checkpoint.StorageReader
Interface used by load_state_dict to read from storage.
One StorageReader instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.
A subclass should expect the following sequence of calls by load_state_dict:
(all ranks) set checkpoint_id if users pass a valid checkpoint_id.
(all ranks) read_metadata()
(all ranks) set_up_storage_reader()
(all ranks) prepare_local_plan()
(coordinator) prepare_global_plan()
(all ranks) read_data()
- abstract prepare_global_plan(plans)
Perform centralized planning of storage loading.
This method is only called on the coordinator instance.
While this method can produce a completely different plan, the preferred way is to store storage-specific data in LoadPlan::storage_data.
- abstract prepare_local_plan(plan)
Perform storage-specific local planning.
While this method can produce a completely different plan, the recommended way is to store storage-specific data in LoadPlan::storage_data.
- abstract read_data(plan, planner)
Read all items from plan using planner to resolve the data.
A subclass should call LoadPlanner::load_bytes to deserialize a BytesIO object into the right place.
A subclass should call LoadPlanner::resolve_tensor to get access to the tensors that it should load data into.
It’s the StorageLayer's responsibility to properly schedule any cross-device copies required.
- Parameters
plan (LoadPlan) – The local plan to execute on
planner (LoadPlanner) – The planner object to use to resolve items.
- Returns
A future that completes once all reads are finished.
- Return type
Future[None]
- abstract read_metadata()
Read the checkpoint metadata.
- Returns
The metadata object associated with the checkpoint being loaded.
- Return type
Metadata
- abstract reset(checkpoint_id=None)
Called to indicate that a brand new checkpoint read is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint read. The meaning of the checkpoint_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.
- Parameters
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is more like a key-value store. (Default: None)
- class torch.distributed.checkpoint.StorageWriter
Interface used by save_state_dict to write to storage.
One StorageWriter instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.
A subclass should expect the following sequence of calls.
(all ranks) set checkpoint_id if users pass a valid checkpoint_id.
(all ranks) set_up_storage_writer()
(all ranks) prepare_local_plan()
(coordinator) prepare_global_plan()
(all ranks) write_data()
(coordinator) finish()
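The call sequence above can be simulated in a single process to see which calls land on which rank. This is a sketch under stated assumptions: `RecordingWriter` and `run_save_sequence` are hypothetical, the writer does not subclass the real StorageWriter, setting the checkpoint_id is modeled as a call to reset, and write_data returns its results directly rather than through a future.

```python
class RecordingWriter:
    """Hypothetical stand-in for a StorageWriter that logs every call."""

    def __init__(self, rank):
        self.rank = rank
        self.calls = []

    def reset(self, checkpoint_id=None):
        self.calls.append("reset")

    def set_up_storage_writer(self, is_coordinator):
        self.calls.append("set_up_storage_writer")
        self.is_coordinator = is_coordinator

    def prepare_local_plan(self, plan):
        self.calls.append("prepare_local_plan")
        return plan

    def prepare_global_plan(self, plans):
        self.calls.append("prepare_global_plan")
        return plans

    def write_data(self, plan, planner):
        self.calls.append("write_data")
        return [f"rank{self.rank}:{item}" for item in plan]

    def finish(self, metadata, results):
        self.calls.append("finish")


def run_save_sequence(writers, plans, checkpoint_id, coordinator_rank=0):
    # Drives all "ranks" in order; a real run executes these per process.
    for w in writers:
        w.reset(checkpoint_id)
    for w in writers:
        w.set_up_storage_writer(w.rank == coordinator_rank)
    local_plans = [w.prepare_local_plan(p) for w, p in zip(writers, plans)]
    global_plans = writers[coordinator_rank].prepare_global_plan(local_plans)
    results = [w.write_data(p, planner=None) for w, p in zip(writers, global_plans)]
    writers[coordinator_rank].finish(metadata={}, results=results)
    return results


writers = [RecordingWriter(0), RecordingWriter(1)]
results = run_save_sequence(writers, [["w0"], ["w1"]], checkpoint_id="ckpt")
```

Only rank 0 records prepare_global_plan and finish; every rank records the remaining calls.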
- abstract finish(metadata, results)
Write the metadata and mark the current checkpoint as successful.
The actual format/schema used for serializing metadata is an implementation detail. The only requirement is that it’s recoverable into the same object graph.
- abstract prepare_global_plan(plans)
Perform centralized planning of storage.
This method is only called on the coordinator instance.
While this method can produce a completely different plan, the preferred way is to store storage-specific data in SavePlan::storage_data.
- abstract prepare_local_plan(plan)
Perform storage-specific local planning.
While this method can produce a completely different plan, the recommended way is to store storage-specific data in SavePlan::storage_data.
- abstract reset(checkpoint_id=None)
Called to indicate that a brand new checkpoint write is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint write. The meaning of the checkpoint_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.
- Parameters
checkpoint_id (Union[str, os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: None)
- abstract set_up_storage_writer(is_coordinator)
Initialize this instance.
- Parameters
is_coordinator (bool) – Whether this instance is responsible for coordinating the checkpoint.
- storage_meta()
Return the storage-specific metadata. This is used to store additional information in a checkpoint that can be useful for providing request-level observability. StorageMeta is passed to the SavePlanner during save calls. Returns None by default.
- Return type
Optional[StorageMeta]
- abstract classmethod validate_checkpoint_id(checkpoint_id)
Check if the given checkpoint_id is supported by the storage. This allows us to enable automatic storage selection.
- Return type
bool
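One way to picture how validate_checkpoint_id enables automatic storage selection is the sketch below. It is not DCP's selection logic: `LocalDirWriter`, `KeyValueWriter`, `infer_writer`, and the `kv://` prefix are all hypothetical, and the classes implement only the classmethod needed for the illustration.

```python
import os


class KeyValueWriter:
    """Hypothetical writer that accepts IDs with a made-up 'kv://' prefix."""

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id):
        return isinstance(checkpoint_id, str) and checkpoint_id.startswith("kv://")


class LocalDirWriter:
    """Hypothetical writer that accepts plain filesystem paths."""

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id):
        if not isinstance(checkpoint_id, (str, os.PathLike)):
            return False
        return "://" not in str(checkpoint_id)


def infer_writer(checkpoint_id, candidates=(KeyValueWriter, LocalDirWriter)):
    # With neither a writer nor a checkpoint_id there is nothing to infer from.
    if checkpoint_id is None:
        raise ValueError("either a storage writer or a checkpoint_id is required")
    for writer_cls in candidates:
        if writer_cls.validate_checkpoint_id(checkpoint_id):
            return writer_cls
    raise ValueError(f"no storage writer accepts {checkpoint_id!r}")


kv_choice = infer_writer("kv://run-1")
fs_choice = infer_writer("/checkpoint/1")
```

Each candidate decides for itself whether it understands the ID, so new storage backends can be added without changing the selection loop.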
- abstract write_data(plan, planner)
Write all items from plan using planner to resolve the data.
A subclass should call SavePlanner::resolve_data on each item from the plan to get access to the underlying object to write.
Subclasses should lazily call resolve_data as it can allocate memory. For tensors, make the following assumptions:
They might be on any device, including one that does not match the device on WriteItem::tensor_data
They might be views or not contiguous. Only the projection needs to be saved.
- Parameters
plan (SavePlan) – The save plan to execute.
planner (SavePlanner) – Planner object to be used to resolve items to data.
- Returns
A future that completes to a list of WriteResult
- Return type
Future[List[WriteResult]]
The following types define the planner interface used during checkpoint:
- class torch.distributed.checkpoint.LoadPlanner
Abstract class defining the protocol used by load_state_dict to plan the load process.
LoadPlanners are stateful objects that can be used to customize the whole load process.
LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process.
A planner subclass can expect the following sequence of calls during load_state_dict:
- set_up_planner - called on all ranks.
Signals the start of loading a checkpoint.
- create_local_plan - called on all ranks.
Processes the state_dict and produces a LoadPlan that will be sent for global planning.
- create_global_plan - called on the coordinator rank only.
Takes the LoadPlan from all ranks and makes any global decisions.
- load_bytes - called multiple times on each rank
This is called once per non-tensor value in state_dict.
- resolve_tensor and commit_tensor - called multiple times on each rank
They are called in pairs for each Tensor value in state_dict.
Users are recommended to extend DefaultLoadPlanner instead of this interface directly as most changes can be expressed by changes in a single method.
There are two usual patterns of extension:
Rewriting state_dict. This is the simplest way to extend the load process as it doesn’t require understanding the intricacies of how LoadPlan works. Because the load happens in place, we need to keep a reference to the original state_dict so that the loaded values can be written back into it.
>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         metadata: Metadata,
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>>
>>>     def load_bytes(self, read_item, value):
>>>         # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)
Modifying resolve_tensor and commit_tensor to handle load time transformation.
>>> class MetaModelMaterialize(DefaultLoadPlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>>
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor
- abstract commit_tensor(read_item, tensor)
Called once the StorageReader has finished loading data into tensor.
The provided tensor is the same one returned by the call to resolve_tensor. This method is only needed if this LoadPlanner needs to post-process tensor prior to copying it back to the one in the state_dict.
The contents of tensor will follow its device synchronization model.
- abstract create_global_plan(global_plan)
Compute the global load plan and return plans for each rank.
N.B. This is called on the coordinator rank only.
- abstract create_local_plan()
Create a LoadPlan based on state_dict and metadata provided by set_up_planner.
N.B. This is called on every rank.
- Return type
LoadPlan
- abstract finish_plan(central_plan)
Accept the plan from coordinator and return final LoadPlan.
- Return type
LoadPlan
- abstract load_bytes(read_item, value)
Load the item described by read_item and value.
This method is expected to modify in-place the underlying state_dict.
The contents of value are defined by the SavePlanner used to produce the checkpoint being loaded.
- resolve_bytes(read_item)
Return the BytesIO to be used by the StorageReader to load read_item.
The BytesIO should alias with one on the underlying state_dict as StorageReader will replace its contents.
- Return type
BytesIO
- abstract resolve_tensor(read_item)
Return the tensor described by read_item to be used by the StorageReader to load read_item.
The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents. If, for any reason, that’s not possible, the planner can use the commit_tensor method to copy the data back to the one in state_dict.
- Return type
Tensor
- class torch.distributed.checkpoint.LoadPlan(items: List[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)
- class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)
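The ReadItem fields are not described here, so the following is a hedged 1-D sketch of how offsets and lengths could express load-time resharding. It assumes that storage_offsets index into a saved shard, dest_offsets index into the requesting rank's local shard, and lengths give the size of the overlapping region. `plan_reads` and the `shard_index` key are hypothetical and plain dictionaries stand in for ReadItem.

```python
def plan_reads(saved_shards, dest_offset, dest_length):
    """Compute which pieces of each saved shard a destination shard needs.

    saved_shards: list of (global_offset, length) pairs, one per saved shard.
    dest_offset, dest_length: the destination shard's span in the full tensor.
    """
    reads = []
    dest_end = dest_offset + dest_length
    for shard_index, (shard_offset, shard_length) in enumerate(saved_shards):
        # Intersect the saved shard with the destination shard.
        start = max(shard_offset, dest_offset)
        end = min(shard_offset + shard_length, dest_end)
        if start < end:
            reads.append(
                {
                    "shard_index": shard_index,
                    "storage_offsets": (start - shard_offset,),
                    "dest_offsets": (start - dest_offset,),
                    "lengths": (end - start,),
                }
            )
    return reads


# A 10-element tensor saved from 2 ranks as [0, 5) and [5, 10) ...
saved = [(0, 5), (5, 5)]
# ... is loaded by a rank that now owns elements [4, 7).
reads = plan_reads(saved, dest_offset=4, dest_length=3)
```

The destination rank reads one element from the tail of the first saved shard and two from the head of the second, which is all it needs.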
- class torch.distributed.checkpoint.SavePlanner
Abstract class defining the protocol used by save_state_dict to plan the save process.
SavePlanners are stateful objects that can be used to customize the whole save process.
SavePlanner acts as an access proxy to the state_dict, so any transformation done to it will be visible to the whole process.
A planner subclass can expect the following sequence of calls during save_state_dict:
- set_up_planner - called on all ranks.
Signals the start of a checkpoint save.
- create_local_plan - called on all ranks.
Processes the state_dict and produces a SavePlan that will be sent for global planning.
- create_global_plan - called on the coordinator rank only.
Takes the SavePlan from all ranks and makes any global decisions.
- finish_plan - called on all ranks.
This gives each rank a chance to adjust to global planning decisions.
- resolve_data - called multiple times on each rank
Looks up a value in the state_dict for the storage layer to write.
Users are recommended to extend DefaultSavePlanner instead of this interface directly as most changes can be expressed by changes in a single method.
There are 3 usual patterns of extension:
Rewriting state_dict. This is the simplest way to extend the save process as it doesn’t require understanding the intricacies of how SavePlan works:
>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         storage_meta: Optional[StorageMeta],
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         # prefix all keys with `foo_`
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)
Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted is needed:
>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan.items:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>>
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)
Using the global planning step to make central decisions that can’t be made individually by each rank:
>>> from itertools import zip_longest
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>>     # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         iters = [iter(all_plans[0].items)] * len(all_plans)
>>>         items_per_rank = [
>>>             [item for item in items if item is not None]
>>>             for items in zip(*zip_longest(*iters), strict=True)
>>>         ]
>>>         all_plans = [
>>>             replace(plan, items=items)
>>>             for plan, items in zip(all_plans, items_per_rank, strict=True)
>>>         ]
>>>         return super().create_global_plan(all_plans)
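The `zip(*zip_longest(*iters))` idiom in DDPLoadBalancingPlanner is compact enough to be worth seeing in isolation. The sketch below applies the same idiom to a plain list of strings standing in for write items; `deal_round_robin` is a hypothetical helper name.

```python
from itertools import zip_longest


def deal_round_robin(items, num_ranks):
    """Distribute items across ranks like dealing cards, one at a time."""
    # The same iterator repeated num_ranks times: each zip_longest row pulls
    # num_ranks consecutive items, padding the final row with None.
    iters = [iter(items)] * num_ranks
    rows = zip_longest(*iters)
    # Transposing the rows yields one column of items per rank.
    return [[item for item in column if item is not None] for column in zip(*rows)]


items_per_rank = deal_round_robin(["a", "b", "c", "d", "e"], num_ranks=2)
```

Rank 0 ends up with `["a", "c", "e"]` and rank 1 with `["b", "d"]`, so writes that would otherwise all land on rank 0 are spread evenly.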
Finally, some planners need to save additional metadata in the checkpoint. This is accomplished by having each rank contribute its data items in the local plan and having the global planner aggregate them:
>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -> SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>>
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata
- abstract create_global_plan(all_plans)
Compute the global checkpoint plan and return the local plan of each rank.
This is called on the coordinator rank only.
- abstract create_local_plan()
Compute the save plan for the current rank.
This will be aggregated and passed to create_global_plan. Planner-specific data can be passed through SavePlan::planner_data.
This is called on all ranks.
- Return type
SavePlan
- abstract finish_plan(new_plan)
Merge the plan created by create_local_plan and the result of create_global_plan.
This is called on all ranks.
- Return type
SavePlan
- abstract resolve_data(write_item)
Transform and prepare write_item from state_dict for storage, ensuring idempotency and thread-safety.
Look up the object associated with write_item in state_dict and apply any transformation (such as serialization) prior to the storage layer consuming it.
Called on each rank multiple times, at least once per WriteItem in the final SavePlan.
This method should be idempotent and thread-safe. StorageWriter implementations are free to call it as frequently as they need.
Any transformation that allocates memory should be done lazily when this method is called in order to reduce the peak memory required by checkpointing.
When returning tensors, they can be on any device or format, and they can be views too. It’s the storage layer's responsibility to figure out how to save them.
- class torch.distributed.checkpoint.SavePlan(items: List[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None)
- class torch.distributed.checkpoint.planner.WriteItem(index, type, tensor_data=None)
Dataclass which holds information about what needs to be written to storage.
We provide a filesystem based storage layer:
- class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True)
Basic implementation of StorageWriter using file IO.
This implementation makes the following assumptions and simplifications:
The checkpoint path is an empty or non-existing directory.
File creation is atomic.
The checkpoint consists of one file per write request plus a .metadata file with the serialized metadata.
We provide default implementations of LoadPlanner and SavePlanner that can handle all of torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor.
- class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False)
- class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)
DefaultLoadPlanner that adds multiple features on top of LoadPlanner.
In particular it adds the following:
flatten_state_dict: Handle state_dict with nested dicts
flatten_sharded_tensors: For FSDP in 2D parallel mode
allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint.
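To show what handling a state_dict with nested dicts can look like, here is a toy flattening routine. It is only an illustration: `flatten` is a hypothetical helper, and the dotted key format and the shape of the returned mapping are assumptions, not the format used by DefaultLoadPlanner.

```python
def flatten(nested):
    """Flatten nested dicts into one level, remembering each original path."""
    flat, mappings = {}, {}

    def visit(value, path):
        if isinstance(value, dict):
            for key, child in value.items():
                visit(child, path + (key,))
        else:
            # Join the path into a single key (dotted form is an assumption).
            flat_key = ".".join(str(part) for part in path)
            flat[flat_key] = value
            mappings[flat_key] = path

    visit(nested, ())
    return flat, mappings


flat, mappings = flatten({"model": {"layer1": {"weight": 1}}, "step": 7})
```

Keeping the mapping from flat keys back to the original paths is what allows loaded values to be placed back into the nested structure afterwards.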
Due to legacy design decisions, the state dictionaries of FSDP and DDP may have different keys or fully qualified names (e.g., layer1.weight) even when the original unparallelized model is identical. Moreover, FSDP offers various types of model state dictionaries, such as full and sharded state dictionaries. Additionally, optimizer state dictionaries employ parameter IDs instead of fully qualified names to identify parameters, potentially causing issues when parallelisms are used (e.g., pipeline parallelism).
To tackle these challenges, we offer a collection of APIs for users to easily manage state_dicts. get_model_state_dict returns a model state dictionary with keys consistent with those returned by the unparallelized model state dictionary. Similarly, get_optimizer_state_dict provides the optimizer state dictionary with keys uniform across all parallelisms applied. To achieve this consistency, get_optimizer_state_dict converts parameter IDs to fully qualified names identical to those found in the unparallelized model state dictionary.
Note that results returned by these APIs can be used directly with the torch.distributed.checkpoint.save() and torch.distributed.checkpoint.load() methods without requiring any additional conversions.
Note that this feature is experimental, and API signatures might change in the future.
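The conversion from parameter IDs to fully qualified names can be pictured with plain dictionaries. This sketch assumes that a parameter ID is the parameter's index in registration order, so that a list of names in that order is enough to translate keys. `optimizer_state_to_fqns` is a hypothetical helper and the state values are placeholders, not real optimizer state.

```python
def optimizer_state_to_fqns(optim_state, param_names):
    """Re-key optimizer state from integer parameter IDs to parameter names.

    param_names[i] is assumed to be the fully qualified name of parameter i.
    """
    return {param_names[param_id]: value for param_id, value in optim_state.items()}


# Names in the order the (unparallelized) model registers its parameters.
param_names = ["layer1.weight", "layer1.bias"]
# Placeholder per-parameter state keyed by integer ID.
optim_state = {0: {"step": 10}, 1: {"step": 10}}

by_fqn = optimizer_state_to_fqns(optim_state, param_names)
```

Keys such as `layer1.weight` stay meaningful when the model is wrapped or sharded differently, whereas integer IDs depend on how parameters happen to be enumerated.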
- torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[source]¶
Return the model state_dict and optimizers state_dict.
get_state_dict
can process any module that is parallelized by PyTorch FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any combination of these parallelisms. The main functions ofget_state_dict
are: 1.) returning a model and optimizer state_dict that can be resharded with a different number of trainers and/or different parallelisms. 2.) hiding the parallelism-specific state_dict APIs. Users don’t have to call these APIs. 3.) sanity checking the result state_dict.The keys of the result state dictionary are the canonical FQNs (Fully Qualified Names). A canonical FQN refers to the FQN based on a parameter’s position in an nn.Module hierarchy. More specifically, a canonical FQN to a parameter is the FQN returned by
module.named_parameters()
ormodule.named_buffers()
when the module is not distributed by any parallelisms. Since the optimizer internally uses parameter IDs to represent a parameter, there will be a conversion from the parameter IDs to the canonical FQNs when calling this API.get_state_dict
can also process a module that is not parallelized. In such a case,get_state_dict
only performs one function – converting the optimizer parameter IDs to the canonical FQNs.Example
>>> import copy
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> # `model` is an existing nn.Module; each wrapper gets its own copy,
>>> # and each optimizer is built from the wrapped copy's parameters.
>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
>>> # the asserts will fail.
>>> assert ddp_state_dict == fsdp_state_dict
>>> assert ddp_optim_state_dict == fsdp_optim_state_dict
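The conversion from optimizer parameter IDs to canonical FQNs can be pictured with a small sketch in plain Python. It is not the library's implementation: the helper names (canonical_fqn, optim_state_to_fqns) are hypothetical, the input layout follows the "state" / "param_groups" structure of torch.optim.Optimizer.state_dict(), and the wrapper name segments stripped here ("module" for DDP, "_fsdp_wrapped_module" for FSDP) are an illustrative, non-exhaustive assumption.

```python
from typing import Any, Dict, List

# Name segments inserted by wrappers (assumed for illustration only).
# A real submodule literally named "module" would be stripped too, so
# this is a sketch, not a robust implementation.
WRAPPER_SEGMENTS = {"module", "_fsdp_wrapped_module"}


def canonical_fqn(name: str) -> str:
    """Drop wrapper segments so the name matches the unwrapped module."""
    return ".".join(s for s in name.split(".") if s not in WRAPPER_SEGMENTS)


def optim_state_to_fqns(
    optim_state: Dict[str, Any], param_names: List[str]
) -> Dict[str, Any]:
    """Re-key an optimizer state_dict from integer parameter IDs to FQNs.

    param_names[i] is the (possibly wrapped) name of the parameter with ID i.
    """
    id_to_fqn = {pid: canonical_fqn(n) for pid, n in enumerate(param_names)}
    return {
        "state": {id_to_fqn[pid]: s for pid, s in optim_state["state"].items()},
        "param_groups": [
            {**group, "params": [id_to_fqn[pid] for pid in group["params"]]}
            for group in optim_state["param_groups"]
        ],
    }


# The same optimizer state, seen through two differently wrapped models.
raw = {
    "state": {0: {"step": 3}, 1: {"step": 3}},
    "param_groups": [{"lr": 1e-3, "params": [0, 1]}],
}
ddp_names = ["module.layer1.weight", "module.layer1.bias"]
fsdp_names = [
    "_fsdp_wrapped_module.layer1.weight",
    "_fsdp_wrapped_module.layer1.bias",
]

ddp_converted = optim_state_to_fqns(raw, ddp_names)
fsdp_converted = optim_state_to_fqns(raw, fsdp_names)
```

After the conversion both dictionaries are keyed by the same names, which is why state dictionaries produced under different parallelisms become comparable.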
- Parameters
model (nn.Module) – the nn.Module to the model.
optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model.
submodules (deprecated) – Optional[Set[nn.Module]]: only return the model parameters that belong to the submodules.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.
- Returns
Tuple that contains the model state_dict and the optimizer state_dict.
- Return type
- torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[source]¶
Return the model state_dict of model.
See get_state_dict for detailed usage.
- Parameters
model (nn.Module) – the nn.Module to the model.
submodules (deprecated) – Optional[Set[nn.Module]]: only return the model parameters that belong to the submodules.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.
- Returns
The state_dict for model.
- Return type
- torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[source]¶
Return the combined state_dict for optimizers.
See get_state_dict for detailed usage.
- Parameters
model (nn.Module) – the nn.Module to the model.
optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model.
submodules (deprecated) – Optional[Set[nn.Module]]: only return the model parameters that belong to the submodules.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be returned. See StateDictOptions for the details.
- Returns
The state_dict for optimizers.
- Return type
OptimizerStateType
- torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[source]¶
Load the model state_dict and optimizers state_dict.
The counterpart of get_state_dict, used to set the state_dict to the model and optimizers. The given model_state_dict and optim_state_dict do not have to be returned by get_state_dict but must meet the following requirements:
1. All FQNs are canonical FQNs as defined in get_state_dict.
2. If a tensor is sharded, it must be either a ShardedTensor or a DTensor.
3. The optimizer state_dict cannot contain parameter IDs; the keys should be the canonical FQNs.
- Parameters
model (nn.Module) – the nn.Module to the model.
optimizers (Union[Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model.
model_state_dict (Dict[str, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): the model state_dict to load. If the key of the model_state_dict is nn.Module, the key is a submodule of model and the value should be the state_dict of the submodule. When loading the state_dict, the prefix of the submodule will be prepended to the keys of the state_dict.
optim_state_dict (OptimizerStateType) – OptimizerStateType: the optimizer state_dict to load.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details.
- Returns
missing_keys is a list of str containing the missing keys of the model state_dict.
unexpected_keys is a list of str containing the unexpected keys of the model state_dict.
- Return type
NamedTuple with missing_keys and unexpected_keys fields
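The returned fields and the submodule-prefix handling described above can be sketched in plain Python. This only illustrates the bookkeeping and is not the library's code: LoadResult, compare_keys and add_submodule_prefix are hypothetical names, and a string prefix stands in for the nn.Module key.

```python
from collections import namedtuple
from typing import Any, Dict, Iterable

# Hypothetical stand-in for the NamedTuple described above.
LoadResult = namedtuple("LoadResult", ["missing_keys", "unexpected_keys"])


def add_submodule_prefix(prefix: str, sub_state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Prepend a submodule's prefix to every key of its state_dict."""
    return {f"{prefix}.{key}": value for key, value in sub_state_dict.items()}


def compare_keys(expected: Iterable[str], provided: Dict[str, Any]) -> LoadResult:
    """Report keys the model expects but did not get, and vice versa."""
    expected = list(expected)
    expected_set = set(expected)
    missing = [key for key in expected if key not in provided]
    unexpected = [key for key in provided if key not in expected_set]
    return LoadResult(missing, unexpected)


# Canonical FQNs the model expects.
model_keys = ["encoder.weight", "encoder.bias", "head.weight"]

# A state_dict for the "encoder" submodule, plus one stray entry.
checkpoint = add_submodule_prefix("encoder", {"weight": 1.0, "bias": 0.0})
checkpoint["head.scale"] = 2.0

result = compare_keys(model_keys, checkpoint)
```

Here result.missing_keys holds the one key the checkpoint lacks, and result.unexpected_keys holds the one key the model does not define.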
- torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[source]¶
Load the model state_dict.
The counterpart of get_model_state_dict, used to set the state_dict to the model. See set_state_dict for detailed usage.
- Parameters
model (nn.Module) – the nn.Module to the model.
model_state_dict (Dict[str, ValueType]) – (Dict[str, ValueType]): the model state_dict to load. If the key of the model_state_dict is nn.Module, the key is a submodule of model and the value should be the state_dict of the submodule. When loading the state_dict, the prefix of the submodule will be prepended to the keys of the state_dict.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details.
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
- torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)[source]¶
Load the optimizers state_dict.
The counterpart of get_optimizer_state_dict, used to set the state_dict to the optimizers. See set_state_dict for detailed usage.
- Parameters
model (nn.Module) – the nn.Module to the model.
optimizers (Union[Optimizer, Iterable[Optimizer]]) – The optimizers that are used to optimize model.
optim_state_dict (OptimizerStateType) – OptimizerStateType: the optimizer state_dict to load.
options (StateDictOptions) – the options to control how model state_dict and optimizer state_dict should be loaded. See StateDictOptions for the details.
- Returns
None
- Return type
None
- class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False)[source]¶
This dataclass specifies how get_state_dict/set_state_dict will work.
- full_state_dict: if this is set to True, all the tensors in the returned state_dict will be gathered. No ShardedTensor and DTensor will be in the returned state_dict.
- cpu_offload: offload all the tensors to CPU. To prevent CPU OOM, if full_state_dict is also True, then only rank0 will get the state_dict and all other ranks will get an empty state_dict.
- ignore_frozen_params: if the value is True, the returned state_dict won’t contain any frozen parameters, that is, parameters whose requires_grad is False. The default value is False.
- keep_submodule_prefixes (deprecated): when submodules is not None, this option indicates whether to keep the submodule prefixes in the state_dict keys. For example, suppose the submodule is module.pretrain and the full FQN of the parameter is pretrain.layer1.weight. When this option is True, the parameter’s key in the returned state_dict will be pretrain.layer1.weight. If the option is False, the key will be layer1.weight. Note that if keep_submodule_prefixes is False, there may be conflicting FQNs, hence there should be only one submodule in submodules.
- strict: the strict option used when set_state_dict calls model.load_state_dict().
- broadcast_from_rank0: when the option is True, rank0 should receive a full state_dict and will broadcast the tensors in the state_dict/optim_state_dict one by one to other ranks. Other ranks will receive the tensors and shard them according to the local shards in the model and optimizer. full_state_dict must be set to True when using this option. This option currently only supports DTensor, not the legacy ShardedTensor.
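The effect of keep_submodule_prefixes on the returned keys, and the reason only one submodule should be given when it is False, can be sketched in plain Python. The function select_submodule_keys is hypothetical, string prefixes stand in for the submodules, and raising ValueError on a key collision is an assumption of this sketch rather than documented library behavior.

```python
from typing import Any, Dict, List


def select_submodule_keys(
    state_dict: Dict[str, Any],
    submodule_prefixes: List[str],
    keep_submodule_prefixes: bool = True,
) -> Dict[str, Any]:
    """Keep only entries under the given prefixes, optionally stripping them."""
    selected: Dict[str, Any] = {}
    for prefix in submodule_prefixes:
        for key, value in state_dict.items():
            if not key.startswith(prefix + "."):
                continue
            new_key = key if keep_submodule_prefixes else key[len(prefix) + 1:]
            if new_key in selected:
                # Two submodules produced the same stripped key.
                raise ValueError(f"conflicting FQN: {new_key}")
            selected[new_key] = value
    return selected


full = {
    "pretrain.layer1.weight": 1,
    "finetune.layer1.weight": 2,
    "head.weight": 3,
}

kept = select_submodule_keys(full, ["pretrain"], keep_submodule_prefixes=True)
stripped = select_submodule_keys(full, ["pretrain"], keep_submodule_prefixes=False)

# Stripping prefixes from two submodules makes their keys collide.
try:
    select_submodule_keys(full, ["pretrain", "finetune"], keep_submodule_prefixes=False)
    conflict = None
except ValueError as exc:
    conflict = str(exc)
```

With one submodule the stripped keys stay unique; with two, pretrain.layer1.weight and finetune.layer1.weight both collapse to layer1.weight.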
For users who are used to using and sharing models in the torch.save format, the following methods provide offline utilities for converting between formats.
- torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)[source]¶
Given a directory containing a DCP checkpoint, this function will convert it into a Torch save file.
- Parameters
Warning
To avoid OOM, it’s recommended to only run this function on a single rank.
- torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)[source]¶
Given the location of a torch save file, converts it into a DCP checkpoint.
- Parameters
Warning
To avoid OOM, it’s recommended to only run this function on a single rank.
The following classes can also be utilized for online loading and resharding of models from the torch.save format.
- class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)[source]¶
StorageReader for reading a Torch Save file. This reader will read the entire checkpoint on the coordinator rank, and then broadcast and shard each tensor to all ranks.
N.B. Intended to be used with DynamicMetaLoadPlanner.
Warning
Current implementation only supports loading Tensors.
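>>> import torch.distributed.checkpoint as dcp
>>> from torch.distributed.checkpoint.format_utils import (
...     BroadcastingTorchSaveReader,
...     DynamicMetaLoadPlanner,
... )
>>> sd = {"mode": model}
>>> dcp.load(
...     sd,
...     storage_reader=BroadcastingTorchSaveReader(),
...     planner=DynamicMetaLoadPlanner(),
...     checkpoint_id="path_to_model.pt",
... )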
- read_data(plan, planner)[source]¶
Reads torch save data on the coordinator rank and broadcasts it afterwards. This incurs a communication cost, but avoids having to load the entire checkpoint on each rank, hopefully preventing OOM issues.
- Return type
Future[None]
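The read-on-coordinator, then broadcast-and-shard pattern can be simulated in a single process with plain Python. This is a conceptual sketch only: there is no real communication, nested lists stand in for tensors, the helper names are hypothetical, and an even row-wise split across ranks is assumed.

```python
from typing import Dict, List

Rows = List[List[float]]


def shard_rows(full_rows: Rows, rank: int, world_size: int) -> Rows:
    """Return the contiguous block of rows owned by `rank` (even split assumed)."""
    if len(full_rows) % world_size != 0:
        raise ValueError("this sketch only handles evenly divisible row counts")
    per_rank = len(full_rows) // world_size
    return full_rows[rank * per_rank:(rank + 1) * per_rank]


def broadcast_and_shard(checkpoint: Dict[str, Rows], world_size: int) -> List[Dict[str, Rows]]:
    """Simulate the coordinator sending tensors one at a time to every rank.

    Each rank sees the full tensor only transiently and keeps just its shard,
    so no rank other than the coordinator holds the whole checkpoint.
    """
    per_rank_state: List[Dict[str, Rows]] = [{} for _ in range(world_size)]
    for name, rows in checkpoint.items():      # one tensor at a time
        for rank in range(world_size):         # stands in for the broadcast
            per_rank_state[rank][name] = shard_rows(rows, rank, world_size)
    return per_rank_state


# Only the coordinator "loads" this full checkpoint.
checkpoint = {"weight": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]}
shards = broadcast_and_shard(checkpoint, world_size=2)
```

The trade-off matches the description above: every tensor crosses the network once, but peak memory on non-coordinator ranks is bounded by one full tensor plus the local shards.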
- class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source]¶
Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed in state dict, avoiding the need to read metadata from disk. This is useful when reading formats which don’t have a metadata file, like Torch Save files.
N.B. Intended to be used with BroadcastingTorchSaveReader.
Warning
Current implementation only supports loading Tensors.
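>>> import torch.distributed.checkpoint as dcp
>>> from torch.distributed.checkpoint.format_utils import (
...     BroadcastingTorchSaveReader,
...     DynamicMetaLoadPlanner,
... )
>>> sd = {"mode": model}
>>> dcp.load(
...     sd,
...     storage_reader=BroadcastingTorchSaveReader(),
...     planner=DynamicMetaLoadPlanner(),
...     checkpoint_id="path_to_model.pt",
... )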
The following experimental interfaces are provided for improved observability in production environments: