@deprecated(
    "`save_state_dict` is deprecated and will be removed in future versions. "
    "Please use `save` instead.",
    category=FutureWarning,
)
def save_state_dict(
    state_dict: STATE_DICT_TYPE,
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[SavePlanner] = None,
) -> Metadata:
    """This method is deprecated. Please switch to ``save``."""
    storage_writer.reset()

    # TODO: test returning `save` here instead.
    with _profile():
        return _save_state_dict(
            state_dict,
            storage_writer,
            process_group,
            coordinator_rank,
            no_dist,
            planner,
        )
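
# Hedged migration sketch (not part of this module): for the common case,
# `save` is a drop-in replacement for the deprecated `save_state_dict`. The
# function name and checkpoint path below are hypothetical.
def _example_migrate_to_save() -> None:
    import torch
    import torch.distributed.checkpoint as dcp

    state_dict = {"weight": torch.rand(4, 4)}
    writer = dcp.FileSystemWriter("/tmp/checkpoint/1")

    # Before (deprecated): dcp.save_state_dict(state_dict, storage_writer=writer)
    # After:
    dcp.save(state_dict, storage_writer=writer)
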
@_dcp_method_logger(log_exceptions=True)  # type: ignore[arg-type]
@_api_bc_check
def save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Metadata:
    """
    Save a distributed model in SPMD style.

    This function is different from ``torch.save()`` as it handles
    ``ShardedTensor`` and ``DTensor`` by having each rank save only its local
    shards.

    For each ``Stateful`` object (having both a ``state_dict`` and a
    ``load_state_dict``), save will call ``state_dict`` before serialization.

    .. warning::
        There is no guarantee of backwards compatibility across PyTorch
        versions for saved state_dicts.

    .. warning::
        If using the ``process_group`` argument, make sure that only its ranks
        call ``save_state_dict`` and that all data in ``state_dict`` belongs
        to it.

    .. note::
        When saving a checkpoint for FSDP's ``ShardingStrategy.HYBRID_SHARD``,
        only one of the shard_groups should be calling ``save_state_dict`` and
        the corresponding process group needs to be passed in.

    .. note::
        If no process group is available, this function assumes the intention
        is to save the state_dict in the local process.

    .. note::
        Rank 0 is assumed to be the coordinator rank.

    Args:
        state_dict (Dict[str, Any]): The state_dict to save.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_writer (Optional[StorageWriter]):
            Instance of StorageWriter used to perform writes. If this is not
            specified, DCP will automatically infer the writer based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        planner (Optional[SavePlanner]):
            Instance of SavePlanner. If this is not specified, the default
            planner will be used. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        Metadata: Metadata object for the saved checkpoint.

    Example:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()
        >>> state_dict = {"model": my_model}
        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
        >>> torch.distributed.checkpoint.save(
        >>>     state_dict=state_dict,
        >>>     storage_writer=fs_storage_writer,
        >>> )

    .. note::
        save_state_dict uses collectives to coordinate writes across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.save")

    no_dist = not (dist.is_available() and dist.is_initialized())
    if no_dist:
        warnings.warn(
            "torch.distributed is unavailable or uninitialized, assuming the intent is to save in a single process."
        )

    with _profile():
        storage_writer = cast(
            StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
        )

        return _save_state_dict(
            state_dict=_stateful_to_state_dict(state_dict),
            storage_writer=storage_writer,
            process_group=process_group,
            no_dist=no_dist,
            planner=planner,
        )
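
# Hedged usage sketch (not part of this module): per the NCCL note above, each
# rank should pin its own GPU with `torch.cuda.set_device()` before calling
# `save`, since collectives move tensors to `torch.cuda.current_device()`.
# The function name and checkpoint path below are hypothetical, and the
# process group is assumed to be initialized already (e.g. via torchrun).
def _example_save_with_nccl(state_dict: STATE_DICT_TYPE) -> Metadata:
    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp

    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    writer = dcp.FileSystemWriter("/tmp/checkpoint/step_100")
    return dcp.save(state_dict, storage_writer=writer)
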
@_dcp_method_logger(log_exceptions=True)
def async_save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Future:
    """Asynchronous version of ``save``. This function first stages the
    state_dict onto the staging storage (defaults to CPU memory), and then
    calls ``save`` in a separate thread.

    .. warning::
        This feature is experimental and subject to change.

    Args:
        state_dict (Dict[str, Any]): The state_dict to save.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_writer (Optional[StorageWriter]):
            Instance of StorageWriter used to perform 'stage' and 'save'. If
            this is not specified, DCP will automatically infer the writer
            based on the checkpoint_id. If checkpoint_id is also None, an
            exception will be raised. (Default: ``None``)
        planner (Optional[SavePlanner]):
            Instance of SavePlanner. If this is not specified, the default
            planner will be used. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        Future: A future holding the resultant Metadata object from ``save``.

    Example:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()
        >>> state_dict = {"model": my_model}
        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
        >>> checkpoint_future = torch.distributed.checkpoint.async_save(
        >>>     state_dict=state_dict,
        >>>     storage_writer=fs_storage_writer,
        >>> )
        >>>
        >>> # ... do some work ...
        >>>
        >>> checkpoint_future.result()
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.async_save")

    if dist.is_available() and dist.is_initialized():
        pg = process_group or _get_default_group()
        assert (
            torch.device("cpu") in pg._device_types  # type: ignore[attr-defined]
        ), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"

    storage_writer = cast(
        StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
    )

    state_dict = _stateful_to_state_dict(state_dict)
    if isinstance(storage_writer, AsyncStager):
        staged_state_dict = storage_writer.stage(state_dict)
    else:  # provides bwc for storage_writers not implementing AsyncStager
        staged_state_dict = _offload_state_dict_to_cpu(state_dict, type_check=False)

    executor = ThreadPoolExecutor(max_workers=1)
    f: Future = executor.submit(
        save,
        staged_state_dict,
        checkpoint_id=checkpoint_id,
        storage_writer=storage_writer,
        planner=planner,
        process_group=process_group,
    )
    f.add_done_callback(lambda f: executor.shutdown(wait=False))

    if (
        isinstance(storage_writer, AsyncStager)
        and storage_writer.should_synchronize_after_execute
    ):
        storage_writer.synchronize_staging()

    return f
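
# Hedged usage sketch (not part of this module): `async_save` asserts that the
# default (or supplied) process group has a CPU backend, so a mixed backend
# string such as "cpu:gloo,cuda:nccl" is one way to satisfy that requirement.
# The function name and checkpoint path below are hypothetical.
def _example_async_save(state_dict: STATE_DICT_TYPE) -> None:
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp

    if not dist.is_initialized():
        # Register both a CPU (gloo) and a CUDA (nccl) backend in one group.
        dist.init_process_group(backend="cpu:gloo,cuda:nccl")

    writer = dcp.FileSystemWriter("/tmp/checkpoint/async_1")
    future = dcp.async_save(state_dict, storage_writer=writer)

    # ... training continues while the checkpoint is written ...

    future.result()  # block only once the Metadata result is actually needed
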
def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
    """Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object."""
    stateful_state_dict = {}
    for key, elem in state_dict.items():
        stateful_state_dict[key] = (
            elem.state_dict() if isinstance(elem, Stateful) else elem
        )
    return stateful_state_dict


def _save_state_dict(
    state_dict: STATE_DICT_TYPE,
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[SavePlanner] = None,
) -> Metadata:
    torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")

    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultSavePlanner()
    assert planner is not None

    global_metadata = None

    ckpt_kwargs = {}
    if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
        ckpt_kwargs["checkpoint_id"] = ckpt_id
        ckpt_kwargs["process_group"] = distW.group

    @_dcp_method_logger(**ckpt_kwargs)
    def local_step():
        assert planner is not None
        storage_meta = storage_writer.storage_meta()
        if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters:
            warnings.warn(
                "The function definition for SavePlanner.set_up_planner has been updated"
                " to include the storage_meta argument. Please update your implementation"
                " to include this parameter."
            )
            planner.set_up_planner(state_dict, distW.is_coordinator)  # type: ignore[call-arg, arg-type]
        else:
            planner.set_up_planner(
                state_dict=state_dict,
                storage_meta=storage_meta,
                is_coordinator=distW.is_coordinator,
            )
        storage_writer.set_up_storage_writer(distW.is_coordinator)

        local_plan = planner.create_local_plan()
        local_plan = storage_writer.prepare_local_plan(local_plan)
        return local_plan

    @_dcp_method_logger(**ckpt_kwargs)
    def global_step(all_local_plans):
        nonlocal global_metadata

        assert planner is not None
        all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
        return all_local_plans

    central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)

    @_dcp_method_logger(**ckpt_kwargs)
    def write_data():
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_writes = storage_writer.write_data(final_local_plan, planner)

        all_writes.wait()
        return all_writes.value()

    @_dcp_method_logger(**ckpt_kwargs)
    def finish_checkpoint(all_results):
        assert global_metadata is not None
        storage_writer.finish(metadata=global_metadata, results=all_results)
        return global_metadata

    return distW.all_reduce("write", write_data, finish_checkpoint)
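
# Hedged sketch (not part of this module): `_stateful_to_state_dict` above
# replaces every value implementing the Stateful protocol with the result of
# its `state_dict()` call, so arbitrary objects can participate in a
# checkpoint. The counter class below is hypothetical.
class _ExampleStepCounter(Stateful):
    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> dict:
        return {"step": self.step}

    def load_state_dict(self, state_dict: dict) -> None:
        self.step = state_dict["step"]


# _stateful_to_state_dict({"counter": _ExampleStepCounter()}) would return
# {"counter": {"step": 0}}, which is what the save plan is then built from.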