Shortcuts

Source code for torch.distributed.checkpoint.format_utils

# mypy: allow-untyped-defs
import argparse
import os
from enum import Enum
from typing import cast, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint.default_planner import (
    _EmptyStateDictLoadPlanner,
    DefaultLoadPlanner,
)
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.futures import Future


__all__ = [
    "dcp_to_torch_save",
    "torch_save_to_dcp",
    "BroadcastingTorchSaveReader",
    "DynamicMetaLoadPlanner",
]


[docs]class BroadcastingTorchSaveReader(StorageReader): """ StorageReader for reading a Torch Save file. This reader will read the entire checkpoint on the coordinator rank, and then broadcast and shard each tensor to all ranks. . N.B. Intended to be used with DynamicMetaLoadPlanner .. warning:: Current implementation only supports loading Tensors. >>> # xdoctest: +SKIP("undefined vars") >>> sd = {"mode": model} >>> dcp.load( >>> sd, >>> storage_reader=BroadcastingTorchSaveReader(), >>> planner=DynamicMetaLoadPlanner(), >>> checkpoint_id="path_to_model.pt" >>> ) """ def __init__( self, checkpoint_id: Optional[Union[str, os.PathLike]] = None, coordinator_rank: int = 0, ) -> None: self.checkpoint_id = checkpoint_id self.coordinator_rank = coordinator_rank
[docs] def read_metadata(self) -> Metadata: """Extends the default StorageReader to support building the metadata file""" # Metadata is built in planner.set_up_planner, since we are not actually reading metadata from # the disk return Metadata(state_dict_metadata={})
[docs] def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: """ Reads torch save data on the coordinator rank, and broadcast afterwards this incurrs a communication cost, but avoids having to load the entire checkpoint on each rank, hopefully preventing OOM issues """ planner = cast(DefaultLoadPlanner, planner) # data is read in on the coordinator rank, and broadcast afterwards # this incurrs a communication cost, but it avoids having to load # the entire checkpoint on each rank, hopefully preventing OOM issues # TODO: read on each host, instead of only the coordinator if self.is_coordinator: assert self.checkpoint_id is not None torch_state_dict = torch.load( self.checkpoint_id, map_location="cpu", weights_only=False ) if planner.flatten_state_dict: torch_state_dict, _ = flatten_state_dict(torch_state_dict) else: torch_state_dict = None for req in plan.items: if req.type == LoadItemType.BYTE_IO: raise RuntimeError( f"Non-tensor value identified at {req.storage_index.fqn}. " f"At this time {type(self).__name__} only supports loading Tensors." ) # Broadcast the tensor from the coordinator rank if self.is_coordinator: tensor = torch_state_dict[req.storage_index.fqn].cuda() else: tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn]) dist.broadcast(tensor, src=self.coordinator_rank, async_op=False) tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) target_tensor = planner.resolve_tensor(req).detach() assert target_tensor.size() == tensor.size(), ( f"req {req.storage_index} mismatch sizes, " f"{target_tensor.size()} vs {tensor.size()}" ) target_tensor.copy_(tensor) planner.commit_tensor(req, target_tensor) fut: Future = Future() fut.set_result(None) return fut
[docs] def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: """Implementation of the StorageReader method""" self.is_coordinator = is_coordinator if self.is_coordinator: assert dist.get_rank() == self.coordinator_rank assert self.checkpoint_id is not None
[docs] def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: """Implementation of the StorageReader method""" return plan
[docs] def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]: """Implementation of the StorageReader method""" return global_plan
[docs] def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: """Implementation of the StorageReader method""" self.checkpoint_id = checkpoint_id
[docs] @classmethod def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: """Implementation of the StorageReader method""" return os.path.isfile(checkpoint_id)
[docs]class DynamicMetaLoadPlanner(DefaultLoadPlanner): """ Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed in state dict, avoiding the need to read metadata from disk. This is useful when reading formats which don't have a metadata file, like Torch Save files. . N.B. Intended to be used with BroadcastingTorchSaveReader .. warning:: Current implementation only supports loading Tensors. >>> # xdoctest: +SKIP("undefined vars") >>> sd = {"mode": model} >>> dcp.load( >>> sd, >>> storage_reader=BroadcastingTorchSaveReader(), >>> planner=DynamicMetaLoadPlanner(), >>> checkpoint_id="path_to_model.pt" >>> ) """
[docs] def set_up_planner( self, state_dict: STATE_DICT_TYPE, metadata: Optional[Metadata] = None, is_coordinator: bool = False, ) -> None: """Setups of the planner, extnding default behavior by creating the Metadata object from the state dict""" super().set_up_planner(state_dict, metadata, is_coordinator) state_dict_metadata: Dict[str, STORAGE_TYPES] = {} for key, tensor in self.state_dict.items(): if not torch.is_tensor(tensor): raise RuntimeError( f"Non-tensor value identified at {key}. " f"At this time {type(self).__name__} only supports loading Tensors." ) state_dict_metadata[key] = TensorStorageMetadata( TensorProperties(dtype=tensor.dtype), tensor.size(), _create_chunk_list(tensor), ) self.metadata = Metadata(state_dict_metadata=state_dict_metadata)
[docs]def dcp_to_torch_save( dcp_checkpoint_dir: Union[str, os.PathLike], torch_save_path: Union[str, os.PathLike], ): """ Given a directory containing a DCP checkpoint, this function will convert it into a Torch save file. Args: dcp_checkpoint_dir: Directory containing the DCP checkpoint. torch_save_path: Filename to store the converted Torch save file. .. warning:: To avoid OOM, it's recommended to only run this function on a single rank. """ sd: STATE_DICT_TYPE = {} _load_state_dict( sd, storage_reader=FileSystemReader(dcp_checkpoint_dir), planner=_EmptyStateDictLoadPlanner(), no_dist=True, ) torch.save(sd, torch_save_path)
[docs]def torch_save_to_dcp( torch_save_path: Union[str, os.PathLike], dcp_checkpoint_dir: Union[str, os.PathLike], ): """ Given the location of a torch save file, converts it into a DCP checkpoint. Args: torch_save_path: Filename of the Torch save file. dcp_checkpoint_dir: Directory to store the DCP checkpoint. .. warning:: To avoid OOM, it's recommended to only run this function on a single rank. """ state_dict = torch.load(torch_save_path, weights_only=False) # we don't need stateful behavior here because the expectation is anything loaded by # torch.load would not contain stateful objects. _save_state_dict( state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True )
if __name__ == "__main__": class FormatMode(Enum): TORCH_TO_DCP = "torch_to_dcp" DCP_TO_TORCH = "dcp_to_torch" # Parse command-line arguments parser = argparse.ArgumentParser() parser.add_argument( "mode", type=str, help="Conversion mode", choices=[m.value for m in FormatMode], default=FormatMode.TORCH_TO_DCP, ) parser.add_argument("src", type=str, help="Path to the source model") parser.add_argument("dst", type=str, help="Path to the destination model") args = parser.parse_args() print( f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'" ) checkpoint_missing_warning = ( f"No checkpoint found at {args.src}. Skipping conversion." ) if args.mode == FormatMode.TORCH_TO_DCP.value: if os.path.isfile(args.src): torch_save_to_dcp(args.src, args.dst) else: print(checkpoint_missing_warning) elif args.mode == FormatMode.DCP_TO_TORCH.value: if os.path.isdir(args.src): dcp_to_torch_save(args.src, args.dst) else: print(checkpoint_missing_warning) else: raise ValueError(f"Unknown conversion mode: {args.mode}")

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources