.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/vt_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_beginner_vt_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_vt_tutorial.py:


Optimizing Vision Transformer Model for Deployment
====================================================

Jeff Tang, Geeta Chauhan

Vision Transformer models apply the cutting-edge, attention-based
transformer architecture, introduced in Natural Language Processing to
achieve all kinds of state-of-the-art (SOTA) results, to Computer Vision
tasks. Facebook's Data-efficient Image Transformers
`DeiT <https://github.com/facebookresearch/deit>`_ is a Vision Transformer
model trained on ImageNet for image classification.

In this tutorial, we will first cover what DeiT is and how to use it,
then go through the complete steps of scripting, quantizing, optimizing,
and using the model in iOS and Android apps. We will also compare the
performance of the quantized, optimized model against the non-quantized,
non-optimized model, and show the benefits of applying quantization and
optimization along the way.

.. GENERATED FROM PYTHON SOURCE LINES 27-47

What is DeiT
---------------------

Convolutional Neural Networks (CNNs) have been the main models for image
classification since deep learning took off in 2012, but CNNs typically
require hundreds of millions of images for training to achieve SOTA
results. DeiT is a vision transformer model that requires much less data
and fewer computing resources for training to compete with the leading
CNNs at image classification, which is made possible by two key
components of DeiT:

- Data augmentation that simulates training on a much larger dataset;
- Native distillation that allows the transformer network to learn from
  a CNN's output (see the sketch after this list for the general idea).

DeiT shows that Transformers can be successfully applied to computer
vision tasks, with limited access to data and resources. For more
details on DeiT, see the `repo <https://github.com/facebookresearch/deit>`_
and `paper <https://arxiv.org/abs/2012.12877>`_.
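To make the distillation idea above concrete, the snippet below is a minimal,
generic knowledge-distillation sketch: a student loss that mixes ordinary
cross-entropy on ground-truth labels with agreement to a teacher's softened
predictions. This is only an illustration under assumed settings (the function
name, temperature ``T``, and weight ``alpha`` are arbitrary); DeiT's actual
recipe uses a dedicated distillation token and a convolutional teacher, which
this sketch does not reproduce.

.. code-block:: python

    # Minimal, generic knowledge-distillation sketch (NOT DeiT's exact
    # distillation-token recipe): the student learns from the true labels
    # and from a (possibly CNN) teacher's soft predictions.
    import torch
    import torch.nn.functional as F

    def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):
        # Hard-label term: ordinary cross-entropy against the ground truth.
        ce = F.cross_entropy(student_logits, labels)
        # Soft-label term: KL divergence between temperature-softened
        # student and teacher distributions (scaled by T*T as is customary).
        kd = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                      F.softmax(teacher_logits / T, dim=1),
                      reduction="batchmean") * T * T
        return alpha * ce + (1 - alpha) * kd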
.. GENERATED FROM PYTHON SOURCE LINES 50-60

Classifying Images with DeiT
-------------------------------

Follow the ``README.md`` in the DeiT repository for detailed information on
how to classify images using DeiT, or, for a quick test, first install the
required packages:

.. code-block:: python

    pip install torch torchvision timm pandas requests

.. GENERATED FROM PYTHON SOURCE LINES 62-67

To run in Google Colab, install the dependencies by running the following command:

.. code-block:: python

    !pip install timm pandas requests

.. GENERATED FROM PYTHON SOURCE LINES 69-70

then run the script below:

.. GENERATED FROM PYTHON SOURCE LINES 70-99

.. code-block:: default


    from PIL import Image
    import torch
    import timm
    import requests
    import torchvision.transforms as transforms
    from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    print(torch.__version__) # should be 1.8.0 or above

    model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
    model.eval()

    transform = transforms.Compose([
        transforms.Resize(256, interpolation=3),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])

    img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
    img = transform(img)[None,]
    out = model(img)
    clsidx = torch.argmax(out)
    print(clsidx.item())


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2.2.1+cu121
    Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /root/.cache/torch/hub/main.zip
    /root/.cache/torch/hub/facebookresearch_deit_main/models.py:63: UserWarning: Overwriting deit_tiny_patch16_224 in registry with models.deit_tiny_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
    /root/.cache/torch/hub/facebookresearch_deit_main/models.py:78: UserWarning: Overwriting deit_small_patch16_224 in registry with models.deit_small_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
    /root/.cache/torch/hub/facebookresearch_deit_main/models.py:93: UserWarning: Overwriting deit_base_patch16_224 in registry with models.deit_base_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
    /root/.cache/torch/hub/facebookresearch_deit_main/models.py:108: UserWarning: Overwriting deit_tiny_distilled_patch16_224 in registry with models.deit_tiny_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
    /root/.cache/torch/hub/facebookresearch_deit_main/models.py:123: UserWarning: Overwriting deit_small_distilled_patch16_224 in registry with models.deit_small_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
    /root/.cache/torch/hub/facebookresearch_deit_main/models.py:138: UserWarning: Overwriting deit_base_distilled_patch16_224 in registry with models.deit_base_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
    /root/.cache/torch/hub/facebookresearch_deit_main/models.py:153: UserWarning: Overwriting deit_base_patch16_384 in registry with models.deit_base_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
    /root/.cache/torch/hub/facebookresearch_deit_main/models.py:168: UserWarning: Overwriting deit_base_distilled_patch16_384 in registry with models.deit_base_distilled_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
    Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /root/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
    269


The output should be 269, which, according to the ImageNet list of class
index to labels, maps to ``timber wolf, grey wolf, gray wolf, Canis lupus``.
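If you want to confirm the label text programmatically instead of looking the
index up by hand, a small sketch like the one below works. It reuses
``clsidx`` from the script above; the label-file URL (the
``imagenet_classes.txt`` file published in the ``pytorch/hub`` repository) is
an assumption, and any ImageNet index-to-label file can be substituted.

.. code-block:: python

    # Optional sanity check: map the predicted class index to a readable label.
    # The labels file below is assumed to list the 1000 ImageNet classes in order.
    import requests

    LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
    labels = requests.get(LABELS_URL).text.strip().splitlines()
    print(labels[clsidx.item()])  # expected to name a wolf class for index 269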
Now that we have verified that we can use the DeiT model to classify images,
let's see how to modify the model so it can run on iOS and Android apps.

.. GENERATED FROM PYTHON SOURCE LINES 111-118

Scripting DeiT
----------------------

To use the model on mobile, we first need to script the model. See the
Script and Optimize for Mobile recipe for a quick overview. Run the code
below to convert the DeiT model used in the previous step to the
TorchScript format that can run on mobile.

.. GENERATED FROM PYTHON SOURCE LINES 118-126

.. code-block:: default


    model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
    model.eval()

    scripted_model = torch.jit.script(model)
    scripted_model.save("fbdeit_scripted.pt")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main


.. GENERATED FROM PYTHON SOURCE LINES 127-130

The scripted model file ``fbdeit_scripted.pt``, about 346MB in size, is
generated.

.. GENERATED FROM PYTHON SOURCE LINES 133-144

Quantizing DeiT
---------------------

To reduce the trained model size significantly while keeping the inference
accuracy about the same, quantization can be applied to the model. Because
DeiT is a transformer model, we can easily apply dynamic quantization to
it: dynamic quantization works best for LSTM and transformer models (see
the PyTorch dynamic quantization documentation for more details).

Now run the code below:

.. GENERATED FROM PYTHON SOURCE LINES 144-155

.. code-block:: default


    # Use 'x86' for server inference (the old 'fbgemm' is still available but
    # 'x86' is the recommended default) and 'qnnpack' for mobile inference.
    backend = "x86" # replacing this with 'qnnpack' causes much worse inference speed for the quantized model on this notebook
    model.qconfig = torch.quantization.get_default_qconfig(backend)
    torch.backends.quantized.engine = backend

    quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
    scripted_quantized_model = torch.jit.script(quantized_model)
    scripted_quantized_model.save("fbdeit_scripted_quantized.pt")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/observer.py:220: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.


.. GENERATED FROM PYTHON SOURCE LINES 156-160

This generates the scripted and quantized version of the model,
``fbdeit_scripted_quantized.pt``, with size about 89MB -- a 74% reduction
from the non-quantized model size of 346MB!

.. GENERATED FROM PYTHON SOURCE LINES 162-165

You can use the ``scripted_quantized_model`` to generate the same
inference result:

.. GENERATED FROM PYTHON SOURCE LINES 165-171

.. code-block:: default


    out = scripted_quantized_model(img)
    clsidx = torch.argmax(out)
    print(clsidx.item()) # The same output 269 should be printed


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    269


.. GENERATED FROM PYTHON SOURCE LINES 172-177

Optimizing DeiT
---------------------

The final step before using the quantized and scripted model on mobile is
to optimize it:

.. GENERATED FROM PYTHON SOURCE LINES 177-183

.. code-block:: default


    from torch.utils.mobile_optimizer import optimize_for_mobile
    optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
    optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")
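If you want to verify the size savings yourself, a quick sketch such as the
one below (assuming the three files saved in the previous steps are in the
current working directory) prints the on-disk size of each generated model:

.. code-block:: python

    # Optional: compare the on-disk sizes of the generated TorchScript files.
    import os

    for name in ["fbdeit_scripted.pt",
                 "fbdeit_scripted_quantized.pt",
                 "fbdeit_optimized_scripted_quantized.pt"]:
        size_mb = os.path.getsize(name) / (1024 * 1024)
        print("{}: {:.1f}MB".format(name, size_mb))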
.. GENERATED FROM PYTHON SOURCE LINES 184-188

The generated ``fbdeit_optimized_scripted_quantized.pt`` file has about the
same size as the quantized, scripted, but non-optimized model. The inference
result remains the same.

.. GENERATED FROM PYTHON SOURCE LINES 188-197

.. code-block:: default


    out = optimized_scripted_quantized_model(img)
    clsidx = torch.argmax(out)
    print(clsidx.item()) # Again, the same output 269 should be printed


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    269


.. GENERATED FROM PYTHON SOURCE LINES 198-204

Using Lite Interpreter
------------------------

To see how much model size reduction and inference speed-up the Lite
Interpreter can deliver, let's create the lite version of the model.

.. GENERATED FROM PYTHON SOURCE LINES 204-209

.. code-block:: default


    optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
    ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")


.. GENERATED FROM PYTHON SOURCE LINES 210-213

Although the lite model size is comparable to the non-lite version, an
inference speed-up is expected when running the lite version on mobile.

.. GENERATED FROM PYTHON SOURCE LINES 216-223

Comparing Inference Speed
---------------------------

To see how the inference speed differs for the four models - the original
model, the scripted model, the quantized-and-scripted model, and the
optimized-quantized-and-scripted model - as well as the lite model, run
the code below:

.. GENERATED FROM PYTHON SOURCE LINES 223-241

.. code-block:: default


    with torch.autograd.profiler.profile(use_cuda=False) as prof1:
        out = model(img)
    with torch.autograd.profiler.profile(use_cuda=False) as prof2:
        out = scripted_model(img)
    with torch.autograd.profiler.profile(use_cuda=False) as prof3:
        out = scripted_quantized_model(img)
    with torch.autograd.profiler.profile(use_cuda=False) as prof4:
        out = optimized_scripted_quantized_model(img)
    with torch.autograd.profiler.profile(use_cuda=False) as prof5:
        out = ptl(img)

    print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
    print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
    print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
    print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
    print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    original model: 159.77ms
    scripted model: 104.92ms
    scripted & quantized model: 123.95ms
    scripted & quantized & optimized model: 137.04ms
    lite model: 150.63ms


.. GENERATED FROM PYTHON SOURCE LINES 242-252

The results running on a Google Colab are:

.. code-block:: sh

    original model: 1236.69ms
    scripted model: 1226.72ms
    scripted & quantized model: 593.19ms
    scripted & quantized & optimized model: 598.01ms
    lite model: 600.72ms
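Keep in mind that single-run profiler numbers like the ones above can be
noisy. If you want a more stable comparison, one option is a simple averaging
loop such as the sketch below, which times several forward passes per model
with ``time.perf_counter``. It reuses the models and ``img`` defined earlier;
the helper name and the number of runs are arbitrary choices, not part of the
original benchmark.

.. code-block:: python

    # Optional: average several timed runs per model for more stable numbers.
    import time

    def average_latency_ms(m, x, runs=10):
        with torch.no_grad():
            m(x)  # warm-up run so one-time costs don't skew the measurement
            start = time.perf_counter()
            for _ in range(runs):
                m(x)
        return (time.perf_counter() - start) / runs * 1000

    for name, m in [("original", model),
                    ("scripted", scripted_model),
                    ("scripted & quantized", scripted_quantized_model),
                    ("scripted & quantized & optimized", optimized_scripted_quantized_model),
                    ("lite", ptl)]:
        print("{}: {:.2f}ms".format(name, average_latency_ms(m, img)))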
.. GENERATED FROM PYTHON SOURCE LINES 255-259

The following results summarize the inference time taken by each model and
the percentage reduction of each model relative to the original model.

.. GENERATED FROM PYTHON SOURCE LINES 259-287

.. code-block:: default


    import pandas as pd
    import numpy as np

    df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
    df = pd.concat([df, pd.DataFrame([
        ["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
        ["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
         "{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
        ["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
         "{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
        ["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
         "{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
        ["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
         "{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
        columns=['Inference Time', 'Reduction'])], axis=1)

    print(df)

    """
       Model                                   Inference Time   Reduction
    0  original model                               1236.69ms          0%
    1  scripted model                               1226.72ms       0.81%
    2  scripted & quantized model                    593.19ms      52.03%
    3  scripted & quantized & optimized model        598.01ms      51.64%
    4  lite model                                    600.72ms      51.43%
    """


.. rst-class:: sphx-glr-script-out

.. code-block:: none

                                        Model  ... Reduction
    0                          original model  ...        0%
    1                          scripted model  ...    34.33%
    2              scripted & quantized model  ...    22.42%
    3  scripted & quantized & optimized model  ...    14.22%
    4                              lite model  ...     5.72%

    [5 rows x 3 columns]

    '\n Model Inference Time Reduction\n0\toriginal model 1236.69ms 0%\n1\tscripted model 1226.72ms 0.81%\n2\tscripted & quantized model 593.19ms 52.03%\n3\tscripted & quantized & optimized model 598.01ms 51.64%\n4\tlite model 600.72ms 51.43%\n'


.. GENERATED FROM PYTHON SOURCE LINES 288-294

Learn More
~~~~~~~~~~~~~~~~~

- Facebook Data-efficient Image Transformers (the `DeiT repository <https://github.com/facebookresearch/deit>`__)
- Vision Transformer with ImageNet and MNIST on iOS (in the ``pytorch/ios-demo-app`` repository)
- Vision Transformer with ImageNet and MNIST on Android (in the ``pytorch/android-demo-app`` repository)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  20.659 seconds)


.. _sphx_glr_download_beginner_vt_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: vt_tutorial.py <vt_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: vt_tutorial.ipynb <vt_tutorial.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_