.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/others/plot_visualization_utils.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_others_plot_visualization_utils.py:


=======================
Visualization utilities
=======================

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_visualization_utils.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_visualization_utils.py>` to download the full example code.

This example illustrates some of the utilities that torchvision offers for
visualizing images, bounding boxes, segmentation masks and keypoints.

.. GENERATED FROM PYTHON SOURCE LINES 13-36

.. code-block:: Python



    import torch
    import numpy as np
    import matplotlib.pyplot as plt

    import torchvision.transforms.functional as F


    plt.rcParams["savefig.bbox"] = 'tight'


    def show(imgs):
        if not isinstance(imgs, list):
            imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
        for i, img in enumerate(imgs):
            img = img.detach()
            img = F.to_pil_image(img)
            axs[0, i].imshow(np.asarray(img))
            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])









.. GENERATED FROM PYTHON SOURCE LINES 38-43

Visualizing a grid of images
----------------------------
The :func:`~torchvision.utils.make_grid` function can be used to create a
tensor that represents multiple images in a grid.  This util accepts a 4D batch
tensor or a list of same-sized images; here we pass a list of ``uint8`` images.
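
As a quick check of the layout, the sketch below builds a grid from two
synthetic ``uint8`` images instead of the dog pictures. The image size and the
names ``img_a``, ``img_b`` and ``toy_grid`` are made up for illustration. With
the default ``padding=2``, the two 32x32 images should end up side by side in a
single padded row:

.. code-block:: Python

    import torch
    from torchvision.utils import make_grid

    # Two synthetic RGB images of the same size, shape (C, H, W), dtype uint8.
    img_a = torch.zeros(3, 32, 32, dtype=torch.uint8)
    img_b = torch.full((3, 32, 32), 255, dtype=torch.uint8)

    # make_grid accepts a list of same-sized images (or a 4D batch tensor).
    toy_grid = make_grid([img_a, img_b])

    # One row of two images, each surrounded by 2 pixels of padding.
    print(toy_grid.shape, toy_grid.dtype)
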

.. GENERATED FROM PYTHON SOURCE LINES 43-55

.. code-block:: Python


    from torchvision.utils import make_grid
    from torchvision.io import read_image
    from pathlib import Path

    dog1_int = read_image(str(Path('../assets') / 'dog1.jpg'))
    dog2_int = read_image(str(Path('../assets') / 'dog2.jpg'))
    dog_list = [dog1_int, dog2_int]

    grid = make_grid(dog_list)
    show(grid)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_001.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 56-61

Visualizing bounding boxes
--------------------------
We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an
image. We can set the colors, labels, width as well as font and font size.
The boxes are in ``(xmin, ymin, xmax, ymax)`` format.
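
If your boxes are stored as ``(x, y, width, height)`` instead, they need to be
converted first. Below is a minimal sketch of that conversion in plain PyTorch.
The ``xywh`` values are chosen so that the result matches the two boxes used in
this example, and the variable names are only illustrative:

.. code-block:: Python

    import torch

    # Hypothetical boxes in (x, y, width, height) format.
    boxes_xywh = torch.tensor([[50, 50, 50, 150], [210, 150, 140, 280]], dtype=torch.float)

    # The top-left corner stays as is; the bottom-right corner is top-left + size.
    top_left = boxes_xywh[:, :2]
    bottom_right = top_left + boxes_xywh[:, 2:]
    boxes_xyxy = torch.cat([top_left, bottom_right], dim=1)

    print(boxes_xyxy)

:func:`torchvision.ops.box_convert` can perform the same conversion with
``in_fmt="xywh"`` and ``out_fmt="xyxy"``.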

.. GENERATED FROM PYTHON SOURCE LINES 61-71

.. code-block:: Python


    from torchvision.utils import draw_bounding_boxes


    boxes = torch.tensor([[50, 50, 100, 200], [210, 150, 350, 430]], dtype=torch.float)
    colors = ["blue", "yellow"]
    result = draw_bounding_boxes(dog1_int, boxes, colors=colors, width=5)
    show(result)





.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_002.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 72-77

Naturally, we can also plot bounding boxes produced by torchvision detection
models.  Here is a demo with a Faster R-CNN model loaded with
:func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`.
For more details on the output of such models, you may
refer to :ref:`instance_seg_output`.

.. GENERATED FROM PYTHON SOURCE LINES 77-92

.. code-block:: Python


    from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights


    weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    transforms = weights.transforms()

    images = [transforms(d) for d in dog_list]

    model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
    model = model.eval()

    outputs = model(images)
    print(outputs)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [{'boxes': tensor([[215.9767, 171.1661, 402.0078, 378.7391],
            [344.6341, 172.6735, 357.6114, 220.1435],
            [153.1306, 185.5567, 172.9223, 254.7014]], grad_fn=<StackBackward0>), 'labels': tensor([18,  1,  1]), 'scores': tensor([0.9989, 0.0701, 0.0611], grad_fn=<IndexBackward0>)}, {'boxes': tensor([[ 23.5964, 132.4331, 449.9359, 493.0222],
            [225.8182, 124.6292, 467.2861, 492.2621],
            [ 18.5248, 135.4171, 420.9786, 479.2225]], grad_fn=<StackBackward0>), 'labels': tensor([18, 18, 17]), 'scores': tensor([0.9980, 0.0879, 0.0671], grad_fn=<IndexBackward0>)}]




.. GENERATED FROM PYTHON SOURCE LINES 93-95

Let's plot the boxes detected by our model. We will only plot the boxes with a
score greater than a given threshold.
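
The filtering relies on boolean mask indexing. The sketch below reproduces it
on the boxes and scores printed above for the first image (rounded here),
without running the model. The names ``demo_boxes``, ``demo_scores`` and
``keep`` are illustrative:

.. code-block:: Python

    import torch

    # Boxes and scores of the first image, copied (rounded) from the output above.
    demo_boxes = torch.tensor([
        [215.98, 171.17, 402.01, 378.74],
        [344.63, 172.67, 357.61, 220.14],
        [153.13, 185.56, 172.92, 254.70],
    ])
    demo_scores = torch.tensor([0.9989, 0.0701, 0.0611])

    # Comparing against the threshold gives one boolean per detection ...
    keep = demo_scores > 0.8
    # ... and indexing with it keeps only the rows where the boolean is True.
    kept_boxes = demo_boxes[keep]

    print(keep, kept_boxes.shape)
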

.. GENERATED FROM PYTHON SOURCE LINES 95-103

.. code-block:: Python


    score_threshold = .8
    dogs_with_boxes = [
        draw_bounding_boxes(dog_int, boxes=output['boxes'][output['scores'] > score_threshold], width=4)
        for dog_int, output in zip(dog_list, outputs)
    ]
    show(dogs_with_boxes)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_003.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_003.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 104-119

Visualizing segmentation masks
------------------------------
The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
draw segmentation masks on images. Semantic segmentation and instance
segmentation models have different outputs, so we will treat each
independently.

.. _semantic_seg_output:

Semantic segmentation models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We will see how to use it with torchvision's FCN Resnet-50, loaded with
:func:`~torchvision.models.segmentation.fcn_resnet50`. Let's start by looking
at the output of the model.

.. GENERATED FROM PYTHON SOURCE LINES 119-132

.. code-block:: Python


    from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

    weights = FCN_ResNet50_Weights.DEFAULT
    transforms = weights.transforms(resize_size=None)

    model = fcn_resnet50(weights=weights, progress=False)
    model = model.eval()

    batch = torch.stack([transforms(d) for d in dog_list])
    output = model(batch)['out']
    print(output.shape, output.min().item(), output.max().item())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading: "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth" to /root/.cache/torch/hub/checkpoints/fcn_resnet50_coco-1167a1af.pth
    torch.Size([2, 21, 500, 500]) -7.089669704437256 14.858257293701172




.. GENERATED FROM PYTHON SOURCE LINES 133-141

As we can see above, the output of the segmentation model is a tensor of shape
``(batch_size, num_classes, H, W)``. Each value is a non-normalized score, and
we can normalize them into ``[0, 1]`` by using a softmax. After the softmax,
we can interpret each value as a probability indicating how likely a given
pixel is to belong to a given class.
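
To make the normalization concrete, here is a small sketch on random scores
(not the model output). After a softmax over the class dimension, every value
lies in ``[0, 1]`` and the class probabilities of each pixel sum to 1. The
tensor sizes and names are arbitrary:

.. code-block:: Python

    import torch

    torch.manual_seed(0)
    # Fake scores with shape (batch_size, num_classes, H, W).
    fake_logits = torch.randn(2, 21, 4, 4)

    # Normalize over the class dimension (dim=1).
    fake_probs = torch.nn.functional.softmax(fake_logits, dim=1)

    # Each pixel now holds a probability distribution over the 21 classes.
    per_pixel_sum = fake_probs.sum(dim=1)
    print(fake_probs.min().item() >= 0, fake_probs.max().item() <= 1)
    print(torch.allclose(per_pixel_sum, torch.ones(2, 4, 4)))
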

Let's plot the masks that have been detected for the dog class and for the
boat class:

.. GENERATED FROM PYTHON SOURCE LINES 141-154

.. code-block:: Python


    sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}

    normalized_masks = torch.nn.functional.softmax(output, dim=1)

    dog_and_boat_masks = [
        normalized_masks[img_idx, sem_class_to_idx[cls]]
        for img_idx in range(len(dog_list))
        for cls in ('dog', 'boat')
    ]

    show(dog_and_boat_masks)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_004.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_004.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 155-162

As expected, the model is confident about the dog class, but not so much for
the boat class.

The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
plot those masks on top of the original image. This function expects the
masks to be boolean masks, but our masks above contain probabilities in ``[0,
1]``. To get boolean masks, we can do the following:

.. GENERATED FROM PYTHON SOURCE LINES 162-169

.. code-block:: Python


    class_dim = 1
    boolean_dog_masks = (normalized_masks.argmax(class_dim) == sem_class_to_idx['dog'])
    print(f"shape = {boolean_dog_masks.shape}, dtype = {boolean_dog_masks.dtype}")
    show([m.float() for m in boolean_dog_masks])





.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_005.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_005.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    shape = torch.Size([2, 500, 500]), dtype = torch.bool




.. GENERATED FROM PYTHON SOURCE LINES 170-182

The line above where we define ``boolean_dog_masks`` is a bit cryptic, but you
can read it as the following query: "For which pixels is 'dog' the most likely
class?"
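
The same query can be tried on a tiny hand-made example. The sketch below uses
3 made-up classes on a 2x2 image, where class index ``1`` stands in for 'dog'
purely for illustration. It also checks that the argmax is the same before and
after the softmax:

.. code-block:: Python

    import torch

    # Scores with shape (num_classes=3, H=2, W=2) for a single image.
    tiny_scores = torch.tensor([
        [[2.0, 0.1], [0.3, 0.2]],   # class 0
        [[1.0, 3.0], [0.1, 4.0]],   # class 1 (plays the role of 'dog')
        [[0.5, 0.2], [5.0, 0.1]],   # class 2
    ])
    tiny_probs = torch.softmax(tiny_scores, dim=0)

    # "For which pixels is class 1 the most likely class?"
    tiny_mask = tiny_probs.argmax(dim=0) == 1
    print(tiny_mask)

    # The softmax is monotonic, so the argmax of the raw scores is the same.
    same_argmax = torch.equal(tiny_scores.argmax(dim=0), tiny_probs.argmax(dim=0))
    print(same_argmax)
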

.. note::
  While we're using the ``normalized_masks`` here, we would have
  gotten the same result by using the non-normalized scores of the model
  directly (as the softmax operation preserves the order).

Now that we have boolean masks, we can use them with
:func:`~torchvision.utils.draw_segmentation_masks` to plot them on top of the
original images:

.. GENERATED FROM PYTHON SOURCE LINES 182-191

.. code-block:: Python


    from torchvision.utils import draw_segmentation_masks

    dogs_with_masks = [
        draw_segmentation_masks(img, masks=mask, alpha=0.7)
        for img, mask in zip(dog_list, boolean_dog_masks)
    ]
    show(dogs_with_masks)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_006.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_006.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 192-199

We can plot more than one mask per image! Remember that the model returned as
many masks as there are classes. Let's ask the same query as above, but this
time for *all* classes, not just the dog class: "For each pixel and each class
C, is class C the most likely class?"

This one is a bit more involved, so we'll first show how to do it with a
single image, and then we'll generalize to the batch.
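
The key ingredient is broadcasting: comparing an ``(H, W)`` tensor of winning
class indices against a ``(num_classes, 1, 1)`` column of class ids yields one
boolean mask per class. Here is a toy sketch with 3 classes and made-up winning
indices:

.. code-block:: Python

    import torch

    # Winning class index for each pixel of a 2x3 image (made-up values).
    winners = torch.tensor([[0, 2, 2],
                            [0, 0, 1]])
    n_classes = 3

    # Shape (3, 1, 1) against shape (2, 3) broadcasts to (3, 2, 3).
    one_mask_per_class = winners == torch.arange(n_classes)[:, None, None]

    print(one_mask_per_class.shape)
    # Every pixel belongs to exactly one class mask.
    print(one_mask_per_class.sum(dim=0))
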

.. GENERATED FROM PYTHON SOURCE LINES 199-211

.. code-block:: Python


    num_classes = normalized_masks.shape[1]
    dog1_masks = normalized_masks[0]
    class_dim = 0
    dog1_all_classes_masks = dog1_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None]

    print(f"dog1_masks shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}")
    print(f"dog1_all_classes_masks = {dog1_all_classes_masks.shape}, dtype = {dog1_all_classes_masks.dtype}")

    dog_with_all_masks = draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks, alpha=.6)
    show(dog_with_all_masks)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_007.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_007.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    dog1_masks shape = torch.Size([21, 500, 500]), dtype = torch.float32
    dog1_all_classes_masks = torch.Size([21, 500, 500]), dtype = torch.bool




.. GENERATED FROM PYTHON SOURCE LINES 212-224

We can see in the image above that only 2 masks were drawn: the mask for the
background and the mask for the dog. This is because the model thinks that
only these 2 classes are the most likely ones across all the pixels. If the
model had detected another class as the most likely among other pixels, we
would have seen its mask above.

Removing the background mask is as simple as passing
``masks=dog1_all_classes_masks[1:]``, because the background class is the
class with index 0.

Let's now do the same but for an entire batch of images. The code is similar
but involves a bit more juggling with the dimensions.
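
The extra juggling comes from broadcasting, which puts the class dimension
first. Below is a shape-only sketch on random data; the sizes and names are
chosen arbitrarily:

.. code-block:: Python

    import torch

    torch.manual_seed(0)
    toy_batch_size, toy_num_classes, toy_h, toy_w = 2, 5, 4, 4
    toy_probs = torch.rand(toy_batch_size, toy_num_classes, toy_h, toy_w)

    # argmax over the class dimension -> (batch_size, H, W)
    batch_winners = toy_probs.argmax(dim=1)

    # (num_classes, 1, 1, 1) vs (batch_size, H, W) -> (num_classes, batch_size, H, W)
    masks_class_first = batch_winners == torch.arange(toy_num_classes)[:, None, None, None]

    # Swap to (batch_size, num_classes, H, W) so that we can iterate over images.
    masks_batch_first = masks_class_first.swapaxes(0, 1)
    print(masks_class_first.shape, masks_batch_first.shape)
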

.. GENERATED FROM PYTHON SOURCE LINES 224-238

.. code-block:: Python


    class_dim = 1
    all_classes_masks = normalized_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None, None]
    print(f"shape = {all_classes_masks.shape}, dtype = {all_classes_masks.dtype}")
    # The first dimension is the classes now, so we need to swap it
    all_classes_masks = all_classes_masks.swapaxes(0, 1)

    dogs_with_masks = [
        draw_segmentation_masks(img, masks=mask, alpha=.6)
        for img, mask in zip(dog_list, all_classes_masks)
    ]
    show(dogs_with_masks)





.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_008.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_008.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    shape = torch.Size([21, 2, 500, 500]), dtype = torch.bool




.. GENERATED FROM PYTHON SOURCE LINES 239-258

.. _instance_seg_output:

Instance segmentation models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Instance segmentation models have a significantly different output from the
semantic segmentation models. We will see here how to plot the masks for such
models. Let's start by analyzing the output of a Mask-RCNN model. Note that
these models don't require the images to be normalized, so we don't need to
use the normalized batch.

.. note::

    We will here describe the output of a Mask-RCNN model. The models in
    :ref:`object_det_inst_seg_pers_keypoint_det` all have a similar output
    format, but some of them may have extra info like keypoints for
    :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`, and some
    of them may not have masks, like
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`.

.. GENERATED FROM PYTHON SOURCE LINES 258-272

.. code-block:: Python


    from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights

    weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
    transforms = weights.transforms()

    images = [transforms(d) for d in dog_list]

    model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
    model = model.eval()

    output = model(images)
    print(output)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
    [{'boxes': tensor([[219.7444, 168.1722, 400.7378, 384.0263],
            [343.9716, 171.2287, 358.3447, 222.6263],
            [301.0303, 192.6917, 313.8879, 232.3154]], grad_fn=<StackBackward0>), 'labels': tensor([18,  1,  1]), 'scores': tensor([0.9987, 0.7187, 0.6525], grad_fn=<IndexBackward0>), 'masks': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              ...,
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.]]],


            [[[0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              ...,
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.]]],


            [[[0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              ...,
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.]]]], grad_fn=<UnsqueezeBackward0>)}, {'boxes': tensor([[ 44.6767, 137.9018, 446.5324, 487.3429],
            [  0.0000, 288.0053, 489.9292, 490.2352]], grad_fn=<StackBackward0>), 'labels': tensor([18, 15]), 'scores': tensor([0.9978, 0.0697], grad_fn=<IndexBackward0>), 'masks': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              ...,
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.]]],


            [[[0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              ...,
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.]]]], grad_fn=<UnsqueezeBackward0>)}]




.. GENERATED FROM PYTHON SOURCE LINES 273-289

Let's break this down. For each image in the batch, the model outputs some
detections (or instances). The number of detections varies for each input
image. Each instance is described by its bounding box, its label, its score
and its mask.

The way the output is organized is as follows: the output is a list of length
``batch_size``. Each entry in the list corresponds to an input image, and it
is a dict with keys 'boxes', 'labels', 'scores', and 'masks'. Each value
associated to those keys has ``num_instances`` elements in it.  In our case
above there are 3 instances detected in the first image, and 2 instances in
the second one.

The boxes can be plotted with :func:`~torchvision.utils.draw_bounding_boxes`
as above, but here we're more interested in the masks. These masks are quite
different from the masks that we saw above for the semantic segmentation
models.

.. GENERATED FROM PYTHON SOURCE LINES 289-295

.. code-block:: Python


    dog1_output = output[0]
    dog1_masks = dog1_output['masks']
    print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, "
          f"min = {dog1_masks.min()}, max = {dog1_masks.max()}")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    shape = torch.Size([3, 1, 500, 500]), dtype = torch.float32, min = 0.0, max = 0.9999862909317017




.. GENERATED FROM PYTHON SOURCE LINES 296-300

Here the masks correspond to probabilities indicating, for each pixel, how
likely it is to belong to the predicted label of that instance. Those
predicted labels correspond to the 'labels' element in the same output dict.
Let's see which labels were predicted for the instances of the first image.

.. GENERATED FROM PYTHON SOURCE LINES 300-304

.. code-block:: Python


    print("For the first dog, the following instances were detected:")
    print([weights.meta["categories"][label] for label in dog1_output['labels']])





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    For the first dog, the following instances were detected:
    ['dog', 'person', 'person']




.. GENERATED FROM PYTHON SOURCE LINES 305-312

Interestingly, the model detects two persons in the image. Let's go ahead and
plot those masks. Since :func:`~torchvision.utils.draw_segmentation_masks`
expects boolean masks, we need to convert those probabilities into boolean
values. Remember that the semantics of those masks is "How likely is this pixel
to belong to the predicted class?". As a result, a natural way of converting
those masks into boolean values is to threshold them with the 0.5 probability
(one could also choose a different threshold).
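
On a toy tensor, the conversion looks as follows. The probabilities are
invented; the shape mimics the ``(num_instances, 1, H, W)`` layout of the
model output:

.. code-block:: Python

    import torch

    # Two instances, one channel, 2x2 pixels: shape (2, 1, 2, 2).
    toy_masks = torch.tensor([
        [[[0.9, 0.6], [0.4, 0.1]]],
        [[[0.2, 0.5], [0.7, 0.8]]],
    ])

    # Strict comparison: a probability of exactly 0.5 is not selected.
    toy_bool_masks = toy_masks > 0.5

    # Drop the singleton channel dimension -> (num_instances, H, W).
    toy_bool_masks = toy_bool_masks.squeeze(1)
    print(toy_bool_masks.shape, toy_bool_masks.dtype)
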

.. GENERATED FROM PYTHON SOURCE LINES 312-322

.. code-block:: Python


    proba_threshold = 0.5
    dog1_bool_masks = dog1_output['masks'] > proba_threshold
    print(f"shape = {dog1_bool_masks.shape}, dtype = {dog1_bool_masks.dtype}")

    # There's an extra dimension (1) to the masks. We need to remove it
    dog1_bool_masks = dog1_bool_masks.squeeze(1)

    show(draw_segmentation_masks(dog1_int, dog1_bool_masks, alpha=0.9))




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_009.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_009.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    shape = torch.Size([3, 1, 500, 500]), dtype = torch.bool




.. GENERATED FROM PYTHON SOURCE LINES 323-326

The model seems to have properly detected the dog, but it also confused trees
with people. Looking more closely at the scores will help us plot more
relevant masks:

.. GENERATED FROM PYTHON SOURCE LINES 326-329

.. code-block:: Python


    print(dog1_output['scores'])





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([0.9987, 0.7187, 0.6525], grad_fn=<IndexBackward0>)




.. GENERATED FROM PYTHON SOURCE LINES 330-334

Clearly the model is more confident about the dog detection than it is about
the people detections. That's good news. When plotting the masks, we can ask
for only those that have a good score. Let's use a score threshold of .75
here, and also plot the masks of the second dog.

.. GENERATED FROM PYTHON SOURCE LINES 334-348

.. code-block:: Python


    score_threshold = .75

    boolean_masks = [
        out['masks'][out['scores'] > score_threshold] > proba_threshold
        for out in output
    ]

    dogs_with_masks = [
        draw_segmentation_masks(img, mask.squeeze(1))
        for img, mask in zip(dog_list, boolean_masks)
    ]
    show(dogs_with_masks)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_010.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_010.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 349-352

The two 'people' masks in the first image were not selected because they have
a lower score than the score threshold. Similarly, in the second image, the
instance with class 15 (which corresponds to 'bench') was not selected.

.. GENERATED FROM PYTHON SOURCE LINES 354-363

.. _keypoint_output:

Visualizing keypoints
---------------------
The :func:`~torchvision.utils.draw_keypoints` function can be used to
draw keypoints on images. We will see how to use it with
torchvision's KeypointRCNN loaded with :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`.
We will first have a look at the output of the model.


.. GENERATED FROM PYTHON SOURCE LINES 363-380

.. code-block:: Python


    from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
    from torchvision.io import read_image

    person_int = read_image(str(Path("../assets") / "person1.jpg"))

    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    transforms = weights.transforms()

    person_float = transforms(person_int)

    model = keypointrcnn_resnet50_fpn(weights=weights, progress=False)
    model = model.eval()

    outputs = model([person_float])
    print(outputs)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading: "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth" to /root/.cache/torch/hub/checkpoints/keypointrcnn_resnet50_fpn_coco-fc266e95.pth
    [{'boxes': tensor([[124.3751, 177.9242, 327.6354, 574.7064],
            [124.3625, 180.7574, 290.1061, 390.7958]], grad_fn=<StackBackward0>), 'labels': tensor([1, 1]), 'scores': tensor([0.9998, 0.1070], grad_fn=<IndexBackward0>), 'keypoints': tensor([[[208.0176, 214.2408,   1.0000],
             [208.0176, 207.0375,   1.0000],
             [197.8246, 210.6392,   1.0000],
             [208.0176, 211.8398,   1.0000],
             [178.6378, 217.8425,   1.0000],
             [221.2086, 253.8590,   1.0000],
             [160.6502, 269.4662,   1.0000],
             [243.9929, 304.2822,   1.0000],
             [138.4655, 328.8935,   1.0000],
             [277.5698, 340.8990,   1.0000],
             [153.4551, 374.5144,   1.0000],
             [226.0053, 375.7150,   1.0000],
             [226.0053, 370.3125,   1.0000],
             [221.8082, 455.5516,   1.0000],
             [273.9723, 448.9486,   1.0000],
             [193.6275, 546.1932,   1.0000],
             [273.3727, 545.5930,   1.0000]],

            [[207.8327, 214.6636,   1.0000],
             [207.2343, 207.4622,   1.0000],
             [198.2590, 209.8627,   1.0000],
             [208.4310, 210.4628,   1.0000],
             [178.5134, 218.2642,   1.0000],
             [219.7997, 251.8704,   1.0000],
             [162.3579, 269.2736,   1.0000],
             [245.5288, 304.6800,   1.0000],
             [138.4238, 330.4848,   1.0000],
             [278.4382, 346.0876,   1.0000],
             [153.3826, 374.8929,   1.0000],
             [233.5618, 368.2917,   1.0000],
             [225.7832, 367.6916,   1.0000],
             [289.8069, 357.4897,   1.0000],
             [245.5288, 389.8956,   1.0000],
             [281.4300, 349.0882,   1.0000],
             [209.0294, 389.8956,   1.0000]]], grad_fn=<CopySlices>), 'keypoints_scores': tensor([[16.0163, 16.6672, 15.8312,  4.6510, 14.2053,  8.8280,  9.1136, 12.2084,
             12.1901, 13.8453, 10.7090,  5.5852,  7.5005, 11.3378,  9.3700,  8.2987,
              8.4479],
            [12.9326, 13.8158, 14.9053,  3.9368, 12.9585,  6.4240,  6.8328, 10.4227,
              9.2907, 10.1066, 10.1019,  0.1822,  4.3058, -4.9904, -2.7409, -2.7874,
             -3.9329]], grad_fn=<CopySlices>)}]




.. GENERATED FROM PYTHON SOURCE LINES 381-388

As we can see, the output is a list of dictionaries.
The list has length ``batch_size``;
since we passed a single image, its length is 1.
Each entry in the list corresponds to an input image,
and it is a dict with keys ``boxes``, ``labels``, ``scores``, ``keypoints`` and ``keypoints_scores``.
Each value associated to those keys has ``num_instances`` elements in it.
In our case above there are 2 instances detected in the image.

.. GENERATED FROM PYTHON SOURCE LINES 388-395

.. code-block:: Python


    kpts = outputs[0]['keypoints']
    scores = outputs[0]['scores']

    print(kpts)
    print(scores)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[[208.0176, 214.2408,   1.0000],
             [208.0176, 207.0375,   1.0000],
             [197.8246, 210.6392,   1.0000],
             [208.0176, 211.8398,   1.0000],
             [178.6378, 217.8425,   1.0000],
             [221.2086, 253.8590,   1.0000],
             [160.6502, 269.4662,   1.0000],
             [243.9929, 304.2822,   1.0000],
             [138.4655, 328.8935,   1.0000],
             [277.5698, 340.8990,   1.0000],
             [153.4551, 374.5144,   1.0000],
             [226.0053, 375.7150,   1.0000],
             [226.0053, 370.3125,   1.0000],
             [221.8082, 455.5516,   1.0000],
             [273.9723, 448.9486,   1.0000],
             [193.6275, 546.1932,   1.0000],
             [273.3727, 545.5930,   1.0000]],

            [[207.8327, 214.6636,   1.0000],
             [207.2343, 207.4622,   1.0000],
             [198.2590, 209.8627,   1.0000],
             [208.4310, 210.4628,   1.0000],
             [178.5134, 218.2642,   1.0000],
             [219.7997, 251.8704,   1.0000],
             [162.3579, 269.2736,   1.0000],
             [245.5288, 304.6800,   1.0000],
             [138.4238, 330.4848,   1.0000],
             [278.4382, 346.0876,   1.0000],
             [153.3826, 374.8929,   1.0000],
             [233.5618, 368.2917,   1.0000],
             [225.7832, 367.6916,   1.0000],
             [289.8069, 357.4897,   1.0000],
             [245.5288, 389.8956,   1.0000],
             [281.4300, 349.0882,   1.0000],
             [209.0294, 389.8956,   1.0000]]], grad_fn=<CopySlices>)
    tensor([0.9998, 0.1070], grad_fn=<IndexBackward0>)




.. GENERATED FROM PYTHON SOURCE LINES 396-402

The KeypointRCNN model detects two instances in the image.
If you plot the boxes using :func:`~torchvision.utils.draw_bounding_boxes`,
you will recognize that they are the person and the surfboard.
Looking at the scores, we see that the model is much more confident about the person than about the surfboard.
We can now set a confidence threshold and plot only the instances we are confident enough about.
Let us set a threshold of 0.75 and keep only the keypoints corresponding to the person.
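
Note that :func:`torch.where` called with a single argument returns a tuple of
index tensors, one per dimension, which can be used directly for indexing. Here
is a small sketch with the scores printed above and placeholder keypoints; the
names ``det_scores``, ``toy_kpts`` and ``keep_idx`` are illustrative:

.. code-block:: Python

    import torch

    det_scores = torch.tensor([0.9998, 0.1070])
    # Placeholder keypoints with shape (num_instances=2, num_keypoints=17, 3).
    toy_kpts = torch.zeros(2, 17, 3)
    toy_kpts[1] = 1.0  # make the second instance distinguishable

    # With one argument, torch.where returns a tuple of index tensors.
    keep_idx = torch.where(det_scores > 0.75)
    print(keep_idx)

    # Indexing with that tuple keeps only the first (confident) instance.
    kept_kpts = toy_kpts[keep_idx]
    print(kept_kpts.shape)
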

.. GENERATED FROM PYTHON SOURCE LINES 402-409

.. code-block:: Python


    detect_threshold = 0.75
    idx = torch.where(scores > detect_threshold)
    keypoints = kpts[idx]

    print(keypoints)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[[208.0176, 214.2408,   1.0000],
             [208.0176, 207.0375,   1.0000],
             [197.8246, 210.6392,   1.0000],
             [208.0176, 211.8398,   1.0000],
             [178.6378, 217.8425,   1.0000],
             [221.2086, 253.8590,   1.0000],
             [160.6502, 269.4662,   1.0000],
             [243.9929, 304.2822,   1.0000],
             [138.4655, 328.8935,   1.0000],
             [277.5698, 340.8990,   1.0000],
             [153.4551, 374.5144,   1.0000],
             [226.0053, 375.7150,   1.0000],
             [226.0053, 370.3125,   1.0000],
             [221.8082, 455.5516,   1.0000],
             [273.9723, 448.9486,   1.0000],
             [193.6275, 546.1932,   1.0000],
             [273.3727, 545.5930,   1.0000]]], grad_fn=<IndexBackward0>)




.. GENERATED FROM PYTHON SOURCE LINES 410-414

Great, now we have the keypoints corresponding to the person.
Each keypoint is represented by x, y coordinates and the visibility.
We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints.
Note that the utility expects uint8 images.
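
If you need the coordinates and the visibility flags separately, they can be
sliced from the last dimension. In the sketch below, the first keypoint is
copied from the output above and the second one is made up. Treating any
non-zero flag as visible is an assumption made for this illustration:

.. code-block:: Python

    import torch

    # Shape (num_instances=1, num_keypoints=2, 3): x, y, visibility.
    two_kpts = torch.tensor([[[208.0176, 214.2408, 1.0],
                              [0.0, 0.0, 0.0]]])

    xy = two_kpts[..., :2]          # (1, 2, 2) pixel coordinates
    # Assumption for this sketch: a non-zero flag means "visible".
    visible = two_kpts[..., 2] > 0  # (1, 2) boolean flags
    print(xy.shape, visible)
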

.. GENERATED FROM PYTHON SOURCE LINES 414-420

.. code-block:: Python


    from torchvision.utils import draw_keypoints

    res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
    show(res)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_011.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_011.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 421-423

As we see, the keypoints appear as colored circles over the image.
The COCO keypoints for a person are ordered and correspond to the following list.

.. GENERATED FROM PYTHON SOURCE LINES 423-431

.. code-block:: Python


    coco_keypoints = [
        "nose", "left_eye", "right_eye", "left_ear", "right_ear",
        "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
        "left_wrist", "right_wrist", "left_hip", "right_hip",
        "left_knee", "right_knee", "left_ankle", "right_ankle",
    ]








.. GENERATED FROM PYTHON SOURCE LINES 432-451

What if we are interested in joining the keypoints?
This is especially useful for pose detection or action recognition.
We can join the keypoints easily using the ``connectivity`` parameter.
Looking closely at the list, we need to join the points in the order
below to construct a human skeleton.

nose -> left_eye -> left_ear.                              (0, 1), (1, 3)

nose -> right_eye -> right_ear.                            (0, 2), (2, 4)

nose -> left_shoulder -> left_elbow -> left_wrist.         (0, 5), (5, 7), (7, 9)

nose -> right_shoulder -> right_elbow -> right_wrist.      (0, 6), (6, 8), (8, 10)

left_shoulder -> left_hip -> left_knee -> left_ankle.      (5, 11), (11, 13), (13, 15)

right_shoulder -> right_hip -> right_knee -> right_ankle.  (6, 12), (12, 14), (14, 16)

We will create a list containing these keypoint ids to be connected.
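
Writing the index pairs by hand is error-prone, so one option is to derive them
from the keypoint names. The sketch below is plain Python; the names
``coco_keypoint_names``, ``chains`` and ``skeleton_from_names`` are illustrative
and not part of torchvision:

.. code-block:: Python

    coco_keypoint_names = [
        "nose", "left_eye", "right_eye", "left_ear", "right_ear",
        "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
        "left_wrist", "right_wrist", "left_hip", "right_hip",
        "left_knee", "right_knee", "left_ankle", "right_ankle",
    ]

    # The six chains listed above, written with keypoint names.
    chains = [
        ["nose", "left_eye", "left_ear"],
        ["nose", "right_eye", "right_ear"],
        ["nose", "left_shoulder", "left_elbow", "left_wrist"],
        ["nose", "right_shoulder", "right_elbow", "right_wrist"],
        ["left_shoulder", "left_hip", "left_knee", "left_ankle"],
        ["right_shoulder", "right_hip", "right_knee", "right_ankle"],
    ]

    name_to_id = {name: i for i, name in enumerate(coco_keypoint_names)}

    # Turn each pair of consecutive names in a chain into a pair of ids.
    skeleton_from_names = [
        (name_to_id[a], name_to_id[b])
        for chain in chains
        for a, b in zip(chain, chain[1:])
    ]
    print(skeleton_from_names)

The resulting pairs are the same 16 connections as in the hand-written list,
only in a different order.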

.. GENERATED FROM PYTHON SOURCE LINES 451-457

.. code-block:: Python


    connect_skeleton = [
        (0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6), (5, 7), (6, 8),
        (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16)
    ]
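
The same pairs can also be derived from the keypoint names rather than typed
by hand. The following is a minimal sketch of that idea; ``chains``,
``name_to_id`` and ``skeleton_from_names`` are hypothetical names introduced
only for this illustration and are not part of torchvision.

.. code-block:: Python


    # Keypoint names in COCO order, repeated here so the sketch is self-contained.
    coco_keypoints = [
        "nose", "left_eye", "right_eye", "left_ear", "right_ear",
        "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
        "left_wrist", "right_wrist", "left_hip", "right_hip",
        "left_knee", "right_knee", "left_ankle", "right_ankle",
    ]

    # Hypothetical description of the skeleton: each chain is a path of names.
    chains = [
        ["nose", "left_eye", "left_ear"],
        ["nose", "right_eye", "right_ear"],
        ["nose", "left_shoulder", "left_elbow", "left_wrist"],
        ["nose", "right_shoulder", "right_elbow", "right_wrist"],
        ["left_shoulder", "left_hip", "left_knee", "left_ankle"],
        ["right_shoulder", "right_hip", "right_knee", "right_ankle"],
    ]

    # Map every keypoint name to its position in the COCO ordering.
    name_to_id = {name: i for i, name in enumerate(coco_keypoints)}

    # Turn consecutive names in each chain into (start_id, end_id) pairs.
    skeleton_from_names = []
    for chain in chains:
        for start, end in zip(chain, chain[1:]):
            edge = (name_to_id[start], name_to_id[end])
            if edge not in skeleton_from_names:
                skeleton_from_names.append(edge)

    print(len(skeleton_from_names))  # 16

The resulting pairs are the same as in ``connect_skeleton``, only in a
different order.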








.. GENERATED FROM PYTHON SOURCE LINES 458-460

We pass the ``connect_skeleton`` list to the ``connectivity`` parameter to connect the keypoints.


.. GENERATED FROM PYTHON SOURCE LINES 460-464

.. code-block:: Python


    res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
    show(res)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_012.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_012.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 465-472

That looks pretty good.

.. _draw_keypoints_with_visibility:

Drawing Keypoints with Visibility
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's have a look at the results that another keypoint prediction model produced, and show the connectivity:

.. GENERATED FROM PYTHON SOURCE LINES 472-496

.. code-block:: Python


    prediction = torch.tensor(
        [[[208.0176, 214.2409, 1.0000],
          [000.0000, 000.0000, 0.0000],
          [197.8246, 210.6392, 1.0000],
          [000.0000, 000.0000, 0.0000],
          [178.6378, 217.8425, 1.0000],
          [221.2086, 253.8591, 1.0000],
          [160.6502, 269.4662, 1.0000],
          [243.9929, 304.2822, 1.0000],
          [138.4654, 328.8935, 1.0000],
          [277.5698, 340.8990, 1.0000],
          [153.4551, 374.5145, 1.0000],
          [000.0000, 000.0000, 0.0000],
          [226.0053, 370.3125, 1.0000],
          [221.8081, 455.5516, 1.0000],
          [273.9723, 448.9486, 1.0000],
          [193.6275, 546.1933, 1.0000],
          [273.3727, 545.5930, 1.0000]]]
    )

    res = draw_keypoints(person_int, prediction, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
    show(res)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_013.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_013.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 497-507

What happened there?
The model that predicted the new keypoints
couldn't detect the three points that are hidden on the upper left side of the skateboarder's body.
More precisely, the model predicted ``(x, y, vis) = (0, 0, 0)`` for the left_eye, left_ear, and left_hip.
We definitely don't want to display those keypoints and their connections, and we don't have to.
Looking at the parameters of :func:`~torchvision.utils.draw_keypoints`,
we can see that we can pass a visibility tensor as an additional argument.
The model's prediction stores the visibility as the third keypoint dimension, so we just need to extract it.
Let's split the ``prediction`` into the keypoint coordinates and their respective visibility,
and pass both of them as arguments to :func:`~torchvision.utils.draw_keypoints`.
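
To make the shapes concrete, here is a small sketch of the split on a toy
tensor. The values and the ``toy_*`` names are made up for illustration; only
``torch.Tensor.split`` and ``torch.Tensor.bool`` are real API calls.

.. code-block:: Python


    import torch

    # Hypothetical prediction: one instance with three keypoints,
    # each stored as an (x, y, visibility) row.
    toy_prediction = torch.tensor(
        [[[10.0, 20.0, 1.0],
          [0.0, 0.0, 0.0],
          [30.0, 40.0, 1.0]]]
    )

    # Split the last dimension into 2 coordinate columns and 1 visibility column.
    toy_coordinates, toy_visibility = toy_prediction.split([2, 1], dim=-1)
    # Convert the 0.0 / 1.0 flags into booleans.
    toy_visibility = toy_visibility.bool()

    print(toy_coordinates.shape)  # torch.Size([1, 3, 2])
    print(toy_visibility.shape)  # torch.Size([1, 3, 1])
    print(toy_visibility.squeeze(-1).tolist())  # [[True, False, True]]

The split leaves the coordinates untouched and keeps the visibility in a
trailing dimension of size one.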

.. GENERATED FROM PYTHON SOURCE LINES 507-516

.. code-block:: Python


    coordinates, visibility = prediction.split([2, 1], dim=-1)
    visibility = visibility.bool()

    res = draw_keypoints(
        person_int, coordinates, visibility=visibility, connectivity=connect_skeleton, colors="blue", radius=4, width=3
    )
    show(res)




.. image-sg:: /auto_examples/others/images/sphx_glr_plot_visualization_utils_014.png
   :alt: plot visualization utils
   :srcset: /auto_examples/others/images/sphx_glr_plot_visualization_utils_014.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 517-523

We can see that the undetected keypoints are not drawn and the connections to invisible keypoints are skipped.
This can reduce the noise on images with multiple detections, or in cases like ours,
when the keypoint-prediction model missed some detections.
Most torch keypoint-prediction models return the visibility for every prediction, ready for you to use.
The :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn` model,
which we used in the first case, does so too.
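
The effect of the visibility flags on the skeleton can be sketched in plain
Python. This illustrates the idea only and is not torchvision's actual
implementation; it assumes that a connection is skipped as soon as either of
its endpoints is invisible. The names ``visible``, ``kept`` and ``dropped``
are hypothetical.

.. code-block:: Python


    # Skeleton pairs, repeated here so the sketch is self-contained.
    connect_skeleton = [
        (0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6), (5, 7), (6, 8),
        (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16)
    ]

    # Visibility of the 17 keypoints, matching the third column of ``prediction``:
    # left_eye (1), left_ear (3) and left_hip (11) were not detected.
    visible = [True] * 17
    for hidden_id in (1, 3, 11):
        visible[hidden_id] = False

    # Assumption: keep a connection only if both of its endpoints are visible.
    kept = [(a, b) for a, b in connect_skeleton if visible[a] and visible[b]]
    dropped = [(a, b) for a, b in connect_skeleton if not (visible[a] and visible[b])]

    print(dropped)  # [(0, 1), (1, 3), (5, 11), (11, 13)]
    print(len(kept))  # 12

Under this assumption, the three hidden keypoints remove four of the sixteen
connections.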


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 14.868 seconds)


.. _sphx_glr_download_auto_examples_others_plot_visualization_utils.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_visualization_utils.ipynb <plot_visualization_utils.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_visualization_utils.py <plot_visualization_utils.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_visualization_utils.zip <plot_visualization_utils.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_