.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_transforms_v2_e2e.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plot_transforms_v2_e2e.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_transforms_v2_e2e.py:


==================================================
Transforms v2: End-to-end object detection example
==================================================

Object detection is not supported out of the box by ``torchvision.transforms`` v1, since it only supports images.
``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This example
showcases end-to-end object detection training using the stable ``torchvision.datasets`` and ``torchvision.models`` as
well as the new ``torchvision.transforms.v2`` API.

.. GENERATED FROM PYTHON SOURCE LINES 11-51

.. code-block:: default


    import pathlib
    from collections import defaultdict

    import PIL.Image

    import torch
    import torch.utils.data

    import torchvision


    def show(sample):
        import matplotlib.pyplot as plt

        from torchvision.transforms.v2 import functional as F
        from torchvision.utils import draw_bounding_boxes

        image, target = sample
        if isinstance(image, PIL.Image.Image):
            image = F.to_image_tensor(image)
        image = F.convert_dtype(image, torch.uint8)
        annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

        fig, ax = plt.subplots()
        ax.imshow(annotated_image.permute(1, 2, 0).numpy())
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        fig.tight_layout()

        fig.show()


    # We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
    # some APIs may slightly change in the future
    torchvision.disable_beta_transforms_warning()

    from torchvision import models, datasets
    import torchvision.transforms.v2 as transforms









.. GENERATED FROM PYTHON SOURCE LINES 52-54

We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently
returns, and we'll see how to convert it to a format that is compatible with our new transforms.

.. GENERATED FROM PYTHON SOURCE LINES 54-71

.. code-block:: default



    def load_example_coco_detection_dataset(**kwargs):
        # This loads fake data for illustration purposes. In practice, you'll have to
        # replace this with your own data.
        root = pathlib.Path("assets") / "coco"
        return datasets.CocoDetection(str(root / "images"), str(root / "instances.json"), **kwargs)


    dataset = load_example_coco_detection_dataset()

    sample = dataset[0]
    image, target = sample
    print(type(image))
    print(type(target), type(target[0]), list(target[0].keys()))






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    loading annotations into memory...
    Done (t=0.00s)
    creating index...
    index created!
    <class 'PIL.Image.Image'>
    <class 'list'> <class 'dict'> ['segmentation', 'iscrowd', 'image_id', 'bbox', 'category_id', 'id']




.. GENERATED FROM PYTHON SOURCE LINES 72-79

The dataset returns a two-tuple with the first item being a :class:`PIL.Image.Image` and the second one a list of
dictionaries, each of which contains the annotations for a single object instance. As is, this format is compatible
neither with ``torchvision.transforms.v2`` nor with the models. To overcome that, we provide the
:func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
:class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
also adds the key-value pairs ``"boxes"``, ``"masks"``, and ``"labels"``, whose values are wrapped in the
corresponding ``torchvision.datapoints`` types.

.. GENERATED FROM PYTHON SOURCE LINES 79-88

.. code-block:: default


    dataset = datasets.wrap_dataset_for_transforms_v2(dataset)

    sample = dataset[0]
    image, target = sample
    print(type(image))
    print(type(target), list(target.keys()))
    print(type(target["boxes"]), type(target["masks"]), type(target["labels"]))





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    <class 'PIL.Image.Image'>
    <class 'dict'> ['segmentation', 'iscrowd', 'image_id', 'bbox', 'category_id', 'id', 'boxes', 'masks', 'labels']
    <class 'torchvision.datapoints._bounding_box.BoundingBox'> <class 'torchvision.datapoints._mask.Mask'> <class 'torch.Tensor'>
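
As a quick, optional check (not part of the original script), the wrapped sample can now be passed directly through a
v2 transform: the bounding boxes and masks are updated together with the image. The sketch below reuses the ``image``
and ``target`` variables from the snippet above.

.. code-block:: python

    flip = transforms.RandomHorizontalFlip(p=1.0)
    flipped_image, flipped_target = flip(image, target)
    # The boxes are still a datapoints.BoundingBox, now with horizontally mirrored coordinates.
    print(type(flipped_target["boxes"]))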




.. GENERATED FROM PYTHON SOURCE LINES 89-90

As a baseline, let's have a look at a sample without any transformations:

.. GENERATED FROM PYTHON SOURCE LINES 90-94

.. code-block:: default


    show(sample)





.. image-sg:: /auto_examples/images/sphx_glr_plot_transforms_v2_e2e_001.png
   :alt: plot transforms v2 e2e
   :srcset: /auto_examples/images/sphx_glr_plot_transforms_v2_e2e_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 95-97

With the dataset properly set up, we can now define the augmentation pipeline. This is done the same way as in
``torchvision.transforms`` v1, but the pipeline now also handles bounding boxes and masks without any extra
configuration.

.. GENERATED FROM PYTHON SOURCE LINES 97-112

.. code-block:: default


    transform = transforms.Compose(
        [
            transforms.RandomPhotometricDistort(),
        transforms.RandomZoomOut(
            # The fill value is given per input type: PIL images are padded with
            # (123, 117, 104), while every other type falls back to 0.
            fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
        ),
            transforms.RandomIoUCrop(),
            transforms.RandomHorizontalFlip(),
            transforms.ToImageTensor(),
            transforms.ConvertImageDtype(torch.float32),
            transforms.SanitizeBoundingBox(),
        ]
    )








.. GENERATED FROM PYTHON SOURCE LINES 113-120

.. note::
   Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBox` transform is a no-op in this example, it
   should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
   the corresponding labels and, optionally, masks. It is particularly critical to add it if
   :class:`~torchvision.transforms.v2.RandomIoUCrop` was used. A standalone sketch of its effect follows below.
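
To make this concrete, here is a minimal standalone sketch (not part of the original script) of what
:class:`~torchvision.transforms.v2.SanitizeBoundingBox` does in isolation, assuming the beta ``datapoints.BoundingBox``
API used elsewhere in this example: the degenerate box and its label are dropped, while the valid box passes through.

.. code-block:: python

    import torch

    from torchvision import datapoints
    import torchvision.transforms.v2 as transforms

    boxes = datapoints.BoundingBox(
        [[10, 10, 50, 50],   # valid box
         [20, 20, 20, 60]],  # degenerate box: zero width
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(100, 100),
    )
    target = {"boxes": boxes, "labels": torch.tensor([17, 4])}
    image = torch.randint(0, 256, (3, 100, 100), dtype=torch.uint8)

    # The transform drops the degenerate box and its corresponding label; the image is left untouched.
    image, target = transforms.SanitizeBoundingBox()(image, target)
    print(target["boxes"].shape)  # torch.Size([1, 4])
    print(target["labels"])       # tensor([17])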

Let's see what the sample looks like with our augmentation pipeline in place:

.. GENERATED FROM PYTHON SOURCE LINES 120-130

.. code-block:: default


    dataset = load_example_coco_detection_dataset(transforms=transform)
    dataset = datasets.wrap_dataset_for_transforms_v2(dataset)

    torch.manual_seed(3141)
    sample = dataset[0]

    show(sample)





.. image-sg:: /auto_examples/images/sphx_glr_plot_transforms_v2_e2e_002.png
   :alt: plot transforms v2 e2e
   :srcset: /auto_examples/images/sphx_glr_plot_transforms_v2_e2e_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    loading annotations into memory...
    Done (t=0.00s)
    creating index...
    index created!




.. GENERATED FROM PYTHON SOURCE LINES 132-134

We can see that the colors of the image were distorted, that we zoomed out on it (off center), and that it was
flipped horizontally. Throughout all of this, the bounding box was transformed accordingly. Without further ado, we
can start training.

.. GENERATED FROM PYTHON SOURCE LINES 134-153

.. code-block:: default


    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,
        # We need a custom collation function here, since the object detection models expect a
        # sequence of images and target dictionaries. The default collation function tries to
        # `torch.stack` the individual elements, which fails in general for object detection,
        # because the number of object instances varies between samples. The same applies to
        # `torchvision.transforms` v1.
        collate_fn=lambda batch: tuple(zip(*batch)),
    )

    model = models.get_model("ssd300_vgg16", weights=None, weights_backbone=None).train()

    for images, targets in data_loader:
        loss_dict = model(images, targets)
        print(loss_dict)
        # Put your training logic here
        break




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    {'bbox_regression': tensor(2.4982, grad_fn=<DivBackward0>), 'classification': tensor(57.6569, grad_fn=<DivBackward0>)}
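
As a side note, here is a tiny standalone sketch (not part of the original script) of what the
``lambda batch: tuple(zip(*batch))`` collation above does: it transposes a list of ``(image, target)`` pairs into a
tuple of images and a tuple of targets. Plain strings stand in for the real image tensors and target dictionaries.

.. code-block:: python

    batch = [("image_0", "target_0"), ("image_1", "target_1")]
    images, targets = tuple(zip(*batch))
    print(images)   # ('image_0', 'image_1')
    print(targets)  # ('target_0', 'target_1')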





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  2.846 seconds)


.. _sphx_glr_download_auto_examples_plot_transforms_v2_e2e.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_transforms_v2_e2e.py <plot_transforms_v2_e2e.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_transforms_v2_e2e.ipynb <plot_transforms_v2_e2e.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_