def save_state_dict(
    state_dict: STATE_DICT_TYPE,
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[SavePlanner] = None,
) -> Metadata:
    """
    Saves a distributed model in SPMD style.

    This function is different from ``torch.save()`` as it handles
    ``ShardedTensor`` by having each rank only save their local shards.

    .. warning::
        There are no guarantees of backwards compatibility across PyTorch
        versions for saved state_dicts.

    .. warning::
        If using the `process_group` argument, make sure that only its ranks
        call `save_state_dict` and that all data in the state_dict belongs to it.

    .. note::
        When saving a checkpoint for FSDP's `ShardingStrategy.HYBRID_SHARD`,
        only one of the shard_groups should call `save_state_dict`, and the
        corresponding process group needs to be passed in.

    .. note::
        This function can be used to save a state_dict without having a
        process group initialized by passing ``no_dist=True``.

    Args:
        state_dict (Dict[str, Any]): The state_dict to save.
        storage_writer (StorageWriter):
            Instance of StorageWriter used to perform writes.
        process_group (ProcessGroup):
            ProcessGroup to be used for cross-rank synchronization.
        coordinator_rank (int):
            Rank to use to coordinate the checkpoint. rank0 is used by default.
        no_dist (bool):
            If ``True``, distributed checkpoint will not save in SPMD style.
            (Default: ``False``)
        planner (SavePlanner):
            Instance of SavePlanner to use. If this is not specified, the
            default planner will be used. (Default: ``None``)

    Returns:
        Metadata: Metadata object for the saved checkpoint.

    Example:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()
        >>> model_state_dict = my_model.state_dict()
        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
        >>> torch.distributed.checkpoint.save_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_writer=fs_storage_writer,
        >>> )

    .. note::
        save_state_dict uses collectives to coordinate writes across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")

    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultSavePlanner()
    assert planner is not None

    global_metadata = None

    def local_step():
        # Each rank builds its own save plan and lets the storage layer adjust it.
        assert planner is not None
        planner.set_up_planner(state_dict, distW.is_coordinator)
        storage_writer.set_up_storage_writer(distW.is_coordinator)
        local_plan = planner.create_local_plan()
        local_plan = storage_writer.prepare_local_plan(local_plan)
        return local_plan

    def global_step(all_local_plans):
        # The coordinator merges all local plans and produces the checkpoint metadata.
        nonlocal global_metadata
        assert planner is not None
        all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
        return all_local_plans

    central_plan = distW.reduce_scatter("plan", local_step, global_step)

    def write_data():
        # Each rank executes its share of the centralized plan.
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_writes = storage_writer.write_data(final_local_plan, planner)
        all_writes.wait()
        return all_writes.value()

    def finish_checkpoint(all_results):
        # The coordinator commits the checkpoint once every rank has written its data.
        assert global_metadata is not None
        storage_writer.finish(metadata=global_metadata, results=all_results)
        return global_metadata

    return distW.all_reduce("write", write_data, finish_checkpoint)
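As a complement to the doctest above, the following is a minimal sketch of the distributed path. It assumes a torchrun-style launch (``LOCAL_RANK`` set in the environment), one GPU per rank, an NCCL backend, and a shared filesystem; the path ``/checkpoint/1`` and the helper name ``save_checkpoint`` are placeholders. The ``torch.cuda.set_device()`` call follows the note above on NCCL-based collectives.

    import os

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp


    def save_checkpoint(model: torch.nn.Module) -> None:
        # Give each rank its own GPU before any NCCL collective runs
        # (see the note on torch.cuda.set_device above).
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")

        # Placeholder path; every rank must be able to write to it.
        fs_storage_writer = dcp.FileSystemWriter("/checkpoint/1")
        dcp.save_state_dict(
            state_dict=model.state_dict(),
            storage_writer=fs_storage_writer,
        )

Passing ``no_dist=True`` instead skips the cross-rank coordination entirely, which is what the note about saving without an initialized process group refers to.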