Source code for torch.distributed.checkpoint.filesystem
# mypy: allow-untyped-defs
import collections
import dataclasses
import io
import operator
import os
import pickle
import queue
import threading
import uuid
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    IO,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import torch
from torch import Tensor
from torch._utils import _get_available_device_type, _get_device_module
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint.metadata import (
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    StorageMeta,
)
from torch.distributed.checkpoint.planner import (
    LoadItemType,
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.staging import BlockingAsyncStager
from torch.distributed.checkpoint.storage import (
    StorageReader,
    StorageWriter,
    WriteResult,
)
from torch.distributed.checkpoint.utils import _create_file_view
from torch.futures import Future


__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystem", "FileSystemBase"]

# File name (relative to the checkpoint directory) of the pickled metadata.
_metadata_fn: str = ".metadata"


@dataclass
class _StorageInfo:
    """This is the per entry storage info."""

    # Name of the data file, relative to the checkpoint directory.
    relative_path: str
    # Byte offset of the entry inside that file.
    offset: int
    # Number of bytes the entry occupies.
    length: int


@dataclass
class _StoragePrefix:
    """Per-plan file name prefix assigned by ``prepare_global_plan``."""

    prefix: str


DEFAULT_SUFFIX = ".distcp"


def _generate_uuid() -> str:
    """Return a fresh random UUID (version 4) rendered as a string."""
    return str(uuid.uuid4())


class _TensorLoader(ABC):
    """Interface for objects that resolve write items into CPU tensors."""

    @abstractmethod
    def add(self, size: int, obj: object) -> None:
        """Register ``obj`` (with its byte ``size``) for loading."""

    @abstractmethod
    def start_loading(self) -> None:
        """Begin resolving the registered items."""

    @abstractmethod
    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        """Yield ``(cpu_tensor, obj)`` pairs for the registered items."""


class _SerialCpuLoader(_TensorLoader):
    """Loader that resolves and copies items to CPU one at a time, lazily."""

    def __init__(self, resolve_fun: Callable) -> None:
        self.resolve_fun = resolve_fun
        self.items: List[Tuple[int, object]] = []

    def add(self, size: int, obj: object) -> None:
        self.items.append((size, obj))

    def start_loading(self) -> None:
        # Nothing to prefetch: the work happens inside ``values``.
        pass

    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        for _, obj in self.items:
            tensor = self.resolve_fun(obj).detach()
            tensor = tensor.cpu()
            # Same minimal-storage check as _OverlappingCpuLoader._refill.
            # ``untyped_storage`` is used instead of the deprecated
            # ``Tensor.storage()``; sizes are compared in bytes.
            if tensor.untyped_storage().size() != tensor.numel() * tensor.itemsize:
                # the tensor is a view into a larger storage; clone so only
                # its own data gets serialized
                tensor = tensor.clone()
            yield (
                tensor,
                obj,
            )


class _OverlappingCpuLoader(_TensorLoader):
    """
    Loader that issues non-blocking device-to-CPU copies on a stream.

    Copies are queued until ``inflight_threshhold`` bytes are in flight;
    ``values`` alternates between draining finished copies and refilling.
    """

    def __init__(
        self,
        resolve_fun: Callable,
        stream: Optional[torch.Stream] = None,
        inflight_threshhold: int = 1_000_000,
    ) -> None:
        self.resolve_fun = resolve_fun
        self.items: List[Tuple[int, object]] = []
        self.inflight_threshhold = inflight_threshhold
        # Bytes currently queued in ``current_items``.
        self.in_flight_data = 0
        self.current_items: collections.deque = collections.deque()
        # Index of the next entry of ``items`` to issue.
        self.idx = 0
        self.started = False
        self.device_type = (
            stream.device_type if stream else _get_available_device_type()
        )
        self.device_module = _get_device_module(self.device_type)
        self.stream = cast(
            torch.cuda.Stream, stream or self.device_module.current_stream()
        )
        if self.stream != self.device_module.current_stream():
            # Order the copy stream after work already queued on the current stream.
            self.stream.wait_stream(self.device_module.current_stream())

    @property
    def _done(self) -> bool:
        """True once every registered item has been issued."""
        return self.idx >= len(self.items)

    def _drain(self) -> List[Tuple[torch.Tensor, object]]:
        """Pop queued items until the in-flight byte count is below the threshold."""
        drained = []
        if self.in_flight_data >= self.inflight_threshhold:
            self.stream.synchronize()
        while self.in_flight_data >= self.inflight_threshhold:
            val = self.current_items.popleft()
            self.in_flight_data -= val[0].numel() * val[0].element_size()
            drained.append(val)
        return drained

    def _refill(self) -> None:
        """Issue copies for pending items until the threshold is reached."""
        with self.device_module.stream(self.stream):
            while not self._done and self.in_flight_data < self.inflight_threshhold:
                _, obj = self.items[self.idx]
                self.idx += 1
                tensor = self.resolve_fun(obj).detach()
                if tensor.device.type == self.device_type:
                    tensor = tensor.to(device="cpu", non_blocking=True)
                elif tensor.device == torch.device("cpu"):
                    if (
                        tensor.untyped_storage().size()
                        != tensor.numel() * tensor.itemsize
                    ):
                        # this forces the tensor to be both contiguous and with minimal storage
                        tensor = tensor.clone()

                self.current_items.append(
                    (
                        tensor,
                        obj,
                    )
                )
                self.in_flight_data += tensor.numel() * tensor.element_size()

    def _finish(self) -> Iterable[Tuple[torch.Tensor, object]]:
        """Synchronize the stream if needed and return the remaining queued items."""
        assert self._done
        if len(self.current_items) > 0:
            self.stream.synchronize()
        return self.current_items

    def add(self, size: int, obj: object) -> None:
        if self.started:
            raise RuntimeError("cannot add items after loading started")
        self.items.append((size, obj))

    def start_loading(self) -> None:
        if self.started:
            return
        self.started = True
        # Issue the smallest items first.
        self.items.sort(key=operator.itemgetter(0))
        self._refill()

    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        self.start_loading()
        while not self._done:
            drained = self._drain()
            self._refill()
            yield from drained

        yield from self._finish()


def _item_size(item: WriteItem) -> int:
    """Return the size in bytes of the tensor described by ``item``."""
    size = 1
    assert item.tensor_data is not None
    # can't use math.prod as PT needs to support older python
    for s in item.tensor_data.size:
        size *= s

    dtype = item.tensor_data.properties.dtype
    return size * torch._utils._element_size(dtype)


def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
    """
    Partition ``items`` into ``bins`` buckets.

    BYTE_IO items are distributed round-robin. All other items are sorted
    by decreasing byte size and each is appended to the bucket with the
    smallest accumulated tensor size. With ``bins == 1`` the input list
    itself is returned as the only bucket.
    """
    if bins == 1:
        return [items]

    bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
    tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]

    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
    bucket_sizes = [0 for _ in range(bins)]

    tensor_w.sort(key=_item_size, reverse=True)

    for i, wi in enumerate(bytes_w):
        buckets[i % bins].append(wi)

    for wi in tensor_w:
        # TODO replace with heapq
        idx = min(enumerate(bucket_sizes), key=operator.itemgetter(1))[0]
        buckets[idx].append(wi)
        bucket_sizes[idx] += _item_size(wi)

    return buckets


