# Tensor transforms and JIT

This example illustrates various features that are now supported by the image transformations on Tensor images. In particular, we show how image transforms can be performed on GPU, and how one can also script them using JIT compilation.

Prior to v0.8.0, transforms in torchvision were traditionally PIL-centric and had multiple limitations as a result. Since v0.8.0, transform implementations are compatible with both Tensor and PIL images, which enables the following new features:

• transform multi-band torch tensor images (with more than 3-4 channels)

• torchscript transforms together with your model for deployment

• support for GPU acceleration

• batched transformations, for example on videos

• reading and decoding data directly as torch tensors, with torchscript support (for PNG and JPEG image formats)

Note

These features are only possible with Tensor images.

```python
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision.transforms as T
from torchvision.io import read_image

plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)


def show(imgs):
    # Display a list of (C, H, W) image tensors side by side
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        # ToPILImage expects a CPU tensor, so move GPU tensors back first
        img = T.ToPILImage()(img.to('cpu'))
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
```


The read_image() function reads an image file and loads it directly as a uint8 tensor of shape (C, H, W):

```python
dog1 = read_image(str(Path('assets') / 'dog1.jpg'))
dog2 = read_image(str(Path('assets') / 'dog2.jpg'))
show([dog1, dog2])
```


## Transforming images on GPU

Most transforms natively support tensors in addition to PIL images (to visualize the effect of the transforms, you may refer to Illustration of transforms). Using tensor images, we can run the transforms on GPUs if CUDA is available!

```python
import torch.nn as nn

transforms = nn.Sequential(
    T.RandomCrop(224),
    T.RandomHorizontalFlip(p=0.3),
)

# Fall back to the CPU when no CUDA device is present
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dog1 = dog1.to(device)
dog2 = dog2.to(device)

transformed_dog1 = transforms(dog1)
transformed_dog2 = transforms(dog2)
show([transformed_dog1, transformed_dog2])
```


## Scriptable transforms for easier deployment via torchscript

We now show how to combine image transformations and a model forward pass, while using torch.jit.script to obtain a single scripted module.

Let’s define a Predictor module that transforms the input tensor and then applies an ImageNet model to it.

```python
from torchvision.models import resnet18


class Predictor(nn.Module):

    def __init__(self):
        super().__init__()
        self.resnet18 = resnet18(pretrained=True, progress=False).eval()
        self.transforms = nn.Sequential(
            T.Resize([256, ]),  # We use a single int value inside a list due to torchscript type restrictions
            T.CenterCrop(224),
            T.ConvertImageDtype(torch.float),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.transforms(x)
        y_pred = self.resnet18(x)
        return y_pred.argmax(dim=1)
```


Now, let’s define scripted and non-scripted instances of Predictor and apply them to a batch of tensor images of the same size:

```python
predictor = Predictor().to(device)
scripted_predictor = torch.jit.script(predictor).to(device)

batch = torch.stack([dog1, dog2]).to(device)

res = predictor(batch)
res_scripted = scripted_predictor(batch)
```




We can verify that the predictions of the scripted and non-scripted models are the same:

```python
import json

with open(Path('assets') / 'imagenet_class_index.json', 'r') as labels_file:
    # Maps each class index (as a string) to a [WordNet ID, class name] pair
    labels = json.load(labels_file)

for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
    assert pred == pred_scripted
    print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}")
```


Out:

```text
Prediction for Dog 1: ['n02113023', 'Pembroke']
Prediction for Dog 2: ['n02106662', 'German_shepherd']
```
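The labels dictionary is keyed by the class index as a string, which is why the loop above looks up str(pred.item()). This minimal sketch uses a hypothetical three-class mapping in the same layout (not the real ImageNet file) and hypothetical logits. It shows the same argmax lookup, extended to the top-2 classes with torch.topk.

```python
import torch

# Hypothetical 3-class label map in the same layout as imagenet_class_index.json
labels = {
    "0": ["n00000001", "cat"],
    "1": ["n00000002", "dog"],
    "2": ["n00000003", "fox"],
}

# Hypothetical logits for a batch of two images
logits = torch.tensor([[0.1, 2.5, 1.0],
                       [3.0, 0.2, 0.4]])

preds = logits.argmax(dim=1)
for i, pred in enumerate(preds):
    print(f"Image {i + 1}: {labels[str(pred.item())]}")

# Top-2 class names per image, highest score first
top2 = torch.topk(logits, k=2, dim=1).indices
top2_names = [[labels[str(idx)][1] for idx in row.tolist()] for row in top2]
print(top2_names)  # [['dog', 'fox'], ['cat', 'fox']]
```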


Since the model is scripted, it can easily be saved to disk and re-used:

```python
import tempfile

with tempfile.NamedTemporaryFile() as f:
    scripted_predictor.save(f.name)
```