Shortcuts

Source code for torch.distributed.checkpoint.storage

import abc
import os
from dataclasses import dataclass
from typing import Any, List, Optional, Union

from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    SavePlan,
    SavePlanner,
)

from torch.futures import Future

__all__ = ["WriteResult", "StorageWriter", "StorageReader"]


@dataclass(frozen=True)
class WriteResult:
    """
    Immutable record describing data written by a ``StorageWriter``.

    ``StorageWriter.write_data`` returns a future resolving to a list of these,
    and ``StorageWriter.finish`` receives the lists gathered from all ranks.
    """

    # Index of the checkpoint entry this result refers to
    # (presumably mirrors the index of the written item — confirm with callers).
    index: MetadataIndex
    # Size of the written data, in bytes.
    size_in_bytes: int
    # Storage-implementation-specific payload; its content is chosen by the
    # concrete StorageWriter (assumption — not constrained here).
    storage_data: Any


[docs]class StorageWriter(abc.ABC): """ Interface used by ``save_state_dict`` to write to storage. One StorageWriter instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role. A subclass should expect the following sequence of calls. 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id. 1) (all ranks) set_up_storage_writer() 2) (all ranks) prepare_local_plan() 3) (coordinator) prepare_global_plan() 4) (all ranks) write_data() 5) (coordinator) finish() """
[docs] @abc.abstractmethod def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: """ Calls to indicates a brand new checkpoint write is going to happen. A checkpoint_id may be present if users set the checkpoint_id for this checkpoint write. The meaning of the checkpiont_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage. Args: checkpoint_id (Union[str, os.PathLike, None]): The ID of this checkpoint instance. The meaning of the checkpoint_id depends on the storage. It can be a path to a folder or to a file. It can also be a key if the storage is a key-value store. (Default: ``None``) """ ...
[docs] @abc.abstractmethod def set_up_storage_writer(self, is_coordinator: bool) -> None: """ Initialize this instance. Args: is_coordinator (bool): Whether this instance is responsible for coordinating the checkpoint. """ pass
[docs] @abc.abstractmethod def prepare_local_plan(self, plan: SavePlan) -> SavePlan: """ Perform storage-specific local planning. While this method can produce a completely different plan, the recommended way is to store storage specific data in SavePlan::storage_data. Args: plan (SavePlan): The local plan from the ``SavePlanner`` in use. Returns: A transformed ``SavePlan`` after storage local planning """ pass
[docs] @abc.abstractmethod def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: """ Perform centralized planning of storage. This method is only called on the coordinator instance. While this method can produce a completely different plan, the preferred way is to store storage specific data in SavePlan::storage_data. Args: plans: A list of ``SavePlan`` instances, one for each rank. Returns: A list of transformed ``SavePlan`` after storage global planning """ pass
[docs] @abc.abstractmethod def write_data( self, plan: SavePlan, planner: SavePlanner ) -> Future[List[WriteResult]]: """ Write all items from ``plan`` using ``planner`` to resolve the data. A subclass should call ``SavePlanner::resolve_data`` on each item from the plan to get access to the underlying object to write. Subclasses should lazily call `resolve_data` as it can allocate memory. In case of tensors, make following assumptions: - They might be on any device, including not matching the one on ``WriteItem::tensor_data`` - They might be views or not contiguous. Only the projection needs to be saved. Args: plan (SavePlan): The save plan to execute. planner (SavePlanner): Planner object to be used to resolve items to data. Returns: A future that completes to a list of WriteResult """ pass
[docs] @abc.abstractmethod def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: """ Write the metadata and marks the current checkpoint as successful. The actual format/schema used for serializing `metadata` is an implementation detail. The only requirement is that it's recoverable in to the same object graph. Args: metadata (Metadata): metadata for the new checkpoint results: A list of WriteResults from all ranks. Returns: None """ pass
[docs] @classmethod @abc.abstractmethod def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: """ Check if the given checkpoint_id is supported by the stroage. This allow us to enable automatic storage selection. """ ...
[docs] def storage_meta(self) -> Optional[StorageMeta]: """ Return the storage-specific metadata. This is used to store additional information in a checkpoint that can be useful for providing request-level observability. StorageMeta is passed to the ``SavePlanner`` during save calls. Returns None by default. TODO: provide an example """ return None
class StorageReader(abc.ABC):
    """
    Interface used by ``load_state_dict`` to read from storage.

    One StorageReader instance acts as both the coordinator and the follower
    in a distributed checkpoint. As part of initialization, each instance
    is told its role.

    A subclass should expect the following sequence of calls by
    ``load_state_dict``:

    0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id
       (the ID is delivered through ``reset``).
    1) (all ranks) read_metadata()
    2) (all ranks) set_up_storage_reader()
    3) (all ranks) prepare_local_plan()
    4) (coordinator) prepare_global_plan()
    5) (all ranks) read_data()
    """

    @abc.abstractmethod
    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """
        Called to indicate that a brand new checkpoint read is going to happen.

        A checkpoint_id may be present if users set the checkpoint_id for
        this checkpoint read. The meaning of the checkpoint_id is
        storage-dependent: it can be a path to a folder/file or a key for a
        key-value storage.

        Args:
            checkpoint_id (Union[str, os.PathLike, None]):
                The ID of this checkpoint instance. The meaning of the
                checkpoint_id depends on the storage. It can be a path to a
                folder or to a file. It can also be a key if the storage is
                more like a key-value store. (Default: ``None``)
        """
        ...

    @abc.abstractmethod
    def read_metadata(self) -> Metadata:
        """
        Read the checkpoint metadata.

        Returns:
            The metadata object associated with the checkpoint being loaded.
        """
        ...

    @abc.abstractmethod
    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """
        Initialize this instance.

        Args:
            metadata (Metadata): The metadata schema to use.
            is_coordinator (bool): Whether this instance is responsible for
                coordinating the checkpoint.
        """
        ...

    @abc.abstractmethod
    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Perform storage-specific local planning.

        While this method can produce a completely different plan, the
        recommended way is to store storage-specific data in
        ``LoadPlan::storage_data``.

        Args:
            plan (LoadPlan): The local plan from the ``LoadPlanner`` in use.

        Returns:
            A transformed ``LoadPlan`` after storage local planning.
        """
        ...

    @abc.abstractmethod
    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        """
        Perform centralized planning of storage loading.

        This method is only called on the coordinator instance.

        While this method can produce a completely different plan, the
        preferred way is to store storage-specific data in
        ``LoadPlan::storage_data``.

        Args:
            plans: A list of ``LoadPlan`` instances, one for each rank.

        Returns:
            A list of transformed ``LoadPlan`` after storage global planning.
        """
        ...

    @abc.abstractmethod
    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Read all items from ``plan`` using ``planner`` to resolve the data.

        A subclass should call ``LoadPlanner::load_bytes`` to deserialize a
        BytesIO object into the right place.

        A subclass should call ``LoadPlanner::resolve_tensor`` to get access
        to the tensors that it should load data into.

        It is this storage layer's responsibility to properly schedule any
        cross-device copies required.

        Args:
            plan (LoadPlan): The local plan to execute.
            planner (LoadPlanner): The planner object to use to resolve items.

        Returns:
            A future that completes once all reads are finished.
        """
        ...

    # classmethod must be the outermost decorator so the abstract marker is
    # applied to the underlying function.
    @classmethod
    @abc.abstractmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Check if the given checkpoint_id is supported by the storage.

        This allows us to enable automatic storage selection.
        """
        ...

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources