[1]:
# Copyright 2020 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

Torch-TensorRT - Using Dynamic Shapes

Torch-TensorRT is a compiler for PyTorch/TorchScript that targets NVIDIA GPUs via NVIDIA’s TensorRT Deep Learning Optimizer and Runtime. Unlike PyTorch’s Just-In-Time (JIT) compiler, Torch-TensorRT is an Ahead-of-Time (AOT) compiler: before you deploy your TorchScript code, you run an explicit compile step that converts a standard TorchScript program into a module targeting a TensorRT engine. Torch-TensorRT operates as a PyTorch extension and compiles modules that integrate into the JIT runtime seamlessly. After compilation, using the optimized graph should feel no different from running a TorchScript module. You also have access to TensorRT’s suite of configurations at compile time, so you can specify the operating precision (FP32/FP16/INT8) and other settings for your module.

We highly encourage users to run this notebook in NVIDIA’s PyTorch container. It comes packaged with a host of NVIDIA libraries and optimizations for widely used third-party libraries. The container is tested and updated on a monthly cadence!

This notebook has the following sections:

1. TL;DR Explanation
2. Setting up the model
3. Working with Dynamic shapes in Torch TRT

```
torch_tensorrt.Input(
    min_shape=(1, 224, 224, 3),
    opt_shape=(1, 512, 512, 3),
    max_shape=(1, 1024, 1024, 3),
    dtype=torch.int32,
    format=torch.channels_last
)
…
```

In this example, we are going to use a simple ResNet model to demonstrate the use of the API. We will be using different batch sizes in the example, but you can use the same method to alter any of the dimensions of the tensor.
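To make the min/opt/max range more concrete, here is a minimal pure-Python sketch of the idea that a concrete input shape is acceptable only when every dimension lies between the minimum and maximum shape. The helper `shape_in_range` is hypothetical and written only for illustration; it is not part of the Torch-TensorRT API, and the real range check happens inside TensorRT at runtime.

```python
# Hypothetical helper, for illustration only (not a Torch-TensorRT function).
def shape_in_range(shape, min_shape, max_shape):
    """Return True if every dimension of `shape` lies within [min, max]."""
    if not (len(shape) == len(min_shape) == len(max_shape)):
        return False  # rank mismatch can never fit the range
    return all(lo <= dim <= hi
               for dim, lo, hi in zip(shape, min_shape, max_shape))

# The same bounds used in the Input(...) snippet above.
MIN_SHAPE = (1, 224, 224, 3)
MAX_SHAPE = (1, 1024, 1024, 3)

print(shape_in_range((1, 512, 512, 3), MIN_SHAPE, MAX_SHAPE))    # True
print(shape_in_range((1, 2048, 2048, 3), MIN_SHAPE, MAX_SHAPE))  # False
print(shape_in_range((2, 512, 512, 3), MIN_SHAPE, MAX_SHAPE))    # False
```

Note that the batch dimension is fixed at 1 in these bounds, so only the two spatial dimensions are allowed to vary; widening the first entry of the minimum and maximum shapes is what would make the batch size dynamic.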

[2]:
!nvidia-smi
!pip install ipywidgets --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host=files.pythonhosted.org
Mon May  2 20:40:30 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA Graphics...  On   | 00000000:01:00.0 Off |                    0 |
| 41%   51C    P0    62W / 200W |      0MiB / 47681MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Requirement already satisfied: ipywidgets in /opt/conda/lib/python3.8/site-packages (7.7.0)
Requirement already satisfied: traitlets>=4.3.1 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (5.1.1)
Requirement already satisfied: nbformat>=4.2.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (5.3.0)
Requirement already satisfied: ipykernel>=4.5.1 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (6.13.0)
Requirement already satisfied: ipython-genutils~=0.2.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (0.2.0)
Requirement already satisfied: widgetsnbextension~=3.6.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (3.6.0)
Requirement already satisfied: ipython>=4.0.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (8.2.0)
Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets) (1.1.0)
Requirement already satisfied: psutil in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets) (5.9.0)
Requirement already satisfied: tornado>=6.1 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets) (6.1)
Requirement already satisfied: packaging in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets) (21.3)
Requirement already satisfied: nest-asyncio in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets) (1.5.5)
Requirement already satisfied: matplotlib-inline>=0.1 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets) (0.1.3)
Requirement already satisfied: jupyter-client>=6.1.12 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets) (7.2.2)
Requirement already satisfied: debugpy>=1.0 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets) (1.6.0)
Requirement already satisfied: decorator in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (5.1.1)
Requirement already satisfied: setuptools>=18.5 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (59.5.0)
Requirement already satisfied: pickleshare in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (0.7.5)
Requirement already satisfied: backcall in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (0.2.0)
Requirement already satisfied: stack-data in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (0.2.0)
Requirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (4.8.0)
Requirement already satisfied: jedi>=0.16 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (0.18.1)
Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (3.0.29)
Requirement already satisfied: pygments>=2.4.0 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets) (2.11.2)
Requirement already satisfied: parso<0.9.0,>=0.8.0 in /opt/conda/lib/python3.8/site-packages (from jedi>=0.16->ipython>=4.0.0->ipywidgets) (0.8.3)
Requirement already satisfied: pyzmq>=22.3 in /opt/conda/lib/python3.8/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets) (22.3.0)
Requirement already satisfied: entrypoints in /opt/conda/lib/python3.8/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets) (0.4)
Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.8/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets) (2.8.2)
Requirement already satisfied: jupyter-core>=4.9.2 in /opt/conda/lib/python3.8/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets) (4.9.2)
Requirement already satisfied: jsonschema>=2.6 in /opt/conda/lib/python3.8/site-packages (from nbformat>=4.2.0->ipywidgets) (4.4.0)
Requirement already satisfied: fastjsonschema in /opt/conda/lib/python3.8/site-packages (from nbformat>=4.2.0->ipywidgets) (2.15.3)
Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /opt/conda/lib/python3.8/site-packages (from jsonschema>=2.6->nbformat>=4.2.0->ipywidgets) (0.18.1)
Requirement already satisfied: attrs>=17.4.0 in /opt/conda/lib/python3.8/site-packages (from jsonschema>=2.6->nbformat>=4.2.0->ipywidgets) (21.4.0)
Requirement already satisfied: importlib-resources>=1.4.0 in /opt/conda/lib/python3.8/site-packages (from jsonschema>=2.6->nbformat>=4.2.0->ipywidgets) (5.7.0)
Requirement already satisfied: zipp>=3.1.0 in /opt/conda/lib/python3.8/site-packages (from importlib-resources>=1.4.0->jsonschema>=2.6->nbformat>=4.2.0->ipywidgets) (3.8.0)
Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.8/site-packages (from pexpect>4.3->ipython>=4.0.0->ipywidgets) (0.7.0)
Requirement already satisfied: wcwidth in /opt/conda/lib/python3.8/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0->ipywidgets) (0.2.5)
Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.8/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets) (1.16.0)
Requirement already satisfied: notebook>=4.4.1 in /opt/conda/lib/python3.8/site-packages (from widgetsnbextension~=3.6.0->ipywidgets) (6.4.1)
Requirement already satisfied: jinja2 in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (3.1.1)
Requirement already satisfied: prometheus-client in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.14.1)
Requirement already satisfied: Send2Trash>=1.5.0 in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.8.0)
Requirement already satisfied: argon2-cffi in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (21.3.0)
Requirement already satisfied: nbconvert in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (6.5.0)
Requirement already satisfied: terminado>=0.8.3 in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.13.3)
Requirement already satisfied: argon2-cffi-bindings in /opt/conda/lib/python3.8/site-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (21.2.0)
Requirement already satisfied: cffi>=1.0.1 in /opt/conda/lib/python3.8/site-packages (from argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.15.0)
Requirement already satisfied: pycparser in /opt/conda/lib/python3.8/site-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (2.21)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.8/site-packages (from jinja2->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (2.1.1)
Requirement already satisfied: defusedxml in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.7.1)
Requirement already satisfied: jupyterlab-pygments in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.2.2)
Requirement already satisfied: mistune<2,>=0.8.1 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.8.4)
Requirement already satisfied: nbclient>=0.5.0 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.6.0)
Requirement already satisfied: bleach in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (5.0.0)
Requirement already satisfied: tinycss2 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.1.1)
Requirement already satisfied: pandocfilters>=1.4.1 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.5.0)
Requirement already satisfied: beautifulsoup4 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (4.11.1)
Requirement already satisfied: soupsieve>1.2 in /opt/conda/lib/python3.8/site-packages (from beautifulsoup4->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (2.3.1)
Requirement already satisfied: webencodings in /opt/conda/lib/python3.8/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.5.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.8/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets) (3.0.8)
Requirement already satisfied: asttokens in /opt/conda/lib/python3.8/site-packages (from stack-data->ipython>=4.0.0->ipywidgets) (2.0.5)
Requirement already satisfied: pure-eval in /opt/conda/lib/python3.8/site-packages (from stack-data->ipython>=4.0.0->ipywidgets) (0.2.2)
Requirement already satisfied: executing in /opt/conda/lib/python3.8/site-packages (from stack-data->ipython>=4.0.0->ipywidgets) (0.8.3)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

Setting up the model

In this section, we will:

* Get sample data.
* Download model from torch hub.
* Build simple utility functions.

Getting sample data

[3]:
!mkdir -p ./data
!wget  -O ./data/img0.JPG "https://d17fnq9dkz9hgj.cloudfront.net/breed-uploads/2018/08/siberian-husky-detail.jpg?bust=1535566590&width=630"
!wget  -O ./data/img1.JPG "https://www.hakaimagazine.com/wp-content/uploads/header-gulf-birds.jpg"
!wget  -O ./data/img2.JPG "https://www.artis.nl/media/filer_public_thumbnails/filer_public/00/f1/00f1b6db-fbed-4fef-9ab0-84e944ff11f8/chimpansee_amber_r_1920x1080.jpg__1920x1080_q85_subject_location-923%2C365_subsampling-2.jpg"
!wget  -O ./data/img3.JPG "https://www.familyhandyman.com/wp-content/uploads/2018/09/How-to-Avoid-Snakes-Slithering-Up-Your-Toilet-shutterstock_780480850.jpg"

!wget  -O ./data/imagenet_class_index.json "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
--2022-05-02 20:40:33--  https://d17fnq9dkz9hgj.cloudfront.net/breed-uploads/2018/08/siberian-husky-detail.jpg?bust=1535566590&width=630
Resolving d17fnq9dkz9hgj.cloudfront.net (d17fnq9dkz9hgj.cloudfront.net)... 18.65.227.37, 18.65.227.99, 18.65.227.223, ...
Connecting to d17fnq9dkz9hgj.cloudfront.net (d17fnq9dkz9hgj.cloudfront.net)|18.65.227.37|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 24112 (24K) [image/jpeg]
Saving to: ‘./data/img0.JPG’

./data/img0.JPG     100%[===================>]  23.55K  --.-KB/s    in 0.005s

2022-05-02 20:40:33 (4.69 MB/s) - ‘./data/img0.JPG’ saved [24112/24112]

--2022-05-02 20:40:34--  https://www.hakaimagazine.com/wp-content/uploads/header-gulf-birds.jpg
Resolving www.hakaimagazine.com (www.hakaimagazine.com)... 164.92.73.117
Connecting to www.hakaimagazine.com (www.hakaimagazine.com)|164.92.73.117|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 452718 (442K) [image/jpeg]
Saving to: ‘./data/img1.JPG’

./data/img1.JPG     100%[===================>] 442.11K  --.-KB/s    in 0.02s

2022-05-02 20:40:34 (26.2 MB/s) - ‘./data/img1.JPG’ saved [452718/452718]

--2022-05-02 20:40:34--  https://www.artis.nl/media/filer_public_thumbnails/filer_public/00/f1/00f1b6db-fbed-4fef-9ab0-84e944ff11f8/chimpansee_amber_r_1920x1080.jpg__1920x1080_q85_subject_location-923%2C365_subsampling-2.jpg
Resolving www.artis.nl (www.artis.nl)... 94.75.225.20
Connecting to www.artis.nl (www.artis.nl)|94.75.225.20|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 361413 (353K) [image/jpeg]
Saving to: ‘./data/img2.JPG’

./data/img2.JPG     100%[===================>] 352.94K   608KB/s    in 0.6s

2022-05-02 20:40:36 (608 KB/s) - ‘./data/img2.JPG’ saved [361413/361413]

--2022-05-02 20:40:37--  https://www.familyhandyman.com/wp-content/uploads/2018/09/How-to-Avoid-Snakes-Slithering-Up-Your-Toilet-shutterstock_780480850.jpg
Resolving www.familyhandyman.com (www.familyhandyman.com)... 104.18.201.107, 104.18.202.107, 2606:4700::6812:c96b, ...
Connecting to www.familyhandyman.com (www.familyhandyman.com)|104.18.201.107|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 90994 (89K) [image/jpeg]
Saving to: ‘./data/img3.JPG’

./data/img3.JPG     100%[===================>]  88.86K  --.-KB/s    in 0.006s

2022-05-02 20:40:37 (15.4 MB/s) - ‘./data/img3.JPG’ saved [90994/90994]

--2022-05-02 20:40:37--  https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.33.238
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.33.238|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 35363 (35K) [application/octet-stream]
Saving to: ‘./data/imagenet_class_index.json’

./data/imagenet_cla 100%[===================>]  34.53K  --.-KB/s    in 0.07s

2022-05-02 20:40:38 (489 KB/s) - ‘./data/imagenet_class_index.json’ saved [35363/35363]

[4]:
# visualizing the downloaded images

from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

fig, axes = plt.subplots(nrows=2, ncols=2)

for i in range(4):
    img_path = './data/img%d.JPG'%i
    img = Image.open(img_path)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(img)
    plt.subplot(2,2,i+1)
    plt.imshow(img)
    plt.axis('off')

# loading labels
with open("./data/imagenet_class_index.json") as json_file:
    d = json.load(json_file)
[Image: the four downloaded sample images displayed in a 2x2 grid]
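The `transforms.Normalize` step in the preprocessing pipeline above subtracts a per-channel mean and divides by a per-channel standard deviation. As a rough sketch of that arithmetic on a single RGB pixel, using plain Python instead of torchvision (the function name `normalize_pixel` is made up for this illustration):

```python
# Per-channel statistics copied from the preprocessing cell above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Apply (value - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]

# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel([0.485, 0.456, 0.406]))  # [0.0, 0.0, 0.0]
```

The real transform applies the same formula to every pixel of the tensor produced by `transforms.ToTensor()`.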

Download model from torch hub.

[5]:
import torch

torch.hub._validate_not_a_forked_repo=lambda a,b,c: True

resnet50_model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
resnet50_model.eval()
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0
[5]:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer2): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer3): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (4): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (5): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

Build simple utility functions

[6]:
import numpy as np
import time
import torch.backends.cudnn as cudnn
cudnn.benchmark = True

def rn50_preprocess():
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess

# decode the results into ([predicted class, description], probability)
def predict(img_path, model):
    img = Image.open(img_path)
    preprocess = rn50_preprocess()
    input_tensor = preprocess(img)
    input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)
        # Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
        sm_output = torch.nn.functional.softmax(output[0], dim=0)

    ind = torch.argmax(sm_output)
    return d[str(ind.item())], sm_output[ind] #([predicted class, description], probability)

# benchmarking models
def benchmark(model, input_shape=(1024, 3, 224, 224), dtype='fp32', nwarmup=50, nruns=10000):
    input_data = torch.randn(input_shape)
    input_data = input_data.to("cuda")
    if dtype=='fp16':
        input_data = input_data.half()

    print("Warm up ...")
    with torch.no_grad():
        for _ in range(nwarmup):
            features = model(input_data)
    torch.cuda.synchronize()
    print("Start timing ...")
    timings = []
    with torch.no_grad():
        for i in range(1, nruns+1):
            start_time = time.time()
            features = model(input_data)
            torch.cuda.synchronize()
            end_time = time.time()
            timings.append(end_time - start_time)
            if i%10==0:
                print('Iteration %d/%d, ave batch time %.2f ms'%(i, nruns, np.mean(timings)*1000))
                print('Images processed per second=', int(1000*input_shape[0]/(np.mean(timings)*1000)))
    print("Input shape:", input_data.size())
    print("Output features size:", features.size())
    print('Average batch time: %.2f ms'%(np.mean(timings)*1000))

Let’s test our utility functions on the model we have set up, starting with simple predictions.

[7]:
for i in range(4):
    img_path = './data/img%d.JPG'%i
    img = Image.open(img_path)

    pred, prob = predict(img_path, resnet50_model)
    print('{} - Predicted: {}, Probability: {}'.format(img_path, pred, prob))

    plt.subplot(2,2,i+1)
    plt.imshow(img);
    plt.axis('off');
    plt.title(pred[1])
./data/img0.JPG - Predicted: ['n02110185', 'Siberian_husky'], Probability: 0.49788108468055725
./data/img1.JPG - Predicted: ['n01820546', 'lorikeet'], Probability: 0.6442285180091858
./data/img2.JPG - Predicted: ['n02481823', 'chimpanzee'], Probability: 0.9899841547012329
./data/img3.JPG - Predicted: ['n01749939', 'green_mamba'], Probability: 0.45675724744796753
[Figure: 2x2 grid of the four test images, each titled with its predicted class]

Onwards, to benchmarking.

[8]:
# Model benchmark without Torch-TensorRT
model = resnet50_model.eval().to("cuda")
benchmark(model, input_shape=(16, 3, 224, 224), nruns=100)
Warm up ...
Start timing ...
Iteration 10/100, ave batch time 10.01 ms
Images processed per second= 1598
Iteration 20/100, ave batch time 10.01 ms
Images processed per second= 1598
Iteration 30/100, ave batch time 10.21 ms
Images processed per second= 1566
Iteration 40/100, ave batch time 10.33 ms
Images processed per second= 1549
Iteration 50/100, ave batch time 10.31 ms
Images processed per second= 1552
Iteration 60/100, ave batch time 10.25 ms
Images processed per second= 1560
Iteration 70/100, ave batch time 10.20 ms
Images processed per second= 1568
Iteration 80/100, ave batch time 10.18 ms
Images processed per second= 1572
Iteration 90/100, ave batch time 10.16 ms
Images processed per second= 1574
Iteration 100/100, ave batch time 10.15 ms
Images processed per second= 1575
Input shape: torch.Size([16, 3, 224, 224])
Output features size: torch.Size([16, 1000])
Average batch time: 10.15 ms

Benchmarking with Torch-TRT (without dynamic shapes)

[9]:
import torch_tensorrt

trt_model_without_ds = torch_tensorrt.compile(model, inputs = [torch_tensorrt.Input((32, 3, 224, 224), dtype=torch.float32)],
    enabled_precisions = {torch.float32}, # Run with FP32 (passed as a set of dtypes)
    workspace_size = 1 << 33
)
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
[10]:
benchmark(trt_model_without_ds, input_shape=(32, 3, 224, 224), nruns=100)
Warm up ...
Start timing ...
Iteration 10/100, ave batch time 6.10 ms
Images processed per second= 5242
Iteration 20/100, ave batch time 6.12 ms
Images processed per second= 5231
Iteration 30/100, ave batch time 6.14 ms
Images processed per second= 5215
Iteration 40/100, ave batch time 6.14 ms
Images processed per second= 5207
Iteration 50/100, ave batch time 6.15 ms
Images processed per second= 5202
Iteration 60/100, ave batch time 6.28 ms
Images processed per second= 5094
Iteration 70/100, ave batch time 6.26 ms
Images processed per second= 5110
Iteration 80/100, ave batch time 6.25 ms
Images processed per second= 5118
Iteration 90/100, ave batch time 6.25 ms
Images processed per second= 5115
Iteration 100/100, ave batch time 6.40 ms
Images processed per second= 5002
Input shape: torch.Size([32, 3, 224, 224])
Output features size: torch.Size([32, 1000])
Average batch time: 6.40 ms

With the baseline ready, we can proceed to the section on dynamic shapes!


Working with Dynamic shapes in Torch TRT

Enabling dynamically shaped tensors means deferring the definition of tensor shapes until runtime. Torch-TensorRT simply leverages TensorRT’s dynamic shape support. You can read more about TensorRT’s implementation in the TensorRT Documentation.

To make use of dynamic shapes, you need to provide three shapes:
* min_shape: The minimum size of the tensor considered for optimizations.
* opt_shape: The optimizations will be done with an effort to maximize performance for this shape.
* max_shape: The maximum size of the tensor considered for optimizations.

Generally, users can expect the best performance within the specified range. Performance for shapes outside that range may be lower, depending on the model ops and the GPU used.
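
To make the idea of a "specified range" concrete, here is a minimal sketch of a check that a runtime input shape falls between a given `min_shape` and `max_shape` in every dimension. The helper `shape_in_range` is hypothetical and not part of Torch-TensorRT; it only compares tuples and assumes the range is inclusive at both ends. The example shapes are the ones used in this notebook.

```python
def shape_in_range(shape, min_shape, max_shape):
    """Return True if every dimension of shape lies within [min, max] for that dimension."""
    # Shapes with a different number of dimensions can never match the range.
    if not (len(shape) == len(min_shape) == len(max_shape)):
        return False
    return all(lo <= dim <= hi for dim, lo, hi in zip(shape, min_shape, max_shape))


MIN_SHAPE = (16, 3, 224, 224)
MAX_SHAPE = (64, 3, 224, 224)

print(shape_in_range((32, 3, 224, 224), MIN_SHAPE, MAX_SHAPE))   # True: batch 32 is inside 16..64
print(shape_in_range((128, 3, 224, 224), MIN_SHAPE, MAX_SHAPE))  # False: batch 128 exceeds 64
print(shape_in_range((32, 1, 224, 224), MIN_SHAPE, MAX_SHAPE))   # False: channel dim must be 3
```

A check like this could run before calling a compiled module, so that inputs outside the optimized range are noticed early rather than discovered through unexpected performance.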

In the following example, we will showcase varying the batch size, which is the zeroth dimension of our input tensors. Because convolution operations require the channel dimension to be a build-time constant, we won’t change the sizes of the other dimensions in this example. For models whose ops are conducive to changes in other dimensions, this functionality can be used freely.
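
As a small illustration of varying only the batch dimension, the sketch below builds the three shape tuples from three batch sizes while holding the remaining dimensions fixed. The helper `batch_dynamic_shapes` and its `sample_shape` parameter are hypothetical names introduced here; only the dictionary keys `min_shape`, `opt_shape`, and `max_shape` come from the `torch_tensorrt.Input` arguments used in this notebook.

```python
def batch_dynamic_shapes(min_batch, opt_batch, max_batch, sample_shape=(3, 224, 224)):
    """Build min/opt/max shapes that differ only in the zeroth (batch) dimension."""
    if not (1 <= min_batch <= opt_batch <= max_batch):
        raise ValueError("expected 1 <= min_batch <= opt_batch <= max_batch")
    fixed = tuple(sample_shape)  # channel, height, width stay constant
    return {
        "min_shape": (min_batch, *fixed),
        "opt_shape": (opt_batch, *fixed),
        "max_shape": (max_batch, *fixed),
    }


shapes = batch_dynamic_shapes(16, 32, 64)
print(shapes)
```

Because the keys match the keyword arguments shown in this notebook, the result could presumably be unpacked as `torch_tensorrt.Input(**shapes, dtype=torch.float32)`.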

[11]:
# The compiled module will have the precision specified by "enabled_precisions".
# Here, it will have FP32 precision.
trt_model_with_ds = torch_tensorrt.compile(model, inputs = [torch_tensorrt.Input(
        min_shape=(16, 3, 224, 224),
        opt_shape=(32, 3, 224, 224),
        max_shape=(64, 3, 224, 224),
        dtype=torch.float32)],
    enabled_precisions = {torch.float32}, # Run with FP32 (passed as a set of dtypes)
    workspace_size = 1 << 33
)
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
[12]:
benchmark(trt_model_with_ds, input_shape=(16, 3, 224, 224), nruns=100)
Warm up ...
Start timing ...
Iteration 10/100, ave batch time 3.88 ms
Images processed per second= 4122
Iteration 20/100, ave batch time 3.89 ms
Images processed per second= 4116
Iteration 30/100, ave batch time 3.88 ms
Images processed per second= 4123
Iteration 40/100, ave batch time 3.86 ms
Images processed per second= 4142
Iteration 50/100, ave batch time 3.85 ms
Images processed per second= 4156
Iteration 60/100, ave batch time 3.84 ms
Images processed per second= 4166
Iteration 70/100, ave batch time 3.84 ms
Images processed per second= 4170
Iteration 80/100, ave batch time 3.83 ms
Images processed per second= 4172
Iteration 90/100, ave batch time 3.83 ms
Images processed per second= 4176
Iteration 100/100, ave batch time 3.83 ms
Images processed per second= 4178
Input shape: torch.Size([16, 3, 224, 224])
Output features size: torch.Size([16, 1000])
Average batch time: 3.83 ms
[13]:
benchmark(trt_model_with_ds, input_shape=(32, 3, 224, 224), nruns=100)
Warm up ...
Start timing ...
Iteration 10/100, ave batch time 6.71 ms
Images processed per second= 4767
Iteration 20/100, ave batch time 6.48 ms
Images processed per second= 4935
Iteration 30/100, ave batch time 6.39 ms
Images processed per second= 5005
Iteration 40/100, ave batch time 6.38 ms
Images processed per second= 5014
Iteration 50/100, ave batch time 6.38 ms
Images processed per second= 5016
Iteration 60/100, ave batch time 6.37 ms
Images processed per second= 5020
Iteration 70/100, ave batch time 6.37 ms
Images processed per second= 5024
Iteration 80/100, ave batch time 6.37 ms
Images processed per second= 5027
Iteration 90/100, ave batch time 6.37 ms
Images processed per second= 5026
Iteration 100/100, ave batch time 6.38 ms
Images processed per second= 5018
Input shape: torch.Size([32, 3, 224, 224])
Output features size: torch.Size([32, 1000])
Average batch time: 6.38 ms
[14]:
benchmark(trt_model_with_ds, input_shape=(64, 3, 224, 224), nruns=100)
Warm up ...
Start timing ...
Iteration 10/100, ave batch time 12.31 ms
Images processed per second= 5197
Iteration 20/100, ave batch time 12.42 ms
Images processed per second= 5153
Iteration 30/100, ave batch time 12.85 ms
Images processed per second= 4980
Iteration 40/100, ave batch time 12.71 ms
Images processed per second= 5033
Iteration 50/100, ave batch time 12.67 ms
Images processed per second= 5052
Iteration 60/100, ave batch time 12.63 ms
Images processed per second= 5067
Iteration 70/100, ave batch time 12.58 ms
Images processed per second= 5088
Iteration 80/100, ave batch time 12.56 ms
Images processed per second= 5096
Iteration 90/100, ave batch time 12.55 ms
Images processed per second= 5100
Iteration 100/100, ave batch time 12.57 ms
Images processed per second= 5091
Input shape: torch.Size([64, 3, 224, 224])
Output features size: torch.Size([64, 1000])
Average batch time: 12.57 ms
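
To compare the runs side by side, the sketch below recomputes throughput and per-image latency from the average batch times reported in the logs of this notebook. The helper names are hypothetical. Because the logged averages are rounded to two decimals, the recomputed throughput can differ by a few images per second from the values printed during the runs.

```python
# (label, batch size, average batch time in ms), copied from the benchmark logs
runs = [
    ("PyTorch, batch 16", 16, 10.15),
    ("TRT static, batch 32", 32, 6.40),
    ("TRT dynamic, batch 16", 16, 3.83),
    ("TRT dynamic, batch 32", 32, 6.38),
    ("TRT dynamic, batch 64", 64, 12.57),
]


def images_per_second(batch_size, avg_batch_ms):
    """Throughput: images in one batch divided by the batch time in seconds."""
    return batch_size / (avg_batch_ms / 1000.0)


def per_image_ms(batch_size, avg_batch_ms):
    """Average time spent per image, in milliseconds."""
    return avg_batch_ms / batch_size


for label, batch, ms in runs:
    print(f"{label:24s} {images_per_second(batch, ms):8.0f} img/s  "
          f"{per_image_ms(batch, ms):.3f} ms/img")
```

In these logs, the dynamic-shape engine at batch 32 (6.38 ms) is on par with the static engine at the same batch size (6.40 ms), and its per-image time is lower at batch sizes 32 and 64 than at 16.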

What’s Next?

Check out the TensorRT Getting started page for more tutorials, or visit the Torch-TensorRT documentation for more information!