def _write_item(
    stream: io.IOBase,
    data: Union[io.BytesIO, torch.Tensor],
    write_item: WriteItem,
    storage_key: str,
) -> WriteResult:
    """
    Append ``data`` to ``stream`` and describe where it landed.

    BYTE_IO items are written as raw bytes; everything else must be a CPU
    tensor and is serialized with ``torch.save``. The returned WriteResult
    records the starting offset and the number of bytes written.
    """
    offset = stream.tell()

    if write_item.type == WriteItemType.BYTE_IO:
        assert isinstance(data, io.BytesIO)
        stream.write(data.getbuffer())
    else:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, cast(IO[bytes], stream))
    length = stream.tell() - offset

    return WriteResult(
        index=write_item.index,
        size_in_bytes=length,
        storage_data=_StorageInfo(storage_key, offset, length),
    )


def _write_files_from_queue(
    create_stream: Callable,
    file_queue: queue.Queue,
    result_queue: queue.Queue,
    planner: SavePlanner,
    inflight_threshhold: int,
    use_fsync: bool,
    thread_count: int,
) -> None:
    """
    Worker loop: write every file description found in ``file_queue``.

    Each queue entry is ``(file_name, storage_key, write_items)``. For every
    entry the BYTE_IO items are written first, followed by the tensor items,
    and the list of WriteResults is put on ``result_queue``. The function
    returns once ``file_queue`` is empty.

    Args:
        create_stream: context-manager factory called as ``create_stream(file_name, "wb")``.
        file_queue: queue of pending file descriptions.
        result_queue: queue receiving one ``List[WriteResult]`` per file.
        planner: used to resolve write items into data.
        inflight_threshhold: byte budget passed to ``_OverlappingCpuLoader``;
            a value ``<= 0`` selects ``_SerialCpuLoader``.
        use_fsync: whether to sync each file before it is closed.
        thread_count: total number of writer threads.
    """
    # The loader choice does not depend on the file being written, so decide
    # it once instead of on every loop iteration.
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, custom_backend_name, None)

    # TODO: Using the OverlappingCpuLoader with multiple threads creates significant
    # performance degradation, observed as being related to cuda stream syncs. We
    # should try to fix this and use _OverlappingCpuLoader for all threaded cases
    use_overlapping_loader = bool(
        thread_count == 1
        and (
            torch.cuda.is_available()
            or (custom_device_mod and custom_device_mod.is_available())
        )
        and inflight_threshhold > 0
    )

    try:
        while True:
            file_name, storage_key, write_items = file_queue.get_nowait()
            loader: _TensorLoader

            if use_overlapping_loader:
                loader = _OverlappingCpuLoader(
                    planner.resolve_data,
                    inflight_threshhold=inflight_threshhold,
                )
            else:
                loader = _SerialCpuLoader(
                    planner.resolve_data,
                )

            tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
            for write_item in tensor_w:
                loader.add(_item_size(write_item), write_item)
            loader.start_loading()

            bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
            write_results = []

            with create_stream(file_name, "wb") as stream:
                for write_item in bytes_w:
                    data = planner.resolve_data(write_item)
                    write_results.append(
                        _write_item(stream, data, write_item, storage_key)
                    )

                for tensor, write_item in loader.values():
                    assert tensor.is_cpu
                    write_results.append(
                        _write_item(stream, tensor, write_item, storage_key)
                    )

                if use_fsync:
                    try:
                        os.fsync(stream.fileno())
                    except (AttributeError, io.UnsupportedOperation):
                        # The stream has no usable file descriptor (either no
                        # ``fileno`` attribute, or an io stream not backed by
                        # one); fall back to a global sync.
                        os.sync()
            result_queue.put(write_results)
    except queue.Empty:
        # Queue exhausted: this worker is done.
        pass


