
torch.quantization

This module implements the functions you call directly to convert your model from FP32 to quantized form. For example, prepare() is used in post-training quantization to prepare your model for the calibration step, and convert() converts the weights to int8 and replaces the operations with their quantized counterparts. Other helper functions handle tasks such as quantizing the input to your model and performing critical fusions like conv+relu.

Top-level quantization APIs

torch.quantization.quantize(model, run_fn, run_args, mapping=None, inplace=False)[source]

Quantize the input float model with post training static quantization.

It first prepares the model for calibration, then calls run_fn to run the calibration step, and finally converts the model to a quantized model.

Parameters
  • model – input float model

  • run_fn – a calibration function for calibrating the prepared model

  • run_args – positional arguments for run_fn

  • inplace – carry out model transformations in-place, the original module is mutated

  • mapping – correspondence between original module types and quantized counterparts

Returns

Quantized model.
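
The three steps described for quantize() can be sketched as follows. This is a minimal illustration rather than a reference workflow: SmallNet, calibrate, and the random calibration batches are hypothetical, and the qconfig is assumed to be the default for whichever quantized engine the local PyTorch build selects.

```python
import torch
import torch.nn as nn
from torch.quantization import DeQuantStub, QuantStub, get_default_qconfig, quantize


class SmallNet(nn.Module):
    """Hypothetical float model; the stubs mark where tensors enter and leave int8."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.fc = nn.Linear(4, 2)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))


def calibrate(model, batches):
    """run_fn: feed representative inputs so the observers can record ranges."""
    with torch.no_grad():
        for batch in batches:
            model(batch)


float_model = SmallNet()
# Assumption: use the default qconfig of the engine this build selects.
float_model.qconfig = get_default_qconfig(torch.backends.quantized.engine)

# Random data stands in for a real calibration set.
calibration_batches = [torch.randn(16, 4) for _ in range(8)]

# run_args is unpacked after the model, i.e. run_fn(model, *run_args).
quantized_model = quantize(float_model, calibrate, [calibration_batches])

output = quantized_model(torch.randn(1, 4))
print(type(quantized_model.fc).__name__, tuple(output.shape))
```

Because inplace defaults to False, float_model is left untouched and the quantized copy is returned.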

torch.quantization.quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False)[source]

Converts a float model to dynamic (i.e. weights-only) quantized model.

Replaces the specified modules with dynamic weight-only quantized versions and outputs the quantized model.

For the simplest usage, provide the dtype argument, which can be float16 or qint8. By default, weight-only quantization is performed for layers with large weights, i.e. Linear and RNN variants.

Fine-grained control is possible with qconfig and mapping, which act similarly to quantize(). If qconfig is provided, the dtype argument is ignored.

Parameters
  • model – input model

  • qconfig_spec

    Either:

    • A dictionary that maps from the name or type of a submodule to a quantization configuration. A qconfig applies to all submodules of a given module unless a qconfig is specified for the submodules (when the submodule already has a qconfig attribute). Entries in the dictionary need to be QConfigDynamic instances.

    • A set of types and/or submodule names to apply dynamic quantization to, in which case the dtype argument is used to specify the bit-width

  • inplace – carry out model transformations in-place, the original module is mutated

  • mapping – maps type of a submodule to a type of corresponding dynamically quantized version with which the submodule needs to be replaced
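
A minimal sketch of the simplest usage, passing a set of module types as qconfig_spec together with dtype. The Sequential model below is hypothetical, and the example assumes a PyTorch build with a quantized CPU engine available.

```python
import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic

# Hypothetical float model with two Linear layers.
float_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

# qconfig_spec given as a set of types; dtype then selects qint8 weights.
dynamic_model = quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

output = dynamic_model(torch.randn(2, 8))
print(dynamic_model)
```

Only the Linear layers are swapped; the ReLU stays a regular float module, and the original model is not mutated because inplace defaults to False.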

torch.quantization.quantize_qat(model, run_fn, run_args, inplace=False)[source]

Do quantization aware training and output a quantized model

Parameters
  • model – input model

  • run_fn – a function for evaluating the prepared model, can be a function that simply runs the prepared model or a training loop

  • run_args – positional arguments for run_fn

Returns

Quantized model.

torch.quantization.prepare(model, inplace=False, allow_list=None, observer_non_leaf_module_list=None, prepare_custom_config_dict=None)[source]

Prepares a copy of the model for quantization calibration or quantization-aware training.

Quantization configuration should be assigned preemptively to individual submodules in their .qconfig attribute.

Observer or fake quant modules will be attached to the model, and qconfig will be propagated.

Parameters
  • model – input model; it is modified in place only when inplace=True

  • inplace – carry out model transformations in-place, the original module is mutated

  • allow_list – list of quantizable modules

  • observer_non_leaf_module_list – list of non-leaf modules we want to add observer

  • prepare_custom_config_dict – customization configuration dictionary for prepare function

# Example of prepare_custom_config_dict:
prepare_custom_config_dict = {
    # user will manually define the corresponding observed
    # module class which has a from_float class method that converts
    # float custom module to observed custom module
    "float_to_observed_custom_module_class": {
        CustomModule: ObservedCustomModule
    }
}
torch.quantization.prepare_qat(model, mapping=None, inplace=False)[source]

Prepares a copy of the model for quantization calibration or quantization-aware training and converts it to quantized version.

Quantization configuration should be assigned preemptively to individual submodules in .qconfig attribute.

Parameters
  • model – input model; it is modified in place only when inplace=True

  • mapping – dictionary that maps float modules to quantized modules to be replaced.

  • inplace – carry out model transformations in-place, the original module is mutated

torch.quantization.convert(module, mapping=None, inplace=False, remove_qconfig=True, convert_custom_config_dict=None)[source]

Converts submodules in the input module to a different module according to mapping by calling the from_float method on the target module class, and removes qconfig at the end if remove_qconfig is set to True.

Parameters
  • module – prepared and calibrated module

  • mapping – a dictionary that maps from source module type to target module type, can be overwritten to allow swapping user defined Modules

  • inplace – carry out model transformations in-place, the original module is mutated

  • convert_custom_config_dict – custom configuration dictionary for convert function

# Example of convert_custom_config_dict:
convert_custom_config_dict = {
    # user will manually define the corresponding quantized
    # module class which has a from_observed class method that converts
    # observed custom module to quantized custom module
    "observed_to_quantized_custom_module_class": {
        ObservedCustomModule: QuantizedCustomModule
    }
}
class torch.quantization.QConfig(activation, weight)[source]

Describes how to quantize a layer or a part of the network by providing settings (observer classes) for activations and weights respectively.

Note that QConfig needs to contain observer classes (like MinMaxObserver) or a callable that returns instances on invocation, not the concrete observer instances themselves. Quantization preparation function will instantiate observers multiple times for each of the layers.

Observer classes usually have reasonable default arguments, but they can be overridden with the with_args method (which behaves like functools.partial):

my_qconfig = QConfig(activation=MinMaxObserver.with_args(dtype=torch.qint8), weight=default_observer.with_args(dtype=torch.qint8))
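
To illustrate why QConfig stores factories rather than instances, the following sketch builds a QConfig and then calls the stored activation factory twice. The particular dtype and qscheme choices are only assumptions for the example.

```python
import torch
from torch.quantization import MinMaxObserver, QConfig

my_qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
    ),
)

# Each call of the stored factory builds a fresh, independent observer,
# which is what lets preparation attach a separate observer to every layer.
first_observer = my_qconfig.activation()
second_observer = my_qconfig.activation()

print(type(first_observer).__name__, first_observer is second_observer)
```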

class torch.quantization.QConfigDynamic(activation=<class 'torch.nn.modules.linear.Identity'>, weight=<class 'torch.nn.modules.linear.Identity'>)[source]

Describes how to dynamically quantize a layer or a part of the network by providing settings (observer classes) for weights.

It’s like QConfig, but for dynamic quantization.

Note that QConfigDynamic needs to contain observer classes (like MinMaxObserver) or a callable that returns instances on invocation, not the concrete observer instances themselves. Quantization function will instantiate observers multiple times for each of the layers.

Observer classes usually have reasonable default arguments, but they can be overridden with the with_args method (which behaves like functools.partial):

my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))

Preparing model for quantization

torch.quantization.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)[source]

Fuses a list of modules into a single module

Fuses only the following sequences of modules:

  • conv, bn

  • conv, bn, relu

  • conv, relu

  • linear, relu

  • bn, relu

All other sequences are left unchanged. For these sequences, replaces the first item in the list with the fused module, replacing the rest of the modules with identity.

Parameters
  • model – Model containing the modules to be fused

  • modules_to_fuse – list of lists of module names to fuse. Can also be a list of strings if there is only a single list of modules to fuse.

  • inplace – bool specifying if fusion happens in place on the model, by default a new model is returned

  • fuser_func – Function that takes in a list of modules and outputs a list of fused modules of the same length. For example, fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]. Defaults to torch.quantization.fuse_known_modules.

  • fuse_custom_config_dict – custom configuration for fusion

# Example of fuse_custom_config_dict
fuse_custom_config_dict = {
    # Additional fuser_method mapping
    "additional_fuser_method_mapping": {
        (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
    },
}
Returns

Model with fused modules. A new copy is created if inplace=False.
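
A self-contained sketch of the replacement behaviour described here: the first module in the sequence becomes the fused module and the rest become identity. ConvBlock is a hypothetical model, and it is put in eval mode on the assumption that post-training conv+bn fusion folds the batch-norm statistics into the convolution.

```python
import torch
import torch.nn as nn
from torch.quantization import fuse_modules


class ConvBlock(nn.Module):
    """Hypothetical model with a fusable conv -> bn -> relu sequence."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        return self.relu1(self.bn1(self.conv1(x)))


model = ConvBlock().eval()
fused = fuse_modules(model, ["conv1", "bn1", "relu1"])

x = torch.randn(1, 3, 16, 16)
# conv1 now holds the fused module; bn1 and relu1 are placeholders.
print(type(fused.conv1).__name__, type(fused.bn1).__name__, type(fused.relu1).__name__)
# The fused model should compute the same function as the original one.
print(torch.allclose(model(x), fused(x), atol=1e-5))
```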

Examples:

>>> m = myModel()
>>> # m is a module containing  the sub-modules below
>>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

>>> m = myModel()
>>> # Alternately provide a single list of modules to fuse
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
>>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)
class torch.quantization.QuantStub(qconfig=None)[source]

Quantize stub module. Before calibration, this is the same as an observer; it will be swapped to nnq.Quantize in convert.

Parameters

qconfig – quantization configuration for the tensor, if qconfig is not provided, we will get qconfig from parent modules

class torch.quantization.DeQuantStub[source]

Dequantize stub module. Before calibration, this is the same as identity; it will be swapped to nnq.DeQuantize in convert.

class torch.quantization.QuantWrapper(module)[source]

A wrapper class that wraps the input module, adds QuantStub and DeQuantStub, and surrounds the call to module with calls to the quant and dequant modules.

This is used by the quantization utility functions to add the quant and dequant modules. Before convert, QuantStub is just an observer that observes the input tensor; after convert, QuantStub is swapped to nnq.Quantize, which does the actual quantization. The same applies to DeQuantStub.
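
A small sketch of the wrapper before any preparation or conversion has happened. The Linear layer is a hypothetical stand-in for the module being wrapped; at this stage the stubs are expected to pass tensors through unchanged.

```python
import torch
import torch.nn as nn
from torch.quantization import QuantWrapper

linear = nn.Linear(4, 2)
wrapped = QuantWrapper(linear)

# The wrapper holds three children: the quant stub, the dequant stub,
# and the wrapped module itself.
child_names = [name for name, _ in wrapped.named_children()]
print(child_names)

x = torch.randn(3, 4)
# Before prepare/convert, wrapping does not change the computed values.
print(torch.equal(wrapped(x), linear(x)))
```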

torch.quantization.add_quant_dequant(module)[source]

Wrap the leaf child module in QuantWrapper if it has a valid qconfig. Note that this function will modify the children of module in place, and it can also return a new module which wraps the input module.

Parameters
  • module – input module with qconfig attributes for all the leaf modules that we want to quantize

Returns

Either the inplace modified module with submodules wrapped in QuantWrapper based on qconfig or a new QuantWrapper module which wraps the input module, the latter case only happens when the input module is a leaf module and we want to quantize it.

Utility functions

torch.quantization.add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None)[source]

Add observer for the leaf child of the module.

This function inserts an observer module into every leaf child module that has a valid qconfig attribute.

Parameters
  • module – input module with qconfig attributes for all the leaf modules that we want to quantize

  • device – parent device, if any

  • non_leaf_module_list – list of non-leaf modules we want to add observer

Returns

None, module is modified inplace with added observer modules and forward_hooks

torch.quantization.swap_module(mod, mapping, custom_module_class_mapping)[source]

Swaps the module if it has a quantized counterpart and it has an observer attached.

Parameters
  • mod – input module

  • mapping – a dictionary that maps from nn module to nnq module

Returns

The corresponding quantized module of mod

torch.quantization.propagate_qconfig_(module, qconfig_dict=None, allow_list=None)[source]

Propagate qconfig through the module hierarchy and assign qconfig attribute on each leaf module

Parameters
  • module – input module

  • qconfig_dict – dictionary that maps from name or type of submodule to quantization configuration, qconfig applies to all submodules of a given module unless qconfig for the submodules are specified (when the submodule already has qconfig attribute)

Returns

None, module is modified inplace with qconfig attached
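
A minimal sketch of propagation, assuming a qconfig set once on a hypothetical Sequential parent and default_qconfig as the configuration.

```python
import torch.nn as nn
from torch.quantization import default_qconfig, propagate_qconfig_

# Hypothetical model; the qconfig is set only on the top-level module.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
model.qconfig = default_qconfig

# Modifies the model in place and returns None.
result = propagate_qconfig_(model)

for name, child in model.named_children():
    print(name, hasattr(child, "qconfig"))
```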

torch.quantization.default_eval_fn(model, calib_data)[source]

Default evaluation function. Takes a torch.utils.data.Dataset or a list of input Tensors and runs the model on the dataset.

Observers

class torch.quantization.ObserverBase(dtype)[source]

Base observer Module. Any observer implementation should derive from this class.

Concrete observers should follow the same API. In forward, they will update the statistics of the observed Tensor. And they should provide a calculate_qparams function that computes the quantization parameters given the collected statistics.

Parameters

dtype – Quantized data type

classmethod with_args(**kwargs)

Wrapper that allows creation of class factories.

This can be useful when there is a need to create classes with the same constructor arguments, but different instances. Can be used in conjunction with _callable_args

Example:

>>> Foo.with_args = classmethod(_with_args)
>>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
>>> foo_instance1 = foo_builder()
>>> foo_instance2 = foo_builder()
>>> id(foo_instance1) == id(foo_instance2)
False
classmethod with_callable_args(**kwargs)

Wrapper that allows creation of class factories with args that need to be called at construction time.

This can be useful when there is a need to create classes with the same constructor arguments, but different instances and those arguments should only be calculated at construction time. Can be used in conjunction with _with_args

Example:

>>> Foo.with_callable_args = classmethod(_with_callable_args)
>>> Foo.with_args = classmethod(_with_args)
>>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan")
>>> foo_instance1 = foo_builder()
>>> wait 50
>>> foo_instance2 = foo_builder()
>>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time)
False
class torch.quantization.MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, factory_kwargs=None)[source]

Observer module for computing the quantization parameters based on the running min and max values.

This observer uses the tensor min/max statistics to compute the quantization parameters. The module records the running minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.

Parameters
  • dtype – Quantized data type

  • qscheme – Quantization scheme to be used

  • reduce_range – Reduces the range of the quantized data type by 1 bit

  • quant_min – Minimum quantization value. If unspecified, it will follow the 8-bit setup.

  • quant_max – Maximum quantization value. If unspecified, it will follow the 8-bit setup.

Given the running min/max as $x_\text{min}$ and $x_\text{max}$, the scale $s$ and zero point $z$ are computed as described below.

The running minimum/maximum $x_\text{min/max}$ is computed as:

$$
\begin{array}{ll}
x_\text{min} &= \begin{cases} \min(X) & \text{if~}x_\text{min} = \text{None} \\ \min\left(x_\text{min}, \min(X)\right) & \text{otherwise} \end{cases}\\
x_\text{max} &= \begin{cases} \max(X) & \text{if~}x_\text{max} = \text{None} \\ \max\left(x_\text{max}, \max(X)\right) & \text{otherwise} \end{cases}\\
\end{array}
$$

where $X$ is the observed tensor.

The scale $s$ and zero point $z$ are then computed as:

$$
\begin{aligned}
\text{if Symmetric:}&\\
&s = 2 \max(|x_\text{min}|, x_\text{max}) / \left( Q_\text{max} - Q_\text{min} \right) \\
&z = \begin{cases} 0 & \text{if dtype is qint8} \\ 128 & \text{otherwise} \end{cases}\\
\text{Otherwise:}&\\
&s = \left( x_\text{max} - x_\text{min} \right ) / \left( Q_\text{max} - Q_\text{min} \right ) \\
&z = Q_\text{min} - \text{round}(x_\text{min} / s)
\end{aligned}
$$

where $Q_\text{min}$ and $Q_\text{max}$ are the minimum and maximum of the quantized data type.
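
The scale and zero point formulas can be worked through in plain Python. This is a sketch of the arithmetic as written in this section, not the observer's actual implementation; minmax_qparams and its symmetric and signed flags are hypothetical names, with signed standing in for "dtype is qint8".

```python
def minmax_qparams(x_min, x_max, q_min, q_max, symmetric=False, signed=False):
    """Return (scale, zero_point) following the formulas in this section."""
    if x_min == x_max:
        # Degenerate range: scale and zero_point fall back to 1.0 and 0.
        return 1.0, 0
    if symmetric:
        scale = 2 * max(abs(x_min), x_max) / (q_max - q_min)
        zero_point = 0 if signed else 128
    else:
        scale = (x_max - x_min) / (q_max - q_min)
        zero_point = q_min - round(x_min / scale)
    return scale, zero_point


# Affine quint8 (Q_min=0, Q_max=255) for an observed range of [-1.0, 3.0].
affine_scale, affine_zero_point = minmax_qparams(-1.0, 3.0, 0, 255)

# Symmetric qint8 (Q_min=-128, Q_max=127) for the same range.
sym_scale, sym_zero_point = minmax_qparams(
    -1.0, 3.0, -128, 127, symmetric=True, signed=True
)

print(affine_scale, affine_zero_point)
print(sym_scale, sym_zero_point)
```

In the affine case the zero point is the quantized value that represents 0.0, so it shifts with x_min; in the symmetric case it is fixed by the dtype.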

Warning

Only works with per-tensor quantization schemes (torch.per_tensor_affine or torch.per_tensor_symmetric).

Warning

dtype can only take torch.qint8 or torch.quint8.

Note

If the running minimum equals the running maximum, the scale and zero_point are set to 1.0 and 0, respectively.

class torch.quantization.MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, quant_min=None, quant_max=None, **kwargs)[source]

Observer module for computing the quantization parameters based on the moving average of the min and max values.

This observer computes the quantization parameters based on the moving averages of minimums and maximums of the incoming tensors. The module records the average minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.

Parameters
  • averaging_constant – Averaging constant for min/max.

  • dtype – Quantized data type

  • qscheme – Quantization scheme to be used

  • reduce_range – Reduces the range of the quantized data type by 1 bit

  • quant_min – Minimum quantization value. If unspecified, it will follow the 8-bit setup.

  • quant_max – Maximum quantization value. If unspecified, it will follow the 8-bit setup.

The moving average min/max is computed as follows:

$$
\begin{array}{ll}
x_\text{min} = \begin{cases} \min(X) & \text{if~}x_\text{min} = \text{None} \\ (1 - c) x_\text{min} + c \min(X) & \text{otherwise} \end{cases}\\
x_\text{max} = \begin{cases} \max(X) & \text{if~}x_\text{max} = \text{None} \\ (1 - c) x_\text{max} + c \max(X) & \text{otherwise} \end{cases}\\
\end{array}
$$

where $x_\text{min/max}$ is the running average min/max, $X$ is the incoming tensor, and $c$ is the averaging_constant.

The scale and zero point are then computed as in MinMaxObserver.
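
The moving-average update can be traced by hand with a short plain-Python sketch. update_running is a hypothetical helper that applies the rule to a single statistic, and c=0.5 is chosen only to keep the arithmetic easy to follow (the observer's default averaging_constant is 0.01).

```python
def update_running(running, observed, c=0.01):
    """Apply one moving-average step to a running min or max."""
    if running is None:
        # First observation initializes the statistic directly.
        return observed
    return (1 - c) * running + c * observed


# Two hypothetical batches of activations.
batches = [[-1.0, 2.0], [-3.0, 4.0]]

x_min = x_max = None
for batch in batches:
    x_min = update_running(x_min, min(batch), c=0.5)
    x_max = update_running(x_max, max(batch), c=0.5)

print(x_min, x_max)  # prints -2.0 3.0
```

Unlike the plain running min/max, a single outlier batch only moves the statistics by a fraction c of the difference.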

Note

Only works with torch.per_tensor_affine quantization scheme.

Note

If the running minimum equals the running maximum, the scale and zero_point are set to 1.0 and 0, respectively.

class torch.quantization.PerChannelMinMaxObserver(ch_axis=0, dtype=torch.quint8, qscheme=torch.per_channel_affine, reduce_range=False, quant_min=None, quant_max=None, factory_kwargs=None)[source]

Observer module for computing the quantization parameters based on the running per channel min and max values.

This observer uses the tensor min/max statistics to compute the per channel quantization parameters. The module records the running minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.

Parameters
  • ch_axis – Channel axis

  • dtype – Quantized data type

  • qscheme – Quantization scheme to be used

  • reduce_range – Reduces the range of the quantized data type by 1 bit

  • quant_min – Minimum quantization value. If unspecified, it will follow the 8-bit setup.

  • quant_max – Maximum quantization value. If unspecified, it will follow the 8-bit setup.

The quantization parameters are computed the same way as in MinMaxObserver, with the difference that the running min/max values are stored per channel. Scales and zero points are thus computed per channel as well.
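
A short sketch of per-channel observation, assuming a hypothetical weight tensor with three output channels along axis 0. Each channel gets its own scale and zero point.

```python
import torch
from torch.quantization import PerChannelMinMaxObserver

# Hypothetical weight with 3 channels along axis 0 and different value ranges.
weight = torch.tensor([[-1.0, 0.5], [-0.2, 0.2], [0.0, 4.0]])

observer = PerChannelMinMaxObserver(ch_axis=0)
observer(weight)  # the forward pass records min/max for every channel
scales, zero_points = observer.calculate_qparams()

# One entry per channel; wider channels receive larger scales.
print(scales.shape, zero_points.shape)
print(scales)
```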

Note

If the running minimum equals the running maximum, the scales and zero_points are set to 1.0 and 0, respectively.

class torch.quantization.MovingAveragePerChannelMinMaxObserver(averaging_constant=0.01, ch_axis=0, dtype=torch.quint8, qscheme=torch.per_channel_affine, reduce_range=False, quant_min=None, quant_max=None, **kwargs)[source]

Observer module for computing the quantization parameters based on the moving average of the per channel min and max values.

This observer uses the moving averages of the per channel min/max statistics to compute the per channel quantization parameters. The module records the moving average of the minimum and maximum of incoming tensors for each channel, and uses this statistic to compute the quantization parameters.

Parameters
  • averaging_constant – Averaging constant for min/max.

  • ch_axis – Channel axis

  • dtype – Quantized data type

  • qscheme – Quantization scheme to be used

  • reduce_range – Reduces the range of the quantized data type by 1 bit

  • quant_min – Minimum quantization value. If unspecified, it will follow the 8-bit setup.

  • quant_max – Maximum quantization value. If unspecified, it will follow the 8-bit setup.

The quantization parameters are computed the same way as in MovingAverageMinMaxObserver, with the difference that the running min/max values are stored per channel. Scales and zero points are thus computed per channel as well.

Note

If the running minimum equals the running maximum, the scales and zero_points are set to 1.0 and 0, respectively.

class torch.quantization.HistogramObserver(bins=2048, upsample_rate=128, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, factory_kwargs=None)[source]

The module records the running histogram of tensor values along with min/max values. calculate_qparams will calculate scale and zero_point.

Parameters
  • bins – Number of bins to use for the histogram

  • upsample_rate – Factor by which the histograms are upsampled, this is used to interpolate histograms with varying ranges across observations

  • dtype – Quantized data type

  • qscheme – Quantization scheme to be used

  • reduce_range – Reduces the range of the quantized data type by 1 bit

The scale and zero point are computed as follows:

  1. Create the histogram of the incoming inputs.

    The histogram is computed continuously, and the ranges per bin change with every new tensor observed.

  2. Search the distribution in the histogram for optimal min/max values.

    The search for the min/max values ensures the minimization of the quantization error with respect to the floating point model.

  3. Compute the scale and zero point the same way as in the MinMaxObserver.

class torch.quantization.FakeQuantize(observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, **observer_kwargs)[source]

Simulate the quantize and dequantize operations at training time. The output of this module is given by:

x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale

  • scale defines the scale factor used for quantization.

  • zero_point specifies the quantized value to which 0 in floating point maps

  • quant_min specifies the minimum allowable quantized value.

  • quant_max specifies the maximum allowable quantized value.

  • fake_quant_enable controls the application of fake quantization on tensors, note that statistics can still be updated.

  • observer_enable controls statistics collection on tensors

  • dtype specifies the quantized dtype that is being emulated with fake-quantization; allowable values are torch.qint8 and torch.quint8. The values of quant_min and quant_max should be chosen to be consistent with the dtype.
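
The x_out formula can be checked on scalars with a few lines of plain Python. This is only a sketch of the arithmetic: fake_quantize is a hypothetical helper, it works on single floats rather than tensors, and it relies on Python's built-in round.

```python
def fake_quantize(x, scale, zero_point, quant_min=0, quant_max=255):
    """Quantize x to the integer grid, clamp it, and map it back to float."""
    q = round(x / scale + zero_point)
    q = max(quant_min, min(quant_max, q))
    return (q - zero_point) * scale


scale, zero_point = 0.1, 0

in_range = fake_quantize(0.3, scale, zero_point)       # lands on a grid point
too_large = fake_quantize(100.0, scale, zero_point)    # clamped to quant_max
too_small = fake_quantize(-5.0, scale, zero_point)     # clamped to quant_min

print(in_range, too_large, too_small)
```

Values inside the representable range are snapped to the nearest multiple of scale, while values outside it saturate at (quant_min - zero_point) * scale or (quant_max - zero_point) * scale.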

Parameters
  • observer (module) – Module for observing statistics on input tensors and calculating scale and zero-point.

  • quant_min (int) – The minimum allowable quantized value.

  • quant_max (int) – The maximum allowable quantized value.

  • observer_kwargs (optional) – Arguments for the observer module

Variables

~FakeQuantize.observer (Module) – User provided module that collects statistics on the input tensor and provides a method to calculate scale and zero-point.

class torch.quantization.NoopObserver(dtype=torch.float16, custom_op_name='')[source]

Observer that doesn’t do anything and just passes its configuration to the quantized module’s .from_float().

Primarily used for quantization to float16 which doesn’t require determining ranges.

Parameters
  • dtype – Quantized data type

  • custom_op_name – (temporary) specify this observer for an operator that doesn’t require any observation (Can be used in Graph Mode Passes for special case ops).

Debugging utilities

torch.quantization.get_observer_dict(mod, target_dict, prefix='')[source]

Traverse the modules and save all observers into dict. This is mainly used for quantization accuracy debug.

Parameters
  • mod – the top module we want to save all observers

  • prefix – the prefix for the current module

  • target_dict – the dictionary used to save all the observers

class torch.quantization.RecordingObserver(**kwargs)[source]

The module is mainly for debug and records the tensor values during runtime.

Parameters
  • dtype – Quantized data type

  • qscheme – Quantization scheme to be used

  • reduce_range – Reduces the range of the quantized data type by 1 bit
