Source code for torch.ao.quantization.backend_config.backend_config
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
from torch.ao.quantization.observer import _PartialWrapper
from torch.ao.quantization.utils import Pattern
from enum import Enum

__all__ = [
    "BackendConfig",
    "BackendPatternConfig",
    "DTypeConfig",
    "DTypeWithConstraints",
    "ObservationType",
]


# Dictionary keys used by DTypeConfig.from_dict / DTypeConfig.to_dict
INPUT_DTYPE_DICT_KEY = "input_dtype"
OUTPUT_DTYPE_DICT_KEY = "output_dtype"
WEIGHT_DTYPE_DICT_KEY = "weight_dtype"
BIAS_DTYPE_DICT_KEY = "bias_dtype"
IS_DYNAMIC_DICT_KEY = "is_dynamic"

# Dictionary keys used by BackendConfig.from_dict / BackendConfig.to_dict
NAME_DICT_KEY = "name"
CONFIGS_DICT_KEY = "configs"

# Dictionary keys used by BackendPatternConfig.from_dict / to_dict
PATTERN_DICT_KEY = "pattern"
OBSERVATION_TYPE_DICT_KEY = "observation_type"
DTYPE_CONFIGS_DICT_KEY = "dtype_configs"
ROOT_MODULE_DICT_KEY = "root_module"
QAT_MODULE_DICT_KEY = "qat_module"
REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root"
FUSED_MODULE_DICT_KEY = "fused_module"
FUSER_METHOD_DICT_KEY = "fuser_method"
ROOT_NODE_GETTER_DICT_KEY = "root_node_getter"
EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter"
NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type"
INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index"
INPUT_OUTPUT_OBSERVED_DICT_KEY = "input_output_observed"
OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY = "overwrite_output_fake_quantize"
OVERWRITE_OUTPUT_OBSERVER_DICT_KEY = "overwrite_output_observer"


# TODO: maybe rename this to something that's not related to observer
# e.g. QParamsType
class ObservationType(Enum):
    """ An enum that represents different ways of how an operator/operator pattern
    should be observed
    """

    OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0
    """this means input and output are observed with different observers, based
    on qconfig.activation
    example: conv, linear, softmax
    """

    OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1
    """this means the output will use the same observer instance as input, based
    on qconfig.activation
    example: torch.cat, maxpool
    """
@dataclass
class DTypeWithConstraints:
    """
    Config for specifying additional constraints for a given dtype, such as
    quantization value ranges and scale value ranges, to be used in
    :class:`~torch.ao.quantization.backend_config.DTypeConfig`.
    """
    # The dtype itself; every constraint below is optional and defaults to
    # "unconstrained" (None).
    dtype: Optional[torch.dtype] = None
    quant_min_lower_bound: Union[int, float, None] = None
    quant_max_upper_bound: Union[int, float, None] = None
    scale_min_lower_bound: Union[int, float, None] = None
    scale_max_upper_bound: Union[int, float, None] = None
