PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor

Author: Leslie Fang, Weiwen Xia, Jiong Gong, Jerry Zhang

Introduction

This tutorial introduces the steps to use the PyTorch 2 Export Quantization flow to generate a quantized model customized for the x86 inductor backend and explains how to lower the quantized model into Inductor.

The PyTorch 2 Export Quantization flow uses torch export (PT2 Export) to capture the model into a graph and performs quantization transformations on top of the ATen graph. This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX. TorchInductor is the new compiler backend that compiles the FX Graphs generated by TorchDynamo into optimized C++/Triton kernels.

This quantization flow with Inductor mainly includes three steps (a compact sketch of the whole flow follows this list):

  • Step 1: Capture the FX Graph from the eager model based on the torch export mechanism.

  • Step 2: Apply the Quantization flow based on the captured FX Graph, including defining the backend-specific quantizer, generating the prepared model with observers, performing the prepared model’s calibration, and converting the prepared model into the quantized model.

  • Step 3: Lower the quantized model into Inductor with the torch.compile API.
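
As a quick orientation, here is a minimal sketch of the whole flow using a torchvision resnet18 model; every call here is explained in detail in the sections below:

import torch
import torchvision.models as models
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

model = models.resnet18(pretrained=True).eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# Step 1: capture the FX Graph from the eager model
with torch.no_grad():
    exported_model = capture_pre_autograd_graph(model, example_inputs)

# Step 2: define the backend quantizer, then prepare, calibrate, and convert
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepared_model = prepare_pt2e(exported_model, quantizer)
prepared_model(*example_inputs)  # calibration with sample data
converted_model = convert_pt2e(prepared_model)

# Step 3: lower into Inductor (run the script with TORCHINDUCTOR_FREEZING=1)
with torch.no_grad():
    optimized_model = torch.compile(converted_model)
    optimized_model(*example_inputs)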

The high-level architecture of this flow could look like this:

float_model(Python)                          Example Input
    \                                              /
     \                                            /
---------------------------------------------------------
|                         export                       |
---------------------------------------------------------
                            |
                    FX Graph in ATen
                            |            X86InductorQuantizer
                            |                 /
---------------------------------------------------------
|                      prepare_pt2e                     |
|                           |                           |
|                     Calibrate/Train                   |
|                           |                           |
|                      convert_pt2e                     |
---------------------------------------------------------
                            |
                     Quantized Model
                            |
---------------------------------------------------------
|                    Lower into Inductor                |
---------------------------------------------------------
                            |
                         Inductor

By combining quantization in PyTorch 2 Export with TorchInductor, we gain flexibility and productivity from the new quantization frontend and outstanding out-of-the-box performance from the compiler backend. This is especially true on 4th generation Intel Xeon Scalable processors (code-named Sapphire Rapids), which can further boost model performance by leveraging the Advanced Matrix Extensions (AMX) feature.

Now, we will walk through a step-by-step example of how to use this flow with the torchvision resnet18 model.

1. Capture FX Graph

We will start by performing the necessary imports and capturing the FX Graph from the eager module.

import torch
import torchvision.models as models
import copy
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch._export import capture_pre_autograd_graph

# Create the Eager Model
model_name = "resnet18"
model = models.__dict__[model_name](pretrained=True)

# Set the model to eval mode
model = model.eval()

# Create the data, using the dummy data here as an example
traced_bs = 50
x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last)
example_inputs = (x,)

# Capture the FX Graph to be quantized
with torch.no_grad():
    # Note 1: if you are using the PyTorch nightlies or building from source with the PyTorch main branch,
    # use the `capture_pre_autograd_graph` API.
    # `capture_pre_autograd_graph` is a short-term API; it will be updated to use the official `torch.export` API when that is ready.
    exported_model = capture_pre_autograd_graph(
        model,
        example_inputs
    )
    # Note 2: if you are using the PyTorch 2.1 release binary or building from source with the PyTorch 2.1 release branch,
    # please use the API of `torch._dynamo.export` to capture the FX Graph.
    # exported_model, guards = torch._dynamo.export(
    #     model,
    #     *copy.deepcopy(example_inputs),
    #     aten_graph=True,
    # )

We now have the FX graph module to be quantized.
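
If you want to inspect what was captured before quantizing it, you can print the graph module. This is optional; the exact node names and generated code depend on the model and PyTorch version:

# Optional: inspect the captured ATen-level FX graph
print(exported_model)         # prints the generated forward() code
print(exported_model.graph)   # prints the underlying FX graph nodes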

2. Apply Quantization

After capturing the FX module to be quantized, we import the backend quantizer for the X86 CPU and configure how to quantize the model.

quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

Note

The default quantization configuration in X86InductorQuantizer uses 8 bits for both activations and weights.

When Vector Neural Network Instructions (VNNI) are not available, the oneDNN backend silently chooses kernels that assume 7-bit x 8-bit multiplications. In other words, potential numeric saturation and accuracy issues may occur when running on a CPU without VNNI.
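
If you are unsure whether your machine exposes VNNI, one quick, Linux-only check is to look at the CPU flags. This is a rough sketch; the flag names avx512_vnni and avx_vnni are assumptions that hold for common x86 CPUs but may differ across platforms:

# Rough VNNI availability check on Linux by reading the CPU flags.
# The flag names `avx512_vnni` and `avx_vnni` are assumptions for common x86 CPUs.
with open("/proc/cpuinfo") as f:
    cpu_flags = f.read()
print("VNNI available:", "avx512_vnni" in cpu_flags or "avx_vnni" in cpu_flags)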

After we import the backend-specific Quantizer, we will prepare the model for post-training quantization. prepare_pt2e folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model.

prepared_model = prepare_pt2e(exported_model, quantizer)
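
To see the effect of prepare_pt2e, you can optionally list the observer submodules it inserted. This is a small sketch; the submodule names are auto-generated and model-specific:

from torch.ao.quantization.observer import ObserverBase

# Optional: list the observers that prepare_pt2e attached to the graph module
for name, module in prepared_model.named_modules():
    if isinstance(module, ObserverBase):
        print(name, type(module).__name__)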

Now that the observers are inserted in the model, we will calibrate the prepared_model.

# We use the dummy data as an example here
prepared_model(*example_inputs)

# Alternatively: user can define the dataset to calibrate
# def calibrate(model, data_loader):
#     model.eval()
#     with torch.no_grad():
#         for image, target in data_loader:
#             model(image)
# calibrate(prepared_model, data_loader_test)  # run calibration on sample data

Finally, we will convert the calibrated model to a quantized model. convert_pt2e takes a calibrated model and produces a quantized model.

converted_model = convert_pt2e(prepared_model)

After these steps, we have finished running the quantization flow and obtained the quantized model.
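
If you are curious what the quantized graph contains, you can optionally inspect it. The sketch below counts the quantize/dequantize nodes; the exact operator names (quantized_decomposed ops) are version-dependent details, so treat this as illustrative:

# Optional: count quantize/dequantize nodes in the converted graph
q_dq_nodes = [
    n for n in converted_model.graph.nodes
    if n.op == "call_function" and "quantize_per" in str(n.target)
]
print(f"Found {len(q_dq_nodes)} quantize/dequantize nodes")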

3. Lower into Inductor

After we get the quantized model, we will further lower it into the Inductor backend. The default Inductor wrapper generates Python code to invoke both generated kernels and external kernels. Additionally, Inductor supports a C++ wrapper that generates pure C++ code. This allows seamless integration of the generated and external kernels, effectively reducing Python overhead. In the future, leveraging the C++ wrapper, this capability can be extended to achieve pure C++ deployment. For more comprehensive details about the C++ wrapper in general, please refer to the dedicated Inductor C++ Wrapper Tutorial.

# Optional: using the C++ wrapper instead of default Python wrapper
import torch._inductor.config as config
config.cpp_wrapper = True
with torch.no_grad():
    optimized_model = torch.compile(converted_model)

    # Running some benchmark
    optimized_model(*example_inputs)
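
The benchmark above is left as a placeholder; a minimal, hedged timing sketch could look like the following. The first call triggers compilation, so it is excluded from the measurement:

import time

with torch.no_grad():
    optimized_model(*example_inputs)  # warm-up; triggers compilation
    iters = 10
    start = time.time()
    for _ in range(iters):
        optimized_model(*example_inputs)
    print(f"Average latency: {(time.time() - start) / iters * 1000:.1f} ms")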

In a more advanced scenario, int8-mixed-bf16 quantization comes into play. In this mode, a convolution or GEMM operator produces a BFloat16 output instead of Float32 when there is no subsequent quantization node. The BFloat16 tensor then propagates through the following pointwise operators, reducing memory usage and potentially improving performance. Using this feature mirrors regular BFloat16 autocast: it is as simple as wrapping the script in the BFloat16 autocast context.

with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True), torch.no_grad():
    # Turn on Autocast to use int8-mixed-bf16 quantization. After lowering into Inductor CPP Backend,
    # For operators such as QConvolution and QLinear:
    # * The input data type is consistently defined as int8, attributable to the presence of a pair
    #   of quantization and dequantization nodes inserted at the input.
    # * The computation precision remains at int8.
    # * The output data type may vary, being either int8 or BFloat16, contingent on the presence
    #   of a pair of quantization and dequantization nodes at the output.
    # For non-quantizable pointwise operators, the data type will be inherited from the previous node,
    # potentially resulting in a data type of BFloat16 in this scenario.
    # For quantizable pointwise operators such as QMaxpool2D, it continues to operate with the int8
    # data type for both input and output.
    optimized_model = torch.compile(converted_model)

    # Running some benchmark
    optimized_model(*example_inputs)

Putting all of this code together gives us the toy example code. Please note that since the Inductor freeze feature is not turned on by default yet, you need to run your example code with TORCHINDUCTOR_FREEZING=1.

For example:

TORCHINDUCTOR_FREEZING=1 python example_x86inductorquantizer_pytorch_2_1.py
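
Alternatively, the same switch can likely be flipped programmatically before calling torch.compile, assuming the freezing flag in torch._inductor.config (which mirrors the TORCHINDUCTOR_FREEZING environment variable):

import torch._inductor.config as config

# Assumed programmatic equivalent of TORCHINDUCTOR_FREEZING=1; set before torch.compile runs
config.freezing = True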

With the PyTorch 2.1 release, all CNN models from the TorchBench test suite have been measured and proven effective compared with the Inductor FP32 inference path. Please refer to this document for detailed benchmark numbers.

4. Conclusion

In this tutorial, we introduced how to use Inductor with an X86 CPU in PyTorch 2 Export Quantization. Users can learn how to use X86InductorQuantizer to quantize a model and lower it into Inductor on X86 CPU devices.
