def load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[LoadPlanner] = None,
) -> None:
    """
    Loads a distributed ``state_dict`` in SPMD style.

    Each rank will try to read the least amount of data necessary
    to fulfill the requested ``state_dict``. When loading :class:`ShardedTensor`
    instances, each rank only reads data for its local shards.

    .. warning::
        All tensors in ``state_dict`` must be allocated on their
        destination device *prior to* calling this function.

        All non-tensor data is loaded using ``torch.load()`` and modified in
        place on ``state_dict``.

    .. warning::
        Users must call ``load_state_dict`` on the root module to ensure that
        load post-processing and non-tensor data propagate properly.

    .. note::
        This function can be used for local inference, loading a checkpoint
        produced by ``save_state_dict`` without an initialized process group,
        by passing ``no_dist=True`` and by using Tensors instead of
        ShardedTensors.

    Args:
        state_dict (Dict[str, Any]): The state_dict to load. Note that this
            state dict will be updated in place.
        storage_reader (StorageReader): StorageReader used to load data from.
        process_group (ProcessGroup):
            ProcessGroup to be used for cross-rank synchronization.
        coordinator_rank (int):
            Rank to use to coordinate the checkpoint.
            rank0 is used by default.
        no_dist (bool): If ``True``, distributed checkpoint will not load
            in SPMD style. (Default: ``False``)
        planner (LoadPlanner): LoadPlanner used to plan the load. If ``None``,
            a ``DefaultLoadPlanner`` is created. (Default: ``None``)

    Returns:
        None.

    Examples
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()
        >>> optimizer = Adagrad(my_model.parameters())
        >>> model_state_dict = my_model.state_dict()
        >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")

        >>> torch.distributed.checkpoint.load_state_dict(
        ...     state_dict=model_state_dict,
        ...     storage_reader=fs_storage_reader,
        ... )

        >>> # module.load_state_dict() might have customized steps
        >>> # to flush the state_dict; it must be called to
        >>> # ensure correct behavior.
        >>> my_model.load_state_dict(model_state_dict)

    .. note::
        load_state_dict uses collectives to coordinate reads across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.
    """

    torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")

    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultLoadPlanner()

    def local_step():
        # Runs on every rank: read the checkpoint metadata and build this
        # rank's local load plan.
        assert planner is not None
        metadata = storage_reader.read_metadata()
        planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
        storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)

        local_plan = planner.create_local_plan()
        local_plan = storage_reader.prepare_local_plan(local_plan)
        return local_plan

    def global_step(all_local_plans):
        # Receives the local plans from all ranks and returns the
        # globally adjusted plans.
        assert planner is not None
        all_local_plans = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
        return all_local_plans

    central_plan = distW.reduce_scatter("plan", local_step, global_step)

    def read_data():
        # Finalize this rank's plan, issue the reads, and wait for them
        # to complete.
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_reads = storage_reader.read_data(final_local_plan, planner)

        all_reads.wait()
        return None

    _ = distW.all_gather("read", read_data)
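The planning phase above passes a `local_step` and a `global_step` to `distW.reduce_scatter`. The following is a minimal single-process sketch of that shape of coordination, under the assumption that each rank contributes one local plan, one step sees all of them, and each rank gets exactly one plan back. It does not use `torch.distributed` and is not the library's implementation. `simulate_reduce_scatter`, the `requested` table, and the toy plans (lists of key names) are hypothetical and exist only for illustration.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def simulate_reduce_scatter(
    world_size: int,
    local_step: Callable[[int], T],
    global_step: Callable[[List[T]], List[R]],
) -> List[R]:
    """Hypothetical single-process stand-in for a plan phase."""
    # Every simulated rank builds its own local plan.
    local_plans = [local_step(rank) for rank in range(world_size)]
    # One coordinating step sees all local plans and returns one plan per rank.
    global_plans = global_step(local_plans)
    if len(global_plans) != world_size:
        raise ValueError("global_step must return one plan per rank")
    # "Scatter": simulated rank i would receive global_plans[i].
    return global_plans


# Toy data (assumption): the checkpoint keys each simulated rank asks for.
requested = {
    0: ["layer1.weight", "step"],
    1: ["layer2.weight", "step"],
}


def local_step(rank: int) -> List[str]:
    # A local plan here is just the sorted list of keys this rank wants.
    return sorted(requested[rank])


def global_step(all_plans: List[List[str]]) -> List[List[str]]:
    # A trivial coordinating pass that copies each plan unchanged.
    return [list(plan) for plan in all_plans]


plans = simulate_reduce_scatter(2, local_step, global_step)
```

In this sketch `plans[i]` plays the role that `central_plan` plays on rank `i` in the function above. A real run would exchange the plans through collectives instead of a Python list.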