@dataclass
class DTypeConfig:
    """
    Config for the set of supported input/output activation, weight,
    and bias data types for the patterns defined in
    :class:`~torch.ao.quantization.backend_config.BackendConfig`.

    Example usage::

        >>> dtype_config1 = DTypeConfig(
        ...     input_dtype=torch.quint8,
        ...     output_dtype=torch.quint8,
        ...     weight_dtype=torch.qint8,
        ...     bias_dtype=torch.float)

        >>> dtype_config2 = DTypeConfig(
        ...     input_dtype=DTypeWithConstraints(
        ...         dtype=torch.quint8,
        ...         quant_min_lower_bound=0,
        ...         quant_max_upper_bound=255,
        ...     ),
        ...     output_dtype=DTypeWithConstraints(
        ...         dtype=torch.quint8,
        ...         quant_min_lower_bound=0,
        ...         quant_max_upper_bound=255,
        ...     ),
        ...     weight_dtype=DTypeWithConstraints(
        ...         dtype=torch.qint8,
        ...         quant_min_lower_bound=-128,
        ...         quant_max_upper_bound=127,
        ...     ),
        ...     bias_dtype=torch.float)

        >>> dtype_config1.input_dtype
        torch.quint8

        >>> dtype_config2.input_dtype
        torch.quint8

        >>> dtype_config2.input_dtype_with_constraints
        DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \
scale_min_lower_bound=None, scale_max_upper_bound=None)
    """
    # All activation/weight dtypes are stored together with their (optional)
    # constraints; the plain torch.dtype views are exposed as properties below.
    # Note: @dataclass supplies __eq__/__repr__ only — the custom __init__
    # below takes precedence over the generated one.
    input_dtype_with_constraints: DTypeWithConstraints
    output_dtype_with_constraints: DTypeWithConstraints
    weight_dtype_with_constraints: DTypeWithConstraints
    bias_dtype: Optional[torch.dtype]
    is_dynamic: Optional[bool]

    def __init__(
        self,
        input_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        output_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        weight_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
        bias_dtype: Optional[torch.dtype] = None,
        is_dynamic: Optional[bool] = None,
    ):
        # Normalize plain torch.dtype arguments into unconstrained
        # DTypeWithConstraints so downstream code has a single representation.
        if isinstance(input_dtype, DTypeWithConstraints):
            self.input_dtype_with_constraints = input_dtype
        else:
            self.input_dtype_with_constraints = DTypeWithConstraints(dtype=input_dtype)

        if isinstance(output_dtype, DTypeWithConstraints):
            self.output_dtype_with_constraints = output_dtype
        else:
            self.output_dtype_with_constraints = DTypeWithConstraints(dtype=output_dtype)

        if isinstance(weight_dtype, DTypeWithConstraints):
            self.weight_dtype_with_constraints = weight_dtype
        else:
            self.weight_dtype_with_constraints = DTypeWithConstraints(dtype=weight_dtype)

        self.bias_dtype = bias_dtype
        self.is_dynamic = is_dynamic

    @property
    def input_dtype(self) -> Optional[torch.dtype]:
        """The plain input activation dtype, without constraints."""
        return self.input_dtype_with_constraints.dtype

    @property
    def output_dtype(self) -> Optional[torch.dtype]:
        """The plain output activation dtype, without constraints."""
        return self.output_dtype_with_constraints.dtype

    @property
    def weight_dtype(self) -> Optional[torch.dtype]:
        """The plain weight dtype, without constraints."""
        return self.weight_dtype_with_constraints.dtype

    @classmethod
    def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig:
        """
        Create a ``DTypeConfig`` from a dictionary with the following items (all optional):

            "input_dtype": torch.dtype or ``DTypeWithConstraints``
            "output_dtype": torch.dtype or ``DTypeWithConstraints``
            "weight_dtype": torch.dtype or ``DTypeWithConstraints``
            "bias_dtype": torch.dtype
            "is_dynamic": bool

        :raises ValueError: if any of the dtype entries is neither a
            ``torch.dtype`` nor a ``DTypeWithConstraints``
        """
        # NOTE: the key documented above was previously misspelled "bias_type";
        # the actual key read below has always been "bias_dtype".
        input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None)
        if input_dtype is not None and not isinstance(input_dtype, (torch.dtype, DTypeWithConstraints)):
            raise ValueError("Expected input_dtype to be a torch.dtype or DTypeWithConstraints")
        output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None)
        if output_dtype is not None and not isinstance(output_dtype, (torch.dtype, DTypeWithConstraints)):
            raise ValueError("Expected output_dtype to be a torch.dtype or DTypeWithConstraints")
        weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None)
        if weight_dtype is not None and not isinstance(weight_dtype, (torch.dtype, DTypeWithConstraints)):
            raise ValueError("Expected weight_dtype to be a torch.dtype or DTypeWithConstraints")
        bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
        is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
        return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``DTypeConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`.

        Unset (None) entries are omitted; activation/weight entries carry the
        full ``DTypeWithConstraints``, not just the plain dtype.
        """
        dtype_config_dict: Dict[str, Any] = {}
        if self.input_dtype is not None:
            dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints
        if self.output_dtype is not None:
            dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = self.output_dtype_with_constraints
        if self.weight_dtype is not None:
            dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = self.weight_dtype_with_constraints
        if self.bias_dtype is not None:
            dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype
        if self.is_dynamic is not None:
            dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic
        return dtype_config_dict
class BackendConfig:
    # TODO: refer to NativeBackendConfig once that is implemented
    """Config that defines the set of patterns that can be quantized on a given backend,
    and how reference quantized models can be produced from these patterns.

    A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph
    of the above. Each pattern supported on the target backend can be individually configured through
    :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of:

    (1) The supported input/output activation, weight, and bias data types

    (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and

    (3) (Optionally) Fusion, QAT, and reference module mappings.

    The format of the patterns is described in:
    https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md

    Example usage::

        import torch
        from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType
        from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2

        weighted_int8_dtype_config = DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.quint8,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float)

        linear_config = BackendPatternConfig(torch.nn.Linear) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(weighted_int8_dtype_config) \
            .set_root_module(torch.nn.Linear) \
            .set_qat_module(torch.nn.qat.Linear) \
            .set_reference_quantized_module(torch.nn.quantized._reference.Linear)

        conv_relu_config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d)) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(weighted_int8_dtype_config) \
            .set_fused_module(torch.nn.intrinsic.ConvReLU2d) \
            .set_fuser_method(reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d))

        backend_config = BackendConfig("my_backend") \
            .set_backend_pattern_config(linear_config) \
            .set_backend_pattern_config(conv_relu_config)
    """

    def __init__(self, name: str = ""):
        self.name = name
        # One config per pattern; registering the same pattern twice keeps
        # only the most recent config.
        self.configs: Dict[Pattern, BackendPatternConfig] = {}

    def set_name(self, name: str) -> BackendConfig:
        """
        Set the name of the target backend.
        """
        self.name = name
        return self

    def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig:
        """
        Set the config for a pattern that can be run on the target backend.
        This overrides any existing config for the given pattern.
        """
        self.configs[config.pattern] = config
        return self

    def set_backend_pattern_configs(self, configs: List[BackendPatternConfig]) -> BackendConfig:
        """
        Set the configs for patterns that can be run on the target backend.
        This overrides any existing config for a given pattern if it was previously registered already.
        """
        for conf in configs:
            self.set_backend_pattern_config(conf)
        return self

    @classmethod
    def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig:
        """
        Create a ``BackendConfig`` from a dictionary with the following items:

            "name": the name of the target backend
            "configs": a list of dictionaries that each represents a `BackendPatternConfig`

        :raises ValueError: if an entry of "configs" is neither a
            ``BackendPatternConfig`` nor a dictionary
        """
        conf = cls(backend_config_dict.get(NAME_DICT_KEY, ""))
        for d in backend_config_dict.get(CONFIGS_DICT_KEY, []):
            if isinstance(d, BackendPatternConfig):
                conf.set_backend_pattern_config(d)
            elif isinstance(d, dict):
                conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d))
            else:
                raise ValueError("Expected backend_config_dict['%s'] to be a dictionary" % CONFIGS_DICT_KEY)
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``BackendConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`.
        """
        return {
            NAME_DICT_KEY: self.name,
            CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs.values()],
        }
class BackendPatternConfig:
    """
    Config for ops defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.
    For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
    """

    def __init__(self, pattern: Pattern):
        self.pattern = pattern
        # Default: input and output get distinct observers (see ObservationType)
        self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        self.dtype_configs: List[DTypeConfig] = []
        self.root_module: Optional[Type[torch.nn.Module]] = None
        self.qat_module: Optional[Type[torch.nn.Module]] = None
        self.reference_quantized_module: Optional[Type[torch.nn.Module]] = None
        self.fused_module: Optional[Type[torch.nn.Module]] = None
        self.fuser_method: Optional[Callable] = None

        # Temporary/internal configs
        self._root_node_getter: Optional[Callable] = None
        self._extra_inputs_getter: Optional[Callable] = None
        self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {}
        self._input_type_to_index: Dict[str, int] = {}
        self._input_output_observed: Optional[bool] = None
        self._overwrite_output_fake_quantize: Optional[_PartialWrapper] = None
        self._overwrite_output_observer: Optional[_PartialWrapper] = None

    def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig:
        """
        Set how observers should be inserted for this pattern.
        See :class:`~torch.ao.quantization.backend_config.ObservationType` for details
        """
        self.observation_type = observation_type
        return self

    def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
        """
        Add a set of supported input/output activation, weight, and bias data types for this pattern.
        """
        self.dtype_configs.append(dtype_config)
        return self

    def set_dtype_configs(self, dtype_configs: List[DTypeConfig]) -> BackendPatternConfig:
        """
        Set the supported input/output activation, weight, and bias data types for this pattern,
        overriding all previously registered data types.
        """
        self.dtype_configs = dtype_configs
        return self

    def set_root_module(self, root_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the root for this pattern.
        For example, the root module for :class:`torch.nn.intrinsic.LinearReLU` should be :class:`torch.nn.Linear`.
        """
        self.root_module = root_module
        return self

    def set_qat_module(self, qat_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the QAT implementation for this pattern.
        """
        self.qat_module = qat_module
        return self

    def set_reference_quantized_module(self, reference_quantized_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the reference quantized implementation for this pattern's root module.
        """
        self.reference_quantized_module = reference_quantized_module
        return self

    def set_fused_module(self, fused_module: Type[torch.nn.Module]) -> BackendPatternConfig:
        """
        Set the module that represents the fused implementation for this pattern.
        """
        self.fused_module = fused_module
        return self

    def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig:
        """
        Set the function that specifies how to fuse the pattern for this pattern.
        """
        self.fuser_method = fuser_method
        return self

    # The following private setters mirror the temporary/internal attributes
    # initialized in __init__; ``from_dict`` below relies on them.

    def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig:
        self._root_node_getter = root_node_getter
        return self

    def _set_extra_inputs_getter(self, extra_inputs_getter: Callable) -> BackendPatternConfig:
        self._extra_inputs_getter = extra_inputs_getter
        return self

    def _set_num_tensor_args_to_observation_type(
            self, num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> BackendPatternConfig:
        self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type
        return self

    def _set_input_type_to_index(self, input_type_to_index: Dict[str, int]) -> BackendPatternConfig:
        self._input_type_to_index = input_type_to_index
        return self

    def _set_input_output_observed(self, input_output_observed: bool) -> BackendPatternConfig:
        self._input_output_observed = input_output_observed
        return self

    def _set_overwrite_output_fake_quantize(self, overwrite_output_fake_quantize: _PartialWrapper) -> BackendPatternConfig:
        self._overwrite_output_fake_quantize = overwrite_output_fake_quantize
        return self

    def _set_overwrite_output_observer(self, overwrite_output_observer: _PartialWrapper) -> BackendPatternConfig:
        self._overwrite_output_observer = overwrite_output_observer
        return self

    @classmethod
    def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig:
        """
        Create a ``BackendPatternConfig`` from a dictionary with the following items:

            "pattern": the pattern being configured
            "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how
            observers should be inserted for this pattern
            "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s
            "root_module": a :class:`torch.nn.Module` that represents the root for this pattern
            "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern
            "reference_quantized_module_for_root": a :class:`torch.nn.Module` that represents the reference quantized
            implementation for this pattern's root module.
            "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern
            "fuser_method": a function that specifies how to fuse the pattern for this pattern

        :raises ValueError: if "pattern" is missing, or if an entry of
            "dtype_configs" is neither a ``DTypeConfig`` nor a dictionary
        """
        def _get_dtype_config(obj: Any) -> DTypeConfig:
            """
            Convert the given object into a ``DTypeConfig`` if possible, else throw an exception.
            """
            if isinstance(obj, DTypeConfig):
                return obj
            if isinstance(obj, dict):
                return DTypeConfig.from_dict(obj)
            raise ValueError("Expected a list of DTypeConfigs in backend_pattern_config_dict[\"%s\"], got '%s'" %
                             (DTYPE_CONFIGS_DICT_KEY, type(obj)))

        if PATTERN_DICT_KEY not in backend_pattern_config_dict:
            raise ValueError("backend_pattern_config_dict must contain '%s'" % PATTERN_DICT_KEY)
        conf = cls(backend_pattern_config_dict[PATTERN_DICT_KEY])
        if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict:
            conf.set_observation_type(backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY])
        for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
            conf.add_dtype_config(_get_dtype_config(d))
        conf.set_root_module(backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None))
        conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None))
        conf.set_reference_quantized_module(backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None))
        conf.set_fused_module(backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None))
        conf.set_fuser_method(backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None))
        conf._set_root_node_getter(backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None))
        conf._set_extra_inputs_getter(backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None))
        conf._set_num_tensor_args_to_observation_type(
            backend_pattern_config_dict.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {}))
        conf._set_input_type_to_index(backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}))
        conf._set_input_output_observed(backend_pattern_config_dict.get(INPUT_OUTPUT_OBSERVED_DICT_KEY, None))
        conf._set_overwrite_output_fake_quantize(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY, None))
        conf._set_overwrite_output_observer(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_OBSERVER_DICT_KEY, None))
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``BackendPatternConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`.

        Optional/internal items are included only when set.
        """
        backend_pattern_config_dict: Dict[str, Any] = {
            PATTERN_DICT_KEY: self.pattern,
            OBSERVATION_TYPE_DICT_KEY: self.observation_type,
            DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs],
        }
        if self.root_module is not None:
            backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module
        if self.qat_module is not None:
            backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module
        if self.reference_quantized_module is not None:
            backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = self.reference_quantized_module
        if self.fused_module is not None:
            backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module
        if self.fuser_method is not None:
            backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method
        if self._root_node_getter is not None:
            backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = self._root_node_getter
        if self._extra_inputs_getter is not None:
            backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = self._extra_inputs_getter
        if len(self._num_tensor_args_to_observation_type) > 0:
            backend_pattern_config_dict[NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY] = \
                self._num_tensor_args_to_observation_type
        if len(self._input_type_to_index) > 0:
            backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = self._input_type_to_index
        if self._input_output_observed is not None:
            backend_pattern_config_dict[INPUT_OUTPUT_OBSERVED_DICT_KEY] = self._input_output_observed
        if self._overwrite_output_fake_quantize is not None:
            backend_pattern_config_dict[OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY] = self._overwrite_output_fake_quantize
        if self._overwrite_output_observer is not None:
            backend_pattern_config_dict[OVERWRITE_OUTPUT_OBSERVER_DICT_KEY] = self._overwrite_output_observer
        return backend_pattern_config_dict
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.