Source code for torch.distributed.checkpoint.state_dict
importcontextlibimportfunctoolsimportgcfromdataclassesimportasdict,dataclass,fieldfromitertoolsimportchainfromtypingimport(Any,Callable,cast,Dict,Generator,Iterable,List,no_type_check,Optional,Set,Tuple,Union,)importtorchimporttorch.distributedasdistimporttorch.nnasnnfromtorch.distributed._shard.sharded_tensorimportShardedTensorfromtorch.distributed._state_dict_utilsimport(_gather_state_dict,_offload_state_dict_to_cpu,)fromtorch.distributed._tensorimportDTensorfromtorch.distributed.algorithms._checkpoint.checkpoint_wrapperimport(_CHECKPOINT_PREFIX,)fromtorch.distributed.fsdpimport(FullOptimStateDictConfig,FullStateDictConfig,FullyShardedDataParallelasFSDP,OptimStateDictConfig,ShardedOptimStateDictConfig,ShardedStateDictConfig,StateDictConfig,StateDictType,)fromtorch.distributed.fsdp._common_utilsimport(_get_module_fsdp_state_if_fully_sharded_module,FSDP_WRAPPED_MODULE,)fromtorch.nn.modules.moduleimport_IncompatibleKeysfromtorch.nn.parallelimportDistributedDataParallelasDDPFLAT_PARAM="_flat_param"PG="param_groups"PG_PREFIX=f"{PG}."STATE="state"STATE_PREFIX=f"{STATE}."PARAMS="params"FQNS_T=Set[str]_patched_state_dict:Set[Callable]=set()PrimitiveType=Union[DTensor,ShardedTensor,torch.Tensor,int,float,str]ValueType=Union[PrimitiveType,List[PrimitiveType],Tuple[PrimitiveType],Dict[str,"ValueType"]]DictValueType=Dict[str,ValueType]ListDictValueType=List[DictValueType]OptimizerStateType=Dict[str,Union[DictValueType,ListDictValueType]]@contextlib.contextmanagerdefgc_context():is_enabled=gc.isenabled()gc.disable()try:yieldfinally:# TODO: add logging for the gc details/timegc.collect()ifis_enabled:gc.enable()
@dataclass
class StateDictOptions:
    """
    This dataclass specifies how get_state_dict/set_state_dict will work.

    - ``full_state_dict``: if this is set to True, all the tensors in the
      returned state_dict will be gathered. No ShardedTensor and DTensor
      will be in the returned state_dict. The default value is False.

    - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if
      ``full_state_dict`` is also true, then only the rank0 will get the
      state_dict and all other ranks will get empty state_dict.
      The default value is False.

    - ``ignore_frozen_params``: if the value is True, the returned state_dict
      won't contain any frozen parameters -- the ``requires_grad`` is False.
      The default value is False.

    - ``keep_submodule_prefixes``: when ``submodules`` is not None, this option
      indicates whether to keep the submodule prefixes from the state_dict keys.
      For example, if the submodule is ``module.pretrain`` and the full FQN of
      the parameter is ``pretrain.layer1.weight``: when this option is True,
      the parameter's key in the returned state_dict will be
      ``pretrain.layer1.weight``; if the option is False, the key will be
      ``layer1.weight``.
      Note that if ``keep_submodule_prefixes`` is False, there may be
      conflicting FQNs, hence there should be only one submodule in
      ``submodules``. The default value is True.

    - ``strict``: the ``strict`` option when ``set_state_dict`` calls
      model.load_state_dict(). The default value is True.
    """

    full_state_dict: bool = False
    cpu_offload: bool = False
    ignore_frozen_params: bool = False
    keep_submodule_prefixes: bool = True
    strict: bool = True
@dataclass
class _StateDictInfo(StateDictOptions):
    # Maps each parameter/buffer tensor to its set of canonical FQNs, and each
    # canonical FQN back to its tensor (populated by ``_verify_options``).
    fqn_param_mapping: Dict[
        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
    ] = field(default_factory=dict)
    # Every canonical FQN found by ``_iterate_valid_model_state``.
    all_fqns: Set[str] = field(default_factory=set)
    # Canonical FQN prefixes (each ending in ".") of the requested submodules.
    submodule_prefixes: Set[str] = field(default_factory=set)
    # Whether the model / optimizer state_dicts should be processed.
    handle_model: bool = True
    handle_optim: bool = True
    # Zero-argument callable returning a context manager; a partial of
    # ``FSDP.state_dict_type`` when the model contains FSDP modules.
    fsdp_context: Callable = contextlib.nullcontext
    # Result of ``FSDP.fsdp_modules(model)``.
    fsdp_modules: List[nn.Module] = field(default_factory=list)


