Source code for torch.distributed.checkpoint.default_planner
# Copyright (c) Meta Platforms, Inc. and affiliates

import dataclasses
import io
import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import List, Tuple, Dict, Any, Union, cast

import torch

from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.planner import (
    SavePlanner,
    LoadPlanner,
    SavePlan,
    LoadPlan,
    ReadItem,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    TensorStorageMetadata,
    MetadataIndex,
    Metadata,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
)
from torch.distributed.checkpoint.planner_helpers import (
    _create_read_items,
    _create_write_items,
    _create_default_metadata_only_plan,
)
from torch.distributed.checkpoint._nested_dict import (
    FLATTEN_MAPPING,
    flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import (
    _flatten_sharded_tensors,
)
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.distributed.checkpoint._traverse import set_element

logger: logging.Logger = logging.getLogger(__file__)

__all__ = [
    "DefaultSavePlanner",
    "DefaultLoadPlanner",
    "create_default_local_load_plan",
    "create_default_global_load_plan",
    "create_default_local_save_plan",
    "create_default_global_save_plan",
]


# TODO: Update docstrings for default_planner.py
class DefaultSavePlanner(SavePlanner):
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        dedup_replicated_tensors: bool = True,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.dedup_replicated_tensors = dedup_replicated_tensors
        self.mappings = {}

    def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)
        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)
        self.state_dict = state_dict
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> SavePlan:
        plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
        if self.flatten_state_dict:
            plan = dataclasses.replace(plan, planner_data=self.mappings)
        self.plan = plan

        return self.plan

    def create_global_plan(
        self, all_plans: List[SavePlan]
    ) -> Tuple[List[SavePlan], Metadata]:
        if self.dedup_replicated_tensors:
            all_plans = dedup_tensors(all_plans)

        global_plan, metadata = create_default_global_save_plan(all_plans)

        if self.flatten_state_dict:
            # | does not work for Python 3.8 or older versions.
            # merged_mappings = reduce(
            #     lambda x, y: x | y, (p.planner_data for p in global_plan)
            # )
            planner_data_dict = [p.planner_data for p in global_plan]
            merged_mappings = dict(ChainMap(*planner_data_dict))
            metadata = dataclasses.replace(metadata, planner_data=merged_mappings)

        if not _validate_global_plan(global_plan, metadata):
            raise ValueError("Failed to validate global plan")

        self.global_plan = global_plan
        self.metadata = metadata

        return self.global_plan, self.metadata

    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        self.plan = new_plan
        return new_plan

    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
        object = self.lookup_object(write_item.index)
        return self.transform_object(write_item, object)
    def lookup_object(self, index: MetadataIndex) -> Any:
        """
        This is an extension from the planner interface to make it easy to extend the default planner.
        """
        return find_state_dict_object(self.state_dict, index)
    def transform_object(self, write_item: WriteItem, object: Any):
        """
        This is an extension from the planner interface to make it easy to extend the default planner.
        """
        if write_item.type == WriteItemType.BYTE_IO:
            bytes = io.BytesIO()
            torch.save(object, bytes)
            object = bytes
        return object
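
# NOTE (editorial, illustrative only; not part of this module): a minimal sketch of how
# DefaultSavePlanner is typically passed to torch.distributed.checkpoint.save_state_dict.
# ``model`` and the "/tmp/checkpoint" path are placeholders, and a process group is
# assumed to already be initialized (or pass ``no_dist=True`` for single-process use).
def _example_save_with_default_planner(model: torch.nn.Module) -> None:
    import torch.distributed.checkpoint as dcp

    dcp.save_state_dict(
        state_dict=model.state_dict(),
        storage_writer=dcp.FileSystemWriter("/tmp/checkpoint"),  # placeholder path
        planner=DefaultSavePlanner(),
    )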
class DefaultLoadPlanner(LoadPlanner):
    """
    DefaultLoadPlanner that adds multiple features on top of LoadPlanner.

    In particular it adds the following:

    flatten_state_dict: Handle state_dict with nested dicts.
    flatten_sharded_tensors: For FSDP in 2D parallel mode.
    """

    original_state_dict: STATE_DICT_TYPE
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.original_state_dict = {}
        self.mappings = {}

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        self.original_state_dict = state_dict

        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)

        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)

        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> LoadPlan:
        return create_default_local_load_plan(self.state_dict, self.metadata)

    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        return create_default_global_load_plan(global_plan)

    def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
        return new_plan

    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        if self.flatten_state_dict:
            set_element(
                self.original_state_dict,
                self.mappings[read_item.dest_index.fqn],
                torch.load(value),
            )
        else:
            self.state_dict[read_item.dest_index.fqn] = torch.load(value)

    def resolve_tensor(self, read_item: ReadItem):
        tensor = self.lookup_tensor(read_item.dest_index)
        return self.transform_tensor(read_item, tensor)

    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
        pass
    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        """
        This is an extension from the planner interface to make it easy to extend the default planner.
        """
        return find_state_dict_object(self.state_dict, index)
    def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
        """
        This is an extension from the planner interface to make it easy to extend the default planner.
        """
        return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
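
# NOTE (editorial, illustrative only; not part of this module): the matching load-side
# sketch, passing DefaultLoadPlanner to torch.distributed.checkpoint.load_state_dict.
# ``model`` and "/tmp/checkpoint" are placeholders; load_state_dict fills ``state_dict``
# in place, so the result is pushed back into the model afterwards.
def _example_load_with_default_planner(model: torch.nn.Module) -> None:
    import torch.distributed.checkpoint as dcp

    state_dict = model.state_dict()
    dcp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dcp.FileSystemReader("/tmp/checkpoint"),  # placeholder path
        planner=DefaultLoadPlanner(),
    )
    model.load_state_dict(state_dict)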
def create_default_local_load_plan(
    state_dict: Dict[str, Any],
    metadata: Metadata,
) -> LoadPlan:
    """
    Create the ``LoadPlan`` used by DefaultLoadPlanner.

    It produces one read item per value in ``state_dict`` using the metadata in ``metadata``.

    The default behavior is to match keys exactly between state_dict and metadata.
    It handles resharding by issuing multiple read requests against storage in order to match
    load requirements.
    """
    requests = []
    for fqn, obj in state_dict.items():
        md = metadata.state_dict_metadata[fqn]
        # Since DTensor supports submesh, adding extra check to ensure _create_read_items()
        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_read_items(fqn, md, obj)
        else:
            requests += _create_read_items(fqn, md, obj)

    return LoadPlan(requests)


def create_default_global_load_plan(
    all_plans: List[LoadPlan],
) -> List[LoadPlan]:
    """
    Create the global load plan used by DefaultLoadPlanner.

    The default load behavior involves no global coordination, and this function
    currently doesn't change the local plans.
    """
    return all_plans


def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: bool) -> SavePlan:
    """
    Create the ``SavePlan`` used by DefaultSavePlanner.

    On non-coordinator ranks, this function ignores tensors and non-tensor objects,
    only producing writes for ShardedTensor objects.

    On the coordinator rank, it produces writes for all values.
    """
    requests = []
    for fqn, obj in state_dict.items():
        # Since DTensor supports submesh, adding extra check to ensure _create_write_items()
        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_write_items(fqn, obj)
        elif isinstance(obj, torch.Tensor) or is_coordinator:
            requests += _create_write_items(fqn, obj)

    return SavePlan(requests)


def create_default_global_save_plan(
    all_plans: List[SavePlan],
    rewrite_index_hints: bool = True,
) -> Tuple[List[SavePlan], Metadata]:
    """
    Create the global plan and metadata used by DefaultSavePlanner.

    Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans.

    The only global planning change is to update index hints in all ``MetadataIndex`` objects if
    ``rewrite_index_hints`` is True.
    """
    md: Dict[str, STORAGE_TYPES] = {}
    new_plans = []
    for plan in all_plans:
        new_items = []
        for item in plan.items:
            if not item.type == WriteItemType.SHARD:
                assert item.index.fqn not in md

            if item.type == WriteItemType.BYTE_IO:
                md[item.index.fqn] = BytesStorageMetadata()
                new_items.append(item)
            else:
                assert item.tensor_data is not None
                tensor_md = cast(
                    TensorStorageMetadata,
                    md.setdefault(
                        item.index.fqn,
                        TensorStorageMetadata(
                            properties=item.tensor_data.properties,
                            size=item.tensor_data.size,
                            chunks=[],
                        ),
                    ),
                )
                new_item = item
                if rewrite_index_hints:
                    new_index = dataclasses.replace(item.index, index=len(tensor_md.chunks))
                    new_item = dataclasses.replace(item, index=new_index)
                new_items.append(new_item)

                assert item.tensor_data.chunk is not None, (
                    f"Cannot create MD for tensor without bounds. FQN: {item.index.fqn}"
                )
                tensor_md.chunks.append(item.tensor_data.chunk)
        new_plans.append(dataclasses.replace(plan, items=new_items))
    return (new_plans, Metadata(md))


def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
    """
    Return the ``Metadata`` that would be produced if DefaultSavePlanner was used to
    checkpoint ``state_dict``.
    """
    plan = _create_default_metadata_only_plan(state_dict)
    _, md = create_default_global_save_plan([plan])
    return md


def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
    """
    Check whether two boxes overlap. Boxes are described by their (offsets, sizes).
    """
    # For each dim of each shard, check if one shard resides on the other
    # end of the second shard with respect to that dim. As an example, for a 2D
    # shard we would check if one shard is above or to the left of the other shard.
    ndims = len(box0.offsets)
    for i in range(ndims):
        if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]:
            return False
        if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]:
            return False

    return True


def _check_box_bounds(
    outer_box_size: torch.Size, inner_box: ChunkStorageMetadata
) -> bool:
    for i in range(len(outer_box_size)):
        if inner_box.offsets[i] < 0:
            return False
        if inner_box.sizes[i] < 0:
            return False
        if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]:
            return False

    return True


def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
    all_good = True
    for key, value in metadata.state_dict_metadata.items():
        if isinstance(value, BytesStorageMetadata):
            continue
        if len(value.size) == 0:
            continue
        chunks_volume = 0
        for chunk_idx, chunk0 in enumerate(value.chunks):
            # Compute the volume
            if not _check_box_bounds(value.size, chunk0):
                logger.warning(
                    "key:%s has out of bounds chunk: tensor-size:%s chunk: %s",
                    key,
                    value.size,
                    chunk0,
                )
                all_good = False
            chunks_volume += reduce(operator.mul, chunk0.sizes, 1)

            # Check for overlap
            for chunk1 in value.chunks[chunk_idx + 1:]:
                if _check_box_overlap(chunk0, chunk1):
                    logger.warning(
                        "key:%s has overlapping chunks: %s %s", key, chunk0, chunk1
                    )
                    all_good = False

        # Check whether the combined chunks cover the whole tensor
        tensor_volume = reduce(operator.mul, value.size, 1)
        if chunks_volume != tensor_volume:
            logger.warning(
                "key:%s invalid fill tensor-volume: %s chunks-volume: %s",
                key,
                tensor_volume,
                chunks_volume,
            )
            all_good = False

    return all_good