# mypy: allow-untyped-defs
import copy
import inspect
import itertools
import warnings

import torch
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig import (
    _activation_is_memoryless,
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
)
from torch.ao.quantization.quantization_mappings import (
    _get_special_act_post_process,
    _has_special_act_post_process,
    get_default_dynamic_quant_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    no_observer_set,
)
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.nn.utils.parametrize import type_before_parametrizations

from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations


__all__ = [
    "get_default_custom_config_dict",
    "propagate_qconfig_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
]

# TODO remove this once BC is no longer required to avoid a SEV
is_activation_post_process = _is_activation_post_process


_DEFAULT_CUSTOM_CONFIG_DICT = {
    "float_to_observed_custom_module_class": {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    "observed_to_quantized_custom_module_class": {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    },
}


def get_default_custom_config_dict():
    r"""Defines the default custom config dict."""
    return _DEFAULT_CUSTOM_CONFIG_DICT


def _propagate_qconfig_helper(
    module,
    qconfig_dict,
    qconfig_parent=None,
    prefix="",
    prepare_custom_config_dict=None,
):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration
        qconfig_parent: quantization config of parent module, we will fallback to
                        this config when there is no specified config for the
                        current module
        prefix: corresponding prefix of the current module, used as key in
                qconfig_dict
        prepare_custom_config_dict: dictionary for custom handling of modules,
                                    see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """
    module_qconfig = qconfig_dict.get(
        type_before_parametrizations(module), qconfig_parent
    )
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, "qconfig", module_qconfig)

    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    for name, child in module.named_children():
        module_prefix = prefix + "." + name if prefix else name
        # do not propagate qconfig to child if child is non traceable
        if prepare_custom_config_dict is None or not (
            name in prepare_custom_config_dict.get("non_traceable_module_name", [])
            or type(child)
            in prepare_custom_config_dict.get("non_traceable_module_class", [])
        ):
            _propagate_qconfig_helper(
                child, qconfig_dict, qconfig_with_device_check, module_prefix
            )
def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign a `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration, the qconfig applies to all submodules of
            a given module unless a qconfig for the submodule is specified (i.e.
            the submodule already has a qconfig attribute)
        prepare_custom_config_dict: dictionary for custom handling of modules,
            see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """
    if qconfig_dict is None:
        qconfig_dict = {}
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    _propagate_qconfig_helper(
        module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict
    )
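
# A minimal usage sketch (illustrative only; `_demo_propagate_qconfig` is not
# part of this module's API): attach a qconfig to the root module and let
# `propagate_qconfig_` push it down to every leaf.
def _demo_propagate_qconfig():
    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    propagate_qconfig_(model)
    # Every submodule now carries a qconfig attribute.
    assert all(hasattr(m, "qconfig") for m in model.modules())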
def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output"""
    return self.activation_post_process(output)


def _observer_forward_pre_hook(self, input):
    r"""Forward pre hook that calls observer on the input"""
    return self.activation_post_process(input[0])


def _register_activation_post_process_hook(module, pre_hook=False):
    assert hasattr(
        module, "activation_post_process"
    ), "Expect activation_post_process attribute already attached to the module"
    if pre_hook:
        module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True)
    else:
        module.register_forward_hook(_observer_forward_hook, prepend=True)


def _add_observer_(
    module,
    qconfig_propagation_list=None,
    non_leaf_module_list=None,
    device=None,
    custom_module_class_mapping=None,
):
    r"""Add observer for the leaf child of the module.

    This function inserts observer modules into all leaf child modules that
    have a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules
            that we want to quantize
        qconfig_propagation_list: a list of quantizable modules that will have
            observers added to them if they are leaf nodes
        device: parent device, if any
        non_leaf_module_list: list of non-leaf modules we want to add observers to

    Return:
        None, module is modified inplace with added observer modules and
        forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = _get_unique_devices_(module)
        assert (
            len(devices) <= 1
        ), f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        activation = (
            qconfig.activation()
            if special_act_post_process is None
            else special_act_post_process()
        )
        if device is not None:
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, "qconfig") and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """Adds an activation post process module and registers
        a pre or post hook that calls the module
        """
        # We don't insert observer/fake_quantize for DeQuantStub
        if needs_observation(m) and not isinstance(m, DeQuantStub):
            # observer and hook will be gone after we swap the module
            m.add_module(
                "activation_post_process",
                get_activation_post_process(
                    m.qconfig, device, special_act_post_process
                ),
            )
            # Register observer as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the observer before convert
            _register_activation_post_process_hook(
                m, pre_hook=_activation_is_memoryless(m.qconfig)
            )

    for name, child in module.named_children():
        # TODO remove Dropout special case after codebase is stable
        if type_before_parametrizations(child) in [nn.Dropout]:
            continue
        elif issubclass(
            type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)
        ):
            if needs_observation(child):
                assert hasattr(child, "activation_post_process"), (
                    f"functional class {type_before_parametrizations(child)}"
                    " has no pre-defined `activation_post_process`"
                )
                child.activation_post_process = get_activation_post_process(
                    child.qconfig, device
                )
        elif isinstance(child, _FusedModule):
            # activation_post_process are now added directly to nn.Sequential/_FusedModule
            if needs_observation(child):
                insert_activation_post_process(child)
        elif (
            non_leaf_module_list is not None
            and type_before_parametrizations(child) in non_leaf_module_list
        ):
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif (
            needs_observation(child)
            and type_before_parametrizations(child) in custom_module_class_mapping
        ):
            observed_child = custom_module_class_mapping[
                type_before_parametrizations(child)
            ].from_float(child)
            setattr(module, name, observed_child)
            # TODO: These are the modules that cannot be observed
            #       Once there are more, we should move them to a separate list
            if (
                custom_module_class_mapping[type_before_parametrizations(child)]
                not in no_observer_set()
            ):
                insert_activation_post_process(observed_child)
        else:
            _add_observer_(
                child,
                qconfig_propagation_list,
                non_leaf_module_list,
                device,
                custom_module_class_mapping,
            )

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    if (
        has_no_children_ignoring_parametrizations(module)
        and not isinstance(module, torch.nn.Sequential)
        and type_before_parametrizations(module) in qconfig_propagation_list
    ):
        insert_activation_post_process(module)

    # This is a special case for AdaRound eager mode:
    # AdaRound contains a weight_fake_quant that must be propagated from the API
    # down to convert. The leaf-node check above (based on the number of
    # children) is too naive an assumption and would block it, so we add an
    # exception case for AdaRound.
    if (
        hasattr(module, "weight_fake_quant")
        and not isinstance(module, torch.nn.Sequential)
        and type_before_parametrizations(module) in qconfig_propagation_list
    ):
        insert_activation_post_process(module)


def _get_unique_devices_(module):
    return {p.device for p in module.parameters()} | {
        p.device for p in module.buffers()
    }
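
# A hedged sketch of what `_add_observer_` does to a qconfig-annotated leaf
# (illustrative only; `_demo_add_observer` is not part of this module's API).
def _demo_add_observer():
    m = nn.Linear(4, 4)
    m.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    _add_observer_(m)
    # The leaf got an observer child plus a forward hook that feeds it.
    assert _is_activation_post_process(m.activation_post_process)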
def add_quant_dequant(module):
    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
    Note that this function will modify the children of module inplace and it
    can return a new module which wraps the input module as well.

    Args:
        module: input module with qconfig attributes for all the leaf modules
        that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
    """
    if (
        has_no_children_ignoring_parametrizations(module)
        and hasattr(module, "qconfig")
        and module.qconfig
    ):
        return QuantWrapper(module)

    for name, child in module.named_children():
        module._modules[name] = add_quant_dequant(child)
    return module
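
# Sketch (illustrative only; `_demo_add_quant_dequant` is not part of this
# module's API): a leaf module with a qconfig comes back wrapped in QuantWrapper.
def _demo_add_quant_dequant():
    leaf = nn.Linear(4, 4)
    leaf.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    wrapped = add_quant_dequant(leaf)
    assert isinstance(wrapped, QuantWrapper)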
def prepare(
    model,
    inplace=False,
    allow_list=None,
    observer_non_leaf_module_list=None,
    prepare_custom_config_dict=None,
):
    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.

    Quantization configuration should be assigned preemptively
    to individual submodules in the `.qconfig` attribute.

    The model will be attached with observer or fake quant modules, and qconfig
    will be propagated.

    Args:
        `model`: input model to be modified in-place
        `inplace`: carry out model transformations in-place, the original module is mutated
        `allow_list`: list of quantizable modules
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observers to
        `prepare_custom_config_dict`: customization configuration dictionary for the prepare function

    .. code-block:: python

       # Example of prepare_custom_config_dict:
       prepare_custom_config_dict = {
           # user will manually define the corresponding observed
           # module class which has a from_float class method that converts
           # float custom module to observed custom module
           "float_to_observed_custom_module_class": {
               CustomModule: ObservedCustomModule
           }
       }
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = prepare_custom_config_dict.get(
        "float_to_observed_custom_module_class", {}
    )

    if not inplace:
        model = copy.deepcopy(model)

    # TODO: remove allow_list
    qconfig_propagation_list = allow_list
    if allow_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()
    propagate_qconfig_(model, qconfig_dict=None)

    # sanity check common API misusage
    if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()):
        warnings.warn(
            "None of the submodules got a qconfig applied. Make sure you "
            "passed correct configuration through `qconfig_dict` or "
            "by assigning the `.qconfig` attribute directly on submodules"
        )

    _add_observer_(
        model,
        qconfig_propagation_list,
        observer_non_leaf_module_list,
        custom_module_class_mapping=custom_module_class_mapping,
    )
    return model
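
# Eager-mode preparation sketch (illustrative only; `_demo_prepare` is not part
# of this module's API): observers are attached to quantizable leaves.
def _demo_prepare():
    model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU()).eval()
    model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    prepared = prepare(model)  # inplace=False, the original model is untouched
    assert hasattr(prepared[0], "activation_post_process")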
def _remove_activation_post_process(module):
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    if hasattr(module, "activation_post_process") and _is_activation_post_process(
        module.activation_post_process
    ):
        delattr(module, "activation_post_process")

    # remove activation_post_process pre and post hooks
    def remove_hooks(pre_hook=False):
        hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
        observer_hook = (
            _observer_forward_pre_hook if pre_hook else _observer_forward_hook
        )
        handle_ids_to_remove = set()
        for handle_id, hook_fn in hook_map.items():
            if hook_fn is observer_hook:
                handle_ids_to_remove.add(handle_id)
        for handle_id in handle_ids_to_remove:
            hook_map.pop(handle_id)

    remove_hooks(pre_hook=True)
    remove_hooks(pre_hook=False)


# TODO: rename to something more general
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that new qconfig can be
    propagated.

    Args:
        module: module to be cleaned up
    """
    for child in module.children():
        _remove_qconfig(child)

    if hasattr(module, "qconfig"):
        del module.qconfig

    _remove_activation_post_process(module)
def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    First it will prepare the model for calibration, then it calls
    `run_fn` which will run the calibration step, after that we will
    convert the model to a quantized model.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module
                 is mutated
        mapping: correspondence between original module types and quantized
                 counterparts

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    prepare(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, mapping, inplace=True)
    return model
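
# End-to-end post-training static quantization sketch (illustrative only;
# `_demo_quantize` and its calibration data are made up for this example).
def _demo_quantize():
    model = nn.Sequential(
        torch.ao.quantization.QuantStub(),
        nn.Linear(4, 4),
        torch.ao.quantization.DeQuantStub(),
    )
    model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")

    def calibrate(m, data):
        for x in data:
            m(x)

    data = [torch.randn(1, 4) for _ in range(4)]
    return quantize(model, calibrate, [data])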
def quantize_dynamic(
    model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False
):
    r"""Converts a float model to a dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and
    outputs the quantized model.

    For simplest usage provide the `dtype` argument, which can be float16 or
    qint8. Weight-only quantization is by default performed for layers with
    large weights - i.e. Linear and RNN variants.

    Fine grained control is possible with `qconfig_spec` and `mapping`, which
    act similarly to `quantize()`. If `qconfig_spec` is provided, the `dtype`
    argument is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps from name or type of submodule to
              quantization configuration, the qconfig applies to all submodules
              of a given module unless a qconfig for the submodule is specified
              (i.e. the submodule already has a qconfig attribute). Entries in
              the dictionary need to be QConfig instances.

            - A set of types and/or submodule names to apply dynamic
              quantization to, in which case the `dtype` argument is used to
              specify the bit-width

        inplace: carry out model transformations in-place, the original module
            is mutated
        mapping: maps type of a submodule to a type of corresponding dynamically
            quantized version with which the submodule needs to be replaced

    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear: default_dynamic_qconfig,
                nn.LSTM: default_dynamic_qconfig,
                nn.GRU: default_dynamic_qconfig,
                nn.LSTMCell: default_dynamic_qconfig,
                nn.RNNCell: default_dynamic_qconfig,
                nn.GRUCell: default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear: float16_dynamic_qconfig,
                nn.LSTM: float16_dynamic_qconfig,
                nn.GRU: float16_dynamic_qconfig,
                nn.LSTMCell: float16_dynamic_qconfig,
                nn.RNNCell: float16_dynamic_qconfig,
                nn.GRUCell: float16_dynamic_qconfig,
            }
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig,
                nn.Embedding: float_qparams_weight_only_qconfig,
            }
        elif dtype == torch.quint4x2:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit,
            }
        else:
            raise ValueError(
                f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please"
            )
    elif isinstance(qconfig_spec, set):
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype is torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        elif dtype is torch.quint4x2:
            default_qconfig = float_qparams_weight_only_qconfig_4bit
        else:
            raise RuntimeError(
                "Unknown dtype specified for quantize_dynamic: ", str(dtype)
            )
        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model
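
# Dynamic quantization sketch (illustrative only; `_demo_quantize_dynamic` is
# not part of this module's API): only the nn.Linear weights are quantized, and
# no calibration is needed.
def _demo_quantize_dynamic():
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
    return qmodel(torch.randn(1, 8))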
def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares a copy of the model for quantization-aware training and converts
    it to a quantized version.

    Quantization configuration should be assigned preemptively
    to individual submodules in the `.qconfig` attribute.

    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to quantized modules to be
                 replaced.
        inplace: carry out model transformations in-place, the original module
                 is mutated
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    assert model.training, "prepare_qat only works on models in training mode"
    if mapping is None:
        mapping = get_default_qat_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)

    propagate_qconfig_(model, qconfig_dict=None)
    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
    return model
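
# QAT preparation sketch (illustrative only; `_demo_prepare_qat` is not part of
# this module's API): the model must be in training mode, and fake-quant
# modules replace the float ones.
def _demo_prepare_qat():
    model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.ReLU()).train()
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
    return prepare_qat(model)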
def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization aware training and output a quantized model

    Args:
        model: input model
        run_fn: a function for evaluating the prepared model, can be a
                function that simply runs the prepared model or a training
                loop
        run_args: positional arguments for `run_fn`

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
    if not inplace:
        model = copy.deepcopy(model)
    model.train()
    prepare_qat(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, inplace=True)
    return model
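
# Full QAT flow sketch (illustrative only; `_demo_quantize_qat` and its
# "training" loop are made up for this example, a real run_fn would also
# optimize the model).
def _demo_quantize_qat():
    model = nn.Sequential(
        torch.ao.quantization.QuantStub(),
        nn.Linear(4, 2),
        torch.ao.quantization.DeQuantStub(),
    )
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")

    def train_fn(m, data):
        for x in data:
            m(x)

    return quantize_qat(model, train_fn, [[torch.randn(2, 4)]])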
def convert(
    module,
    mapping=None,
    inplace=False,
    remove_qconfig=True,
    is_reference=False,
    convert_custom_config_dict=None,
    use_precomputed_fake_quant=False,
):
    r"""Converts submodules in the input module to a different module according
    to `mapping` by calling the `from_float` method on the target module class.
    It also removes the qconfig at the end if `remove_qconfig` is set to True.

    Args:
        `module`: prepared and calibrated module
        `mapping`: a dictionary that maps from source module type to target
                   module type, can be overwritten to allow swapping user defined
                   Modules
        `inplace`: carry out model transformations in-place, the original module
                   is mutated
        `convert_custom_config_dict`: custom configuration dictionary for the convert function
        `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant

    .. code-block:: python

       # Example of convert_custom_config_dict:
       convert_custom_config_dict = {
           # user will manually define the corresponding quantized
           # module class which has a from_observed class method that converts
           # observed custom module to quantized custom module
           "observed_to_quantized_custom_module_class": {
               ObservedCustomModule: QuantizedCustomModule
           }
       }
    """
    torch._C._log_api_usage_once("quantization_api.quantize.convert")
    if not inplace:
        module = copy.deepcopy(module)
    _convert(
        module,
        mapping,
        inplace=True,
        is_reference=is_reference,
        convert_custom_config_dict=convert_custom_config_dict,
        use_precomputed_fake_quant=use_precomputed_fake_quant,
    )
    if remove_qconfig:
        _remove_qconfig(module)
    return module
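
# Manual prepare -> calibrate -> convert sketch, the three steps that
# `quantize` chains together (illustrative only; `_demo_convert` is not part of
# this module's API).
def _demo_convert():
    model = nn.Sequential(nn.Linear(4, 4)).eval()
    model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    prepare(model, inplace=True)
    model(torch.randn(2, 4))  # calibration pass populates the observers
    convert(model, inplace=True)
    assert isinstance(model[0], nnq.Linear)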
def _convert(
    module,
    mapping=None,
    inplace=False,
    is_reference=False,
    convert_custom_config_dict=None,
    use_precomputed_fake_quant=False,
):
    r"""Converts submodules in the input module to a different module according
    to `mapping` by calling the `from_float` method on the target module class

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
                 module type, can be overwritten to allow swapping user defined
                 Modules
        inplace: carry out model transformations in-place, the original module
                 is mutated
        is_reference: a flag to enable quantized reference module
        use_precomputed_fake_quant: a flag to enable use of precomputed fake quant

    """
    if mapping is None:
        mapping = (
            get_default_static_quant_reference_module_mappings()
            if is_reference
            else get_default_static_quant_module_mappings()
        )
    if convert_custom_config_dict is None:
        convert_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = convert_custom_config_dict.get(
        "observed_to_quantized_custom_module_class", {}
    )

    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
    for name, mod in module.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit
        if (
            not isinstance(mod, _FusedModule)
            and type_before_parametrizations(mod) not in custom_module_class_mapping
        ):
            _convert(
                mod,
                mapping,
                True,  # inplace
                is_reference,
                convert_custom_config_dict,
                use_precomputed_fake_quant=use_precomputed_fake_quant,
            )
        reassign[name] = swap_module(
            mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant
        )

    for key, value in reassign.items():
        module._modules[key] = value

    return module
def swap_module(
    mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False
):
    r"""Swaps the module if it has a quantized counterpart and it has an
    `observer` attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module

    Return:
        The corresponding quantized module of `mod`
    """
    new_mod = mod
    if hasattr(mod, "qconfig") and mod.qconfig is not None:
        swapped = False
        if type_before_parametrizations(mod) in custom_module_class_mapping:
            new_mod = custom_module_class_mapping[
                type_before_parametrizations(mod)
            ].from_observed(mod)
            swapped = True
        elif type_before_parametrizations(mod) in mapping:
            qmod = mapping[type_before_parametrizations(mod)]
            if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE:
                assert mod.qconfig is not None
                weight_post_process = mod.qconfig.weight()
                weight_post_process(mod.weight)
                weight_qparams = get_qparam_dict(weight_post_process)
                new_mod = qmod.from_float(mod, weight_qparams)
            else:
                sig = inspect.signature(qmod.from_float)
                if "use_precomputed_fake_quant" in sig.parameters:
                    new_mod = qmod.from_float(
                        mod, use_precomputed_fake_quant=use_precomputed_fake_quant
                    )
                else:
                    new_mod = qmod.from_float(mod)
            swapped = True

        if swapped:
            # Preserve module's pre forward hooks. They'll be called on quantized input
            for pre_hook_fn in mod._forward_pre_hooks.values():
                new_mod.register_forward_pre_hook(pre_hook_fn)
            # Preserve module's post forward hooks except _observer_forward_hook
            # After convert they'll work with quantized output
            for hook_fn in mod._forward_hooks.values():
                if hook_fn is not _observer_forward_hook:
                    new_mod.register_forward_hook(hook_fn)

            # respect device affinity when swapping modules
            devices = _get_unique_devices_(mod)
            assert (
                len(devices) <= 1
            ), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
            device = next(iter(devices)) if len(devices) > 0 else None
            if device:
                new_mod.to(device)
    return new_mod
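
# Single-module swap sketch (illustrative only; `_demo_swap_module` is not part
# of this module's API): a prepared, calibrated nn.Linear is swapped for its
# quantized counterpart from the default static mapping.
def _demo_swap_module():
    lin = nn.Linear(4, 4)
    lin.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    prepare(lin, inplace=True)
    lin(torch.randn(2, 4))  # populate the attached observer
    qlin = swap_module(lin, get_default_static_quant_module_mappings(), {})
    assert isinstance(qlin, nnq.Linear)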
def _get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into dict.

    This is mainly used for quantization accuracy debug.

    Args:
        mod: the top module we want to save all observers
        prefix: the prefix for the current module
        target_dict: the dictionary used to save all the observers
    """

    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + "."

    if hasattr(mod, "activation_post_process"):
        target_dict[
            get_prefix(prefix) + "activation_post_process"
        ] = mod.activation_post_process
    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        _get_observer_dict(child, target_dict, module_prefix)
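
# Observer-collection sketch (illustrative only; `_demo_observer_dict` is not
# part of this module's API): gather every observer keyed by its dotted path.
def _demo_observer_dict():
    model = nn.Sequential(nn.Linear(4, 4)).eval()
    model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    prepare(model, inplace=True)
    observers = {}
    _get_observer_dict(model, observers)
    # e.g. {"0.activation_post_process": HistogramObserver(...)}
    return observers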