def _get_fqns(
    model: nn.Module,
    name: str,
    skip_ddp_prefix: bool = True,
    skip_compiler_prefix: bool = True,
) -> FQNS_T:
    """
    This API is used to convert the name of a parameter to the FQNs. For FSDP
    without `use_orig_params`, the name of FlatParameter can be mapped to
    multiple original parameters. As a result, the return type of this function
    is `Set[str]`.

    Args:
        model (nn.Module): the root model.
        name (str): the name
        skip_ddp_prefix (bool): whether to skip DDP's `module` prefix
        skip_compiler_prefix (bool): whether to skip torch.compile's
            `_orig_mod` prefix

    Returns:
        The canonical FQNs based on the model traversal.

    Raises:
        RuntimeError: if `_extra_state` is not the last component of ``name``.
    """
    # Remove the checkpoint prefix, if it exists.
    name = name.replace(_CHECKPOINT_PREFIX, "")
    if "." not in name:
        return {name}

    obj_names = name.split(".")
    fqn_obj_names = []
    curr_obj = model
    for i, curr_obj_name in enumerate(obj_names):
        if isinstance(curr_obj, DDP):
            assert curr_obj_name == "module"
            curr_obj = curr_obj.module
            if not skip_ddp_prefix:
                fqn_obj_names.append(curr_obj_name)
        elif isinstance(curr_obj, FSDP):
            if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM:
                # A FlatParameter expands to all the original parameters it
                # holds; every returned FQN shares the same prefix.
                prefix = ".".join(fqn_obj_names)
                flat_param = getattr(curr_obj, FLAT_PARAM)
                if prefix:
                    prefix = f"{prefix}."
                return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
            curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
            if curr_obj_name != FSDP_WRAPPED_MODULE:
                fqn_obj_names.append(curr_obj_name)
                curr_obj = getattr(curr_obj, curr_obj_name)
        elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
            assert curr_obj_name == "_orig_mod"
            curr_obj = curr_obj._orig_mod
            if not skip_compiler_prefix:
                fqn_obj_names.append(curr_obj_name)
        else:
            fqn_obj_names.append(curr_obj_name)
            if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
                if i != len(obj_names) - 1:
                    raise RuntimeError("Expect `_extra_state` to be the last obj name")
            else:
                curr_obj = getattr(curr_obj, curr_obj_name)

    return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")}


class _EXTRA_STATE:
    # Placeholder yielded by ``_iterate_valid_model_state`` for a module's
    # ``get_extra_state`` entry, where no tensor exists.
    pass


def _iterate_valid_model_state(model):
    """Yield ``(name, value)`` for the state a module tree would serialize.

    Traversal is depth-first: children first, then the module's own
    persistent buffers and parameters, then an ``_EXTRA_STATE()`` placeholder
    if the module's class overrides ``get_extra_state``. Names are raw
    traversal names (they may still contain wrapper components). A module
    object reachable through several paths is visited only once.
    """
    visited_modules: Set[nn.Module] = set()

    def recurse(module: nn.Module, curr_fqn: str) -> Generator:
        visited_modules.add(module)

        curr_fqn = f"{curr_fqn}." if curr_fqn else ""
        for name, submodule in module.named_children():
            if submodule in visited_modules:
                continue
            new_fqn = f"{curr_fqn}{name}"
            yield from recurse(submodule, new_fqn)

        for name, obj in chain(
            module.named_buffers(recurse=False), module.named_parameters(recurse=False)
        ):
            if name in module._non_persistent_buffers_set:
                continue
            new_fqn = f"{curr_fqn}{name}"
            yield new_fqn, obj

        if (
            getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state)
            != nn.Module.get_extra_state
        ):
            new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}"
            yield new_fqn, _EXTRA_STATE()

    yield from recurse(model, "")


