.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/others/plot_repurposing_annotations.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_others_plot_repurposing_annotations.py:


=====================================
Repurposing masks into bounding boxes
=====================================

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_repurposing_annotations.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_repurposing_annotations.py>` to download the full example code.

The following example illustrates the operations available in
the :ref:`torchvision.ops <ops>` module for repurposing
segmentation masks into object localization annotations for different tasks
(e.g. transforming masks used by instance and panoptic segmentation
methods into bounding boxes used by object detection methods).

.. GENERATED FROM PYTHON SOURCE LINES 16-42

.. code-block:: Python



    import os
    import numpy as np
    import torch
    import matplotlib.pyplot as plt

    import torchvision.transforms.functional as F


    ASSETS_DIRECTORY = "../assets"

    plt.rcParams["savefig.bbox"] = "tight"


    def show(imgs):
        if not isinstance(imgs, list):
            imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
        for i, img in enumerate(imgs):
            img = img.detach()
            img = F.to_pil_image(img)
            axs[0, i].imshow(np.asarray(img))
            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])









.. GENERATED FROM PYTHON SOURCE LINES 44-59

Masks
-----
In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package,
as a multi-dimensional array (e.g. a NumPy array or a PyTorch tensor) with the following shape:

      (num_objects, height, width)

where num_objects is the number of annotated objects in the image, and each (height, width) slice corresponds to
exactly one object. For example, if your input image has the dimensions 224 x 224 and contains four annotated
objects, your masks annotation has the following shape:

      (4, 224, 224).

A nice property of masks is that they can be easily repurposed for methods that solve a variety of object
localization tasks.
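
As a minimal illustrative sketch (not part of the original script), a masks
annotation with this shape could be constructed as follows:

.. code-block:: Python

    # Hypothetical example: a boolean masks tensor for an image of size
    # 224 x 224 with four annotated objects (all pixels background here).
    dummy_masks = torch.zeros((4, 224, 224), dtype=torch.bool)
    print(dummy_masks.shape)  # torch.Size([4, 224, 224])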

.. GENERATED FROM PYTHON SOURCE LINES 61-67

Converting Masks to Bounding Boxes
-----------------------------------------------
For example, the :func:`~torchvision.ops.masks_to_boxes` operation can be used to
transform masks into bounding boxes that can be
used as input to detection models such as FasterRCNN and RetinaNet.
We will take images and masks from the `PennFudan Dataset <https://www.cis.upenn.edu/~jshi/ped_html/>`_.

.. GENERATED FROM PYTHON SOURCE LINES 67-77

.. code-block:: Python



    from torchvision.io import read_image

    img_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054.png")
    mask_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054_mask.png")
    img = read_image(img_path)
    mask = read_image(mask_path)









.. GENERATED FROM PYTHON SOURCE LINES 78-81

Here the mask is represented as a PNG image with integer values.
Each object instance is encoded as a different pixel value, with 0 being the background.
Notice that the spatial dimensions of the image and the mask match.

.. GENERATED FROM PYTHON SOURCE LINES 81-86

.. code-block:: Python


    print(mask.size())
    print(img.size())
    print(mask)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    torch.Size([1, 498, 533])
    torch.Size([3, 498, 533])
    tensor([[[0, 0, 0,  ..., 0, 0, 0],
             [0, 0, 0,  ..., 0, 0, 0],
             [0, 0, 0,  ..., 0, 0, 0],
             ...,
             [0, 0, 0,  ..., 0, 0, 0],
             [0, 0, 0,  ..., 0, 0, 0],
             [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8)




.. GENERATED FROM PYTHON SOURCE LINES 87-98

.. code-block:: Python


    # We get the unique colors, as these would be the object ids.
    obj_ids = torch.unique(mask)

    # first id is the background, so remove it.
    obj_ids = obj_ids[1:]

    # split the color-encoded mask into a set of boolean masks.
    # Note that this snippet would work as well if the masks were float values instead of ints.
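    # Broadcasting: mask has shape (1, H, W) and obj_ids[:, None, None] has shape
    # (num_objects, 1, 1), so the comparison broadcasts to (num_objects, H, W).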
    masks = mask == obj_ids[:, None, None]








.. GENERATED FROM PYTHON SOURCE LINES 99-104

Now the masks are a boolean tensor.
The first dimension, in this case 3, denotes the number of instances: there are 3 people in the image.
The other two dimensions are height and width, which are equal to the dimensions of the image.
For each instance, the boolean mask indicates whether each pixel
belongs to that instance's segmentation mask.

.. GENERATED FROM PYTHON SOURCE LINES 104-108

.. code-block:: Python


    print(masks.size())
    print(masks)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    torch.Size([3, 498, 533])
    tensor([[[False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False],
             ...,
             [False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False]],

            [[False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False],
             ...,
             [False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False]],

            [[False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False],
             ...,
             [False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False],
             [False, False, False,  ..., False, False, False]]])
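
Since each mask is boolean, summing over the spatial dimensions gives the
number of pixels, i.e. the area, of each instance (a small aside, not part
of the original script):

.. code-block:: Python

    # Illustrative aside: per-instance pixel counts (mask areas).
    print(masks.sum(dim=(1, 2)))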




.. GENERATED FROM PYTHON SOURCE LINES 109-111

Let us visualize an image and plot its corresponding segmentation masks.
We will use :func:`~torchvision.utils.draw_segmentation_masks` to draw the segmentation masks.

.. GENERATED FROM PYTHON SOURCE LINES 111-120

.. code-block:: Python


    from torchvision.utils import draw_segmentation_masks

    drawn_masks = []
    for mask in masks:
        drawn_masks.append(draw_segmentation_masks(img, mask, alpha=0.8, colors="blue"))

    show(drawn_masks)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_repurposing_annotations_001.png
   :alt: plot repurposing annotations
   :srcset: /auto_examples/others/images/sphx_glr_plot_repurposing_annotations_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 121-124

To convert the boolean masks into bounding boxes,
we will use :func:`~torchvision.ops.masks_to_boxes` from the torchvision.ops module.
It returns the boxes in ``(xmin, ymin, xmax, ymax)`` format.

.. GENERATED FROM PYTHON SOURCE LINES 124-131

.. code-block:: Python


    from torchvision.ops import masks_to_boxes

    boxes = masks_to_boxes(masks)
    print(boxes.size())
    print(boxes)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    torch.Size([3, 4])
    tensor([[ 96., 134., 181., 417.],
            [286., 113., 357., 331.],
            [363., 120., 436., 328.]])
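
As a quick sanity check (an aside, not part of the original script),
:func:`~torchvision.ops.box_area` can compute the area of each of these
boxes:

.. code-block:: Python

    # Illustrative aside: area of each (xmin, ymin, xmax, ymax) box.
    from torchvision.ops import box_area

    print(box_area(boxes))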




.. GENERATED FROM PYTHON SOURCE LINES 132-135

As the shape denotes, there are 3 boxes, in ``(xmin, ymin, xmax, ymax)`` format.
These can be visualized very easily with the :func:`~torchvision.utils.draw_bounding_boxes` utility
provided in :ref:`torchvision.utils <utils>`.

.. GENERATED FROM PYTHON SOURCE LINES 135-141

.. code-block:: Python


    from torchvision.utils import draw_bounding_boxes

    drawn_boxes = draw_bounding_boxes(img, boxes, colors="red")
    show(drawn_boxes)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_repurposing_annotations_002.png
   :alt: plot repurposing annotations
   :srcset: /auto_examples/others/images/sphx_glr_plot_repurposing_annotations_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 142-145

These boxes can now directly be used by detection models in torchvision.
Here is a demo with a Faster R-CNN model loaded from
:func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`.

.. GENERATED FROM PYTHON SOURCE LINES 145-160

.. code-block:: Python


    from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

    weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
    print(img.size())

    transforms = weights.transforms()
    img = transforms(img)
    target = {}
    target["boxes"] = boxes
    # there is only one class
    target["labels"] = torch.ones((masks.size(0),), dtype=torch.int64)
    detection_outputs = model(img.unsqueeze(0), [target])






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
    torch.Size([3, 498, 533])
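
Note that the model above is in training mode, so passing targets makes it
return the training losses. As a minimal inference sketch (an assumption,
not part of the original script), switching to eval mode and passing only
the images yields predictions instead:

.. code-block:: Python

    # Hypothetical inference sketch: in eval mode, torchvision detection
    # models take a list of images and return a list of prediction dicts
    # with "boxes", "labels" and "scores".
    model = model.eval()
    with torch.no_grad():
        predictions = model([img])
    print(predictions[0]["boxes"].shape)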




.. GENERATED FROM PYTHON SOURCE LINES 161-169

Converting Segmentation Dataset to Detection Dataset
----------------------------------------------------

With this utility it becomes very simple to convert a segmentation dataset to a detection dataset,
which means we can now use a segmentation dataset to train a detection model.
One can similarly convert a panoptic dataset to a detection dataset.
Here is an example where we re-purpose the dataset from the
`PennFudan Detection Tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_.

.. GENERATED FROM PYTHON SOURCE LINES 169-212

.. code-block:: Python


    class SegmentationToDetectionDataset(torch.utils.data.Dataset):
        def __init__(self, root, transforms):
            self.root = root
            self.transforms = transforms
            # load all image files, sorting them to
            # ensure that they are aligned
            self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
            self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

        def __getitem__(self, idx):
            # load images and masks
            img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
            mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])

            img = read_image(img_path)
            mask = read_image(mask_path)

            img = F.convert_image_dtype(img, dtype=torch.float)
            mask = F.convert_image_dtype(mask, dtype=torch.float)

            # We get the unique colors, as these would be the object ids.
            obj_ids = torch.unique(mask)

            # first id is the background, so remove it.
            obj_ids = obj_ids[1:]

            # split the color-encoded mask into a set of boolean masks.
            masks = mask == obj_ids[:, None, None]

            boxes = masks_to_boxes(masks)

            # there is only one class
            labels = torch.ones((masks.shape[0],), dtype=torch.int64)

            target = {}
            target["boxes"] = boxes
            target["labels"] = labels

            if self.transforms is not None:
                img, target = self.transforms(img, target)

            return img, target
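
As a usage sketch (assuming the ``PennFudanPed`` directory layout from the
linked tutorial; the path and ``collate_fn`` here are illustrative), the
converted dataset can be consumed by a ``DataLoader`` like any detection
dataset:

.. code-block:: Python

    # Hypothetical usage, assuming the PennFudanPed folder layout from the
    # linked tutorial. The collate_fn keeps the variable-sized targets as
    # tuples instead of trying to stack them into a single tensor.
    dataset = SegmentationToDetectionDataset("PennFudanPed", transforms=None)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=lambda batch: tuple(zip(*batch)),
    )
    imgs, targets = next(iter(data_loader))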








.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.558 seconds)


.. _sphx_glr_download_auto_examples_others_plot_repurposing_annotations.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_repurposing_annotations.ipynb <plot_repurposing_annotations.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_repurposing_annotations.py <plot_repurposing_annotations.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_