.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/_rendered_examples/dynamo/vgg16_ptq.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials__rendered_examples_dynamo_vgg16_ptq.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials__rendered_examples_dynamo_vgg16_ptq.py:

.. _vgg16_ptq:

Deploy Quantized Models using Torch-TensorRT
======================================================

Here we demonstrate how to deploy a model quantized to INT8 or FP8 using the Dynamo frontend of Torch-TensorRT.

.. GENERATED FROM PYTHON SOURCE LINES 11-13

Imports and Model Definition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 13-125

.. code-block:: python

    import argparse

    import modelopt.torch.quantization as mtq
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch_tensorrt as torchtrt
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    from modelopt.torch.quantization.utils import export_torch_mode


    class VGG(nn.Module):
        def __init__(self, layer_spec, num_classes=1000, init_weights=False):
            super(VGG, self).__init__()

            layers = []
            in_channels = 3
            for l in layer_spec:
                if l == "pool":
                    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                else:
                    layers += [
                        nn.Conv2d(in_channels, l, kernel_size=3, padding=1),
                        nn.BatchNorm2d(l),
                        nn.ReLU(),
                    ]
                    in_channels = l

            self.features = nn.Sequential(*layers)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.classifier = nn.Sequential(
                nn.Linear(512 * 1 * 1, 4096),
                nn.ReLU(),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )
            if init_weights:
                self._initialize_weights()

        def _initialize_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

        def forward(self, x):
            x = self.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
            return x


    def vgg16(num_classes=1000, init_weights=False):
        vgg16_cfg = [
            64,
            64,
            "pool",
            128,
            128,
            "pool",
            256,
            256,
            256,
            "pool",
            512,
            512,
            512,
            "pool",
            512,
            512,
            512,
            "pool",
        ]
        return VGG(vgg16_cfg, num_classes, init_weights)


    PARSER = argparse.ArgumentParser(
        description="Load a pre-trained VGG model and quantize it with PTQ (INT8 or FP8). To obtain a pre-trained VGG model, please refer to https://github.com/pytorch/TensorRT/tree/main/examples/int8/training/vgg16"
    )
    PARSER.add_argument(
        "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"
    )
    PARSER.add_argument(
        "--batch-size",
        default=128,
        type=int,
        help="Batch size used for PTQ calibration",
    )
    PARSER.add_argument(
        "--quantize-type",
        default="int8",
        type=str,
        help="Quantization type; currently int8 and fp8 are supported for PTQ",
    )
    args = PARSER.parse_args()

    model = vgg16(num_classes=10, init_weights=False)
    model = model.cuda()
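As a quick, optional sanity check (a minimal sketch that is not part of the original
script; the batch of two random images is an arbitrary choice), you can confirm that the
network accepts CIFAR-10 sized inputs and emits one logit per class:

.. code-block:: python

    # Optional sanity check: a random CIFAR-10 sized batch (2 x 3 x 32 x 32)
    # should produce one logit per class, i.e. an output of shape (2, 10).
    dummy_batch = torch.randn(2, 3, 32, 32).cuda()
    with torch.no_grad():
        print(model(dummy_batch).shape)  # expected: torch.Size([2, 10])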
.. GENERATED FROM PYTHON SOURCE LINES 126-128

Load the pre-trained model weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 128-145

.. code-block:: python

    ckpt = torch.load(args.ckpt)
    weights = ckpt["model_state_dict"]

    if torch.cuda.device_count() > 1:
        from collections import OrderedDict

        new_state_dict = OrderedDict()
        for k, v in weights.items():
            name = k[7:]  # remove the `module.` prefix added by nn.DataParallel
            new_state_dict[name] = v
        weights = new_state_dict

    model.load_state_dict(weights)
    # Don't forget to set the model to evaluation mode!
    model.eval()

.. GENERATED FROM PYTHON SOURCE LINES 146-148

Load the training dataset and define the loss function for PTQ
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 148-175

.. code-block:: python

    training_dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ]
        ),
    )
    training_dataloader = torch.utils.data.DataLoader(
        training_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        drop_last=True,
    )

    data = iter(training_dataloader)
    images, _ = next(data)

    crit = nn.CrossEntropyLoss()

.. GENERATED FROM PYTHON SOURCE LINES 176-178

Define the calibration loop for quantization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 178-196

.. code-block:: python

    def calibrate_loop(model):
        # calibrate over the training dataset
        total = 0
        correct = 0
        loss = 0.0
        for data, labels in training_dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = model(data)
            loss += crit(out, labels)
            preds = torch.max(out, 1)[1]
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        print("PTQ Loss: {:.5f} Acc: {:.2f}%".format(loss / total, 100 * correct / total))

.. GENERATED FROM PYTHON SOURCE LINES 197-199

Quantize the pre-trained model with FP8 or INT8 using PTQ
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 199-207

.. code-block:: python

    if args.quantize_type == "int8":
        quant_cfg = mtq.INT8_DEFAULT_CFG
    elif args.quantize_type == "fp8":
        quant_cfg = mtq.FP8_DEFAULT_CFG
    # PTQ with in-place replacement of modules by their quantized counterparts
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has INT8/FP8 Q/DQ nodes at this point
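At this point the model still executes in PyTorch, with fake-quantization (Q/DQ) nodes
inserted around the calibrated layers. Before handing it to TensorRT, you can gauge the
accuracy impact of quantization and inspect what was quantized. The sketch below reuses
``calibrate_loop`` from above; the ``mtq.print_quant_summary`` helper is assumed to be
available in your installed ``modelopt`` release:

.. code-block:: python

    # Optional: evaluate the fake-quantized PyTorch model before TensorRT compilation.
    # This reuses the accuracy printout from calibrate_loop() defined above.
    with torch.no_grad():
        calibrate_loop(model)

    # List the quantizers that PTQ inserted and their calibrated ranges
    # (assumes print_quant_summary is available in your modelopt version).
    mtq.print_quant_summary(model)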
.. GENERATED FROM PYTHON SOURCE LINES 208-210

Inference
^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 210-275

.. code-block:: python

    # Load the testing dataset
    testing_dataset = datasets.CIFAR10(
        root="./data",
        train=False,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ]
        ),
    )

    testing_dataloader = torch.utils.data.DataLoader(
        testing_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        drop_last=True,
    )  # set drop_last=True to drop the last incomplete batch for static shape `torchtrt.dynamo.compile()`

    with torch.no_grad():
        with export_torch_mode():
            # Compile the model with Torch-TensorRT Dynamo backend
            input_tensor = images.cuda()
            # torch.export.export() fails here with "RuntimeError: Attempting to use
            # FunctionalTensor on its own. Instead, please use it with a corresponding
            # FunctionalTensorMode()", so the internal _export API is used as a workaround.
            from torch.export._trace import _export

            exp_program = _export(model, (input_tensor,))
            if args.quantize_type == "int8":
                enabled_precisions = {torch.int8}
            elif args.quantize_type == "fp8":
                enabled_precisions = {torch.float8_e4m3fn}
            trt_model = torchtrt.dynamo.compile(
                exp_program,
                inputs=[input_tensor],
                enabled_precisions=enabled_precisions,
                min_block_size=1,
                debug=False,
            )
            # You can also use the torch.compile path to compile the model with Torch-TensorRT:
            # trt_model = torch.compile(model, backend="tensorrt")

            # Run inference with the compiled Torch-TensorRT model over the testing dataset
            total = 0
            correct = 0
            loss = 0.0
            class_probs = []
            class_preds = []
            for data, labels in testing_dataloader:
                data, labels = data.cuda(), labels.cuda(non_blocking=True)
                out = trt_model(data)
                loss += crit(out, labels)
                preds = torch.max(out, 1)[1]
                class_probs.append([F.softmax(i, dim=0) for i in out])
                class_preds.append(preds)
                total += labels.size(0)
                correct += (preds == labels).sum().item()

            test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
            test_preds = torch.cat(class_preds)
            test_loss = loss / total
            test_acc = correct / total
            print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_tutorials__rendered_examples_dynamo_vgg16_ptq.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: vgg16_ptq.py <vgg16_ptq.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: vgg16_ptq.ipynb <vgg16_ptq.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_