def load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[LoadPlanner] = None,
) -> None:
    """
    Loads a distributed ``state_dict`` in SPMD style.

    Each rank will try to read the least amount of data necessary
    to fulfill the requested ``state_dict``. When loading :class:`ShardedTensor`
    instances, each rank only reads data for their local shards.

    .. warning::
        All tensors in ``state_dict`` must be allocated on their
        destination device *prior to* calling this function.

        All non-tensor data is loaded using `torch.load()` and modified in place
        on ``state_dict``.

    .. warning::
        Users must call `load_state_dict` on the root module to ensure load
        post-processing and non-tensor data properly propagate.

    .. note::
        This function can be used for local inference to load a checkpoint
        produced by ``save_state_dict`` without a process group initialized,
        by passing ``no_dist=True`` and by using Tensors instead of ShardedTensors.

    Args:
        state_dict (Dict[str, Any]): The state_dict to load. Note that this
            state dict will be updated in place.
        storage_reader (StorageReader): StorageReader used to load data from.
        process_group (ProcessGroup): ProcessGroup to be used for cross-rank
            synchronization.
        coordinator_rank (int): Rank to use to coordinate the checkpoint.
            rank0 is used by default.
        no_dist (bool): If ``True``, distributed checkpoint will not load
            in SPMD style. (Default: ``False``)
        planner (Optional[LoadPlanner]): Instance of LoadPlanner to use. If this
            is not specified, the default planner will be used. (Default: ``None``)

    Returns:
        None.

    Examples:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()
        >>> optimizer = Adagrad(my_model.parameters())
        >>> model_state_dict = my_model.state_dict()
        >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")

        >>> torch.distributed.checkpoint.load_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_reader=fs_storage_reader,
        >>> )

        >>> # module.load_state_dict() function might have customized steps
        >>> # to flush the state_dict; it must be called to
        >>> # ensure correct behavior.
        >>> my_model.load_state_dict(model_state_dict)

    .. note::
        load_state_dict uses collectives to coordinate reads across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes place.
        In this case, the device used is given by ``torch.cuda.current_device()``
        and it is the user's responsibility to ensure that this is set so that
        each rank has an individual GPU, via ``torch.cuda.set_device()``.
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")

    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultLoadPlanner()

    def local_step():
        assert planner is not None
        metadata = storage_reader.read_metadata()
        planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
        storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)

        local_plan = planner.create_local_plan()
        local_plan = storage_reader.prepare_local_plan(local_plan)
        return local_plan

    def global_step(all_local_plans):
        assert planner is not None
        all_local_plans = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
        return all_local_plans

    central_plan = distW.reduce_scatter("plan", local_step, global_step)

    def read_data():
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_reads = storage_reader.read_data(final_local_plan, planner)

        all_reads.wait()
        return None

    _ = distW.all_gather("read", read_data)
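As noted in the docstring, the same entry point can restore a checkpoint in a single
process by passing ``no_dist=True``. The sketch below illustrates that path; it follows
the doctest style of the Examples above, and ``MyModule`` and the ``"/checkpoint/1"``
path are placeholder assumptions, not part of the library.

    >>> # xdoctest: +SKIP
    >>> import torch.distributed.checkpoint as dist_cp
    >>> my_model = MyModule()  # hypothetical module
    >>> model_state_dict = my_model.state_dict()  # plain Tensors, not ShardedTensors
    >>> dist_cp.load_state_dict(
    >>>     state_dict=model_state_dict,  # updated in place
    >>>     storage_reader=dist_cp.FileSystemReader("/checkpoint/1"),
    >>>     no_dist=True,  # skip all collectives; no process group required
    >>> )
    >>> # Run the module's own load_state_dict so post-processing hooks fire.
    >>> my_model.load_state_dict(model_state_dict)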
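The function body also shows where a custom ``LoadPlanner`` plugs into the load
protocol (``set_up_planner`` -> ``create_local_plan`` -> ``create_global_plan`` ->
``finish_plan``). Below is a minimal sketch, assuming you only want to observe the
per-rank plan: it subclasses ``DefaultLoadPlanner`` and prints how many read items
this rank planned. The logging is purely illustrative and not part of the library.

    from torch.distributed.checkpoint import DefaultLoadPlanner

    class LoggingLoadPlanner(DefaultLoadPlanner):
        # Illustrative subclass: report how many items this rank plans to read.
        def create_local_plan(self):
            plan = super().create_local_plan()
            print(f"planned {len(plan.items)} read items on this rank")
            return plan

    # Passed via the ``planner=`` argument, e.g.:
    # dist_cp.load_state_dict(..., planner=LoggingLoadPlanner())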