.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_scripted_tensor_transforms.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plot_scripted_tensor_transforms.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_scripted_tensor_transforms.py:


=========================
Tensor transforms and JIT
=========================

This example illustrates various features that are now supported by the
:ref:`image transformations <transforms>` on Tensor images. In particular, we
show how image transforms can be performed on GPU, and how one can also script
them using JIT compilation.

Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric
and presented multiple limitations due to that. Now, since v0.8.0, transforms
implementations are Tensor and PIL compatible, and we can achieve the following
new features:

- transform multi-band torch tensor images (with more than 3-4 channels)
- torchscript transforms together with your model for deployment
- support for GPU acceleration
- batched transformation such as for videos
- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)

.. note::
    These features are only possible with **Tensor** images.

.. GENERATED FROM PYTHON SOURCE LINES 25-48

.. code-block:: default


    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np

    import torch
    import torchvision.transforms as T
    from torchvision.io import read_image


    plt.rcParams["savefig.bbox"] = 'tight'
    torch.manual_seed(1)


    def show(imgs):
        fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
        for i, img in enumerate(imgs):
            img = T.ToPILImage()(img.to('cpu'))
            axs[0, i].imshow(np.asarray(img))
            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])









.. GENERATED FROM PYTHON SOURCE LINES 49-51

The :func:`~torchvision.io.read_image` function allows to read an image and
directly load it as a tensor

.. GENERATED FROM PYTHON SOURCE LINES 51-56

.. code-block:: default


    dog1 = read_image(str(Path('assets') / 'dog1.jpg'))
    dog2 = read_image(str(Path('assets') / 'dog2.jpg'))
    show([dog1, dog2])




.. image-sg:: /auto_examples/images/sphx_glr_plot_scripted_tensor_transforms_001.png
   :alt: plot scripted tensor transforms
   :srcset: /auto_examples/images/sphx_glr_plot_scripted_tensor_transforms_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 57-63

Transforming images on GPU
--------------------------
Most transforms natively support tensors on top of PIL images (to visualize
the effect of the transforms, you may refer to see
:ref:`sphx_glr_auto_examples_plot_transforms.py`).
Using tensor images, we can run the transforms on GPUs if cuda is available!

.. GENERATED FROM PYTHON SOURCE LINES 63-79

.. code-block:: default


    import torch.nn as nn

    transforms = torch.nn.Sequential(
        T.RandomCrop(224),
        T.RandomHorizontalFlip(p=0.3),
    )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dog1 = dog1.to(device)
    dog2 = dog2.to(device)

    transformed_dog1 = transforms(dog1)
    transformed_dog2 = transforms(dog2)
    show([transformed_dog1, transformed_dog2])




.. image-sg:: /auto_examples/images/sphx_glr_plot_scripted_tensor_transforms_002.png
   :alt: plot scripted tensor transforms
   :srcset: /auto_examples/images/sphx_glr_plot_scripted_tensor_transforms_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 80-87

Scriptable transforms for easier deployment via torchscript
-----------------------------------------------------------
We now show how to combine image transformations and a model forward pass,
while using ``torch.jit.script`` to obtain a single scripted module.

Let's define a ``Predictor`` module that transforms the input tensor and then
applies an ImageNet model on it.

.. GENERATED FROM PYTHON SOURCE LINES 87-106

.. code-block:: default


    from torchvision.models import resnet18, ResNet18_Weights


    class Predictor(nn.Module):

        def __init__(self):
            super().__init__()
            weights = ResNet18_Weights.DEFAULT
            self.resnet18 = resnet18(weights=weights, progress=False).eval()
            self.transforms = weights.transforms()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():
                x = self.transforms(x)
                y_pred = self.resnet18(x)
                return y_pred.argmax(dim=1)









.. GENERATED FROM PYTHON SOURCE LINES 107-109

Now, let's define scripted and non-scripted instances of ``Predictor`` and
apply it on multiple tensor images of the same size

.. GENERATED FROM PYTHON SOURCE LINES 109-118

.. code-block:: default


    predictor = Predictor().to(device)
    scripted_predictor = torch.jit.script(predictor).to(device)

    batch = torch.stack([dog1, dog2]).to(device)

    res = predictor(batch)
    res_scripted = scripted_predictor(batch)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/project/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).
      warnings.warn(
    /home/circleci/project/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).
      warnings.warn(




.. GENERATED FROM PYTHON SOURCE LINES 119-121

We can verify that the prediction of the scripted and non-scripted models are
the same:

.. GENERATED FROM PYTHON SOURCE LINES 121-131

.. code-block:: default


    import json

    with open(Path('assets') / 'imagenet_class_index.json') as labels_file:
        labels = json.load(labels_file)

    for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
        assert pred == pred_scripted
        print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Prediction for Dog 1: ['n02113023', 'Pembroke']
    Prediction for Dog 2: ['n02106662', 'German_shepherd']




.. GENERATED FROM PYTHON SOURCE LINES 132-133

Since the model is scripted, it can be easily dumped on disk and re-used

.. GENERATED FROM PYTHON SOURCE LINES 133-142

.. code-block:: default


    import tempfile

    with tempfile.NamedTemporaryFile() as f:
        scripted_predictor.save(f.name)

        dumped_scripted_predictor = torch.jit.load(f.name)
        res_scripted_dumped = dumped_scripted_predictor(batch)
    assert (res_scripted_dumped == res_scripted).all()




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    code/__torch__/torchvision/transforms/functional.py:188: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  1.535 seconds)


.. _sphx_glr_download_auto_examples_plot_scripted_tensor_transforms.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_scripted_tensor_transforms.py <plot_scripted_tensor_transforms.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_scripted_tensor_transforms.ipynb <plot_scripted_tensor_transforms.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_