Source code for torch.ao.quantization.fuse_modules
import copy

import torch.nn as nn

from torch.ao.quantization.fuser_method_mappings import get_fuser_method
# for backward compatibility
from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn  # noqa: F401
from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn_relu  # noqa: F401
from torch.nn.utils.parametrize import type_before_parametrizations

from typing import List, Optional

__all__ = [
    "fuse_known_modules",
    "fuse_modules",
    "fuse_modules_qat",
]


# Generalization of getattr
def _get_module(model, submodule_key):
    """Return the attribute of ``model`` at the dotted path ``submodule_key``.

    For example, ``_get_module(model, "a.b")`` returns ``model.a.b``.
    """
    cur_mod = model
    for name in submodule_key.split('.'):
        cur_mod = getattr(cur_mod, name)
    return cur_mod


# Generalization of setattr
def _set_module(model, submodule_key, module):
    """Set the attribute of ``model`` at the dotted path ``submodule_key`` to ``module``.

    Every path component except the last is resolved with ``getattr``; the
    last one is assigned with ``setattr`` on the resulting parent object.
    """
    tokens = submodule_key.split('.')
    cur_mod = model
    for name in tokens[:-1]:
        cur_mod = getattr(cur_mod, name)
    setattr(cur_mod, tokens[-1], module)


def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
    r"""Fuse a sequence of modules whose type sequence has a known fuser method.

    The tuple of module types (computed with ``type_before_parametrizations``)
    is looked up with ``get_fuser_method``, which is also given
    ``additional_fuser_method_mapping``. Sequences documented as supported
    include:

        conv, bn
        conv, bn, relu
        conv, relu
        linear, bn
        linear, relu

    Forward pre-hooks of ``mod_list[0]`` and forward hooks of ``mod_list[-1]``
    are removed from those modules and re-registered (with default
    registration options) on the fused module. Hooks on any other module are
    not transferred.

    Args:
        mod_list: modules to fuse, in order.
        is_qat: passed as the first positional argument to the fuser method.
        additional_fuser_method_mapping: optional extra mapping from a tuple of
            module types to a fuser method; forwarded to ``get_fuser_method``.

    Returns:
        A list of the same length as ``mod_list``. The first element is the
        module returned by the fuser method; every other element is a new
        ``nn.Identity()`` whose ``training`` flag copies ``mod_list[0].training``.

    Raises:
        NotImplementedError: if ``get_fuser_method`` returns ``None`` for the
            type sequence.
    """
    types = tuple(type_before_parametrizations(m) for m in mod_list)
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError(f"Cannot fuse modules: {types}")
    fused = fuser_method(is_qat, *mod_list)

    # NOTE: forward hooks not processed below will be lost after the fusion.
    # Each source hook dict is snapshotted and cleared *before* re-registering,
    # so this is also correct when the fuser returns one of the input modules
    # itself: iterating a dict while registering into it would raise, and a
    # clear() afterwards would discard the hooks just moved.

    # Move pre forward hooks of the base module to the resulting fused module.
    pre_hooks = list(mod_list[0]._forward_pre_hooks.values())
    mod_list[0]._forward_pre_hooks.clear()
    for pre_hook_fn in pre_hooks:
        fused.register_forward_pre_hook(pre_hook_fn)

    # Move post forward hooks of the last module to the resulting fused module.
    post_hooks = list(mod_list[-1]._forward_hooks.values())
    mod_list[-1]._forward_hooks.clear()
    for hook_fn in post_hooks:
        fused.register_forward_hook(hook_fn)

    new_mod: List[Optional[nn.Module]] = [fused]
    for _ in mod_list[1:]:
        identity = nn.Identity()
        # Keep the placeholders in the same train/eval mode as the first module.
        identity.training = mod_list[0].training
        new_mod.append(identity)
    return new_mod


def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    """Fuse one group of submodules of ``model`` (given by dotted names) in place.

    The modules named in ``modules_to_fuse`` are fetched, passed to
    ``fuser_func(mod_list, is_qat, additional_fuser_method_mapping)``, and each
    name is then rebound to the corresponding element of the returned list.
    ``additional_fuser_method_mapping`` is read from
    ``fuse_custom_config_dict`` and defaults to ``{}``.
    """
    if fuse_custom_config_dict is None:
        fuse_custom_config_dict = {}
    additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
    mod_list = [_get_module(model, item) for item in modules_to_fuse]

    # Fuse list of modules
    new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)

    # Replace original module list with fused module list
    for i, item in enumerate(modules_to_fuse):
        _set_module(model, item, new_mod_list[i])