def _verify_options(
    model: nn.Module,
    optims: Tuple[torch.optim.Optimizer, ...],
    optim_only: bool,
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> _StateDictInfo:
    """
    Verify the model and options passed by the user and generates _StateDictInfo.

    Raises:
        RuntimeError: if ``optim_only`` is True but no optimizers are given.
    """
    if optim_only and not optims:
        raise RuntimeError(
            "Optimizers are not passed in but optim_only is set to True."
        )

    options = options or StateDictOptions()

    fqn_param_mapping: Dict[
        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
    ] = {}
    all_fqns = set()
    for name, param in _iterate_valid_model_state(model):
        fqns = _get_fqns(model, name)
        if not isinstance(param, _EXTRA_STATE):
            fqn_param_mapping[param] = fqns
        for fqn in fqns:
            if not isinstance(param, _EXTRA_STATE):
                fqn_param_mapping[fqn] = param
            all_fqns.add(fqn)

    submodule_prefixes = set()
    if submodules:
        submodules = set(submodules)
        for name, module in model.named_modules():
            if module not in submodules:
                continue
            fqns = _get_fqns(model, name)
            assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
            for fqn in fqns:
                submodule_prefixes.add(f"{fqn}.")

    fsdp_modules = FSDP.fsdp_modules(model)
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig
    fsdp_context: Callable
    if fsdp_modules:
        # FSDP API only work if at least one FSDP instance exists.
        if options.full_state_dict:
            state_dict_config = FullStateDictConfig(
                offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload
            )
            optim_state_dict_config = FullOptimStateDictConfig(
                offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload
            )
            state_dict_type = StateDictType.FULL_STATE_DICT
        else:
            state_dict_config = ShardedStateDictConfig(
                offload_to_cpu=options.cpu_offload,
            )
            optim_state_dict_config = ShardedOptimStateDictConfig(
                offload_to_cpu=options.cpu_offload,
            )
            state_dict_type = StateDictType.SHARDED_STATE_DICT

        fsdp_context = functools.partial(
            FSDP.state_dict_type,
            module=model,
            state_dict_type=state_dict_type,
            state_dict_config=state_dict_config,
            optim_state_dict_config=optim_state_dict_config,
        )
    else:
        fsdp_context = contextlib.nullcontext

    return _StateDictInfo(
        **asdict(options),
        fqn_param_mapping=fqn_param_mapping,
        all_fqns=all_fqns,
        submodule_prefixes=submodule_prefixes,
        fsdp_context=fsdp_context,
        fsdp_modules=cast(List[nn.Module], fsdp_modules),
        handle_model=not optim_only,
        handle_optim=(len(optims) > 0),
    )


def _rank_for_error_message() -> int:
    """Return this process's rank, or 0 when no default process group exists.

    ``dist.get_rank()`` raises if the default process group is not
    initialized (e.g. single-process use), which would otherwise mask the
    error being reported.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def _verify_state_dict(
    model_state_dict: Dict[str, ValueType],
    optim_state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> None:
    """Sanity-check model/optimizer state_dicts against ``info``.

    Raises:
        RuntimeError: if a required model or optimizer state_dict is empty,
            or if a model state_dict key still contains ``_flat_param``.
    """
    for module in info.fsdp_modules:
        fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
        assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module."

    # Verify if the model_state_dict and optim_state_dict are valid. This API
    # should give the users an explicit error message to debug or report.
    if (
        info.handle_model
        and not model_state_dict
        and not info.submodule_prefixes
        and not info.ignore_frozen_params
        and not (info.cpu_offload and info.full_state_dict)
        and info.strict
    ):
        raise RuntimeError(
            "The option indicates that model state_dict is required to save "
            "or load, but model state_dict is empty. "
            f"rank = {_rank_for_error_message()}."
        )

    if info.handle_optim:
        # ``.get`` so a dict without a "state" entry reports the intended
        # error instead of a KeyError.
        if not (optim_state_dict and optim_state_dict.get(STATE)) and not (
            info.cpu_offload and info.full_state_dict
        ):
            raise RuntimeError(
                "The option indicates that optimizer state_dict is required to "
                f"save or load, but optim state_dict is empty. {optim_state_dict}"
            )

    for key in model_state_dict.keys():
        if FLAT_PARAM in key:
            raise RuntimeError(
                f"{key} contains {FLAT_PARAM}. This can happen if the model "
                "is not the root module."
            )


def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Callable:
    """Return ``obj.<api>``, bypassing wrappers installed by the patch helpers.

    If the attribute is one of the callables in ``_patched_state_dict``, the
    class's own method bound to ``obj`` is returned instead, so the patched
    wrapper does not end up calling itself.
    """
    call = getattr(obj, api)
    if call in _patched_state_dict:
        call = functools.partial(getattr(obj.__class__, api), self=obj)
    return call


def _is_wrapper_prefixed_key(key: str, fqn: str) -> bool:
    """Check that ``key`` is ``fqn`` with extra wrapper components inserted.

    The components of ``fqn`` are matched in order against ``key``; the only
    unmatched components allowed in ``key`` are ``module`` (DDP) and
    ``_orig_mod`` (torch.compile), and the match must end on the last key
    component. NOTE(review): if ``key`` is exhausted before ``fqn`` is fully
    matched this still returns True (kept from the original logic).
    """
    if len(fqn) >= len(key):
        return False
    fqn_split = fqn.split(".")
    key_split = key.split(".")
    fqn_idx = 0
    for key_idx, key_name in enumerate(key_split):
        if key_name == fqn_split[fqn_idx]:
            fqn_idx += 1
            if fqn_idx == len(fqn_split):
                return key_idx == len(key_split) - 1
        elif key_name in ("module", "_orig_mod"):
            continue
        else:
            return False
    return True


def _get_model_state_dict(model: nn.Module, info: _StateDictInfo) -> Dict[str, ValueType]:
    """Return ``model``'s state_dict keyed by canonical FQNs, per ``info``.

    Steps, in order:
    1. Call ``model.state_dict()`` under ``info.fsdp_context``.
    2. Rename wrapper-prefixed keys to canonical FQNs.
    3. Drop frozen parameters if ``info.ignore_frozen_params``.
    4. Keep only keys under ``info.submodule_prefixes``, optionally stripping
       the prefix.
    5. Drop meta tensors.
    6. Gather (``full_state_dict``) and/or offload to CPU (``cpu_offload``).

    Raises:
        RuntimeError: if a key differs from its FQN by anything other than
            DDP / compiler wrapper components.
    """
    if not info.handle_model:
        return {}

    with info.fsdp_context():
        state_dict = _state_dict_fn(model, "state_dict")()

    for key in list(state_dict.keys()):
        fqns = _get_fqns(model, key)
        assert len(fqns) == 1
        fqn = next(iter(fqns))
        if fqn != key:
            # As we only support FSDP, DDP, and TP, the only cases are
            # wrapper-based DDP and compiler. Verify if the assumption
            # is correct.
            if not _is_wrapper_prefixed_key(key, fqn):
                raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}")
            state_dict[fqn] = state_dict.pop(key)

    if info.ignore_frozen_params:
        # Do this before the submodule filtering below. Once prefixes are
        # stripped, the canonical FQNs no longer match the keys.
        # ``pop(..., None)``: a key may legitimately be absent, e.g. on ranks
        # that receive an empty rank0-only full state_dict, or outside the
        # requested submodules.
        for key, param in model.named_parameters():
            if param.requires_grad:
                continue
            for fqn in _get_fqns(model, key):
                state_dict.pop(fqn, None)

    if info.submodule_prefixes:
        new_state_dict: Dict[str, ValueType] = {}
        # TODO: make this faster.
        for fqn in state_dict.keys():
            for prefix in info.submodule_prefixes:
                if not fqn.startswith(prefix):
                    continue
                if info.keep_submodule_prefixes:
                    new_state_dict[fqn] = state_dict[fqn]
                else:
                    new_fqn = fqn[len(prefix) :]
                    new_state_dict[new_fqn] = state_dict[fqn]
        state_dict = new_state_dict

    for key, p in list(state_dict.items()):
        if torch.is_tensor(p) and p.is_meta:
            state_dict.pop(key)

    if info.full_state_dict:
        ranks_only = tuple() if not info.cpu_offload else (0,)
        return _gather_state_dict(
            state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only
        )
    elif info.cpu_offload:
        return _offload_state_dict_to_cpu(state_dict)
    else:
        return state_dict


def _load_model_state_dict(
    model: nn.Module,
    state_dict: Dict[str, ValueType],
    info: _StateDictInfo,
) -> _IncompatibleKeys:
    """Load a canonical-FQN ``state_dict`` into ``model``.

    Keys are renamed back to the names that include DDP ``module`` /
    ``_orig_mod`` components, as ``model.load_state_dict`` expects. The
    result is loaded under ``info.fsdp_context`` with ``strict=info.strict``.
    The caller's dict is not modified.
    """
    if not info.handle_model or not state_dict:
        return _IncompatibleKeys([], [])

    # Shallow copy: the renaming below must not rewrite the caller's keys.
    state_dict = dict(state_dict)
    for key, _ in _iterate_valid_model_state(model):
        fqns = _get_fqns(model, key)
        fqns_with_prefix = _get_fqns(
            model, key, skip_ddp_prefix=False, skip_compiler_prefix=False
        )

        # Both are sets. For a FlatParameter, each set holds the same
        # original names under one shared prefix, so sorting both sets pairs
        # each FQN with its prefixed counterpart. Zipping unsorted sets
        # depends on hash order.
        for fqn, fqn_with_prefix in zip(sorted(fqns), sorted(fqns_with_prefix)):
            # Skip absent keys so a partial state_dict is reported by
            # load_state_dict (per ``strict``) instead of raising KeyError.
            if fqn != fqn_with_prefix and fqn in state_dict:
                state_dict[fqn_with_prefix] = state_dict.pop(fqn)

    with info.fsdp_context():
        return cast(
            _IncompatibleKeys,
            _state_dict_fn(model, "load_state_dict")(
                state_dict=state_dict, strict=info.strict
            ),
        )


def _init_optim_state(optim: torch.optim.Optimizer) -> None:
    """
    Initialize optim states by calling the step() with zero grads.

    NOTE(review): with zero gradients, step() can still change parameters for
    optimizers that apply weight decay. Confirm this is acceptable.
    """
    if optim.state:
        # The optimizer state is initialized.
        return

    for param_group in optim.param_groups:
        for param in param_group[PARAMS]:
            if param.grad is not None:
                raise RuntimeError(
                    "state_dict can only be used if the optimizer "
                    "states are initialized (usually after one step() with "
                    "gradients) or gradients are None. For the later case, "
                    "state_dict will fake the gradients as zero "
                    "to initialize the optimizer states. However, the "
                    "gradients are not None."
                )
            if param.requires_grad:
                param.grad = torch.zeros_like(param)
    optim.step(closure=None)
    optim.zero_grad(set_to_none=True)


def _get_optim_state_dict(
    model: nn.Module,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    info: _StateDictInfo,
) -> OptimizerStateType:
    """Return the merged state_dict of ``optimizers`` keyed by canonical FQNs.

    Each optimizer's state is initialized first (``_init_optim_state``).
    Parameter ids are then translated to FQNs: via ``FSDP.optim_state_dict``
    when FSDP is present, otherwise by matching ``model.named_parameters()``.
    The "state" and "param_groups" of all optimizers are merged, then gathered
    and/or offloaded per ``info``.
    """
    if not info.handle_optim:
        return {}

    optim_state_dict: OptimizerStateType = {STATE: {}, PG: []}
    for optim in optimizers:
        _init_optim_state(optim)
        osd = _state_dict_fn(optim, "state_dict")()
        if info.fsdp_modules:
            with info.fsdp_context():
                osd = FSDP.optim_state_dict(model, optim, osd)

            # We need to specially handle FlatParameter FSDP as
            # FlatParameter FSDP converts the FQNs.
            # There are no easy ways to do this conversion systematically.
            # We can only use a string replacement without correctness check.
            if not osd:
                continue
            for k in list(osd[STATE].keys()):
                if "_orig_mod" in k:
                    osd[STATE][k.replace("_orig_mod.", "")] = osd[STATE].pop(k)
            for g in osd[PG]:
                params = [k.replace("_orig_mod.", "") for k in g[PARAMS]]
                g[PARAMS] = params
        else:
            # Plain optimizers key state by the position of the parameter in
            # ``optim.param_groups``; map those ids to canonical FQNs.
            params = list(chain.from_iterable(g[PARAMS] for g in optim.param_groups))
            param_pid_mapping = dict(zip(params, range(len(params))))
            fqn_pid_mapping = {}
            for key, param in model.named_parameters():
                fqns = _get_fqns(model, key)
                assert len(fqns) == 1
                fqn = next(iter(fqns))
                if param not in param_pid_mapping:
                    continue
                pid = param_pid_mapping[param]
                fqn_pid_mapping[fqn] = pid
                fqn_pid_mapping[pid] = fqn

            for key in list(osd[STATE].keys()):
                fqn = fqn_pid_mapping[key]
                osd[STATE][fqn] = osd[STATE].pop(key)

            for group in osd[PG]:
                group[PARAMS] = [fqn_pid_mapping[pid] for pid in group[PARAMS]]

        cast(DictValueType, optim_state_dict[STATE]).update(osd[STATE])
        cast(ListDictValueType, optim_state_dict[PG]).extend(osd[PG])

    if info.full_state_dict:
        ranks_only = tuple() if not info.cpu_offload else (0,)
        return _gather_state_dict(
            optim_state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only
        )
    elif info.cpu_offload:
        return _offload_state_dict_to_cpu(optim_state_dict)
    else:
        return optim_state_dict


def _split_optim_state_dict(
    model: nn.Module,
    optim: torch.optim.Optimizer,
    optim_state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> OptimizerStateType:
    """
    Extract the corresponding optim state_dict from ``optim_state_dict`` for
    ``optim`` and return the result optim state_dict.

    Args:
        model (nn.Module): the root model.
        optim (torch.optim.Optimizer): the optimizer.
        optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that
            contains the optim state_dict of ``optim``.
        info (_StateDictInfo): state dict information.

    Returns:
        The optim state_dict of ``optim``.

    Raises:
        KeyError: if a parameter of ``optim`` that requires grad has no entry
            in ``optim_state_dict["state"]``.
    """
    state: DictValueType = {}
    pg_state: ListDictValueType = []
    return_osd: OptimizerStateType = {STATE: state, PG: pg_state}
    loaded_state = cast(DictValueType, optim_state_dict[STATE])
    loaded_param_groups = cast(ListDictValueType, optim_state_dict[PG])

    # Index the loaded param groups by the FQNs they contain once, instead of
    # scanning every loaded group's params list for every parameter.
    fqn_to_loaded_groups: Dict[str, List[DictValueType]] = {}
    for loaded_param_group in loaded_param_groups:
        loaded_params = loaded_param_group[PARAMS]
        assert isinstance(loaded_params, list)
        for loaded_fqn in loaded_params:
            fqn_to_loaded_groups.setdefault(loaded_fqn, []).append(loaded_param_group)

    # id(loaded param group) -> index of the matching group in ``pg_state``.
    pg_mapping: Dict[int, int] = {}
    for param_group in optim.param_groups:
        pg_state.append({PARAMS: []})
        for param in param_group[PARAMS]:
            for fqn in info.fqn_param_mapping[param]:
                params = pg_state[-1][PARAMS]
                assert isinstance(params, list)
                params.append(fqn)
                if param.requires_grad:
                    state[fqn] = loaded_state[fqn]
                for loaded_param_group in fqn_to_loaded_groups.get(fqn, ()):
                    pg_mapping[id(loaded_param_group)] = len(pg_state) - 1

    for param_group in loaded_param_groups:
        idx = pg_mapping.get(id(param_group), -1)
        if idx == -1:
            continue
        for key, value in param_group.items():
            if key == PARAMS:
                continue
            # TODO: check if value is the same if exists.
            pg_state[idx][key] = value

    return return_osd


def _load_optim_state_dict(
    model: nn.Module,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> None:
    """Load the canonical-FQN optimizer ``state_dict`` into each optimizer.

    For each optimizer, its part of ``state_dict`` is extracted with
    ``_split_optim_state_dict``. Under FSDP, FQNs are first re-prefixed with
    ``_orig_mod`` components where the model contains compiled submodules,
    then converted with ``FSDP.optim_state_dict_to_load``. The optimizer
    state is initialized before ``load_state_dict`` is called.
    """
    if not info.handle_optim:
        return

    for optim in optimizers:
        optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info)
        if info.fsdp_modules:
            # We need to specially handle FlatParameter FSDP as
            # FlatParameter FSDP converts the FQNs.
            for original_fqn, _ in model.named_parameters():
                fqns = _get_fqns(model, original_fqn)
                fqns_with_compiler = _get_fqns(
                    model, original_fqn, skip_compiler_prefix=False
                )
                if fqns == fqns_with_compiler:
                    continue

                assert len(fqns) == 1
                fqn = fqns.pop()
                fqn_with_compiler = fqns_with_compiler.pop()
                # The split state_dict is keyed by exact FQNs. Rename exact
                # matches only; a substring replace would also rewrite
                # unrelated keys that merely contain ``fqn``.
                for g in optim_state_dict[PG]:
                    val = cast(Dict[str, Any], g)
                    val[PARAMS] = [
                        fqn_with_compiler if key == fqn else key
                        for key in val[PARAMS]
                    ]
                osd_state = cast(DictValueType, optim_state_dict[STATE])
                if fqn in osd_state:
                    osd_state[fqn_with_compiler] = osd_state.pop(fqn)

            with info.fsdp_context():
                optim_state_dict = FSDP.optim_state_dict_to_load(
                    model, optim, optim_state_dict
                )

        # Note that we do not have to convert the FQN back to param id here if
        # order in optim.param_groups[idx][PARAMS] is the same as the one in
        # optim_state_dict[PG][idx][PARAMS].
        _init_optim_state(optim)
        _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict)
def get_model_state_dict(
    model: nn.Module,
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> Dict[str, ValueType]:
    """
    Return the model state_dict of ``model``.

    See ``get_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module to the model.
        submodules (Optional[Set[nn.Module]]): only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        The state_dict for ``model``.

    :rtype: typing.Dict[str, ValueType]
    """
    with gc_context():
        # No optimizers are involved when only the model state is requested.
        no_optimizers: Tuple[torch.optim.Optimizer, ...] = ()
        info = _verify_options(
            model,
            no_optimizers,
            optim_only=False,
            submodules=submodules,
            options=options,
        )
        result = _get_model_state_dict(model, info)
        _verify_state_dict(result, {}, info)
    return result
def get_optimizer_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> OptimizerStateType:
    """
    Return the combined state_dict for optimizers.

    See ``get_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module to the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        submodules (Optional[Set[nn.Module]]): only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        The state_dict for ``optimizers``.

    :rtype: OptimizerStateType
    """
    with gc_context():
        # Accept either a single optimizer or any iterable of optimizers.
        if isinstance(optimizers, torch.optim.Optimizer):
            optim_tuple: Tuple[torch.optim.Optimizer, ...] = (optimizers,)
        else:
            optim_tuple = tuple(optimizers)
        info = _verify_options(
            model,
            optim_tuple,
            optim_only=True,
            submodules=submodules,
            options=options,
        )
        result = _get_optim_state_dict(model, optim_tuple, info)
        _verify_state_dict({}, result, info)
    return result
def get_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> Tuple[Dict[str, ValueType], OptimizerStateType]:
    """
    Return the model state_dict and optimizers state_dict.

    ``get_state_dict`` can process any module that is parallelized by PyTorch
    FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any
    combination of these parallelisms. The main functions of ``get_state_dict``
    are: 1.) returning a model and optimizer state_dict that can be resharded
    with a different number of trainers and/or different parallelisms.
    2.) hiding the parallelism-specific state_dict APIs. Users don't have to call
    these APIs.
    3.) sanity checking the result state_dict.

    The keys of the result state dictionary are the canonical FQNs (Fully
    Qualified Names). A canonical FQN refers to the FQN based on a parameter's
    position in an nn.Module hierarchy. More specifically, a canonical FQN to a
    parameter is the FQN returned by ``module.named_parameters()`` or
    ``module.named_buffers()`` when the module is not distributed by any
    parallelisms. Since the optimizer internally uses parameter IDs to represent
    a parameter, there will be a conversion from the parameter IDs to the
    canonical FQNs when calling this API.

    ``get_state_dict`` can also process a module that is not parallelized. In
    such a case, ``get_state_dict`` only performs one function -- converting the
    optimizer parameter IDs to the canonical FQNs.

    Example:
        >>> # xdoctest: +SKIP
        >>> import copy
        >>> import torch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.checkpoint.state_dict import get_state_dict

        >>> fsdp_model = FSDP(copy.deepcopy(model))
        >>> fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
        >>> ddp_model = DDP(copy.deepcopy(model))
        >>> ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)


        >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
        >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)

        >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
        >>> # the asserts will fail.
        >>> assert ddp_state_dict == fsdp_state_dict
        >>> assert ddp_optim_state_dict == fsdp_optim_state_dict


    Args:
        model (nn.Module): the nn.Module to the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``. A single
            optimizer or any iterable of optimizers is accepted.
        submodules (Optional[Set[nn.Module]]): only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        ``Tuple`` that contain model state_dict and optimizer state_dict.

    :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType]
    """

    with gc_context():
        # Normalize a single optimizer or any iterable of them to a tuple.
        optimizers = (
            (optimizers,)
            if isinstance(optimizers, torch.optim.Optimizer)
            else tuple(optimizers)
        )
        info = _verify_options(
            model,
            optimizers,
            optim_only=False,
            submodules=submodules,
            options=options,
        )
        model_state_dict = _get_model_state_dict(model, info)
        optim_state_dict = _get_optim_state_dict(model, optimizers, info)
        _verify_state_dict(model_state_dict, optim_state_dict, info)
        return model_state_dict, optim_state_dict
def _unflatten_model_state_dict(
    model: nn.Module,
    state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]],
) -> Dict[str, ValueType]:
    """Flatten a submodule-keyed state_dict into one keyed by canonical FQNs.

    If the first key of ``state_dict`` is an ``nn.Module``, every key is
    treated as a submodule of ``model``. Each value's keys are then prefixed
    with that submodule's canonical FQN plus ".". Otherwise ``state_dict`` is
    returned unchanged. An empty input returns a new empty dict.

    Raises:
        ValueError: if a submodule key is not part of ``model``.
    """
    if not state_dict:
        return {}

    if not isinstance(next(iter(state_dict.keys())), nn.Module):
        return cast(Dict[str, ValueType], state_dict)

    cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict)
    # Walk the module tree once, keyed by identity, rather than once per
    # submodule. named_modules() yields each module object only once.
    module_names = {id(m): name for name, m in model.named_modules()}

    new_state_dict: Dict[str, ValueType] = {}
    for submodule, sub_state_dict in cast_state_dict.items():
        name = module_names.get(id(submodule))
        if name is None:
            # Previously this submodule's entire state_dict was silently
            # dropped; surface the mistake instead.
            raise ValueError(
                f"{type(submodule).__name__} passed as a model_state_dict key "
                "is not a submodule of the model."
            )
        fqns = _get_fqns(model, name)
        assert len(fqns) == 1, "FQNs for a submodule should only have 1 element"
        prefix = f"{next(iter(fqns))}."
        new_state_dict.update(
            {prefix + subfqn: value for subfqn, value in sub_state_dict.items()}
        )
    return new_state_dict
def set_model_state_dict(
    model: nn.Module,
    model_state_dict: Dict[str, ValueType],
    *,
    options: Optional[StateDictOptions] = None,
) -> _IncompatibleKeys:
    """Load the model state_dict.

    The counterpart of ``get_model_state_dict`` to set the state_dict to the
    model. See ``set_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module to the model.
        model_state_dict: (Dict[str, ValueType]):
           the model state_dict to load. If a key of ``model_state_dict``
           is an nn.Module, it must be a submodule of ``model`` and its value
           is that submodule's state_dict. When loading, the submodule's FQN
           is prepended to each key of its state_dict.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    :type model_state_dict: typing.Dict[str, ValueType]
    """
    flat_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
        model, model_state_dict
    )
    with gc_context():
        info = _verify_options(model, (), optim_only=False, options=options)
        _verify_state_dict(flat_state_dict, {}, info)
        incompatible_keys = _load_model_state_dict(model, flat_state_dict, info)
    return incompatible_keys
def set_optimizer_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    *,
    optim_state_dict: OptimizerStateType,
    options: Optional[StateDictOptions] = None,
) -> None:
    """Load the optimizers state_dict.

    The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the
    optimizers. See ``set_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module to the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        optim_state_dict: OptimizerStateType:
            the optimizer state_dict to load.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        None

    :type optim_state_dict: typing.OptimizerStateType
    """
    with gc_context():
        # Accept either a single optimizer or any iterable of optimizers.
        if isinstance(optimizers, torch.optim.Optimizer):
            optim_tuple: Tuple[torch.optim.Optimizer, ...] = (optimizers,)
        else:
            optim_tuple = tuple(optimizers)
        info = _verify_options(model, optim_tuple, optim_only=True, options=options)
        _verify_state_dict({}, optim_state_dict, info)
        _load_optim_state_dict(model, optim_tuple, optim_state_dict, info)
def set_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    *,
    model_state_dict: Dict[str, ValueType],
    optim_state_dict: OptimizerStateType,
    options: Optional[StateDictOptions] = None,
) -> _IncompatibleKeys:
    """Load the model state_dict and optimizers state_dict.

    The counterpart of ``get_state_dict`` to set the state_dict to the model and
    optimizers. The given ``model_state_dict`` and ``optim_state_dict`` do not
    have to be returned by ``get_state_dict`` but must meet the following
    requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``,
    2) if a tensor is sharded, it must be either a ShardedTensor or DTensor,
    3) optimizer state_dict cannot contain the parameter IDs; the keys should be
    the canonical FQNs.

    The optimizer state is loaded before the model state.

    Args:
        model (nn.Module): the nn.Module to the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):
           the model state_dict to load. If a key of ``model_state_dict``
           is an nn.Module, it must be a submodule of ``model`` and its value
           is that submodule's state_dict. When loading, the submodule's FQN
           is prepended to each key of its state_dict. An empty model
           state_dict means only the optimizers are loaded.
        optim_state_dict: OptimizerStateType:
            the optimizer state_dict to load.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys of the model state_dict.
            * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict.

    :type model_state_dict: typing.Dict[str, ValueType]
    :type optim_state_dict: typing.OptimizerStateType
    """
    flat_model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
        model, model_state_dict
    )
    with gc_context():
        if isinstance(optimizers, torch.optim.Optimizer):
            optim_tuple: Tuple[torch.optim.Optimizer, ...] = (optimizers,)
        else:
            optim_tuple = tuple(optimizers)
        only_optimizers = not flat_model_state_dict
        info = _verify_options(
            model, optim_tuple, optim_only=only_optimizers, options=options
        )

        _verify_state_dict(flat_model_state_dict, optim_state_dict, info)
        _load_optim_state_dict(model, optim_tuple, optim_state_dict, info)
        incompatible_keys = _load_model_state_dict(model, flat_model_state_dict, info)
    return incompatible_keys
# TODO: correct the state_dict function signature.
# TODO: this API is not yet fully tested. Make it private
@no_type_check
def _patch_model_state_dict(
    model: nn.Module,
    *,
    options: Optional[StateDictOptions] = None,
) -> None:
    """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``.

    Replace the ``state_dict`` and ``load_state_dict`` attributes of ``model``
    with functions that call ``get_model_state_dict`` and
    ``set_model_state_dict``. The patched ``state_dict()`` takes no
    arguments. The patched ``load_state_dict(state_dict)`` takes only the
    state_dict and returns the ``_IncompatibleKeys`` result.

    Example:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import _patch_model_state_dict

        model = FSDP(model)
        _patch_model_state_dict(model)

    Args:
        model (nn.Module): the nn.Module to the model.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.
    Returns:
        None
    """

    _state_dict_call = functools.partial(
        get_model_state_dict,
        model=model,
        options=options,
    )

    def state_dict_call():
        return _state_dict_call()

    model.state_dict = state_dict_call

    _load_state_dict_call = functools.partial(
        set_model_state_dict,
        model=model,
        options=options,
    )

    def load_state_dict_call(state_dict: Dict[str, Any]):
        # Propagate the missing/unexpected keys like nn.Module.load_state_dict.
        return _load_state_dict_call(model_state_dict=state_dict)

    model.load_state_dict = load_state_dict_call

    # Registered so _state_dict_fn calls the original class methods instead of
    # recursing into these wrappers.
    _patched_state_dict.add(state_dict_call)
    _patched_state_dict.add(load_state_dict_call)


# TODO: correct the load_state_dict function signature.
# TODO: this API is not yet fully tested. Make it private
@no_type_check
def _patch_optimizer_state_dict(
    model: nn.Module,
    *,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    options: Optional[StateDictOptions] = None,
) -> None:
    """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``.

    Replace the ``state_dict`` and ``load_state_dict`` attributes of every
    optimizer in ``optimizers`` with functions that call
    ``get_optimizer_state_dict`` and ``set_optimizer_state_dict`` on all of
    them. Because every optimizer shares the same patched functions, calling
    ``state_dict()`` on any one of them returns the combined result for all
    optimizers.

    Example:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import _patch_optimizer_state_dict

        model = FSDP(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)
        _patch_optimizer_state_dict(model, optimizers=(optim,))

    Args:
        model (nn.Module): the nn.Module to the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]): the optimizers
            used to optimize ``model``. All of them are patched.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.
    Returns:
        None
    """
    # Normalize before the partials capture ``optimizers``. Otherwise a
    # one-shot iterator would be exhausted by tuple() below, and the patched
    # calls would later see no optimizers.
    optimizers = (
        (optimizers,)
        if isinstance(optimizers, torch.optim.Optimizer)
        else tuple(optimizers)
    )

    _state_dict_call = functools.partial(
        get_optimizer_state_dict,
        model=model,
        optimizers=optimizers,
        options=options,
    )

    def state_dict_call():
        return _state_dict_call()

    _load_state_dict_call = functools.partial(
        set_optimizer_state_dict,
        model=model,
        optimizers=optimizers,
        options=options,
    )

    def load_state_dict_call(state_dict: Dict[str, Any]):
        _load_state_dict_call(optim_state_dict=state_dict)

    # Registered so _state_dict_fn calls the original class methods instead of
    # recursing into these wrappers.
    _patched_state_dict.add(state_dict_call)
    _patched_state_dict.add(load_state_dict_call)
    for optim in optimizers:
        optim.state_dict = state_dict_call
        optim.load_state_dict = load_state_dict_call
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.