class BroadcastingTorchSaveReader(StorageReader):
    """
    StorageReader for reading a Torch Save file. This reader reads the entire checkpoint
    on the coordinator rank, and then broadcasts and shards each tensor to all ranks.

    .. note:: Intended to be used with DynamicMetaLoadPlanner

    .. warning::
        Current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>     sd,
    >>>     storage_reader=BroadcastingTorchSaveReader(),
    >>>     planner=DynamicMetaLoadPlanner(),
    >>>     checkpoint_id="path_to_model.pt",
    >>> )
    """

    def __init__(
        self,
        checkpoint_id: Optional[Union[str, os.PathLike]] = None,
        coordinator_rank: int = 0,
    ) -> None:
        self.checkpoint_id = checkpoint_id
        self.coordinator_rank = coordinator_rank
    def read_metadata(self) -> Metadata:
        """Extends the default StorageReader to support building the metadata file."""
        # Metadata is built in planner.set_up_planner, since we are not actually
        # reading metadata from disk
        return Metadata(state_dict_metadata={})
    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Reads torch save data on the coordinator rank and broadcasts it afterwards.
        This incurs a communication cost, but avoids having to load the entire
        checkpoint on each rank, hopefully preventing OOM issues.
        """
        planner = cast(DefaultLoadPlanner, planner)

        # Data is read in on the coordinator rank and broadcast afterwards.
        # This incurs a communication cost, but avoids having to load the
        # entire checkpoint on each rank, hopefully preventing OOM issues.
        # TODO: read on each host, instead of only the coordinator
        if self.is_coordinator:
            assert self.checkpoint_id is not None
            torch_state_dict = torch.load(
                self.checkpoint_id, map_location="cpu", weights_only=False
            )
            if planner.flatten_state_dict:
                torch_state_dict, _ = flatten_state_dict(torch_state_dict)
        else:
            torch_state_dict = None

        for req in plan.items:
            if req.type == LoadItemType.BYTE_IO:
                raise RuntimeError(
                    f"Non-tensor value identified at {req.storage_index.fqn}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            # Broadcast the tensor from the coordinator rank
            if self.is_coordinator:
                pg_device = dist.distributed_c10d._get_pg_default_device()
                tensor = torch_state_dict[req.storage_index.fqn].to(pg_device)
            else:
                tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn])
            dist.broadcast(tensor, src=self.coordinator_rank, async_op=False)

            tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
            target_tensor = planner.resolve_tensor(req).detach()
            assert target_tensor.size() == tensor.size(), (
                f"req {req.storage_index} mismatch sizes, "
                f"{target_tensor.size()} vs {tensor.size()}"
            )
            target_tensor.copy_(tensor)
            planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut
    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """Implementation of the StorageReader method"""
        self.is_coordinator = is_coordinator
        if self.is_coordinator:
            assert dist.get_rank() == self.coordinator_rank

        assert self.checkpoint_id is not None
    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """Implementation of the StorageReader method"""
        return plan
    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        """Implementation of the StorageReader method"""
        return global_plan
    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """Implementation of the StorageReader method"""
        self.checkpoint_id = checkpoint_id
    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Implementation of the StorageReader method"""
        return os.path.isfile(checkpoint_id)
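# --- Usage sketch (standalone example, not part of the module) --------------------
# A minimal sketch of combining BroadcastingTorchSaveReader with DynamicMetaLoadPlanner
# (defined below) to load a torch.save checkpoint across ranks. Assumptions: the script
# is launched with torchrun, "model.pt" was produced beforehand via
# torch.save(model.state_dict(), "model.pt"), and the import path
# torch.distributed.checkpoint.format_utils is where these helpers live.
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.format_utils import (
    BroadcastingTorchSaveReader,
    DynamicMetaLoadPlanner,
)

dist.init_process_group(backend="gloo")
model = torch.nn.Linear(8, 8)      # stand-in for a real model
state_dict = model.state_dict()    # keys must match those stored in "model.pt"
dcp.load(
    state_dict,
    storage_reader=BroadcastingTorchSaveReader(),
    planner=DynamicMetaLoadPlanner(),
    checkpoint_id="model.pt",      # hypothetical path
)
model.load_state_dict(state_dict)  # tensors were already copied in place by dcp.load
dist.destroy_process_group()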
class DynamicMetaLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which creates a new Metadata object based on the
    passed-in state dict, avoiding the need to read metadata from disk. This is useful
    when reading formats which don't have a metadata file, like Torch Save files.

    .. note:: Intended to be used with BroadcastingTorchSaveReader

    .. warning::
        Current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>     sd,
    >>>     storage_reader=BroadcastingTorchSaveReader(),
    >>>     planner=DynamicMetaLoadPlanner(),
    >>>     checkpoint_id="path_to_model.pt",
    >>> )
    """
    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        """Set up the planner, extending the default behavior by creating the
        Metadata object from the state dict."""
        super().set_up_planner(state_dict, metadata, is_coordinator)

        state_dict_metadata: Dict[str, STORAGE_TYPES] = {}
        for key, tensor in self.state_dict.items():
            if not torch.is_tensor(tensor):
                raise RuntimeError(
                    f"Non-tensor value identified at {key}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            state_dict_metadata[key] = TensorStorageMetadata(
                TensorProperties(dtype=tensor.dtype),
                tensor.size(),
                _create_chunk_list(tensor),
            )
        self.metadata = Metadata(state_dict_metadata=state_dict_metadata)
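# --- Metadata sketch (standalone example, not part of the module) -----------------
# A minimal sketch showing that DynamicMetaLoadPlanner builds its Metadata entirely
# from the in-memory state dict, so no metadata file is read from disk. The key and
# shape below are made up; the import path is an assumption.
import torch
from torch.distributed.checkpoint.format_utils import DynamicMetaLoadPlanner

planner = DynamicMetaLoadPlanner()
planner.set_up_planner({"layer.weight": torch.zeros(4, 4)}, is_coordinator=True)
meta = planner.metadata.state_dict_metadata["layer.weight"]
print(meta.size, meta.properties.dtype)  # expected: torch.Size([4, 4]) torch.float32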
def dcp_to_torch_save(
    dcp_checkpoint_dir: Union[str, os.PathLike],
    torch_save_path: Union[str, os.PathLike],
):
    """
    Given a directory containing a DCP checkpoint, this function will convert it into a
    Torch save file.

    Args:
        dcp_checkpoint_dir: Directory containing the DCP checkpoint.
        torch_save_path: Filename to store the converted Torch save file.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    sd: STATE_DICT_TYPE = {}

    _load_state_dict(
        sd,
        storage_reader=FileSystemReader(dcp_checkpoint_dir),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    torch.save(sd, torch_save_path)
def torch_save_to_dcp(
    torch_save_path: Union[str, os.PathLike],
    dcp_checkpoint_dir: Union[str, os.PathLike],
):
    """
    Given the location of a torch save file, converts it into a DCP checkpoint.

    Args:
        torch_save_path: Filename of the Torch save file.
        dcp_checkpoint_dir: Directory to store the DCP checkpoint.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    state_dict = torch.load(torch_save_path, weights_only=False)
    # We don't need stateful behavior here because the expectation is that anything
    # loaded by torch.load would not contain stateful objects.
    _save_state_dict(
        state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True
    )
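# --- Conversion sketch (standalone example, not part of the module) ---------------
# A hedged round-trip sketch for the two helpers above; the paths are placeholders and
# the import path is an assumption. Run on a single rank to avoid OOM, per the warnings
# in the docstrings.
import torch
from torch.distributed.checkpoint.format_utils import (
    dcp_to_torch_save,
    torch_save_to_dcp,
)

torch.save({"weight": torch.randn(4, 4)}, "model.pt")  # plain torch.save checkpoint
torch_save_to_dcp("model.pt", "dcp_checkpoint/")       # torch.save file -> DCP directory
dcp_to_torch_save("dcp_checkpoint/", "model_copy.pt")  # DCP directory -> torch.save file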
if __name__ == "__main__":

    class FormatMode(Enum):
        TORCH_TO_DCP = "torch_to_dcp"
        DCP_TO_TORCH = "dcp_to_torch"

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mode",
        type=str,
        help="Conversion mode",
        choices=[m.value for m in FormatMode],
        default=FormatMode.TORCH_TO_DCP,
    )
    parser.add_argument("src", type=str, help="Path to the source model")
    parser.add_argument("dst", type=str, help="Path to the destination model")
    args = parser.parse_args()

    print(
        f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'"
    )
    checkpoint_missing_warning = (
        f"No checkpoint found at {args.src}. Skipping conversion."
    )

    if args.mode == FormatMode.TORCH_TO_DCP.value:
        if os.path.isfile(args.src):
            torch_save_to_dcp(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)
    elif args.mode == FormatMode.DCP_TO_TORCH.value:
        if os.path.isdir(args.src):
            dcp_to_torch_save(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)
    else:
        raise ValueError(f"Unknown conversion mode: {args.mode}")
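# Example invocations of the CLI above (paths are placeholders; the module path is an
# assumption based on where this file appears to live in the PyTorch tree):
#
#   python -m torch.distributed.checkpoint.format_utils torch_to_dcp model.pt dcp_dir/
#   python -m torch.distributed.checkpoint.format_utils dcp_to_torch dcp_dir/ model.pt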