def _fuse_modules(model, modules_to_fuse, is_qat, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    """Shared implementation of ``fuse_modules`` / ``fuse_modules_qat``.

    Args:
        model: model containing the modules to fuse.
        modules_to_fuse: either a single list of dotted module names, or a list
            of such lists. An empty list fuses nothing.
        is_qat: forwarded to ``fuser_func``.
        inplace: if False, ``model`` is deep-copied first and the copy is fused.
        fuser_func: see ``_fuse_modules_helper``.
        fuse_custom_config_dict: see ``_fuse_modules_helper``.

    Returns:
        The fused model (``model`` itself when ``inplace`` is True).
    """
    if not inplace:
        model = copy.deepcopy(model)
    # A non-empty list of plain strings is a single group; anything else is a
    # list of groups. An empty list is zero groups, so it is a no-op rather than
    # an attempt to fuse an empty type sequence.
    if modules_to_fuse and all(isinstance(module_element, str) for module_element in modules_to_fuse):
        groups = [modules_to_fuse]
    else:
        groups = modules_to_fuse
    for module_list in groups:
        _fuse_modules_helper(model, module_list, is_qat, fuser_func, fuse_custom_config_dict)
    return model
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    r"""Fuses a list of modules into a single module.

    With the default ``fuser_func`` (``fuse_known_modules``), the supported
    sequences are those known to the fuser method mapping, for example:

        conv, bn
        conv, bn, relu
        conv, relu
        linear, bn
        linear, relu
        bn, relu

    For each fused sequence, the first module in the list is replaced with the
    fused module and the remaining modules are replaced with ``nn.Identity()``.
    With the default ``fuser_func``, an unsupported sequence raises an error
    rather than being left unchanged.

    Args:
        model: Model containing the modules to be fused.
        modules_to_fuse: list of lists of (dotted) module names to fuse.
            Can also be a single list of strings if there is only one group
            of modules to fuse.
        inplace: bool specifying if fusion happens in place on the model. By
            default a new (deep-copied) model is returned.
        fuser_func: Function called as
            ``fuser_func(mod_list, is_qat, additional_fuser_method_mapping)``
            that returns a list of modules of the same length as ``mod_list``.
            For example, ``fuser_func([convModule, BNModule], False, {})``
            returns ``[ConvBNModule, nn.Identity()]``.
            Defaults to ``torch.ao.quantization.fuse_known_modules``.
        fuse_custom_config_dict: custom configuration for fusion. Its
            ``"additional_fuser_method_mapping"`` entry (default ``{}``) is
            passed to ``fuser_func``.

    .. code-block:: python

       # Example of fuse_custom_config_dict
       fuse_custom_config_dict = {
           # Additional fuser_method mapping
           "additional_fuser_method_mapping": {
               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
           },
       }

    Returns:
        model with fused modules. A new copy is created if ``inplace=False``;
        otherwise ``model`` itself is modified and returned.

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = M().eval()
        >>> # m is a module containing the sub-modules below
        >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
        >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
        >>> output = fused_m(input)

        >>> m = M().eval()
        >>> # Alternately provide a single list of modules to fuse
        >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
        >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
        >>> output = fused_m(input)
    """
    return _fuse_modules(
        model,
        modules_to_fuse,
        is_qat=False,
        inplace=inplace,
        fuser_func=fuser_func,
        fuse_custom_config_dict=fuse_custom_config_dict,
    )
def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    r"""QAT version of :func:`fuse_modules`.

    Takes exactly the same arguments as ``fuse_modules``. The only difference
    is that ``is_qat=True`` is forwarded, and ``fuser_func`` receives it as its
    second argument.
    """
    forwarded_options = {
        "inplace": inplace,
        "fuser_func": fuser_func,
        "fuse_custom_config_dict": fuse_custom_config_dict,
    }
    # Third positional parameter of _fuse_modules is ``is_qat``.
    return _fuse_modules(model, modules_to_fuse, True, **forwarded_options)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, we apply Facebook’s Cookies Policy. Learn more, including about available controls: Cookies Policy.