(prototype) PyTorch 2 Export Quantization-Aware Training (QAT)

Author: Andrew Or

This tutorial shows how to perform quantization-aware training (QAT) in graph mode based on torch.export.export. For more details about PyTorch 2 Export Quantization in general, refer to the post training quantization tutorial.

The PyTorch 2 Export QAT flow looks like the following. For the most part, it is similar to the post training quantization (PTQ) flow:

import torch
from torch.ao.quantization.quantize_pt2e import (
  prepare_qat_pt2e,
  convert_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
  XNNPACKQuantizer,
  get_symmetric_quantization_config,
)

class M(torch.nn.Module):
   def __init__(self):
      super().__init__()
      self.linear = torch.nn.Linear(5, 10)

   def forward(self, x):
      return self.linear(x)


example_inputs = (torch.randn(1, 5),)
m = M()

# Step 1. program capture
# This is available for pytorch 2.5+, for more details on lower pytorch versions
# please check `Export the model with torch.export` section
m = torch.export.export_for_training(m, example_inputs).module()
# we get a model with aten ops

# Step 2. quantization-aware training
# backend developer will write their own Quantizer and expose methods to allow
# users to express how they want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True))
m = prepare_qat_pt2e(m, quantizer)

# train omitted

m = convert_pt2e(m)
# we have a model with aten ops doing integer computations when possible

# move the quantized model to eval mode, equivalent to `m.eval()`
torch.ao.quantization.move_exported_model_to_eval(m)

Note that calling model.eval() or model.train() after program capture is not allowed, because these methods no longer correctly change the behavior of certain ops like dropout and batch normalization. Instead, please use torch.ao.quantization.move_exported_model_to_eval() and torch.ao.quantization.move_exported_model_to_train() (coming soon) respectively.
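The reason a module-level flag stops mattering after capture can be illustrated with a toy model in plain Python. This is only a sketch under simplifying assumptions: the "graph" is a list of dictionaries, and `ToyExportedModel` and `move_toy_model_to_eval` are hypothetical names invented for this illustration, not PyTorch APIs and not a description of how the real helpers are implemented.

```python
# Toy illustration only: a "captured graph" here is just a list of nodes,
# each storing its op name and the arguments that were frozen at capture time.
import random


def dropout(x, p, training):
    """Zero each element with probability p in training mode; identity in eval mode."""
    if not training:
        return list(x)
    return [0.0 if random.random() < p else v / (1.0 - p) for v in x]


class ToyExportedModel:
    def __init__(self):
        self.training = True  # module-level flag, the kind .train()/.eval() would flip
        # The `training` argument was baked into the node when the graph was captured.
        self.nodes = [{"op": "dropout", "args": {"p": 0.5, "training": True}}]

    def __call__(self, x):
        for node in self.nodes:
            if node["op"] == "dropout":
                x = dropout(x, **node["args"])
        return x


def move_toy_model_to_eval(model):
    """Hypothetical helper: rewrite the baked-in flag on every dropout node."""
    for node in model.nodes:
        if node["op"] == "dropout":
            node["args"]["training"] = False


random.seed(0)
model = ToyExportedModel()
x = [1.0] * 8

model.training = False           # flipping the module flag; the graph never reads it
still_dropping = model(x) != x   # every element is now 0.0 or 2.0, so dropout is still active

move_toy_model_to_eval(model)    # rewriting the node argument does take effect
eval_output = model(x)           # identical to the input
```

In this toy setting, only the rewrite of the node argument changes the forward pass, which mirrors why a dedicated helper is needed for exported models.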

Define Helper Functions and Prepare the Dataset

To run the code in this tutorial using the entire ImageNet dataset, first download ImageNet by following the instructions in ImageNet Data. Unzip the downloaded file into the data_path folder.

Next, download the torchvision resnet18 model and rename it to data/resnet18_pretrained_float.pth.

We’ll start by doing the necessary imports, defining some helper functions, and preparing the data. These steps are very similar to the ones defined in the static eager mode post training quantization tutorial:

import os
import sys
import time
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
from torchvision.models.resnet import resnet18
import torchvision.transforms as transforms

# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
_ = torch.manual_seed(191009)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified
    values of k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

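def evaluate(model, criterion, data_loader, device="cuda", neval_batches=None):
    # `device` defaults to "cuda" to match where the model is placed in this tutorial.
    # `neval_batches` limits evaluation to the first N batches; None evaluates the full loader.
    torch.ao.quantization.move_exported_model_to_eval(model)
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            image = image.to(device)
            target = target.to(device)
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if neval_batches is not None and cnt >= neval_batches:
                break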
    print('')

    return top1, top5

def load_model(model_file):
    model = resnet18(weights=None)
    state_dict = torch.load(model_file, weights_only=True)
    model.load_state_dict(state_dict)
    return model

def print_size_of_model(model):
    if isinstance(model, torch.jit.RecursiveScriptModule):
        torch.jit.save(model, "temp.p")
    else:
        torch.jit.save(torch.jit.script(model), "temp.p")
    print("Size (MB):", os.path.getsize("temp.p")/1e6)
    os.remove("temp.p")

def prepare_data_loaders(data_path):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = torchvision.datasets.ImageNet(
        data_path, split="train", transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    dataset_test = torchvision.datasets.ImageNet(
        data_path, split="val", transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test

def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
    # Note: do not call model.train() here, since this doesn't work on an exported model.
    # Instead, call `torch.ao.quantization.move_exported_model_to_train(model)`, which will
    # be added in the near future
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    avgloss = AverageMeter('Loss', ':1.5f')

    cnt = 0
    for image, target in data_loader:
        start_time = time.time()
        print('.', end = '')
        cnt += 1
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1[0], image.size(0))
        top5.update(acc5[0], image.size(0))
        # Store a plain float so the meter does not keep each batch's autograd graph alive
        avgloss.update(loss.item(), image.size(0))
        if cnt >= ntrain_batches:
            print('Loss', avgloss.avg)

            print('Training: * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
                  .format(top1=top1, top5=top5))
            return

    print('Full imagenet train set:  * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return

data_path = '~/.data/imagenet'
saved_model_dir = 'data/'
float_model_file = 'resnet18_pretrained_float.pth'

train_batch_size = 32
eval_batch_size = 32

data_loader, data_loader_test = prepare_data_loaders(data_path)
example_inputs = (next(iter(data_loader))[0],)
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to("cuda")
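To make the metric reported by the `accuracy` helper concrete, here is a minimal plain-Python sketch of top-k accuracy on made-up scores. It does not use PyTorch and is not the helper above; `topk_accuracy` and the sample data are hypothetical and exist only for illustration.

```python
def topk_accuracy(scores, targets, topk=(1,)):
    """Return, for each k, the percentage of samples whose target is among the k highest scores."""
    results = []
    for k in topk:
        hits = 0
        for row, target in zip(scores, targets):
            # Indices of the k largest scores in this row.
            ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
            hits += target in ranked
        results.append(100.0 * hits / len(targets))
    return results


# Made-up scores for 4 samples over 3 classes.
scores = [
    [0.1, 0.7, 0.2],  # best class is 1, runner-up is 2
    [0.5, 0.3, 0.2],  # best class is 0, runner-up is 1
    [0.2, 0.3, 0.5],  # best class is 2, runner-up is 1
    [0.6, 0.1, 0.3],  # best class is 0, runner-up is 2
]
targets = [1, 1, 0, 2]

top1, top2 = topk_accuracy(scores, targets, topk=(1, 2))
print(top1, top2)  # 25.0 75.0
```

Only the first sample is correct at k=1, while the second and fourth samples become correct once the runner-up class is also accepted.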

Export the model with torch.export

Here is how you can use torch.export to export the model:

# the example inputs must be on the same device as float_model
example_inputs = (torch.rand(2, 3, 224, 224, device="cuda"),)
# for pytorch 2.5+
exported_model = torch.export.export_for_training(float_model, example_inputs).module()
# for pytorch 2.4 and before
# from torch._export import capture_pre_autograd_graph
# exported_model = capture_pre_autograd_graph(float_model, example_inputs)
# or, to capture with dynamic dimensions:

# for pytorch 2.5+
dynamic_shapes = tuple(
  {0: torch.export.Dim("dim")} if i == 0 else None
  for i in range(len(example_inputs))
)
exported_model = torch.export.export_for_training(float_model, example_inputs, dynamic_shapes=dynamic_shapes).module()

# for pytorch 2.4 and before
# dynamic_shape API may vary as well
# from torch._export import dynamic_dim

# example_inputs = (torch.rand(2, 3, 224, 224),)
# exported_model = capture_pre_autograd_graph(
#     float_model,
#     example_inputs,
#     constraints=[dynamic_dim(example_inputs[0], 0)],
# )

Import the Backend Specific Quantizer and Configure how to Quantize the Model

The following code snippets describe how to quantize the model:

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))

A Quantizer is backend specific, and each Quantizer provides its own way for users to configure how their model is quantized.

Note

Check out our tutorial that describes how to write a new Quantizer.

Prepare the Model for Quantization-Aware Training

prepare_qat_pt2e inserts fake quantize ops in appropriate places in the model and performs the appropriate QAT “fusions”, such as Conv2d + BatchNorm2d, for better training accuracy. The fused operations are represented as a subgraph of ATen ops in the prepared graph.

from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e

prepared_model = prepare_qat_pt2e(exported_model, quantizer)
print(prepared_model)
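As a rough illustration of what a fake quantize step does numerically, the plain-Python sketch below rounds values onto a symmetric integer grid and immediately maps them back to floats, so that training sees the rounding error. The int8 range of [-127, 127], the max-based scale, and the function names are assumptions chosen for this example; they are not read from XNNPACKQuantizer or from the ops that prepare_qat_pt2e actually inserts.

```python
def symmetric_qparams(values, qmax=127):
    """Pick a scale so the largest magnitude maps to qmax; symmetric means zero_point is 0."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    return scale, 0


def fake_quantize(values, scale, zero_point, qmin=-127, qmax=127):
    """Quantize to the integer grid, clamp to the range, then map back to float."""
    out = []
    for v in values:
        q = round(v / scale) + zero_point
        q = max(qmin, min(qmax, q))
        out.append((q - zero_point) * scale)
    return out


weights = [-1.27, -0.5, 0.0, 0.004, 0.63, 1.27]
scale, zero_point = symmetric_qparams(weights)
dequantized = fake_quantize(weights, scale, zero_point)

# The rounding error of any in-range value is at most half a quantization step.
max_error = max(abs(a - b) for a, b in zip(weights, dequantized))
```

The output stays in floating point, which is why gradients can still flow during training; only the values are restricted to what the integer grid can represent.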

Note

If your model contains batch normalization, the actual ATen ops you get in the graph depend on the model’s device when you export the model. If the model is on CPU, then you’ll get torch.ops.aten._native_batch_norm_legit. If the model is on CUDA, then you’ll get torch.ops.aten.cudnn_batch_norm. However, this is not fundamental and may be subject to change in the future.

Between these two ops, it has been shown that torch.ops.aten.cudnn_batch_norm provides better numerics on models like MobileNetV2. To get this op, either call model.cuda() before export, or run the following after prepare to manually swap the ops:

for n in prepared_model.graph.nodes:
    if n.target == torch.ops.aten._native_batch_norm_legit.default:
        n.target = torch.ops.aten.cudnn_batch_norm.default
prepared_model.recompile()

In the future, we plan to consolidate the batch normalization ops such that the above will no longer be necessary.
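For context on why Conv2d + BatchNorm2d is treated as a single pattern, the sketch below checks the standard inference-time identity for folding batch normalization into the preceding affine op. It is plain Python on scalars, with a one-channel "convolution" reduced to y = w * x + b. The QAT fusion performed during prepare is more involved than this, since batch norm statistics keep updating during training, so treat this as background arithmetic rather than a description of the prepared graph. All names and numbers are made up for the example.

```python
import math


def batch_norm(x, gamma, beta, mean, var, eps=1e-5):
    """Batch normalization in inference form, using fixed running statistics."""
    return gamma * (x - mean) / math.sqrt(var + eps) + beta


def fold_bn_into_conv(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Return (weight, bias) of a single affine op equal to conv followed by batch norm."""
    factor = gamma / math.sqrt(var + eps)
    return weight * factor, (bias - mean) * factor + beta


# One-channel "convolution" reduced to y = w * x + b for readability.
w, b = 0.8, 0.1
gamma, beta, running_mean, running_var = 1.5, -0.2, 0.05, 0.9

x = 2.0
unfused = batch_norm(w * x + b, gamma, beta, running_mean, running_var)

w_folded, b_folded = fold_bn_into_conv(w, b, gamma, beta, running_mean, running_var)
fused = w_folded * x + b_folded
```

Because the folded weight is what ultimately gets quantized, simulating quantization on the folded form during training is what makes the fused pattern useful for QAT.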

Training Loop

The training loop is similar to the ones in previous versions of QAT. To achieve better accuracy, you may optionally disable observers and stop updating batch normalization statistics after a certain number of epochs, or evaluate the QAT model or the quantized model trained so far every N epochs.

import copy

from torch.ao.quantization.quantize_pt2e import convert_pt2e

num_epochs = 10
num_train_batches = 20
num_eval_batches = 20
num_observer_update_epochs = 4
num_batch_norm_update_epochs = 3
num_epochs_between_evals = 2

# The optimizer is not defined elsewhere in this tutorial. SGD with this learning
# rate is an assumed choice; adjust it for your own model.
optimizer = torch.optim.SGD(prepared_model.parameters(), lr=0.0001)

# QAT takes time and one needs to train over a few epochs.
# Train and check accuracy after each epoch
for nepoch in range(num_epochs):
    train_one_epoch(prepared_model, criterion, optimizer, data_loader, "cuda", num_train_batches)

    # Optionally disable observer/batchnorm stats after certain number of epochs
    if nepoch >= num_observer_update_epochs:
        print("Disabling observer for subseq epochs, epoch = ", nepoch)
        prepared_model.apply(torch.ao.quantization.disable_observer)
    if nepoch >= num_batch_norm_update_epochs:
        print("Freezing BN for subseq epochs, epoch = ", nepoch)
        for n in prepared_model.graph.nodes:
            # Args: input, weight, bias, running_mean, running_var, training, momentum, eps
            # We set the `training` flag to False here to freeze BN stats
            if n.target in [
                torch.ops.aten._native_batch_norm_legit.default,
                torch.ops.aten.cudnn_batch_norm.default,
            ]:
                new_args = list(n.args)
                new_args[5] = False
                n.args = new_args
        prepared_model.recompile()

    # Check the quantized accuracy every N epochs
    # Note: If you wish to just evaluate the QAT model (not the quantized model),
    # then you can just call `torch.ao.quantization.move_exported_model_to_eval/train`.
    # However, the latter API is not ready yet and will be available in the near future.
    if (nepoch + 1) % num_epochs_between_evals == 0:
        # Convert a copy so that the prepared model can keep training
        prepared_model_copy = copy.deepcopy(prepared_model)
        quantized_model = convert_pt2e(prepared_model_copy)
        top1, top5 = evaluate(quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
        print('Epoch %d: Evaluation accuracy on %d images, %2.2f' % (nepoch, num_eval_batches * eval_batch_size, top1.avg))
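The epoch thresholds in the loop above can be easier to reason about when laid out as a plan. The following plain-Python sketch, using a hypothetical `qat_schedule` helper that is not part of PyTorch, reproduces only the three conditions from the loop with the same constants.

```python
def qat_schedule(num_epochs, observer_epochs, bn_epochs, eval_every):
    """Describe, per epoch index, which optional steps run at the end of that epoch."""
    plan = []
    for nepoch in range(num_epochs):
        plan.append({
            "epoch": nepoch,
            "disable_observers": nepoch >= observer_epochs,
            "freeze_bn": nepoch >= bn_epochs,
            "evaluate": (nepoch + 1) % eval_every == 0,
        })
    return plan


plan = qat_schedule(num_epochs=10, observer_epochs=4, bn_epochs=3, eval_every=2)

eval_epochs = [p["epoch"] for p in plan if p["evaluate"]]
first_frozen = next(p["epoch"] for p in plan if p["freeze_bn"])
first_disabled = next(p["epoch"] for p in plan if p["disable_observers"])

print(eval_epochs)     # [1, 3, 5, 7, 9]
print(first_frozen)    # 3
print(first_disabled)  # 4
```

Since each check runs after that epoch's training step, the change first affects the following epoch: with these constants, batch norm statistics stop updating from epoch index 4 onward and observers stop updating from epoch index 5 onward.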

Saving and Loading Model Checkpoints

Model checkpoints for the PyTorch 2 Export QAT flow are the same as in any other training flow. They are useful for pausing training and resuming it later, recovering from failed training runs, and performing inference on different machines at a later time. You can save model checkpoints during or after training as follows:

checkpoint_path = "/path/to/my/checkpoint_%s.pth" % nepoch
torch.save(prepared_model.state_dict(), checkpoint_path)

To load the checkpoints, you must export and prepare the model the exact same way it was initially exported and prepared. For example:

import torch
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torchvision.models.resnet import resnet18

example_inputs = (torch.rand(2, 3, 224, 224),)
float_model = resnet18(weights=None)
# for pytorch 2.5+; on pytorch 2.4 and before, use capture_pre_autograd_graph instead
exported_model = torch.export.export_for_training(float_model, example_inputs).module()
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(exported_model, quantizer)
prepared_model.load_state_dict(torch.load(checkpoint_path))

# resume training or perform inference
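When resuming, you also need to decide which saved file to load. A minimal sketch, assuming checkpoints follow the `checkpoint_%s.pth` naming used above with an integer epoch, is shown below. The `latest_checkpoint` helper is hypothetical and uses only the standard library; the empty files created in a temporary directory stand in for real checkpoints.

```python
import os
import re
import tempfile

# Matches file names such as "checkpoint_7.pth" and captures the epoch number.
CHECKPOINT_PATTERN = re.compile(r"checkpoint_(\d+)\.pth$")


def latest_checkpoint(directory):
    """Return (epoch, path) for the highest-numbered checkpoint, or None if there is none."""
    best = None
    for name in os.listdir(directory):
        match = CHECKPOINT_PATTERN.match(name)
        if match:
            epoch = int(match.group(1))  # compare numerically, so 10 sorts after 2
            if best is None or epoch > best[0]:
                best = (epoch, os.path.join(directory, name))
    return best


with tempfile.TemporaryDirectory() as tmp:
    # Create empty placeholder files in place of real checkpoints.
    for nepoch in (0, 1, 2, 10):
        open(os.path.join(tmp, "checkpoint_%s.pth" % nepoch), "wb").close()

    found = latest_checkpoint(tmp)
    latest_epoch = found[0]
    latest_name = os.path.basename(found[1])
```

The returned epoch can then be used both to build `checkpoint_path` for loading and to decide which epoch index the training loop should continue from.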

Convert the Trained Model to a Quantized Model

convert_pt2e takes a prepared model that has been trained and produces a quantized model. Note that, before inference, you must first call torch.ao.quantization.move_exported_model_to_eval() to ensure certain ops like dropout behave correctly in the eval graph. Otherwise, ops such as dropout would still be applied in the forward pass during inference.

quantized_model = convert_pt2e(prepared_model)

# move certain ops like dropout to eval mode, equivalent to `m.eval()`
torch.ao.quantization.move_exported_model_to_eval(quantized_model)

print(quantized_model)

top1, top5 = evaluate(quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Final evaluation accuracy on %d images, %2.2f' % (num_eval_batches * eval_batch_size, top1.avg))
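To give a feel for what "integer computations" means after conversion, the plain-Python sketch below computes a dot product by accumulating products of quantized integers and rescaling once at the end. The symmetric int8 range, the max-based scales, and the function names are assumptions made for this illustration; this is not the graph that convert_pt2e emits.

```python
def quantize(values, scale, qmin=-127, qmax=127):
    """Map floats to clamped integers on a symmetric grid (zero_point is 0)."""
    return [max(qmin, min(qmax, round(v / scale))) for v in values]


def int8_dot(x_q, w_q, x_scale, w_scale):
    """Accumulate in integers, then rescale once to get a float result."""
    acc = sum(a * b for a, b in zip(x_q, w_q))  # pure integer arithmetic
    return acc * x_scale * w_scale


x = [0.5, -1.0, 0.25, 0.75]
w = [0.2, 0.4, -0.6, 0.1]
x_scale = max(abs(v) for v in x) / 127
w_scale = max(abs(v) for v in w) / 127

x_q = quantize(x, x_scale)
w_q = quantize(w, w_scale)

approx = int8_dot(x_q, w_q, x_scale, w_scale)
exact = sum(a * b for a, b in zip(x, w))
error = abs(approx - exact)
```

The small gap between `approx` and `exact` is the quantization error that QAT lets the model adapt to during training.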

Conclusion

In this tutorial, we demonstrated how to run the Quantization-Aware Training (QAT) flow in PyTorch 2 Export Quantization. After convert, the rest of the flow is the same as Post-Training Quantization (PTQ); the user can serialize/deserialize the model and further lower it to a backend that supports inference with XNNPACK. For more detail, follow the PTQ tutorial.