class FileSystemBase(ABC):
    """Abstract file system operations used by the checkpoint reader/writer."""

    @contextmanager
    @abstractmethod
    def create_stream(
        self, path: Union[str, os.PathLike], mode: str
    ) -> Generator[io.IOBase, None, None]:
        ...

    @abstractmethod
    def concat_path(
        self, path: Union[str, os.PathLike], suffix: str
    ) -> Union[str, os.PathLike]:
        ...

    @abstractmethod
    def rename(
        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
    ) -> None:
        ...

    @abstractmethod
    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        ...

    @abstractmethod
    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        ...

    @classmethod
    @abstractmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        ...

    @abstractmethod
    def exists(self, path: Union[str, os.PathLike]) -> bool:
        ...

    @abstractmethod
    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        ...


class FileSystem(FileSystemBase):
    """``FileSystemBase`` implementation backed by ``pathlib.Path``.

    Apart from ``init_path`` and ``validate_checkpoint_id``, the methods
    treat their path arguments as ``Path`` objects (they are only ``cast``,
    not converted), so paths should first go through ``init_path``.
    """

    @contextmanager
    def create_stream(
        self, path: Union[str, os.PathLike], mode: str
    ) -> Generator[io.IOBase, None, None]:
        """Open ``path`` with ``mode`` and yield the stream; closed on exit."""
        with cast(Path, path).open(mode) as stream:
            yield cast(io.IOBase, stream)

    def concat_path(
        self, path: Union[str, os.PathLike], suffix: str
    ) -> Union[str, os.PathLike]:
        """Return ``path / suffix``."""
        return cast(Path, path) / suffix

    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        """Return ``path`` as a ``Path`` instance."""
        if not isinstance(path, Path):
            path = Path(path)
        return path

    def rename(
        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
    ) -> None:
        cast(Path, path).rename(cast(Path, new_path))

    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        """Create ``path`` including missing parents; existing directories are fine."""
        cast(Path, path).mkdir(parents=True, exist_ok=True)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Tell whether ``checkpoint_id`` looks like a usable local path.

        ``Path`` instances are always accepted and anything containing
        ``"://"`` is rejected. Otherwise the id is accepted when at least
        one of its parent directories exists and is writable.
        """
        if isinstance(checkpoint_id, Path):
            return True

        if "://" in str(checkpoint_id):
            return False

        for p in Path(checkpoint_id).parents:
            if p.exists() and os.access(str(p), os.W_OK):
                return True

        return False

    def exists(self, path: Union[str, os.PathLike]) -> bool:
        return cast(Path, path).exists()

    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        cast(Path, path).unlink()


class _FileSystemWriter(StorageWriter):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic

    The checkpoint consist of one file per write request plus
    a `.metadata` file with the serialized metadata.

    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = True,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True.
            sync_files : force files to be synced to permanent storage. Default to True.
            thread_count: Number of IO threads to use to write. Default to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Default 10Mb.
            overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.

        N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__()
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)
        self.single_file_per_rank = single_file_per_rank
        self.sync_files = sync_files
        self.thread_count = thread_count
        self.per_thread_copy_ahead = per_thread_copy_ahead
        self.save_id = _generate_uuid()
        self.overwrite = overwrite

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """Optionally retarget the writer to ``checkpoint_id`` and draw a new save id."""
        if checkpoint_id:
            self.path = self.fs.init_path(checkpoint_id)
        self.save_id = _generate_uuid()

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        pass

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Create the checkpoint directory and check for an existing checkpoint.

        Raises:
            RuntimeError: a metadata file already exists and ``overwrite`` is False.
        """
        self.fs.mkdir(self.path)
        if self.fs.exists(self.metadata_path):
            if self.overwrite:
                warnings.warn(
                    f"Detected an existing checkpoint in {self.metadata_path}, overwriting since {self.overwrite=}."
                    " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to"
                    " maintain this functionality or False to raise when an existing checkpoint is found."
                )
            else:
                raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.")

        return plan

    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        """Tag each plan with a ``__<index>_`` file prefix so file names do not collide."""
        new_plans = [
            dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
            for i, plan in enumerate(plans)
        ]
        return new_plans

    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner,
    ) -> Future[List[WriteResult]]:
        """
        Write every item of ``plan`` and return a completed future of the results.

        Work is queued per output file and consumed by ``thread_count - 1``
        helper threads plus the calling thread.
        """
        storage_plan: _StoragePrefix = plan.storage_data
        file_count = 0

        def gen_file():
            nonlocal file_count
            file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
            file_count += 1
            return file_name

        file_queue: queue.Queue = queue.Queue()
        if self.single_file_per_rank:
            for bucket in _split_by_size_and_type(self.thread_count, plan.items):
                file_name = gen_file()
                path = self.fs.concat_path(self.path, file_name)
                file_queue.put((path, file_name, bucket))
        else:
            for item in plan.items:
                file_name = gen_file()
                path = self.fs.concat_path(self.path, file_name)
                file_queue.put((path, file_name, [item]))

        result_queue: queue.Queue = queue.Queue()

        threads = []
        # The calling thread acts as one of the workers, hence the range from 1.
        for _ in range(1, self.thread_count):
            t = threading.Thread(
                target=_write_files_from_queue,
                args=(
                    self.fs.create_stream,
                    file_queue,
                    result_queue,
                    planner,
                    self.per_thread_copy_ahead,
                    self.sync_files,
                    self.thread_count,
                ),
            )
            t.start()
            threads.append(t)

        _write_files_from_queue(
            create_stream=self.fs.create_stream,
            file_queue=file_queue,
            result_queue=result_queue,
            planner=planner,
            inflight_threshhold=self.per_thread_copy_ahead,
            use_fsync=self.sync_files,
            thread_count=self.thread_count,
        )

        for t in threads:
            t.join()

        res = []
        try:
            while True:
                res += result_queue.get_nowait()
        except queue.Empty:
            # All per-file result lists have been collected.
            fut: Future[List[WriteResult]] = Future()
            fut.set_result(res)
            return fut

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        """
        Record storage locations in ``metadata`` and persist it.

        The metadata is pickled to ``<_metadata_fn>.tmp`` first and then
        renamed over the final metadata path.
        """
        storage_md = {}
        for wr_list in results:
            storage_md.update({wr.index: wr.storage_data for wr in wr_list})
        metadata.storage_data = storage_md

        metadata.storage_meta = self.storage_meta()

        tmp_path = cast(Path, self.fs.concat_path(self.path, f"{_metadata_fn}.tmp"))
        with self.fs.create_stream(tmp_path, "wb") as metadata_file:
            pickle.dump(metadata, metadata_file)
            if self.sync_files:
                try:
                    os.fsync(metadata_file.fileno())
                except (AttributeError, io.UnsupportedOperation):
                    # No usable file descriptor on this stream; fall back to
                    # a global sync.
                    os.sync()

        # delete in-case other checkpoints were present.
        if self.fs.exists(self.metadata_path):
            self.fs.rm_file(self.metadata_path)

        self.fs.rename(tmp_path, self.metadata_path)

    def storage_meta(self) -> Optional[StorageMeta]:
        return StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id)

    @property
    def metadata_path(self) -> Union[str, os.PathLike]:
        """Path of the metadata file inside the checkpoint directory."""
        return cast(Path, self.fs.concat_path(self.path, _metadata_fn))

    @property
    def checkpoint_id(self) -> Union[str, os.PathLike]:
        """
        return the checkpoint_id that will be used to save the checkpoint.
        """
        return self.path

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return FileSystem.validate_checkpoint_id(checkpoint_id)
class FileSystemReader(StorageReader):
    """
    StorageReader that loads checkpoint data through ``FileSystem``.

    ``storage_data`` maps each metadata index to the ``_StorageInfo``
    (relative file, offset, length) describing where the entry is stored;
    it is populated by ``set_up_storage_reader``.
    """

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        """
        Initialize the reader pointing to `path`.

        Args:
            path: directory the checkpoint is read from.
        """
        super().__init__()
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = {}
        self.load_id = _generate_uuid()

    def _slice_file(self, file, sinfo: _StorageInfo) -> io.IOBase:
        """Return a view of ``file`` limited to the byte range given by ``sinfo``."""
        return _create_file_view(file, sinfo.offset, sinfo.length)

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """Drop cached storage data, optionally retarget the reader, and draw a new load id."""
        self.storage_data = {}
        if checkpoint_id:
            self.path = self.fs.init_path(checkpoint_id)
        self.load_id = _generate_uuid()

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Load every item of ``plan`` and hand it to ``planner``.

        Requests are grouped by file so each file is opened once. BYTE_IO
        items are passed to ``planner.load_bytes``; tensor items are loaded
        onto CPU, narrowed to the requested region, copied into the tensor
        returned by ``planner.resolve_tensor`` and committed.

        Returns:
            An already-completed future holding ``None``.
        """
        # group requests by file
        per_file: Dict[str, List[ReadItem]] = {}
        for read_item in plan.items:
            item_md = self.storage_data[read_item.storage_index]
            path = item_md.relative_path
            per_file.setdefault(path, []).append(read_item)

        for relative_path, reqs in per_file.items():
            new_path = self.fs.concat_path(self.path, relative_path)
            with self.fs.create_stream(new_path, "rb") as stream:
                # TODO sort by offset and cache the reading
                for req in reqs:
                    item_md = self.storage_data[req.storage_index]
                    file_slice = self._slice_file(stream, item_md)
                    if req.type == LoadItemType.BYTE_IO:
                        read_bytes = io.BytesIO(file_slice.read(item_md.length))
                        read_bytes.seek(0)
                        planner.load_bytes(req, read_bytes)
                    else:
                        tensor = cast(
                            Tensor,
                            torch.load(
                                cast(IO[bytes], file_slice),
                                map_location="cpu",
                                weights_only=True,
                            ),
                        )
                        tensor = narrow_tensor_by_index(
                            tensor, req.storage_offsets, req.lengths
                        )
                        target_tensor = planner.resolve_tensor(req).detach()

                        assert (
                            target_tensor.size() == tensor.size()
                        ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                        target_tensor.copy_(tensor)
                        planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    # Implementing the abstract function in StorageReader
    def read_metadata(self) -> Metadata:
        """
        Unpickle and return the checkpoint metadata.

        A missing ``storage_meta`` is replaced by a default ``StorageMeta``,
        and the reader's ``load_id`` is stamped onto it.
        """
        # Use the module-level constant so reader and writer always agree on
        # the metadata file name.
        path = self.fs.concat_path(self.path, _metadata_fn)
        with self.fs.create_stream(path, "rb") as metadata_file:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load checkpoints from locations you trust.
            metadata = pickle.load(metadata_file)

        if getattr(metadata, "storage_meta", None) is None:
            metadata.storage_meta = StorageMeta()
        metadata.storage_meta.load_id = self.load_id

        return metadata

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """Cache ``metadata.storage_data`` for use by ``read_data``."""
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        return plan

    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        return plans

    @property
    def checkpoint_id(self) -> Union[str, os.PathLike]:
        """
        return the checkpoint_id that will be used to load the checkpoint.
        """
        return self.path

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return FileSystem.validate_checkpoint_id(checkpoint_id)
class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic

    The checkpoint consist of one file per write request plus
    a `.metadata` file with the serialized metadata.

    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        cache_staged_state_dict: bool = False,
        overwrite: bool = True,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True.
            sync_files : force files to be synced to permanent storage. Default to True.
            thread_count: Number of IO threads to use to write. Default to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Default 10Mb.
            cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency
                at the cost of increased memory usage. Additionally, if this parameter is set to True, it's the expectation
                that the stager is maintained and re-used for multiple dcp.async_save calls. Default to False.
            overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.

        N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        # Both bases are initialized explicitly, each with its own keyword
        # arguments, rather than through a single super().__init__ call.
        _FileSystemWriter.__init__(
            self,
            path=path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            thread_count=thread_count,
            per_thread_copy_ahead=per_thread_copy_ahead,
            overwrite=overwrite,
        )
        BlockingAsyncStager.__init__(
            self,
            cache_staged_state_dict=cache_staged_state_dict,
        )

    def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
        """Override of AsyncStager.stage"""
        # in the async case, the state dict is already on CPU, so maintaining this
        # buffer makes no sense
        # NOTE: this overwrites the instance attribute, so the copy-ahead
        # budget stays at 0 for later writes made with this writer.
        self.per_thread_copy_ahead = 0
        return super().stage(state_dict)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.