
Deploying Quantization Aware Trained models in INT8 using Torch-TensorRT

Overview

Quantization Aware Training (QAT) simulates quantization during training by quantizing weights and activations. This helps reduce the loss in accuracy when a network trained in FP32 is converted to INT8 for faster inference. QAT introduces additional quantizer nodes into the graph, which are used to determine the dynamic ranges of the weights and activations. In this notebook, we illustrate the following steps, from training to inference, of a QAT model in Torch-TensorRT.

  1. Requirements

  2. VGG16 Overview

  3. Training a baseline VGG16 model

  4. Apply Quantization

  5. Model Calibration

  6. Quantization Aware Training

  7. Export to TorchScript

  8. Inference using Torch-TensorRT

  9. References

## 1. Requirements

Please install the required dependencies and import the following libraries.

[ ]:
!pip install ipywidgets --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host=files.pythonhosted.org
[1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch_tensorrt

from torch.utils.tensorboard import SummaryWriter

import pytorch_quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import calib
from tqdm import tqdm

print(pytorch_quantization.__version__)

import os
import sys
sys.path.insert(0, "../examples/int8/training/vgg16")
from vgg16 import vgg16

2.1.0

## 2. VGG16 Overview

### Very Deep Convolutional Networks for Large-Scale Image Recognition

VGG is one of the earliest families of image classification networks to use small (3x3) convolution filters, and it achieved significant improvements on the ImageNet recognition challenge. The network architecture is shown below.

[Figure: VGG16 network architecture]
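As a rough illustration of the architecture, the pure-Python sketch below lists the feature-extractor layers of the VGG configuration "D" (VGG16). That configuration has 13 convolutions of size 3x3 interleaved with 2x2 max pooling, and 3 fully connected layers follow it. The sketch assumes that the vgg16 helper imported from ../examples/int8/training/vgg16 follows this layout with batch normalization, so each convolution contributes Conv2d, BatchNorm2d and ReLU modules. This assumption is consistent with the quantizer names (features.0, features.3, features.7, ...) printed during calibration in Section 5. VGG16_CFG and feature_layers are illustrative names only.

```python
# Hypothetical restatement of VGG configuration "D": numbers are output channels
# of 3x3 convolutions, "M" marks a 2x2 max-pooling layer.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]


def feature_layers(cfg, batch_norm=True):
    """Return (index, description) pairs for the sequential feature extractor."""
    layers = []
    in_channels = 3  # RGB input images
    for v in cfg:
        if v == "M":
            layers.append("MaxPool2d(2x2)")
        else:
            layers.append(f"Conv2d({in_channels}->{v}, 3x3, padding=1)")
            if batch_norm:
                layers.append(f"BatchNorm2d({v})")
            layers.append("ReLU")
            in_channels = v
    return list(enumerate(layers))


layers = feature_layers(VGG16_CFG)
conv_indices = [i for i, name in layers if name.startswith("Conv2d")]
num_pools = sum(1 for _, name in layers if name.startswith("MaxPool"))

print("conv layers:", len(conv_indices))  # 13; plus 3 FC layers gives VGG "16"
print("conv module indices:", conv_indices)
# Each 2x2 max pool halves the spatial size, so a 32x32 CIFAR10 image ends up 1x1
print("final feature map size:", 32 // 2 ** num_pools)
```

The printed convolution indices line up with the names of the input and weight quantizers that appear after calibration.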

## 3. Training a baseline VGG16 model

We train VGG16 on the CIFAR10 dataset. First, define the training and testing datasets and dataloaders; this downloads the CIFAR10 data into the ./data directory. Data preprocessing is performed using torchvision transforms.

[2]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ========== Define Training dataset and dataloaders =============#
training_dataset = datasets.CIFAR10(root='./data',
                                        train=True,
                                        download=True,
                                        transform=transforms.Compose([
                                            transforms.RandomCrop(32, padding=4),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                        ]))

training_dataloader = torch.utils.data.DataLoader(training_dataset,
                                                      batch_size=32,
                                                      shuffle=True,
                                                      num_workers=2)

# ========== Define Testing dataset and dataloaders =============#
testing_dataset = datasets.CIFAR10(root='./data',
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                   ]))

testing_dataloader = torch.utils.data.DataLoader(testing_dataset,
                                                 batch_size=16,
                                                 shuffle=False,
                                                 num_workers=2)

Files already downloaded and verified
Files already downloaded and verified
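For reference, transforms.ToTensor scales pixel values to [0, 1], and transforms.Normalize then maps each channel value x to (x - mean) / std, using the per-channel statistics passed above. The pure-Python sketch below reproduces that arithmetic for a single pixel. It also shows where the 1563 batches per epoch in the training log come from: CIFAR10 has 50,000 training and 10,000 test images, and the DataLoader keeps the last partial batch because drop_last defaults to False. The helper names normalize_pixel and num_batches are hypothetical.

```python
import math

# Per-channel statistics used in the CIFAR10 transforms above
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def normalize_pixel(rgb_uint8):
    """Mimic ToTensor + Normalize for one pixel given as 0-255 integers."""
    scaled = [c / 255.0 for c in rgb_uint8]                       # ToTensor: [0, 255] -> [0, 1]
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]   # Normalize per channel


def num_batches(num_samples, batch_size):
    """Batches per epoch when the last, partial batch is kept (drop_last=False)."""
    return math.ceil(num_samples / batch_size)


print(normalize_pixel((255, 255, 255)))  # a white pixel
print(num_batches(50_000, 32))            # training batches per epoch
print(num_batches(10_000, 16))            # test batches
```

The blue channel of a white pixel normalizes to about 2.7537. This matches the amax printed for features.0._input_quantizer in Section 5: max calibration of the network input recovers the largest possible normalized value.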
[3]:
def train(model, dataloader, crit, opt, epoch):
    # Train the model for one epoch
    model.train()
    running_loss = 0.0
    for batch, (data, labels) in enumerate(dataloader):
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        opt.zero_grad()
        out = model(data)
        loss = crit(out, labels)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        if batch % 500 == 499:
            # running_loss sums 500 batches but is divided by 100, so the printed
            # value is 5x the mean batch loss (this is the scale of the logs below)
            print("Batch: [%5d | %5d] loss: %.3f" % (batch + 1, len(dataloader), running_loss / 100))
            running_loss = 0.0

def test(model, dataloader, crit, epoch):
    # Evaluate the model; returns (sum of batch-mean losses / number of samples, accuracy)
    total = 0
    correct = 0
    loss = 0.0
    model.eval()
    with torch.no_grad():
        for data, labels in dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = model(data)
            loss += crit(out, labels)
            preds = torch.max(out, 1)[1]  # index of the largest logit is the predicted class
            total += labels.size(0)
            correct += (preds == labels).sum().item()

    return loss / total, correct / total

def save_checkpoint(state, ckpt_path="checkpoint.pth"):
    torch.save(state, ckpt_path)
    print("Checkpoint saved")

Define the baseline VGG16 model. Its trained weights are the starting point for QAT.

[4]:
# CIFAR 10 has 10 classes
model = vgg16(num_classes=len(classes), init_weights=False)
model = model.cuda()
[5]:
# Declare Learning rate
lr = 0.1
state = {}
state["lr"] = lr

# Use cross entropy loss for classification and SGD optimizer
crit = nn.CrossEntropyLoss()
opt = optim.SGD(model.parameters(), lr=state["lr"], momentum=0.9, weight_decay=1e-4)


# Adjust learning rate based on epoch number
def adjust_lr(optimizer, epoch):
    global state
    new_lr = lr * (0.5**(epoch // 12)) if state["lr"] > 1e-7 else state["lr"]
    if new_lr != state["lr"]:
        state["lr"] = new_lr
        print("Updating learning rate: {}".format(state["lr"]))
        for param_group in optimizer.param_groups:
            param_group["lr"] = state["lr"]
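The schedule implemented by adjust_lr halves the initial learning rate every 12 epochs, with epoch counted from 0. The pure-Python restatement below uses the hypothetical name step_decay_lr and reproduces the "Updating learning rate" messages in the log below: 0.1 for epochs 1-12, 0.05 for epochs 13-24 and 0.025 for epoch 25. The formula depends only on the epoch, not on the current rate. Calling adjust_lr(opt, 0) again later therefore resets the rate to 0.1, which is why the QAT fine-tuning log in Section 6 prints "Updating learning rate: 0.1".

```python
BASE_LR = 0.1  # initial learning rate, same as lr above


def step_decay_lr(epoch, base_lr=BASE_LR, step_size=12, gamma=0.5):
    """Learning rate for a 0-indexed epoch under the halving schedule used by adjust_lr.

    The 1e-7 floor checked by adjust_lr is omitted here for brevity.
    """
    return base_lr * (gamma ** (epoch // step_size))


schedule = [step_decay_lr(epoch) for epoch in range(25)]

# Report the (1-indexed) epochs at which the learning rate changes, as in the log
for epoch, lr_value in enumerate(schedule):
    if epoch == 0 or lr_value != schedule[epoch - 1]:
        print(f"Epoch {epoch + 1}: lr = {lr_value}")
```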
[6]:
# Train the model for 25 epochs to get ~80% accuracy.
num_epochs=25
for epoch in range(num_epochs):
    adjust_lr(opt, epoch)
    print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, state["lr"]))

    train(model, training_dataloader, crit, opt, epoch)
    test_loss, test_acc = test(model, testing_dataloader, crit, epoch)

    print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

save_checkpoint({'epoch': epoch + 1,
                 'model_state_dict': model.state_dict(),
                 'acc': test_acc,
                 'opt_state_dict': opt.state_dict(),
                 'state': state},
                ckpt_path="vgg16_base_ckpt")
Epoch: [    1 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 13.288
Batch: [ 1000 |  1563] loss: 11.345
Batch: [ 1500 |  1563] loss: 11.008
Test Loss: 0.13388 Test Acc: 13.23%
Epoch: [    2 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 10.742
Batch: [ 1000 |  1563] loss: 10.311
Batch: [ 1500 |  1563] loss: 10.141
Test Loss: 0.11888 Test Acc: 23.96%
Epoch: [    3 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 9.877
Batch: [ 1000 |  1563] loss: 9.821
Batch: [ 1500 |  1563] loss: 9.818
Test Loss: 0.11879 Test Acc: 24.68%
Epoch: [    4 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 9.677
Batch: [ 1000 |  1563] loss: 9.613
Batch: [ 1500 |  1563] loss: 9.504
Test Loss: 0.11499 Test Acc: 23.68%
Epoch: [    5 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 9.560
Batch: [ 1000 |  1563] loss: 9.536
Batch: [ 1500 |  1563] loss: 9.309
Test Loss: 0.10990 Test Acc: 27.84%
Epoch: [    6 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 9.254
Batch: [ 1000 |  1563] loss: 9.234
Batch: [ 1500 |  1563] loss: 9.188
Test Loss: 0.11594 Test Acc: 23.29%
Epoch: [    7 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 9.141
Batch: [ 1000 |  1563] loss: 9.110
Batch: [ 1500 |  1563] loss: 9.013
Test Loss: 0.10732 Test Acc: 29.24%
Epoch: [    8 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 9.120
Batch: [ 1000 |  1563] loss: 9.086
Batch: [ 1500 |  1563] loss: 8.948
Test Loss: 0.10732 Test Acc: 27.24%
Epoch: [    9 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 8.941
Batch: [ 1000 |  1563] loss: 8.997
Batch: [ 1500 |  1563] loss: 9.028
Test Loss: 0.11299 Test Acc: 25.52%
Epoch: [   10 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 8.927
Batch: [ 1000 |  1563] loss: 8.837
Batch: [ 1500 |  1563] loss: 8.860
Test Loss: 0.10130 Test Acc: 34.61%
Epoch: [   11 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 8.953
Batch: [ 1000 |  1563] loss: 8.738
Batch: [ 1500 |  1563] loss: 8.724
Test Loss: 0.10018 Test Acc: 32.27%
Epoch: [   12 /    25] LR: 0.100000
Batch: [  500 |  1563] loss: 8.721
Batch: [ 1000 |  1563] loss: 8.716
Batch: [ 1500 |  1563] loss: 8.701
Test Loss: 0.10070 Test Acc: 29.57%
Updating learning rate: 0.05
Epoch: [   13 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 7.944
Batch: [ 1000 |  1563] loss: 7.649
Batch: [ 1500 |  1563] loss: 7.511
Test Loss: 0.08555 Test Acc: 44.62%
Epoch: [   14 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 7.057
Batch: [ 1000 |  1563] loss: 6.944
Batch: [ 1500 |  1563] loss: 6.687
Test Loss: 0.08331 Test Acc: 52.27%
Epoch: [   15 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 6.470
Batch: [ 1000 |  1563] loss: 6.439
Batch: [ 1500 |  1563] loss: 6.126
Test Loss: 0.07266 Test Acc: 58.02%
Epoch: [   16 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 5.834
Batch: [ 1000 |  1563] loss: 5.801
Batch: [ 1500 |  1563] loss: 5.622
Test Loss: 0.06340 Test Acc: 65.17%
Epoch: [   17 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 5.459
Batch: [ 1000 |  1563] loss: 5.442
Batch: [ 1500 |  1563] loss: 5.314
Test Loss: 0.05945 Test Acc: 67.22%
Epoch: [   18 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 5.071
Batch: [ 1000 |  1563] loss: 5.145
Batch: [ 1500 |  1563] loss: 5.063
Test Loss: 0.06567 Test Acc: 64.46%
Epoch: [   19 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 4.796
Batch: [ 1000 |  1563] loss: 4.781
Batch: [ 1500 |  1563] loss: 4.732
Test Loss: 0.05374 Test Acc: 71.87%
Epoch: [   20 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 4.568
Batch: [ 1000 |  1563] loss: 4.564
Batch: [ 1500 |  1563] loss: 4.484
Test Loss: 0.05311 Test Acc: 71.12%
Epoch: [   21 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 4.385
Batch: [ 1000 |  1563] loss: 4.302
Batch: [ 1500 |  1563] loss: 4.285
Test Loss: 0.05080 Test Acc: 74.29%
Epoch: [   22 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 4.069
Batch: [ 1000 |  1563] loss: 4.105
Batch: [ 1500 |  1563] loss: 4.096
Test Loss: 0.04807 Test Acc: 75.20%
Epoch: [   23 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 3.959
Batch: [ 1000 |  1563] loss: 3.898
Batch: [ 1500 |  1563] loss: 3.916
Test Loss: 0.04743 Test Acc: 75.81%
Epoch: [   24 /    25] LR: 0.050000
Batch: [  500 |  1563] loss: 3.738
Batch: [ 1000 |  1563] loss: 3.847
Batch: [ 1500 |  1563] loss: 3.797
Test Loss: 0.04609 Test Acc: 76.42%
Updating learning rate: 0.025
Epoch: [   25 /    25] LR: 0.025000
Batch: [  500 |  1563] loss: 2.952
Batch: [ 1000 |  1563] loss: 2.906
Batch: [ 1500 |  1563] loss: 2.735
Test Loss: 0.03466 Test Acc: 82.00%
Checkpoint saved

## 4. Apply Quantization

quant_modules.initialize() ensures that the quantized versions of modules are used instead of the original modules. For example, when you define a model with convolution, linear and pooling layers, QuantConv2d, QuantLinear and the quantized pooling modules are used instead. QuantConv2d wraps quantizer nodes around the input and the weight of a regular Conv2d. Please refer to the quantized modules in the pytorch-quantization toolkit for more information. The forward pass of QuantConv2d in the toolkit looks as follows.

# QuantConv2d.forward (excerpt from pytorch-quantization; F is torch.nn.functional
# and _pair comes from torch.nn.modules.utils)
def forward(self, input):
    # the actual quantization happens in the next level of the class hierarchy:
    # self._quant applies self._input_quantizer and self._weight_quantizer
    quant_input, quant_weight = self._quant(input)

    if self.padding_mode == 'circular':
        expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                            (self.padding[0] + 1) // 2, self.padding[0] // 2)
        output = F.conv2d(F.pad(quant_input, expanded_padding, mode='circular'),
                          quant_weight, self.bias, self.stride,
                          _pair(0), self.dilation, self.groups)
    else:
        output = F.conv2d(quant_input, quant_weight, self.bias, self.stride, self.padding, self.dilation,
                          self.groups)

    return output
[7]:
quant_modules.initialize()
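Conceptually, quant_modules.initialize() works by monkey-patching. It swaps layer classes in torch.nn, such as nn.Conv2d, for their quantized counterparts, so any model constructed afterwards picks up the quantized classes. quant_modules.deactivate() restores the originals. This is also why the baseline model above, built before initialize() was called, has no quantizers. The toy sketch below illustrates the mechanism with plain Python classes standing in for torch.nn. The names fake_nn, Conv2d, QuantConv2d and build_model here are hypothetical and are not the toolkit's internals.

```python
import types

# Hypothetical stand-in for the torch.nn namespace
fake_nn = types.SimpleNamespace()


class Conv2d:
    """Stand-in for a floating-point layer."""
    quantized = False


class QuantConv2d(Conv2d):
    """Stand-in for the quantized layer (the real one adds input/weight quantizers)."""
    quantized = True


fake_nn.Conv2d = Conv2d
_originals = {}


def initialize():
    """Patch the namespace so that Conv2d resolves to the quantized class."""
    _originals["Conv2d"] = fake_nn.Conv2d
    fake_nn.Conv2d = QuantConv2d


def deactivate():
    """Undo the patch and restore the original class."""
    fake_nn.Conv2d = _originals.pop("Conv2d")


def build_model():
    # Like a model definition calling nn.Conv2d(...), the class is looked up at construction time
    return [fake_nn.Conv2d() for _ in range(2)]


float_layers = build_model()   # built before patching, like the baseline model
initialize()
quant_layers = build_model()   # built after patching, like qat_model
deactivate()

print([layer.quantized for layer in float_layers])
print([layer.quantized for layer in quant_layers])
```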
[8]:
# All the regular conv and FC layers are converted to their quantized counterparts because of quant_modules.initialize()
qat_model = vgg16(num_classes=len(classes), init_weights=False)
qat_model = qat_model.cuda()
[9]:
# vgg16_base_ckpt is the checkpoint generated from Step 3 : Training a baseline VGG16 model.
ckpt = torch.load("./vgg16_base_ckpt")
modified_state_dict={}
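for key, val in ckpt["model_state_dict"].items():
    # Remove the 'module.' prefix (added e.g. by nn.DataParallel) from the key names
    if key.startswith('module.'):
        modified_state_dict[key[len('module.'):]] = val
    else:
        modified_state_dict[key] = val

# Load the pre-trained checkpoint
qat_model.load_state_dict(modified_state_dict)

# The optimizer opt defined earlier holds references to the baseline model's
# parameters, so create one for qat_model before restoring the optimizer state
opt = optim.SGD(qat_model.parameters(), lr=state["lr"], momentum=0.9, weight_decay=1e-4)
opt.load_state_dict(ckpt["opt_state_dict"])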

## 5. Model Calibration

The quantizer nodes inserted around the desired layers capture the dynamic range observed at each layer. With the symmetric quantization used here, this range is represented by amax, the maximum absolute value, so values are quantized within [-amax, amax]. Calibration is the process of computing these dynamic ranges by passing calibration data through the model. The calibration data is usually a subset of the training or validation data. The toolkit supports max calibration and histogram calibration; in histogram calibration, amax is chosen by the percentile, MSE or entropy method. We use max calibration because it is simple and effective.
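Max calibration tracks the largest absolute value seen by each quantizer across the calibration batches. Per-tensor quantizers (the layer inputs) keep a single amax. Per-channel quantizers (the weights, axis=0) keep one amax per output channel, as in the calibration output below. The pure-Python sketch uses a hypothetical RunningMaxCalibrator class to illustrate the idea; it is not the toolkit's calib.MaxCalibrator.

```python
class RunningMaxCalibrator:
    """Track the absolute maximum seen so far, per tensor or per channel (hypothetical)."""

    def __init__(self, per_channel=False):
        self.per_channel = per_channel
        self.amax = None

    def collect(self, batch):
        """batch: a flat list of floats (per-tensor) or a list of channels (per-channel)."""
        if self.per_channel:
            batch_amax = [max(abs(x) for x in channel) for channel in batch]
            if self.amax is None:
                self.amax = batch_amax
            else:
                self.amax = [max(a, b) for a, b in zip(self.amax, batch_amax)]
        else:
            batch_amax = max(abs(x) for x in batch)
            self.amax = batch_amax if self.amax is None else max(self.amax, batch_amax)


# Per-tensor calibration of an activation over three "batches"
act_calib = RunningMaxCalibrator()
for batch in ([0.5, -1.2, 0.3], [2.0, -0.1], [-0.7, 1.9]):
    act_calib.collect(batch)

# Per-channel calibration of a weight with two output channels
w_calib = RunningMaxCalibrator(per_channel=True)
w_calib.collect([[0.02, -0.03], [0.4, -0.9]])

print(act_calib.amax, w_calib.amax)
```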

[10]:
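def compute_amax(model, **kwargs):
    # Load the calibration result (amax) into each TensorQuantizer that has a calibrator
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    # Histogram calibrators take the method (and percentile) as kwargs
                    module.load_calib_amax(**kwargs)
            print(F"{name:40}: {module}")
    # Move the newly loaded amax buffers to the GPU
    model.cuda()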

def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect statistics"""
    # Enable calibrators (and disable quantization) on quantizers that have one
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # Feed exactly num_batches batches to the network for collecting stats
    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        if i >= num_batches:
            break
        model(image.cuda())

    # Disable calibrators and re-enable quantization
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()

def calibrate_model(model, model_name, data_loader, num_calib_batch, calibrator, hist_percentile, out_dir):
    """
        Feed data to the network and calibrate.
        Arguments:
            model: classification model
            model_name: name to use when creating state files
            data_loader: calibration data set
            num_calib_batch: number of calibration batches to feed
            calibrator: type of calibration to use (max/histogram); must match the
                calibrator configured on the quantizers (MaxCalibrator by default)
            hist_percentile: percentiles to be used for histogram calibration
            out_dir: dir to save state files in
    """

    if num_calib_batch > 0:
        print("Calibrating model")
        with torch.no_grad():
            collect_stats(model, data_loader, num_calib_batch)

        if calibrator != "histogram":
            compute_amax(model, method="max")
            calib_output = os.path.join(
                out_dir,
                F"{model_name}-max-{num_calib_batch*data_loader.batch_size}.pth")
            torch.save(model.state_dict(), calib_output)
        else:
            # Save one calibrated state per percentile, then one each for MSE and entropy
            for percentile in hist_percentile:
                print(F"{percentile} percentile calibration")
                compute_amax(model, method="percentile", percentile=percentile)
                calib_output = os.path.join(
                    out_dir,
                    F"{model_name}-percentile-{percentile}-{num_calib_batch*data_loader.batch_size}.pth")
                torch.save(model.state_dict(), calib_output)

            for method in ["mse", "entropy"]:
                print(F"{method} calibration")
                compute_amax(model, method=method)
                calib_output = os.path.join(
                    out_dir,
                    F"{model_name}-{method}-{num_calib_batch*data_loader.batch_size}.pth")
                torch.save(model.state_dict(), calib_output)
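In the histogram branch, calibrate_model computes and saves amax for several methods. That branch assumes the quantizers were created with a histogram calibrator, for example through QuantDescriptor(calib_method="histogram"); this notebook uses the default max calibrator. The percentile method clips rare outliers. Instead of the absolute maximum, amax becomes the value below which the given percentage of absolute values falls, which gives the remaining values a finer quantization step. The sketch below approximates this on raw samples rather than on binned histogram counts, so it only illustrates the effect. percentile_amax is a hypothetical helper using the nearest-rank definition of a percentile.

```python
import math


def percentile_amax(samples, percentile):
    """amax such that `percentile` % of |samples| are <= it (nearest-rank method)."""
    magnitudes = sorted(abs(x) for x in samples)
    rank = math.ceil(percentile * len(magnitudes) / 100.0)  # 1-based nearest rank
    return magnitudes[max(rank, 1) - 1]


# 1999 well-behaved activations in [0, 1) plus one large outlier
samples = [i / 2000.0 for i in range(1999)] + [50.0]

max_amax = max(abs(x) for x in samples)          # max calibration keeps the outlier
p999_amax = percentile_amax(samples, 99.9)      # 99.9 percentile clips it

# Quantization step size (amax / 127) for each choice of amax
print("max:       amax", max_amax, "step", max_amax / 127)
print("99.9 pct:  amax", p999_amax, "step", p999_amax / 127)
```

The trade-off is that values above the clipped amax saturate at the largest integer level, which is why the notebook saves a separate state file for each percentile so they can be compared.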
[11]:
# Calibrate the model using the max calibration technique.
with torch.no_grad():
    calibrate_model(
        model=qat_model,
        model_name="vgg16",
        data_loader=training_dataloader,
        num_calib_batch=32,
        calibrator="max",
        hist_percentile=[99.9, 99.99, 99.999, 99.9999],
        out_dir="./")
Calibrating model
100%|███████████████████████████████████████████████████████| 32/32 [00:00<00:00, 96.04it/s]
WARNING: Logging before flag parsing goes to stderr.
W1109 04:01:43.512364 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.513354 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.514046 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.514638 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.515270 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.515859 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.516441 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.517009 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.517600 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.518167 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.518752 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.519333 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.519911 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.520473 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.521038 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.521596 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.522170 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.522742 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.523360 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.523957 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.524581 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.525059 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.525366 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.525675 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.525962 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.526257 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.526566 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.526885 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.527188 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.527489 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.527792 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.528097 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.528387 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator
W1109 04:01:43.528834 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.529163 139704147265344 tensor_quantizer.py:238] Call .cuda() if running on GPU after loading calibrated amax.
W1109 04:01:43.532748 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([64, 1, 1, 1]).
W1109 04:01:43.533468 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.534033 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([64, 1, 1, 1]).
W1109 04:01:43.534684 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.535320 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([128, 1, 1, 1]).
W1109 04:01:43.535983 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.536569 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([128, 1, 1, 1]).
W1109 04:01:43.537248 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.537833 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).
W1109 04:01:43.538480 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.539074 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).
W1109 04:01:43.539724 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.540307 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).
W1109 04:01:43.540952 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.541534 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W1109 04:01:43.542075 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.542596 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W1109 04:01:43.543248 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.543719 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W1109 04:01:43.544424 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.544952 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W1109 04:01:43.545530 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.546114 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W1109 04:01:43.546713 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.547292 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).
W1109 04:01:43.547902 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.548453 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.549015 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([4096, 1]).
W1109 04:01:43.549665 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.550436 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([4096, 1]).
W1109 04:01:43.551925 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).
W1109 04:01:43.553105 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([10, 1]).
features.0._input_quantizer             : TensorQuantizer(8bit narrow fake per-tensor amax=2.7537 calibrator=MaxCalibrator scale=1.0 quant)
features.0._weight_quantizer            : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0263, 2.7454](64) calibrator=MaxCalibrator scale=1.0 quant)
features.3._input_quantizer             : TensorQuantizer(8bit narrow fake per-tensor amax=27.5676 calibrator=MaxCalibrator scale=1.0 quant)
features.3._weight_quantizer            : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0169, 1.8204](64) calibrator=MaxCalibrator scale=1.0 quant)
features.7._input_quantizer             : TensorQuantizer(8bit narrow fake per-tensor amax=15.2002 calibrator=MaxCalibrator scale=1.0 quant)
features.7._weight_quantizer            : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0493, 1.3207](128) calibrator=MaxCalibrator scale=1.0 quant)
features.10._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=7.7376 calibrator=MaxCalibrator scale=1.0 quant)
features.10._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0163, 0.9624](128) calibrator=MaxCalibrator scale=1.0 quant)
features.14._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=8.8351 calibrator=MaxCalibrator scale=1.0 quant)
features.14._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0622, 0.8791](256) calibrator=MaxCalibrator scale=1.0 quant)
features.17._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=12.5746 calibrator=MaxCalibrator scale=1.0 quant)
features.17._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0505, 0.5117](256) calibrator=MaxCalibrator scale=1.0 quant)
features.20._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=9.7203 calibrator=MaxCalibrator scale=1.0 quant)
features.20._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0296, 0.5335](256) calibrator=MaxCalibrator scale=1.0 quant)
features.24._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=8.9367 calibrator=MaxCalibrator scale=1.0 quant)
features.24._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0220, 0.3763](512) calibrator=MaxCalibrator scale=1.0 quant)
features.27._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=6.6539 calibrator=MaxCalibrator scale=1.0 quant)
features.27._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0151, 0.1777](512) calibrator=MaxCalibrator scale=1.0 quant)
features.30._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=3.7099 calibrator=MaxCalibrator scale=1.0 quant)
features.30._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0087, 0.1906](512) calibrator=MaxCalibrator scale=1.0 quant)
features.34._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=4.0491 calibrator=MaxCalibrator scale=1.0 quant)
features.34._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0106, 0.1971](512) calibrator=MaxCalibrator scale=1.0 quant)
features.37._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=2.1531 calibrator=MaxCalibrator scale=1.0 quant)
features.37._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0070, 0.2305](512) calibrator=MaxCalibrator scale=1.0 quant)
features.40._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=3.3631 calibrator=MaxCalibrator scale=1.0 quant)
features.40._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0023, 0.4726](512) calibrator=MaxCalibrator scale=1.0 quant)
avgpool._input_quantizer                : TensorQuantizer(8bit narrow fake per-tensor amax=5.3550 calibrator=MaxCalibrator scale=1.0 quant)
classifier.0._input_quantizer           : TensorQuantizer(8bit narrow fake per-tensor amax=5.3550 calibrator=MaxCalibrator scale=1.0 quant)
classifier.0._weight_quantizer          : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0026, 0.5320](4096) calibrator=MaxCalibrator scale=1.0 quant)
classifier.3._input_quantizer           : TensorQuantizer(8bit narrow fake per-tensor amax=6.6733 calibrator=MaxCalibrator scale=1.0 quant)
classifier.3._weight_quantizer          : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0018, 0.5172](4096) calibrator=MaxCalibrator scale=1.0 quant)
classifier.6._input_quantizer           : TensorQuantizer(8bit narrow fake per-tensor amax=9.4352 calibrator=MaxCalibrator scale=1.0 quant)
classifier.6._weight_quantizer          : TensorQuantizer(8bit narrow fake axis=0 amax=[0.3877, 0.5620](10) calibrator=MaxCalibrator scale=1.0 quant)

## 6. Quantization Aware Training

In this phase, we fine-tune the model weights while keeping the quantizer node values frozen. The dynamic ranges obtained for each layer during calibration stay constant, and the weights are fine-tuned to preserve the accuracy of the original FP32 model (the model without quantizer nodes) as closely as possible. Fine-tuning a QAT model is usually quick compared to fully training the original model; a common guideline is to fine-tune for around 10% of the original training schedule with an annealed learning rate. Please refer to "Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT" for detailed recommendations. For this VGG model, fine-tuning for 1 epoch is enough to get acceptable accuracy. During QAT fine-tuning, quantization is applied as a composition of max, clamp, round and mul ops:

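# Excerpt of the quantization math in pytorch-quantization's tensor_quant
# amax is the absolute maximum value of the input (per tensor or per channel)
# max_bound is the upper bound for integer quantization (127 for signed int8)
max_bound = torch.tensor((2.0**(num_bits - 1 + int(unsigned))) - 1.0, device=amax.device)
# min_bound is 0 if unsigned, -127 for narrow range ("8bit narrow" above), otherwise -128
min_bound = 0 if unsigned else (-max_bound if narrow_range else -max_bound - 1)
scale = max_bound / amax
outputs = torch.clamp((inputs * scale).round_(), min_bound, max_bound)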
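To make the formula concrete, the pure-Python sketch below applies the same steps (scale = max_bound / amax, round, clamp) to a few scalars. It then divides by the scale again. This is what fake quantization does, so that training continues in floating point while seeing only the representable INT8 levels. The sketch assumes signed 8-bit narrow-range quantization, matching the "8bit narrow" quantizers printed during calibration, so the integer codes lie in [-127, 127]. The amax of 2.7537 is the value reported for features.0._input_quantizer. fake_quantize is a hypothetical helper, not the toolkit's tensor_quant.

```python
def fake_quantize(values, amax, num_bits=8, narrow_range=True):
    """Quantize then dequantize floats with a symmetric per-tensor amax."""
    max_bound = 2 ** (num_bits - 1) - 1                            # 127 for int8
    min_bound = -max_bound if narrow_range else -max_bound - 1    # -127 (narrow) or -128
    scale = max_bound / amax
    # Scale, round to the nearest integer level, and clamp to the integer range
    codes = [min(max(round(x * scale), min_bound), max_bound) for x in values]
    # Dequantize: map the integer levels back to floating point
    return codes, [q / scale for q in codes]


codes, dequantized = fake_quantize([1.0, -3.5, 0.01], amax=2.7537)
print(codes)        # integer levels
print(dequantized)  # values seen by the next layer during QAT
```

In this example, -3.5 lies outside [-amax, amax] and saturates at -127, and 0.01 is smaller than half a quantization step, so it rounds to 0.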

The tensor_quant function in the pytorch-quantization toolkit performs the tensor quantization shown above. Per-channel quantization is usually recommended for weights, and per-tensor quantization for activations. This is the configuration in the calibration output above, where the weight quantizers use axis=0 and the input quantizers are per-tensor. For export and inference, we instead use torch.fake_quantize_per_tensor_affine and torch.fake_quantize_per_channel_affine to perform quantization, because these are easier to convert into the corresponding TensorRT operators. The next sections describe how these operators are exported to TorchScript and converted by Torch-TensorRT.
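Two points in this paragraph can be checked numerically.

First, per-channel scales help weights because output channels can differ widely in magnitude; see, for example, the amax range [0.0026, 0.5320] printed for classifier.0. With a single per-tensor amax, small channels get only a few integer levels.

Second, the PyTorch documentation defines the torch.fake_quantize_*_affine ops as (clamp(round(x / scale) + zero_point, quant_min, quant_max) - zero_point) * scale. The symmetric formula above therefore corresponds to scale = amax / 127 (the reciprocal of the toolkit's scale) with zero_point = 0.

The sketch below re-implements that documented formula in pure Python (affine_fake_quant is a hypothetical name) on a toy two-channel weight.

```python
def affine_fake_quant(x, scale, zero_point=0, quant_min=-127, quant_max=127):
    """Pure-Python version of the documented torch fake_quantize_*_affine formula."""
    q = min(quant_max, max(quant_min, round(x / scale) + zero_point))
    return (q - zero_point) * scale


def max_abs_error(rows, scales):
    """Largest |w - fake_quant(w)| over all weights, using one scale per row."""
    return max(abs(w - affine_fake_quant(w, s))
               for row, s in zip(rows, scales) for w in row)


# Toy weight with two output channels of very different magnitude (axis 0 = channel)
weight = [[0.021, -0.013, 0.017],   # small channel
          [0.90, -0.52, 0.31]]      # large channel

# Per-tensor: a single amax for the whole weight, scale = amax / 127
tensor_amax = max(abs(w) for row in weight for w in row)
per_tensor_scales = [tensor_amax / 127] * len(weight)

# Per-channel: one amax, and hence one scale, per output channel
per_channel_scales = [max(abs(w) for w in row) / 127 for row in weight]

# Quantization error of the small channel under each scheme
err_per_tensor = max_abs_error(weight[:1], per_tensor_scales[:1])
err_per_channel = max_abs_error(weight[:1], per_channel_scales[:1])
print(f"small channel: per-tensor error {err_per_tensor:.6f}, per-channel error {err_per_channel:.6f}")

# Toolkit-style (multiply by 127 / amax, round, clamp, divide) vs torch-style
# (divide by amax / 127, round, clamp, multiply) give the same value
x = 0.31
toolkit_style = min(127, max(-127, round(x * (127 / tensor_amax)))) / (127 / tensor_amax)
torch_style = affine_fake_quant(x, tensor_amax / 127)
print(toolkit_style, torch_style)
```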

[12]:
# Finetune the QAT model for 1 epoch
num_epochs=1
for epoch in range(num_epochs):
    adjust_lr(opt, epoch)
    print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, state["lr"]))

    train(qat_model, training_dataloader, crit, opt, epoch)
    test_loss, test_acc = test(qat_model, testing_dataloader, crit, epoch)

    print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

save_checkpoint({'epoch': epoch + 1,
                 'model_state_dict': qat_model.state_dict(),
                 'acc': test_acc,
                 'opt_state_dict': opt.state_dict(),
                 'state': state},
                ckpt_path="vgg16_qat_ckpt")
Updating learning rate: 0.1
Epoch: [    1 /     1] LR: 0.100000
Batch: [  500 |  1563] loss: 2.635
Batch: [ 1000 |  1563] loss: 2.655
Batch: [ 1500 |  1563] loss: 2.646
Test Loss: 0.03291 Test Acc: 82.98%
Checkpoint saved

## 7. Export to Torchscript

Trace the QAT model and convert it into TorchScript for deployment. To learn more about TorchScript, please refer to https://pytorch.org/docs/stable/jit.html. Setting quant_nn.TensorQuantizer.use_fb_fake_quant = True makes the QAT model export its quantization operators as torch.fake_quantize_per_tensor_affine and torch.fake_quantize_per_channel_affine instead of through the tensor_quant function. In TorchScript, these operators appear as aten::fake_quantize_per_tensor_affine and aten::fake_quantize_per_channel_affine.
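The export step can be reproduced in miniature. The module below is hypothetical and not part of the notebook. It calls `torch.fake_quantize_per_tensor_affine` directly, which mirrors what `TensorQuantizer` does once `use_fb_fake_quant = True`. Tracing it records an `aten::fake_quantize_per_tensor_affine` node in the TorchScript graph, and the traced module survives a `torch.jit.save` / `torch.jit.load` round trip. The sketch runs on CPU and needs only PyTorch; the amax value and layer sizes are assumptions.

```python
import io

import torch
import torch.nn as nn


class TinyFakeQuantNet(nn.Module):
    """Hypothetical stand-in for a QAT layer: fake-quantize the input, then a linear layer."""

    def __init__(self, amax=4.0):
        super().__init__()
        self.fc = nn.Linear(8, 4)
        self.scale = amax / 127.0  # per-tensor scale for signed INT8

    def forward(self, x):
        x = torch.fake_quantize_per_tensor_affine(x, self.scale, 0, -128, 127)
        return self.fc(x)


model = TinyFakeQuantNet().eval()
example = torch.randn(2, 8)
with torch.no_grad():
    traced = torch.jit.trace(model, example)

graph_text = str(traced.graph)
print("fake-quant node in graph:", "aten::fake_quantize_per_tensor_affine" in graph_text)

# Save and reload through an in-memory buffer, like torch.jit.save / torch.jit.load on a file.
buf = io.BytesIO()
torch.jit.save(traced, buf)
buf.seek(0)
reloaded = torch.jit.load(buf)
with torch.no_grad():
    same = torch.allclose(reloaded(example), model(example))
print("reloaded module matches eager module:", same)
```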

[13]:
quant_nn.TensorQuantizer.use_fb_fake_quant = True
with torch.no_grad():
    # Grab one batch of test images to use as the example input for tracing
    images, _ = next(iter(testing_dataloader))
    jit_model = torch.jit.trace(qat_model, images.to("cuda"))
    torch.jit.save(jit_model, "trained_vgg16_qat.jit.pt")
E1109 04:02:37.101168 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!
W1109 04:02:37.268160 139704147265344 tensor_quantizer.py:280] Use Pytorch's native experimental fake quantization.
/opt/conda/lib/python3.8/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py:285: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  inputs, amax.item() / bound, 0,
/opt/conda/lib/python3.8/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py:291: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  quant_dim = list(amax.shape).index(list(amax_sequeeze.shape)[0])

## 8. Inference using Torch-TensorRT

In this phase, we run the exported TorchScript graph of the VGG QAT model using Torch-TensorRT. Torch-TensorRT is a PyTorch-TensorRT compiler that converts TorchScript graphs into TensorRT engines. TensorRT 8.0 supports inference of quantization aware trained models and introduces new APIs: QuantizeLayer and DequantizeLayer. The quantization nodes of the entire VGG QAT graph can be observed in the Torch-TensorRT debug log. To enable debug logging, call torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug). For example, a QuantConv2d layer from the pytorch_quantization toolkit is represented in TorchScript as follows:

%quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%x, %636, %637, %638, %639)
%quant_weight : Tensor = aten::fake_quantize_per_channel_affine(%394, %640, %641, %637, %638, %639)
%input.2 : Tensor = aten::_convolution(%quant_input, %quant_weight, %395, %687, %688, %689, %643, %690, %642, %643, %643, %644, %644)

Internally, Torch-TensorRT converts each aten::fake_quantize_per_*_affine op into a QuantizeLayer followed by a DequantizeLayer. Please refer to the quantization op converters in Torch-TensorRT for details.
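The equivalence behind this conversion can be checked numerically in plain PyTorch. A fake-quantize op returns the same floats as quantizing to INT8 and immediately dequantizing, and a QuantizeLayer/DequantizeLayer pair expresses exactly that pattern. This sketch uses PyTorch's own quantized tensors (`torch.quantize_per_tensor`), not TensorRT. The scale and zero point are assumed values; a power-of-two scale is chosen so that both paths are exact in float32.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
# Assumed per-tensor INT8 parameters; a power-of-two scale keeps both paths exact.
scale, zero_point = 0.03125, 0

# Fake quantization, as recorded in the exported TorchScript graph (float in, float out).
fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, -128, 127)

# Explicit quantize -> dequantize, mirroring a QuantizeLayer followed by a DequantizeLayer.
q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)  # int8 payload plus scale
dq = q.dequantize()                                               # back to float32

print("int8 payload dtype:", q.int_repr().dtype)
print("fake-quant == quantize + dequantize:", torch.equal(fq, dq))
```

Keeping the Q/DQ pair explicit in the graph gives TensorRT the scales learned during QAT. TensorRT can then fuse the pair into neighbouring layers so that they actually run in INT8.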

[14]:
qat_model = torch.jit.load("trained_vgg16_qat.jit.pt").eval()

compile_spec = {"inputs": [torch_tensorrt.Input([16, 3, 32, 32])],
                # enabled_precisions is documented as a set of dtypes
                "enabled_precisions": {torch.int8},
                }
trt_mod = torch_tensorrt.compile(qat_model, **compile_spec)

test_loss, test_acc = test(trt_mod, testing_dataloader, crit, 0)
print("VGG QAT accuracy using TensorRT: {:.2f}%".format(100 * test_acc))
WARNING: [Torch-TensorRT] - Cannot infer input type from calcuations in graph for input x.2. Assuming it is Float32. If not, specify input type explicity
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Detected invalid timing cache, setup a local cache instead
VGG QAT accuracy using TensorRT: 82.97%

Performance benchmarking

[15]:
import time
import numpy as np

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

# Helper function to benchmark the model
# The default input shape matches the [16, 3, 32, 32] shape the TensorRT module was compiled for
def benchmark(model, input_shape=(16, 3, 32, 32), dtype='fp32', nwarmup=50, nruns=1000):
    input_data = torch.randn(input_shape).to("cuda")
    if dtype == 'fp16':
        input_data = input_data.half()

    print("Warm up ...")
    with torch.no_grad():
        for _ in range(nwarmup):
            model(input_data)
    torch.cuda.synchronize()
    print("Start timing ...")
    timings = []
    with torch.no_grad():
        for i in range(1, nruns + 1):
            start_time = time.perf_counter()
            output = model(input_data)
            torch.cuda.synchronize()  # wait for the GPU to finish before stopping the clock
            end_time = time.perf_counter()
            timings.append(end_time - start_time)
            if i % 100 == 0:
                print('Iteration %d/%d, avg batch time %.2f ms' % (i, nruns, np.mean(timings) * 1000))

    print("Input shape:", input_data.size())
    print("Output shape:", output.shape)
    print('Average batch time: %.2f ms' % (np.mean(timings) * 1000))

[16]:
benchmark(jit_model, input_shape=(16, 3, 32, 32))
Warm up ...
Start timing ...
Iteration 100/1000, avg batch time 4.83 ms
Iteration 200/1000, avg batch time 4.83 ms
Iteration 300/1000, avg batch time 4.83 ms
Iteration 400/1000, avg batch time 4.83 ms
Iteration 500/1000, avg batch time 4.83 ms
Iteration 600/1000, avg batch time 4.83 ms
Iteration 700/1000, avg batch time 4.83 ms
Iteration 800/1000, avg batch time 4.83 ms
Iteration 900/1000, avg batch time 4.83 ms
Iteration 1000/1000, avg batch time 4.83 ms
Input shape: torch.Size([16, 3, 32, 32])
Output shape: torch.Size([16, 10])
Average batch time: 4.83 ms
[17]:
benchmark(trt_mod, input_shape=(16, 3, 32, 32))
Warm up ...
Start timing ...
Iteration 100/1000, avg batch time 1.87 ms
Iteration 200/1000, avg batch time 1.84 ms
Iteration 300/1000, avg batch time 1.85 ms
Iteration 400/1000, avg batch time 1.83 ms
Iteration 500/1000, avg batch time 1.82 ms
Iteration 600/1000, avg batch time 1.81 ms
Iteration 700/1000, avg batch time 1.81 ms
Iteration 800/1000, avg batch time 1.80 ms
Iteration 900/1000, avg batch time 1.80 ms
Iteration 1000/1000, avg batch time 1.79 ms
Input shape: torch.Size([16, 3, 32, 32])
Output shape: torch.Size([16, 10])
Average batch time: 1.79 ms
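The two benchmark runs can be compared with a little arithmetic. For a batch of 16 images, the measured average batch times are 4.83 ms for the traced TorchScript QAT model and 1.79 ms for the Torch-TensorRT module. The speedup and throughput work out as follows. These figures describe only this GPU and setup. The TorchScript baseline also runs the fake-quantized graph in floating point rather than true INT8.

$$
\text{speedup} = \frac{4.83\ \text{ms}}{1.79\ \text{ms}} \approx 2.70
$$

$$
\text{throughput}_{\text{TorchScript}} = \frac{16\ \text{images}}{4.83\ \text{ms}} \approx 3.3 \times 10^{3}\ \text{images/s},
\qquad
\text{throughput}_{\text{TensorRT}} = \frac{16\ \text{images}}{1.79\ \text{ms}} \approx 8.9 \times 10^{3}\ \text{images/s}
$$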

## 9. References

* Very Deep Convolutional Networks for Large-Scale Image Recognition
* Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT
* QAT workflow for VGG16
* Deploying VGG QAT model in C++ using Torch-TensorRT
* PyTorch Quantization toolkit from NVIDIA
* PyTorch Quantization toolkit user guide
* Quantization basics
