import copy
import itertools
import warnings

import torch
import torch.nn as nn
import torch.nn.quantized as nnq
from torch.nn.intrinsic import _FusedModule

from torch.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
    get_default_static_quant_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    no_observer_set,
    _has_special_act_post_process,
    _get_special_act_post_process,
)

from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.quantization.qconfig import (
    add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig)

def is_activation_post_process(module):
    return (isinstance(module, torch.quantization.ObserverBase) or
            isinstance(module, torch.quantization.FakeQuantizeBase))

def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
                              qconfig_parent=None, prefix=''):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
                      configuration
        allow_list: list of quantizable modules
        qconfig_parent: quantization config of parent module, we will fallback to
                        this config when there is no specified config for current
                        module
        prefix: corresponding prefix of the current module, used as key in
                qconfig_dict

    Return:
        None, module is modified inplace with qconfig attached
    """
    # TODO: Add test
    if allow_list is None:
        allow_list = get_default_qconfig_propagation_list()

    module_qconfig = qconfig_dict.get(type(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        _propagate_qconfig_helper(child, qconfig_dict, allow_list,
                                  qconfig_with_device_check, module_prefix)

# TODO(jerryzh): expose allow_list
def propagate_qconfig_(module, qconfig_dict=None, allow_list=None):
    r"""Propagate qconfig through the module hierarchy and assign `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration, qconfig applies to all submodules of a
            given module unless qconfig for the submodules are specified (when
            the submodule already has qconfig attribute)

    Return:
        None, module is modified inplace with qconfig attached
    """
    if qconfig_dict is None:
        qconfig_dict = {}
    _propagate_qconfig_helper(module, qconfig_dict, allow_list)
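# Illustrative usage sketch (not part of the original file): the helper name
# `_example_propagate_qconfig_usage` and the toy model are made up, and
# `torch.quantization.get_default_qconfig` is assumed to be available. It shows
# how a qconfig set on the root is inherited by every submodule.
def _example_propagate_qconfig_usage():
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    propagate_qconfig_(model)
    # every child now carries the qconfig inherited from the parent
    assert all(hasattr(m, 'qconfig') for m in model.modules())
    return model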
def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output
    """
    return self.activation_post_process(output)

def register_activation_post_process_hook(module):
    assert hasattr(module, 'activation_post_process'), \
        'Expect activation_post_process attribute already attached to the module'
    return module.register_forward_hook(_observer_forward_hook)
def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None,
                  device=None, custom_module_class_mapping=None):
    r"""Add observer for the leaf child of the module.

    This function inserts observer modules to all leaf child modules that
    have a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules that we want to quantize
        device: parent device, if any
        non_leaf_module_list: list of non-leaf modules we want to add observer

    Return:
        None, module is modified inplace with added observer modules and forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = get_unique_devices_(module)
        assert len(devices) <= 1, (
            "add_observer_ only works with cpu or single-device CUDA modules, "
            "but got devices {}".format(devices)
        )
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
        if device is not None:
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, 'qconfig') and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """ Adds an activation post process module and register
        a post hook that calls the module
        """
        # We don't insert observer/fake_quantize for DeQuantStub
        if needs_observation(m) and not isinstance(m, DeQuantStub):
            # observer and hook will be gone after we swap the module
            m.add_module('activation_post_process', get_activation_post_process(
                m.qconfig, device, special_act_post_process))
            # Register observer as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the observer before convert
            handle = register_activation_post_process_hook(m)
            m._forward_hooks.move_to_end(handle.id, last=False)

    for name, child in module.named_children():
        if type(child) in [nnq.FloatFunctional, nnq.QFunctional]:
            if needs_observation(child):
                child.activation_post_process = get_activation_post_process(child.qconfig, device)
        elif isinstance(child, _FusedModule):
            # activation_post_process are now added directly to nn.Sequential/_FusedModule
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
            if needs_observation(child):
                insert_activation_post_process(child)
        elif needs_observation(child) and type(child) in custom_module_class_mapping:
            observed_child = custom_module_class_mapping[type(child)].from_float(child)
            setattr(module, name, observed_child)
            # TODO: These are the modules that cannot be observed
            #       Once there are more, we should move them to a separate list
            if custom_module_class_mapping[type(child)] not in no_observer_set():
                insert_activation_post_process(observed_child)
        else:
            add_observer_(child, qconfig_propagation_list, non_leaf_module_list,
                          device, custom_module_class_mapping)

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    if len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
            and type(module) in qconfig_propagation_list:
        insert_activation_post_process(module)
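# `get_unique_devices_` is called by `add_observer_` above and `swap_module`
# below, but its definition did not survive in this capture. The following is a
# reconstruction based on how the callers use it (the set of devices of all
# parameters and buffers of the module); treat it as a best-effort gap fill.
def get_unique_devices_(module):
    return {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}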
def add_quant_dequant(module):
    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
    Note that this function will modify the children of module inplace and it
    can return a new module which wraps the input module as well.

    Args:
        module: input module with qconfig attributes for all the leaf modules
        that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
    """
    if len(module._modules) == 0 and hasattr(module, 'qconfig') and module.qconfig:
        return QuantWrapper(module)

    for name, child in module.named_children():
        module._modules[name] = add_quant_dequant(child)
    return module
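# Usage sketch (illustrative only; the helper name and toy model are made up):
# leaf modules that carry a qconfig get wrapped in `QuantWrapper`, which places
# a QuantStub/DeQuantStub pair around them, while children without a qconfig
# are left untouched.
def _example_add_quant_dequant_usage():
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    model[0].qconfig = torch.quantization.get_default_qconfig('fbgemm')
    model = add_quant_dequant(model)
    # the Linear child is now a QuantWrapper(Linear); the ReLU stays as-is
    assert isinstance(model[0], QuantWrapper)
    return model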
def prepare(model, inplace=False, allow_list=None,
            observer_non_leaf_module_list=None,
            prepare_custom_config_dict=None):
    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    The model will be attached with observer or fake quant modules, and qconfig
    will be propagated.

    Args:
        `model`: input model to be modified in-place
        `inplace`: carry out model transformations in-place, the original module is mutated
        `allow_list`: list of quantizable modules
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
        `prepare_custom_config_dict`: customization configuration dictionary for prepare function

    .. code-block:: python

       # Example of prepare_custom_config_dict:
       prepare_custom_config_dict = {
           # user will manually define the corresponding observed
           # module class which has a from_float class method that converts
           # float custom module to observed custom module
           "float_to_observed_custom_module_class": {
               CustomModule: ObservedCustomModule
           }
       }

    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

    if not inplace:
        model = copy.deepcopy(model)

    # TODO: remove allow_list
    qconfig_propagation_list = allow_list
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()
    propagate_qconfig_(model, qconfig_dict=None)

    # sanity check common API misusage
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
        warnings.warn("None of the submodule got qconfig applied. Make sure you "
                      "passed correct configuration through `qconfig_dict` or "
                      "by assigning the `.qconfig` attribute directly on submodules")

    add_observer_(
        model, qconfig_propagation_list, observer_non_leaf_module_list,
        custom_module_class_mapping=custom_module_class_mapping)
    return model
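# Usage sketch (illustrative only; the helper name and toy model are made up,
# and the 'fbgemm' qconfig assumes an x86 build): assign a qconfig, let
# `prepare` attach observers, then run representative data through the model so
# the observers record activation statistics.
def _example_prepare_usage():
    model = QuantWrapper(nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    prepared = prepare(model)                  # non-inplace: original model untouched
    prepared(torch.randn(1, 3, 16, 16))        # calibration pass populates the observers
    return prepared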
def _remove_activation_post_process(module):
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    if hasattr(module, 'activation_post_process') and \
       is_activation_post_process(module.activation_post_process):
        delattr(module, 'activation_post_process')

    # remove activation_post_process hook
    handle_ids_to_remove = set()
    for handle_id, hook_fn in module._forward_hooks.items():
        if hook_fn is _observer_forward_hook:
            handle_ids_to_remove.add(handle_id)
    for handle_id in handle_ids_to_remove:
        module._forward_hooks.pop(handle_id)

# TODO: rename to something more general
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that new qconfig can be
    propagated.

    Args:
        module: module to be cleaned up
    """
    for child in module.children():
        _remove_qconfig(child)

    if hasattr(module, "qconfig"):
        del module.qconfig

    _remove_activation_post_process(module)
def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    First it will prepare the model for calibration, then it calls
    `run_fn` which will run the calibration step, after that we will
    convert the model to a quantized model.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module is mutated
        mapping: correspondence between original module types and quantized counterparts

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    prepare(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, mapping, inplace=True)
    return model
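# Usage sketch (illustrative only; the helper name, toy model and random
# calibration data are made up, and the 'fbgemm' qconfig assumes an x86 build):
# post training static quantization in one call, where `run_fn` is a trivial
# calibration loop.
def _example_quantize_usage():
    float_model = QuantWrapper(nn.Linear(16, 4))
    float_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

    def calibrate(model, data):
        for batch in data:
            model(batch)

    calibration_data = [torch.randn(2, 16) for _ in range(4)]
    quantized = quantize(float_model, calibrate, [calibration_data])
    return quantized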
def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
                     mapping=None, inplace=False):
    r"""Converts a float model to dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and outputs the quantized model.

    For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
    by default is performed for layers with large weights size - i.e. Linear and RNN variants.

    Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
    If `qconfig` is provided, the `dtype` argument is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps from name or type of submodule to quantization
              configuration, qconfig applies to all submodules of a given
              module unless qconfig for the submodules are specified (when the
              submodule already has qconfig attribute). Entries in the dictionary
              need to be QConfigDynamic instances.

            - A set of types and/or submodule names to apply dynamic quantization to,
              in which case the `dtype` argument is used to specify the bit-width

        inplace: carry out model transformations in-place, the original module is mutated
        mapping: maps type of a submodule to a type of corresponding dynamically quantized version
            with which the submodule needs to be replaced

    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear: default_dynamic_qconfig,
                nn.LSTM: default_dynamic_qconfig,
                nn.GRU: default_dynamic_qconfig,
                nn.LSTMCell: default_dynamic_qconfig,
                nn.RNNCell: default_dynamic_qconfig,
                nn.GRUCell: default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear: float16_dynamic_qconfig,
                nn.LSTM: float16_dynamic_qconfig,
                nn.GRU: float16_dynamic_qconfig,
                nn.LSTMCell: float16_dynamic_qconfig,
                nn.RNNCell: float16_dynamic_qconfig,
                nn.GRUCell: float16_dynamic_qconfig,
            }
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig,
            }
        else:
            raise ValueError(
                "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
    elif isinstance(qconfig_spec, set):
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype is torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        else:
            raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model
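# Usage sketch (illustrative only; the helper name and toy model are made up):
# dynamic quantization needs no calibration data. Passing a set of module types
# as `qconfig_spec` lets `dtype` pick the default dynamic qconfig.
def _example_quantize_dynamic_usage():
    float_model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))
    quantized = quantize_dynamic(float_model, qconfig_spec={nn.Linear}, dtype=torch.qint8)
    # Linear layers are replaced with dynamically quantized counterparts;
    # the ReLU stays a regular float module
    quantized(torch.randn(2, 32))
    return quantized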
def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares a copy of the model for quantization calibration or
    quantization-aware training and converts it to quantized version.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to quantized modules to be
                 replaced.
        inplace: carry out model transformations in-place, the original module
                 is mutated
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    if mapping is None:
        mapping = get_default_qat_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)

    propagate_qconfig_(model, qconfig_dict=None)
    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
    return model
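# Usage sketch (illustrative only; the helper name and toy model are made up,
# and `torch.quantization.get_default_qat_qconfig` with the 'fbgemm' backend is
# assumed): for QAT the model is swapped to fake-quantized modules and then
# trained (or fine-tuned) before conversion.
def _example_prepare_qat_usage():
    model = QuantWrapper(nn.Linear(16, 4))
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    model.train()
    qat_model = prepare_qat(model)
    # ... run the training loop on qat_model, then:
    # quantized = convert(qat_model.eval())
    return qat_model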
def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization aware training and output a quantized model

    Args:
        model: input model
        run_fn: a function for evaluating the prepared model, can be a
                function that simply runs the prepared model or a training
                loop
        run_args: positional arguments for `run_fn`

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
    if not inplace:
        model = copy.deepcopy(model)
    model.train()
    prepare_qat(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, inplace=True)
    return model
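# Usage sketch (illustrative only; the helper name, toy model and the stand-in
# `train_fn` are made up): `quantize_qat` bundles prepare_qat, the user supplied
# training function and convert into a single call.
def _example_quantize_qat_usage():
    model = QuantWrapper(nn.Linear(16, 4))
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

    def train_fn(model, data):
        # stand-in for a real training loop; here it only runs forward passes
        for batch in data:
            model(batch)

    training_data = [torch.randn(2, 16) for _ in range(4)]
    quantized = quantize_qat(model, train_fn, [training_data])
    return quantized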
def convert(
        module, mapping=None, inplace=False, remove_qconfig=True,
        convert_custom_config_dict=None):
    r"""Converts submodules in input module to a different module according to `mapping`
    by calling `from_float` method on the target module class, and removes qconfig at
    the end if `remove_qconfig` is set to True.

    Args:
        `module`: prepared and calibrated module
        `mapping`: a dictionary that maps from source module type to target
                   module type, can be overwritten to allow swapping user defined
                   Modules
        `inplace`: carry out model transformations in-place, the original module
                   is mutated
        `convert_custom_config_dict`: custom configuration dictionary for convert function

    .. code-block:: python

       # Example of convert_custom_config_dict:
       convert_custom_config_dict = {
           # user will manually define the corresponding quantized
           # module class which has a from_observed class method that converts
           # observed custom module to quantized custom module
           "observed_to_quantized_custom_module_class": {
               ObservedCustomModule: QuantizedCustomModule
           }
       }

    """
    torch._C._log_api_usage_once("quantization_api.quantize.convert")
    if not inplace:
        module = copy.deepcopy(module)
    _convert(
        module, mapping, inplace=True,
        convert_custom_config_dict=convert_custom_config_dict)
    if remove_qconfig:
        _remove_qconfig(module)
    return module
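# Usage sketch (illustrative only; the helper name and toy model are made up,
# and the 'fbgemm' qconfig assumes an x86 build): convert is the last step of
# the eager mode flow and swaps each observed module for its quantized
# counterpart using the collected observer statistics.
def _example_convert_usage():
    model = QuantWrapper(nn.Linear(16, 4))
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    prepared = prepare(model)
    prepared(torch.randn(8, 16))              # calibration
    quantized = convert(prepared.eval())
    return quantized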
def _convert(
        module, mapping=None, inplace=False,
        convert_custom_config_dict=None):
    r"""Converts submodules in input module to a different module according to `mapping`
    by calling `from_float` method on the target module class

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
                 module type, can be overwritten to allow swapping user defined
                 Modules
        inplace: carry out model transformations in-place, the original module
                 is mutated

    """
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
    for name, mod in module.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit
        if not isinstance(mod, _FusedModule) and \
           type(mod) not in custom_module_class_mapping:
            _convert(mod, mapping, True,  # inplace
                     convert_custom_config_dict)
        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)

    for key, value in reassign.items():
        module._modules[key] = value

    return module
def swap_module(mod, mapping, custom_module_class_mapping):
    r"""Swaps the module if it has a quantized counterpart and it has an
    `observer` attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module

    Return:
        The corresponding quantized module of `mod`
    """
    new_mod = mod
    if hasattr(mod, 'qconfig') and mod.qconfig is not None:
        swapped = False
        if type(mod) in custom_module_class_mapping:
            new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
            swapped = True
        elif type(mod) in mapping:
            new_mod = mapping[type(mod)].from_float(mod)
            swapped = True

        if swapped:
            # Preserve module's pre forward hooks. They'll be called on quantized input
            for pre_hook_fn in mod._forward_pre_hooks.values():
                new_mod.register_forward_pre_hook(pre_hook_fn)
            # Preserve module's post forward hooks except _observer_forward_hook
            # After convert they'll work with quantized output
            for hook_fn in mod._forward_hooks.values():
                if hook_fn is not _observer_forward_hook:
                    new_mod.register_forward_hook(hook_fn)

            # respect device affinity when swapping modules
            devices = get_unique_devices_(mod)
            assert len(devices) <= 1, (
                "swap_module only works with cpu or single-device CUDA modules, "
                "but got devices {}".format(devices)
            )
            device = next(iter(devices)) if len(devices) > 0 else None
            if device:
                new_mod.to(device)
    return new_mod
def get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into dict.
    This is mainly used for quantization accuracy debugging.

    Args:
        mod: the top module we want to save all observers
        prefix: the prefix for the current module
        target_dict: the dictionary used to save all the observers
    """
    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + '.'

    if hasattr(mod, 'activation_post_process'):
        target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        get_observer_dict(child, target_dict, module_prefix)
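# Usage sketch (illustrative only; the helper name and toy model are made up):
# after calibration, collect every observer into a flat dict keyed by the fully
# qualified module path, e.g. to inspect the recorded ranges while debugging
# quantization accuracy.
def _example_get_observer_dict_usage():
    model = QuantWrapper(nn.Linear(16, 4))
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    prepared = prepare(model)
    prepared(torch.randn(8, 16))
    observers = {}
    get_observer_dict(prepared, observers)
    for name, observer in observers.items():
        print(name, observer)
    return observers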