class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
    """
    Describes how to quantize a layer or a part of the network by providing
    settings (observer classes) for activations and weights respectively.

    Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
    instances on invocation, not the concrete observer instances themselves.
    The quantization preparation function will instantiate observers multiple times for each of the layers.

    Observer classes usually have reasonable default arguments, but they can be overwritten with the `with_args`
    method (that behaves like functools.partial)::

      my_qconfig = QConfig(
          activation=MinMaxObserver.with_args(dtype=torch.qint8),
          weight=default_observer.with_args(dtype=torch.qint8))

    """
    def __new__(cls, activation, weight):
        # catch common mistakes
        if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
            raise ValueError("QConfig received observer instance, please pass observer class instead. " +
                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
        return super().__new__(cls, activation, weight)
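
# Illustrative sketch (not part of the original module): a typical eager-mode static PTQ
# flow that consumes a QConfig built with `with_args`. The model, example input, and the
# specific observer arguments below are assumptions for demonstration only.
def _example_static_ptq_with_custom_qconfig(model: nn.Module, example_input: torch.Tensor) -> nn.Module:
    custom_qconfig = QConfig(
        activation=MinMaxObserver.with_args(dtype=torch.quint8),
        weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
    model.eval()
    model.qconfig = custom_qconfig
    prepared = torch.ao.quantization.prepare(model)    # insert observers
    prepared(example_input)                            # calibrate on representative data
    return torch.ao.quantization.convert(prepared)     # swap in quantized modules
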
class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])):
    """
    Describes how to dynamically quantize a layer or a part of the network by providing
    settings (observer classes) for weights.

    It's like QConfig, but for dynamic quantization.

    Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
    instances on invocation, not the concrete observer instances themselves.
    The quantization function will instantiate observers multiple times for each of the layers.

    Observer classes usually have reasonable default arguments, but they can be overwritten with the `with_args`
    method (that behaves like functools.partial)::

      my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
    """
    def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
        # catch common mistakes
        if isinstance(weight, nn.Module):
            raise ValueError("QConfigDynamic received observer instance, please pass observer class instead. " +
                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
        warnings.warn("QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead")
        return super().__new__(cls, activation, weight)


default_qconfig = QConfig(activation=default_observer,
                          weight=default_weight_observer)
"""
Default qconfig configuration.
"""

default_debug_qconfig = QConfig(weight=default_weight_observer,
                                activation=default_debug_observer)
"""
Default qconfig configuration for debugging.
"""

default_per_channel_qconfig = QConfig(activation=default_observer,
                                      weight=default_per_channel_weight_observer)
"""
Default qconfig configuration for per channel weight quantization.
"""

default_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer,
                                  weight=default_weight_observer)
"""
Default dynamic qconfig.
"""

float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True),
                                  weight=PlaceholderObserver.with_args(dtype=torch.float16))
"""
Dynamic qconfig with weights quantized to `torch.float16`.
"""

float16_static_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16),
                                 weight=PlaceholderObserver.with_args(dtype=torch.float16))
"""
Dynamic qconfig with both activations and weights quantized to `torch.float16`.
"""

per_channel_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer,
                                      weight=default_per_channel_weight_observer)
"""
Dynamic qconfig with weights quantized per channel.
"""

float_qparams_weight_only_qconfig = QConfig(
    activation=default_placeholder_observer,
    weight=default_float_qparams_observer)
"""
Dynamic qconfig with weights quantized with a floating point zero_point.
"""

float_qparams_weight_only_qconfig_4bit = QConfig(
    activation=default_placeholder_observer,
    weight=default_float_qparams_observer_4bit)

default_qat_qconfig = QConfig(activation=default_fake_quant,
                              weight=default_weight_fake_quant)
"""
Default qconfig for QAT.
"""

default_dynamic_qat_qconfig = QConfig(activation=default_dynamic_fake_quant,
                                      weight=default_weight_fake_quant)
"""
Default qconfig for dynamic QAT.
"""

default_weight_only_qconfig = QConfig(activation=torch.nn.Identity,
                                      weight=default_weight_fake_quant)
"""
Default qconfig for quantizing weights only.
"""

default_activation_only_qconfig = QConfig(activation=default_fake_quant,
                                          weight=torch.nn.Identity)
"""
Default qconfig for quantizing activations only.
"""
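
# Illustrative sketch (not part of the original module): the dynamic qconfigs above are
# typically consumed through `torch.ao.quantization.quantize_dynamic`, e.g. by mapping
# `nn.Linear` modules to `default_dynamic_qconfig`. The model argument is a hypothetical
# float model used only for demonstration.
def _example_dynamic_quantization(model: nn.Module) -> nn.Module:
    return torch.ao.quantization.quantize_dynamic(
        model.eval(),
        qconfig_spec={nn.Linear: default_dynamic_qconfig})
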
# QAT config that uses a fused observer + fake quant modules for optimized training performance.
# To modify the activation/weight observers, the default entries in fake_quantize.py can be modified.
default_qat_qconfig_v2 = QConfig(activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant)
"""
Fused version of `default_qat_qconfig`, has performance benefits.
"""

default_reuse_input_qconfig = QConfig(activation=default_reuse_input_observer,
                                      weight=NoopObserver)
"""
Default qconfig for operators that reuse the observers from the input Tensor, e.g. reshape.
"""

def get_default_qconfig(backend='x86', version=0):
    """
    Returns the default PTQ qconfig for the specified backend.

    Args:
      * `backend` (str): a string representing the target backend. Currently supports
        `x86` (default), `fbgemm`, `qnnpack` and `onednn`.

    Return:
        qconfig
    """
    supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"]
    if backend not in supported_backends:
        raise AssertionError(
            "backend: " + str(backend) +
            f" not supported. backend must be one of {supported_backends}"
        )

    if version == 0:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                              weight=default_per_channel_weight_observer)
        elif backend == 'qnnpack':
            # TODO: make this compatible with xnnpack constraints
            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
                              weight=default_weight_observer)
        elif backend == 'onednn':
            if not torch.cpu._is_cpu_support_vnni():
                warnings.warn(
                    "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues "
                    "on CPU without Vector Neural Network Instruction support.")
            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
                              weight=default_per_channel_weight_observer)
        elif backend == 'x86':
            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                              weight=default_per_channel_weight_observer)
        else:
            # won't reach
            qconfig = default_qconfig
    else:
        raise AssertionError("Version number: " + str(version) +
                             " in get_default_qconfig is not supported. Version number must be 0")

    return qconfig

"""
Default, symmetric PTQ qconfig for the specified backend, and a per_channel
variant of the same.

Symmetric here applies to signed weights with zero point = 0, and additional
value restrictions. The activations are also signed 8-bit integers with this
qconfig.

    * Once this change is merged [as of 3/17/22], with backend or qengine =
      'qnnpack', some quantized operators with this symmetric qconfig may use
      operators from the xnnpack library.

        ** Support to use xnnpack ops with the `qnnpack` backend for an
           asymmetric qconfig (returned by get_default_qconfig()) is not
           available yet.

    * This qconfig uses signed activations and weights. Weights have added
      restrictions, such as the zero point being forced to 0, making the
      weights symmetric, hence the name. The 8-bit quantized values are
      restricted to [-127, +127], excluding -128.

    * xnnpack has a requantization scale value restriction, 0x1p-32 <=
      requantization_scale < 256.0, where
      `requantization_scale = (input_scale * kernel_scale) / (output_scale)`.
      Using this eps (with an assumed max value of 256) prevents
      requantization_scale from going below the xnnpack lower threshold.
"""

default_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8,
                                                                                   reduce_range=False,
                                                                                   eps=2 ** -12),
                                            weight=weight_observer_range_neg_127_to_127)

default_per_channel_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8,
                                                                                               reduce_range=False,
                                                                                               eps=2 ** -12),
                                                        weight=per_channel_weight_observer_range_neg_127_to_127)

default_embedding_qat_qconfig = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
                                        weight=default_embedding_fake_quant)

default_embedding_qat_qconfig_4bit = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
                                             weight=default_embedding_fake_quant_4bit)

default_quint8_weight_qconfig = QConfig(activation=HistogramObserver, weight=MinMaxObserver)

def get_default_qat_qconfig(backend='x86', version=1):
    """
    Returns the default QAT qconfig for the specified backend.

    Args:
      * `backend` (str): a string representing the target backend. Currently supports
        `x86` (default), `fbgemm`, `qnnpack` and `onednn`.
      * `version`: version, for backwards compatibility. Can be `0` or `1`.

    Return:
        qconfig
    """
    supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"]
    if backend not in supported_backends:
        raise AssertionError(
            "backend: " + str(backend) +
            f" not supported. backend must be one of {supported_backends}"
        )

    # Histogram observer is too slow for quantization aware training
    if version == 0:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=True),
                              weight=default_per_channel_weight_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=False),
                              weight=default_weight_fake_quant)
        elif backend == 'onednn':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255),
                              weight=default_per_channel_weight_fake_quant)
        elif backend == 'x86':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=True),
                              weight=default_per_channel_weight_fake_quant)
        else:
            qconfig = default_qat_qconfig
    # Use the fused observer + fake_quant modules for doing QAT.
    elif version == 1:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=True),
                              weight=default_fused_per_channel_wt_fake_quant)
        elif backend == 'qnnpack':
            # TODO: make this compatible with xnnpack constraints
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=False),
                              weight=default_fused_wt_fake_quant)
        elif backend == 'onednn':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255),
                              weight=default_fused_per_channel_wt_fake_quant)
        elif backend == 'x86':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=True),
                              weight=default_fused_per_channel_wt_fake_quant)
        else:
            qconfig = default_qat_qconfig_v2
    else:
        raise AssertionError("Version number: " + str(version) +
                             " in get_default_qat_qconfig is not supported. Version number must be 0 or 1")

    return qconfig
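
# Illustrative sketch (not part of the original module): a typical eager-mode QAT flow
# using the backend-specific qconfig returned by `get_default_qat_qconfig`. The training
# loop is elided and `model` is a hypothetical float model with quant/dequant stubs
# already in place.
def _example_qat_preparation(model: nn.Module, backend: str = 'x86') -> nn.Module:
    torch.backends.quantized.engine = backend
    model.train()
    model.qconfig = get_default_qat_qconfig(backend)
    prepared = torch.ao.quantization.prepare_qat(model)    # insert fused fake-quant modules
    # ... fine-tune `prepared` with the usual training loop here ...
    return torch.ao.quantization.convert(prepared.eval())
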
"""
Default symmetric QAT qconfig for qnnpack, and its per channel weight variant.
"""
default_symmetric_qnnpack_qat_qconfig = QConfig(
    activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                       quant_min=-128,
                                                       quant_max=127,
                                                       dtype=torch.qint8,
                                                       reduce_range=False,
                                                       eps=2 ** -12),
    weight=fused_wt_fake_quant_range_neg_127_to_127)

default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig(
    activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                       quant_min=-128,
                                                       quant_max=127,
                                                       dtype=torch.qint8,
                                                       reduce_range=False,
                                                       eps=2 ** -12),
    weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)

_default_fp32_placeholder_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float32),
    weight=PlaceholderObserver.with_args(dtype=torch.float32)
)

_default_quint8_placeholder_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.quint8),
    # operators using this qconfig don't have weights
    weight=None,
)

def get_default_qconfig_dict(backend='x86', version=0):
    warnings.warn(
        "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in "
        "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.")
    return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict()

def get_default_qat_qconfig_dict(backend='x86', version=1):
    warnings.warn(
        "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in "
        "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.")
    return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict()

def _assert_valid_qconfig(qconfig: Optional[QConfig],
                          mod: torch.nn.Module) -> None:
    """
    Verifies that this `qconfig` is valid.
    """
    if qconfig is None:
        return
    is_conv_transpose_mod = (
        isinstance(
            mod,
            (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)))
    if is_conv_transpose_mod:
        if qconfig.weight is None:
            # for now, we assume that any qconfig for ConvTranspose without a weight is valid
            return
        example_observer = qconfig.weight()
        is_per_channel = (
            isinstance(example_observer, (torch.ao.quantization.PerChannelMinMaxObserver,
                                          torch.ao.quantization.MovingAveragePerChannelMinMaxObserver)))
        assert not is_per_channel, \
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.'

QConfigAny = Optional[QConfig]
QConfigAny.__module__ = "torch.ao.quantization.qconfig"

def _add_module_to_qconfig_obs_ctr(
        qconfig: QConfigAny,
        module: Optional[nn.Module]) -> Any:
    r"""This is a helper function for use in quantization prepare that updates a qconfig so that
    the constructors stored in the qconfig will create observers on the same device that
    'module' is on. This is intended to be used when the qconfigs are propagated to each
    module in order to avoid potential device alignment issues.

    Args:
        qconfig: QConfig with obs constructors stored in activation and weight
        module: module which the qconfig is related to

    Return:
        qconfig: configured so that obs constructors are set to construct on the same device as module
    """

    if module is None or qconfig is None or qconfig._fields != ('activation', 'weight'):
        return qconfig

    def get_factory_kwargs_based_on_module_device():
        assert isinstance(module, torch.nn.Module)
        devices = {p.device for p in module.parameters()} | \
            {p.device for p in module.buffers()}
        device = next(iter(devices)) if len(devices) > 0 else None
        return None if device is None else {'device': device}

    def configure_constructor_to_put_obs_on_module_device(original_constructor):
        try:
            # check if constructor can accept factory_kwargs
            check = original_constructor.with_args(factory_kwargs=None)
            check()
            return original_constructor.with_callable_args(factory_kwargs=get_factory_kwargs_based_on_module_device)
        except AttributeError:  # qconfig doesn't have activation or weight
            return original_constructor
        except TypeError:  # the class doesn't accept factory_kwargs argument
            return original_constructor

    activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation)
    weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight)
    return QConfig(activation, weight)

_ObserverOrFakeQuantizeConstructor = Union[_PartialWrapper, Type[ObserverBase], Type[FakeQuantizeBase]]

def _obs_or_fq_ctr_equals(obs_or_fq1: _ObserverOrFakeQuantizeConstructor,
                          obs_or_fq2: _ObserverOrFakeQuantizeConstructor):
    if isinstance(obs_or_fq1, _PartialWrapper) and isinstance(obs_or_fq2, _PartialWrapper):
        return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2)
    return obs_or_fq1 == obs_or_fq2

def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper):
    """
    Return whether the two partial wrappers are equal.
    """
    # functools.partial has no __eq__ operator defined so '==' defaults to 'is'
    obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords)
    obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords)
    keywords_equal = True
    # compare observer constructor with _obs_or_fq_ctr_equals since direct compare would fail
    if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords:
        keywords_equal = keywords_equal and _obs_or_fq_ctr_equals(obs_or_fq1_keywords["observer"],
                                                                  obs_or_fq2_keywords["observer"])
        obs_or_fq1_keywords.pop("observer")
        obs_or_fq2_keywords.pop("observer")
    keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords
    return obs_or_fq1.p.func == obs_or_fq2.p.func and obs_or_fq1.p.args == obs_or_fq2.p.args and keywords_equal

def qconfig_equals(q1: QConfigAny, q2: QConfigAny):
    """
    Returns `True` if `q1` equals `q2`, and `False` otherwise.
    """
    if q1 is None or q2 is None:
        return q1 == q2
    else:
        assert q1 is not None and q2 is not None
        try:
            # Qconfig weight and activation can be either a partial wrapper,
            # or an observer class. Special handling is required (above) for
            # comparing partial wrappers.
            activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation)
            weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight)
            return activation_same and weight_same
        except AttributeError:
            return q1 == q2
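
# Illustrative sketch (not part of the original module): `qconfig_equals` compares the
# underlying constructors (partial-wrapper func, args and keywords), so two QConfigs
# built from identical `with_args` calls compare equal even though the wrapper objects
# themselves are distinct.
def _example_qconfig_equals() -> bool:
    q1 = QConfig(activation=MinMaxObserver.with_args(dtype=torch.quint8),
                 weight=default_weight_observer)
    q2 = QConfig(activation=MinMaxObserver.with_args(dtype=torch.quint8),
                 weight=default_weight_observer)
    return qconfig_equals(q1, q2)  # True: same func, args and keywords
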
def _activation_is_memoryless(qconfig: QConfig):
    """
    Return whether the observer for activations defined in the given QConfig is memoryless.
    This means a MovingAverage observer with averaging constant equal to 1.
    """
    def _is_memoryless(observer):
        return hasattr(observer, "averaging_constant") and observer.averaging_constant == 1
    act = qconfig.activation()
    if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"):
        return _is_memoryless(act.activation_post_process)
    else:
        return _is_memoryless(act)

def _is_reuse_input_qconfig(qconfig: Optional[QConfig]):
    return qconfig is not None and \
        isinstance(qconfig.activation(), ReuseInputObserver) and \
        isinstance(qconfig.weight(), NoopObserver)
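
# Illustrative sketch (not part of the original module): `_activation_is_memoryless`
# returns True only when the activation observer is a MovingAverage observer whose
# averaging constant is 1; the qconfig below is a hypothetical example of that case.
def _example_activation_is_memoryless() -> bool:
    memoryless_qconfig = QConfig(
        activation=MovingAverageMinMaxObserver.with_args(averaging_constant=1),
        weight=default_weight_observer)
    return _activation_is_memoryless(memoryless_qconfig)  # True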