Source code for torch.ao.quantization.backend_config.backend_config
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch
from torch.ao.quantization.observer import _PartialWrapper
from torch.ao.quantization.utils import Pattern
from enum import Enum
__all__ = [
"BackendConfig",
"BackendPatternConfig",
"DTypeConfig",
"DTypeWithConstraints",
"ObservationType",
]
# DTypeConfig dict keys
INPUT_DTYPE_DICT_KEY = "input_dtype"
OUTPUT_DTYPE_DICT_KEY = "output_dtype"
WEIGHT_DTYPE_DICT_KEY = "weight_dtype"
BIAS_DTYPE_DICT_KEY = "bias_dtype"
IS_DYNAMIC_DICT_KEY = "is_dynamic"
# BackendConfig dict keys
NAME_DICT_KEY = "name"
CONFIGS_DICT_KEY = "configs"
# BackendPatternConfig dict keys
PATTERN_DICT_KEY = "pattern"
OBSERVATION_TYPE_DICT_KEY = "observation_type"
DTYPE_CONFIGS_DICT_KEY = "dtype_configs"
ROOT_MODULE_DICT_KEY = "root_module"
QAT_MODULE_DICT_KEY = "qat_module"
REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root"
FUSED_MODULE_DICT_KEY = "fused_module"
FUSER_METHOD_DICT_KEY = "fuser_method"
ROOT_NODE_GETTER_DICT_KEY = "root_node_getter"
EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter"
NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type"
INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index"
INPUT_OUTPUT_OBSERVED_DICT_KEY = "input_output_observed"
OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY = "overwrite_output_fake_quantize"
OVERWRITE_OUTPUT_OBSERVER_DICT_KEY = "overwrite_output_observer"
# TODO: maybe rename this to something that's not related to observer
# e.g. QParamsType
class ObservationType(Enum):
    """ An enum that represents the different ways in which an operator/operator pattern
should be observed
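
    For illustration (a hedged sketch, using the ``BackendPatternConfig`` API defined
    later in this module): an op that computes new values, such as ``torch.nn.Linear``,
    typically uses ``OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT``, while a "passthrough" op
    such as ``torch.nn.MaxPool2d`` can reuse its input observer::

        BackendPatternConfig(torch.nn.MaxPool2d) \
            .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)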
"""
OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0
"""this means input and output are observed with different observers, based
on qconfig.activation
example: conv, linear, softmax
"""
OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1
"""this means the output will use the same observer instance as input, based
on qconfig.activation
example: torch.cat, maxpool
"""
@dataclass
class DTypeWithConstraints:
"""
Config for specifying additional constraints for a given dtype, such as quantization value
ranges and scale value ranges, to be used in :class:`~torch.ao.quantization.backend_config.DTypeConfig`.
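
    Example usage (an illustrative sketch; the bound values below are placeholders,
    not requirements of any particular backend)::

        >>> act_quint8_constrained = DTypeWithConstraints(
        ...     dtype=torch.quint8,
        ...     quant_min_lower_bound=0,
        ...     quant_max_upper_bound=127,
        ... )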
"""
dtype: Optional[torch.dtype] = None
quant_min_lower_bound: Union[int, float, None] = None
quant_max_upper_bound: Union[int, float, None] = None
scale_min_lower_bound: Union[int, float, None] = None
scale_max_upper_bound: Union[int, float, None] = None
@dataclass
class DTypeConfig:
"""
Config for the set of supported input/output activation, weight, and bias data types for the
patterns defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.
Example usage::
>>> dtype_config1 = DTypeConfig(
... input_dtype=torch.quint8,
... output_dtype=torch.quint8,
... weight_dtype=torch.qint8,
... bias_dtype=torch.float)
>>> dtype_config2 = DTypeConfig(
... input_dtype=DTypeWithConstraints(
... dtype=torch.quint8,
... quant_min_lower_bound=0,
... quant_max_upper_bound=255,
... ),
... output_dtype=DTypeWithConstraints(
... dtype=torch.quint8,
... quant_min_lower_bound=0,
... quant_max_upper_bound=255,
... ),
... weight_dtype=DTypeWithConstraints(
... dtype=torch.qint8,
... quant_min_lower_bound=-128,
... quant_max_upper_bound=127,
... ),
... bias_dtype=torch.float)
>>> dtype_config1.input_dtype
torch.quint8
>>> dtype_config2.input_dtype
torch.quint8
>>> dtype_config2.input_dtype_with_constraints
DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \
scale_min_lower_bound=None, scale_max_upper_bound=None)
"""
input_dtype_with_constraints: DTypeWithConstraints
output_dtype_with_constraints: DTypeWithConstraints
weight_dtype_with_constraints: DTypeWithConstraints
bias_dtype: Optional[torch.dtype]
is_dynamic: Optional[bool]
def __init__(
self,
input_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
output_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
weight_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
bias_dtype: Optional[torch.dtype] = None,
is_dynamic: Optional[bool] = None,
):
if isinstance(input_dtype, DTypeWithConstraints):
self.input_dtype_with_constraints = input_dtype
else:
self.input_dtype_with_constraints = DTypeWithConstraints(dtype=input_dtype)
if isinstance(output_dtype, DTypeWithConstraints):
self.output_dtype_with_constraints = output_dtype
else:
self.output_dtype_with_constraints = DTypeWithConstraints(dtype=output_dtype)
if isinstance(weight_dtype, DTypeWithConstraints):
self.weight_dtype_with_constraints = weight_dtype
else:
self.weight_dtype_with_constraints = DTypeWithConstraints(dtype=weight_dtype)
self.bias_dtype = bias_dtype
self.is_dynamic = is_dynamic
@property
def input_dtype(self) -> Optional[torch.dtype]:
return self.input_dtype_with_constraints.dtype
@property
def output_dtype(self) -> Optional[torch.dtype]:
return self.output_dtype_with_constraints.dtype
@property
def weight_dtype(self) -> Optional[torch.dtype]:
return self.weight_dtype_with_constraints.dtype
    @classmethod
def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig:
"""
Create a ``DTypeConfig`` from a dictionary with the following items (all optional):
"input_dtype": torch.dtype or ``DTypeWithConstraints``
"output_dtype": torch.dtype or ``DTypeWithConstraints``
"weight_dtype": torch.dtype or ``DTypeWithConstraints``
"bias_type": torch.dtype
"is_dynamic": bool
"""
input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None)
if input_dtype is not None and not isinstance(input_dtype, (torch.dtype, DTypeWithConstraints)):
raise ValueError("Expected input_dtype to be a torch.dtype or DTypeWithConstraints")
output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None)
if output_dtype is not None and not isinstance(output_dtype, (torch.dtype, DTypeWithConstraints)):
raise ValueError("Expected output_dtype to be a torch.dtype or DTypeWithConstraints")
weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None)
if weight_dtype is not None and not isinstance(weight_dtype, (torch.dtype, DTypeWithConstraints)):
raise ValueError("Expected weight_dtype to be a torch.dtype or DTypeWithConstraints")
bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)
    def to_dict(self) -> Dict[str, Any]:
"""
Convert this ``DTypeConfig`` to a dictionary with the items described in
:func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`.
"""
dtype_config_dict: Dict[str, Any] = {}
if self.input_dtype is not None:
dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints
if self.output_dtype is not None:
dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = self.output_dtype_with_constraints
if self.weight_dtype is not None:
dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = self.weight_dtype_with_constraints
if self.bias_dtype is not None:
dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype
if self.is_dynamic is not None:
dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic
return dtype_config_dict
class BackendConfig:
# TODO: refer to NativeBackendConfig once that is implemented
"""Config that defines the set of patterns that can be quantized on a given backend, and how reference
quantized models can be produced from these patterns.
A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph
of the above. Each pattern supported on the target backend can be individually configured through
:class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of:
(1) The supported input/output activation, weight, and bias data types
(2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and
(3) (Optionally) Fusion, QAT, and reference module mappings.
The format of the patterns is described in:
https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md
Example usage::
import torch
from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType
from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2
weighted_int8_dtype_config = DTypeConfig(
input_dtype=torch.quint8,
output_dtype=torch.quint8,
weight_dtype=torch.qint8,
            bias_dtype=torch.float)
linear_config = BackendPatternConfig(torch.nn.Linear) \
.set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
.add_dtype_config(weighted_int8_dtype_config) \
.set_root_module(torch.nn.Linear) \
.set_qat_module(torch.nn.qat.Linear) \
.set_reference_quantized_module(torch.nn.quantized._reference.Linear)
conv_relu_config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d)) \
.set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
.add_dtype_config(weighted_int8_dtype_config) \
.set_fused_module(torch.nn.intrinsic.ConvReLU2d) \
.set_fuser_method(reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d))
backend_config = BackendConfig("my_backend") \
.set_backend_pattern_config(linear_config) \
.set_backend_pattern_config(conv_relu_config)
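
    The resulting ``BackendConfig`` can then be passed to the FX graph mode quantization
    entry points (an illustrative sketch; assumes ``prepare_fx`` accepts a ``backend_config``
    keyword argument and that ``model``, ``qconfig_mapping``, and ``example_inputs`` are
    defined elsewhere)::

        from torch.ao.quantization.quantize_fx import prepare_fx
        prepared = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)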
"""
def __init__(self, name: str = ""):
self.name = name
self.configs: Dict[Pattern, BackendPatternConfig] = {}
    def set_name(self, name: str) -> BackendConfig:
"""
Set the name of the target backend.
"""
self.name = name
return self
    def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig:
        """
        Set the config for a pattern that can be run on the target backend.
This overrides any existing config for the given pattern.
"""
self.configs[config.pattern] = config
return self
    def set_backend_pattern_configs(self, configs: List[BackendPatternConfig]) -> BackendConfig:
        """
        Set the configs for patterns that can be run on the target backend.
        This overrides any existing config for a given pattern that was previously registered.
"""
for conf in configs:
self.set_backend_pattern_config(conf)
return self
    @classmethod
def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig:
"""
Create a ``BackendConfig`` from a dictionary with the following items:
"name": the name of the target backend
"configs": a list of dictionaries that each represents a `BackendPatternConfig`
"""
conf = cls(backend_config_dict.get(NAME_DICT_KEY, ""))
for d in backend_config_dict.get(CONFIGS_DICT_KEY, []):
if isinstance(d, BackendPatternConfig):
conf.set_backend_pattern_config(d)
elif isinstance(d, Dict):
conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d))
else:
raise ValueError("Expected backend_config_dict['%s'] to be a dictionary" % CONFIGS_DICT_KEY)
return conf
    def to_dict(self) -> Dict[str, Any]:
"""
Convert this ``BackendConfig`` to a dictionary with the items described in
:func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`.
"""
return {
NAME_DICT_KEY: self.name,
CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs.values()],
}
class BackendPatternConfig:
"""
Config for ops defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.
For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
"""
def __init__(self, pattern: Pattern):
self.pattern = pattern
self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
self.dtype_configs: List[DTypeConfig] = []
self.root_module: Optional[Type[torch.nn.Module]] = None
self.qat_module: Optional[Type[torch.nn.Module]] = None
self.reference_quantized_module: Optional[Type[torch.nn.Module]] = None
self.fused_module: Optional[Type[torch.nn.Module]] = None
self.fuser_method: Optional[Callable] = None
# Temporary/internal configs
self._root_node_getter: Optional[Callable] = None
self._extra_inputs_getter: Optional[Callable] = None
self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {}
self._input_type_to_index: Dict[str, int] = {}
self._input_output_observed: Optional[bool] = None
self._overwrite_output_fake_quantize: Optional[_PartialWrapper] = None
self._overwrite_output_observer: Optional[_PartialWrapper] = None
    def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig:
        """
        Set how observers should be inserted for this pattern.
        See :class:`~torch.ao.quantization.backend_config.ObservationType` for details.
"""
self.observation_type = observation_type
return self
    def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
"""
Add a set of supported input/output activation, weight, and bias data types for this pattern.
"""
self.dtype_configs.append(dtype_config)
return self
    def set_dtype_configs(self, dtype_configs: List[DTypeConfig]) -> BackendPatternConfig:
"""
Set the supported input/output activation, weight, and bias data types for this pattern,
overriding all previously registered data types.
"""
self.dtype_configs = dtype_configs
return self
    def set_root_module(self, root_module: Type[torch.nn.Module]) -> BackendPatternConfig:
"""
Set the module that represents the root for this pattern.
For example, the root module for :class:`torch.nn.intrinsic.LinearReLU` should be :class:`torch.nn.Linear`.
"""
self.root_module = root_module
return self
    def set_qat_module(self, qat_module: Type[torch.nn.Module]) -> BackendPatternConfig:
"""
Set the module that represents the QAT implementation for this pattern.
"""
self.qat_module = qat_module
return self
    def set_reference_quantized_module(self, reference_quantized_module: Type[torch.nn.Module]) -> BackendPatternConfig:
"""
Set the module that represents the reference quantized implementation for this pattern's root module.
"""
self.reference_quantized_module = reference_quantized_module
return self
    def set_fused_module(self, fused_module: Type[torch.nn.Module]) -> BackendPatternConfig:
"""
Set the module that represents the fused implementation for this pattern.
"""
self.fused_module = fused_module
return self
    def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig:
        """
        Set the function that specifies how to fuse this pattern.
"""
self.fuser_method = fuser_method
return self
def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig:
self._root_node_getter = root_node_getter
return self
def _set_extra_inputs_getter(self, extra_inputs_getter: Callable) -> BackendPatternConfig:
self._extra_inputs_getter = extra_inputs_getter
return self
def _set_num_tensor_args_to_observation_type(
self, num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> BackendPatternConfig:
self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type
return self
def _set_input_type_to_index(self, input_type_to_index: Dict[str, int]) -> BackendPatternConfig:
self._input_type_to_index = input_type_to_index
return self
def _set_input_output_observed(self, input_output_observed: bool) -> BackendPatternConfig:
self._input_output_observed = input_output_observed
return self
def _set_overwrite_output_fake_quantize(self, overwrite_output_fake_quantize: _PartialWrapper) -> BackendPatternConfig:
self._overwrite_output_fake_quantize = overwrite_output_fake_quantize
return self
def _set_overwrite_output_observer(self, overwrite_output_observer: _PartialWrapper) -> BackendPatternConfig:
self._overwrite_output_observer = overwrite_output_observer
return self
    @classmethod
def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig:
"""
Create a ``BackendPatternConfig`` from a dictionary with the following items:
"pattern": the pattern being configured
"observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how
observers should be inserted for this pattern
"dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s
"root_module": a :class:`torch.nn.Module` that represents the root for this pattern
"qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern
"reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized
implementation for this pattern's root module.
"fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern
"fuser_method": a function that specifies how to fuse the pattern for this pattern
"""
def _get_dtype_config(obj: Any) -> DTypeConfig:
"""
Convert the given object into a ``DTypeConfig`` if possible, else throw an exception.
"""
if isinstance(obj, DTypeConfig):
return obj
if isinstance(obj, Dict):
return DTypeConfig.from_dict(obj)
raise ValueError("Expected a list of DTypeConfigs in backend_pattern_config_dict[\"%s\"], got '%s'" %
(DTYPE_CONFIGS_DICT_KEY, type(obj)))
if PATTERN_DICT_KEY not in backend_pattern_config_dict:
raise ValueError("backend_pattern_config_dict must contain '%s'" % PATTERN_DICT_KEY)
conf = cls(backend_pattern_config_dict[PATTERN_DICT_KEY])
if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict:
conf.set_observation_type(backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY])
for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
conf.add_dtype_config(_get_dtype_config(d))
conf.set_root_module(backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None))
conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None))
conf.set_reference_quantized_module(backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None))
conf.set_fused_module(backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None))
conf.set_fuser_method(backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None))
conf._set_root_node_getter(backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None))
conf._set_extra_inputs_getter(backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None))
conf._set_num_tensor_args_to_observation_type(
backend_pattern_config_dict.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {}))
conf._set_input_type_to_index(backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}))
conf._set_input_output_observed(backend_pattern_config_dict.get(INPUT_OUTPUT_OBSERVED_DICT_KEY, None))
conf._set_overwrite_output_fake_quantize(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY, None))
conf._set_overwrite_output_observer(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_OBSERVER_DICT_KEY, None))
return conf
    def to_dict(self) -> Dict[str, Any]:
"""
Convert this ``BackendPatternConfig`` to a dictionary with the items described in
:func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`.
"""
backend_pattern_config_dict: Dict[str, Any] = {
PATTERN_DICT_KEY: self.pattern,
OBSERVATION_TYPE_DICT_KEY: self.observation_type,
DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs],
}
if self.root_module is not None:
backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module
if self.qat_module is not None:
backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module
if self.reference_quantized_module is not None:
backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = self.reference_quantized_module
if self.fused_module is not None:
backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module
if self.fuser_method is not None:
backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method
if self._root_node_getter is not None:
backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = self._root_node_getter
if self._extra_inputs_getter is not None:
backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = self._extra_inputs_getter
if len(self._num_tensor_args_to_observation_type) > 0:
backend_pattern_config_dict[NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY] = self._num_tensor_args_to_observation_type
if len(self._input_type_to_index) > 0:
backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = self._input_type_to_index
if self._input_output_observed is not None:
backend_pattern_config_dict[INPUT_OUTPUT_OBSERVED_DICT_KEY] = self._input_output_observed
if self._overwrite_output_fake_quantize is not None:
backend_pattern_config_dict[OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY] = self._overwrite_output_fake_quantize
if self._overwrite_output_observer is not None:
backend_pattern_config_dict[OVERWRITE_OUTPUT_OBSERVER_DICT_KEY] = self._overwrite_output_observer
return backend_pattern_config_dict