.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_visualization_utils.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_plot_visualization_utils.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_visualization_utils.py:


=======================
Visualization utilities
=======================

This example illustrates some of the utilities that torchvision offers for
visualizing images, bounding boxes, segmentation masks and keypoints.

.. GENERATED FROM PYTHON SOURCE LINES 9-33

.. code-block:: default


    # sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail2.png"

    import torch
    import numpy as np
    import matplotlib.pyplot as plt

    import torchvision.transforms.functional as F


    plt.rcParams["savefig.bbox"] = 'tight'


    def show(imgs):
        if not isinstance(imgs, list):
            imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
        for i, img in enumerate(imgs):
            img = img.detach()
            img = F.to_pil_image(img)
            axs[0, i].imshow(np.asarray(img))
            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

.. GENERATED FROM PYTHON SOURCE LINES 34-39

Visualizing a grid of images
----------------------------
The :func:`~torchvision.utils.make_grid` function can be used to create a
tensor that represents multiple images in a grid. This util requires the
images to be of dtype ``uint8``.

.. GENERATED FROM PYTHON SOURCE LINES 39-50

.. code-block:: default

    from torchvision.utils import make_grid
    from torchvision.io import read_image
    from pathlib import Path

    dog1_int = read_image(str(Path('assets') / 'dog1.jpg'))
    dog2_int = read_image(str(Path('assets') / 'dog2.jpg'))

    grid = make_grid([dog1_int, dog2_int, dog1_int, dog2_int])
    show(grid)

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_001.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 51-56

Visualizing bounding boxes
--------------------------
We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an
image. We can set the colors, labels, width, as well as font and font size.
The boxes are in ``(xmin, ymin, xmax, ymax)`` format.

.. GENERATED FROM PYTHON SOURCE LINES 56-66

.. code-block:: default

    from torchvision.utils import draw_bounding_boxes


    boxes = torch.tensor([[50, 50, 100, 200], [210, 150, 350, 430]], dtype=torch.float)
    colors = ["blue", "yellow"]
    result = draw_bounding_boxes(dog1_int, boxes, colors=colors, width=5)
    show(result)

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_002.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_002.png
   :class: sphx-glr-single-img
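The call above only sets ``colors`` and ``width``. As a minimal sketch (the
label strings are made up for illustration and this is not part of the
generated output), the same call can also render labels and filled boxes:

.. code-block:: default

    # Hypothetical labels, to illustrate the `labels` and `fill` parameters.
    labels = ["dog A", "dog B"]
    result_with_labels = draw_bounding_boxes(
        dog1_int, boxes, labels=labels, colors=colors, fill=True, width=5
    )
    show(result_with_labels)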
.. GENERATED FROM PYTHON SOURCE LINES 67-75

Naturally, we can also plot bounding boxes produced by torchvision detection
models. Here is a demo with a Faster R-CNN model loaded from
:func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`. You can also
try using a RetinaNet with
:func:`~torchvision.models.detection.retinanet_resnet50_fpn`, an SSDlite with
:func:`~torchvision.models.detection.ssdlite320_mobilenet_v3_large` or an SSD
with :func:`~torchvision.models.detection.ssd300_vgg16`. For more details on
the output of such models, you may refer to :ref:`instance_seg_output`.

.. GENERATED FROM PYTHON SOURCE LINES 75-89

.. code-block:: default

    from torchvision.models.detection import fasterrcnn_resnet50_fpn
    from torchvision.transforms.functional import convert_image_dtype


    batch_int = torch.stack([dog1_int, dog2_int])
    batch = convert_image_dtype(batch_int, dtype=torch.float)

    model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
    model = model.eval()

    outputs = model(batch)
    print(outputs)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [{'boxes': tensor([[215.9767, 171.1661, 402.0078, 378.7391],
            [344.6341, 172.6735, 357.6114, 220.1435],
            [153.1306, 185.5568, 172.9223, 254.7014]], grad_fn=<...>),
      'labels': tensor([18, 1, 1]),
      'scores': tensor([0.9989, 0.0701, 0.0611], grad_fn=<...>)},
     {'boxes': tensor([[ 23.5963, 132.4332, 449.9359, 493.0222],
            [225.8183, 124.6292, 467.2861, 492.2621],
            [ 18.5249, 135.4171, 420.9786, 479.2226]], grad_fn=<...>),
      'labels': tensor([18, 18, 17]),
      'scores': tensor([0.9980, 0.0879, 0.0671], grad_fn=<...>)}]

.. GENERATED FROM PYTHON SOURCE LINES 90-92

Let's plot the boxes detected by our model. We will only plot the boxes with a
score greater than a given threshold.

.. GENERATED FROM PYTHON SOURCE LINES 92-100

.. code-block:: default

    score_threshold = .8
    dogs_with_boxes = [
        draw_bounding_boxes(dog_int, boxes=output['boxes'][output['scores'] > score_threshold], width=4)
        for dog_int, output in zip(batch_int, outputs)
    ]
    show(dogs_with_boxes)

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_003.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 101-122

Visualizing segmentation masks
------------------------------
The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
draw segmentation masks on images. Semantic segmentation and instance
segmentation models have different outputs, so we will treat each
independently.

.. _semantic_seg_output:

Semantic segmentation models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We will see how to use it with torchvision's FCN ResNet-50, loaded with
:func:`~torchvision.models.segmentation.fcn_resnet50`. You can also try using
DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`) or
LR-ASPP MobileNetV3
(:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`) models.

Let's start by looking at the output of the model. Remember that in general,
images must be normalized before they're passed to a semantic segmentation
model.

.. GENERATED FROM PYTHON SOURCE LINES 122-133

.. code-block:: default

    from torchvision.models.segmentation import fcn_resnet50


    model = fcn_resnet50(pretrained=True, progress=False)
    model = model.eval()

    normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    output = model(normalized_batch)['out']
    print(output.shape, output.min().item(), output.max().item())

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading: "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth" to /root/.cache/torch/hub/checkpoints/fcn_resnet50_coco-1167a1af.pth
    torch.Size([2, 21, 500, 500]) -7.089669704437256 14.858256340026855

.. GENERATED FROM PYTHON SOURCE LINES 134-142

As we can see above, the output of the segmentation model is a tensor of shape
``(batch_size, num_classes, H, W)``. Each value is a non-normalized score, and
we can normalize them into ``[0, 1]`` by using a softmax. After the softmax,
we can interpret each value as a probability indicating how likely a given
pixel is to belong to a given class.
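As a quick sanity check (a sketch, not part of the original example), we can
verify that the per-pixel scores do sum to one after the softmax:

.. code-block:: default

    # Softmax over the class dimension turns each pixel's scores into a
    # probability distribution over the 21 classes.
    probs = torch.nn.functional.softmax(output, dim=1)
    print(torch.allclose(probs.sum(dim=1), torch.ones_like(probs[:, 0])))  # True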
Let's plot the masks that have been detected for the dog class and for the
boat class:

.. GENERATED FROM PYTHON SOURCE LINES 142-160

.. code-block:: default

    sem_classes = [
        '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
        'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]
    sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}

    normalized_masks = torch.nn.functional.softmax(output, dim=1)

    dog_and_boat_masks = [
        normalized_masks[img_idx, sem_class_to_idx[cls]]
        for img_idx in range(batch.shape[0])
        for cls in ('dog', 'boat')
    ]

    show(dog_and_boat_masks)

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_004.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 161-168

As expected, the model is confident about the dog class, but not so much for
the boat class.

The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
plot those masks on top of the original image. This function expects the masks
to be boolean masks, but our masks above contain probabilities in ``[0, 1]``.
To get boolean masks, we can do the following:

.. GENERATED FROM PYTHON SOURCE LINES 168-175

.. code-block:: default

    class_dim = 1
    boolean_dog_masks = (normalized_masks.argmax(class_dim) == sem_class_to_idx['dog'])
    print(f"shape = {boolean_dog_masks.shape}, dtype = {boolean_dog_masks.dtype}")
    show([m.float() for m in boolean_dog_masks])

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_005.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_005.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    shape = torch.Size([2, 500, 500]), dtype = torch.bool

.. GENERATED FROM PYTHON SOURCE LINES 176-188

The line above where we define ``boolean_dog_masks`` is a bit cryptic, but you
can read it as the following query: "For which pixels is 'dog' the most likely
class?"

.. note::
    While we're using the ``normalized_masks`` here, we would have gotten the
    same result by using the non-normalized scores of the model directly (as
    the softmax operation preserves the order).

Now that we have boolean masks, we can use them with
:func:`~torchvision.utils.draw_segmentation_masks` to plot them on top of the
original images:

.. GENERATED FROM PYTHON SOURCE LINES 188-197

.. code-block:: default

    from torchvision.utils import draw_segmentation_masks

    dogs_with_masks = [
        draw_segmentation_masks(img, masks=mask, alpha=0.7)
        for img, mask in zip(batch_int, boolean_dog_masks)
    ]
    show(dogs_with_masks)

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_006.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_006.png
   :class: sphx-glr-single-img
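:func:`~torchvision.utils.draw_segmentation_masks` also accepts a ``colors``
parameter. As a small sketch (the color choice is arbitrary and not part of
the generated example), we could draw the same masks in green:

.. code-block:: default

    # Same masks as above, with an explicit color per mask.
    dogs_with_green_masks = [
        draw_segmentation_masks(img, masks=mask, alpha=0.7, colors=["green"])
        for img, mask in zip(batch_int, boolean_dog_masks)
    ]
    show(dogs_with_green_masks)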
.. GENERATED FROM PYTHON SOURCE LINES 198-205

We can plot more than one mask per image! Remember that the model returned as
many masks as there are classes. Let's ask the same query as above, but this
time for *all* classes, not just the dog class: "For each pixel and each class
C, is class C the most likely class?"

This one is a bit more involved, so we'll first show how to do it with a
single image, and then we'll generalize to the batch.

.. GENERATED FROM PYTHON SOURCE LINES 205-217

.. code-block:: default

    num_classes = normalized_masks.shape[1]
    dog1_masks = normalized_masks[0]
    class_dim = 0
    dog1_all_classes_masks = dog1_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None]

    print(f"dog1_masks shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}")
    print(f"dog1_all_classes_masks = {dog1_all_classes_masks.shape}, dtype = {dog1_all_classes_masks.dtype}")

    dog_with_all_masks = draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks, alpha=.6)
    show(dog_with_all_masks)

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_007.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_007.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    dog1_masks shape = torch.Size([21, 500, 500]), dtype = torch.float32
    dog1_all_classes_masks = torch.Size([21, 500, 500]), dtype = torch.bool

.. GENERATED FROM PYTHON SOURCE LINES 218-230

We can see in the image above that only 2 masks were drawn: the mask for the
background and the mask for the dog. This is because the model thinks that
only these 2 classes are the most likely ones across all the pixels. If the
model had detected another class as the most likely among other pixels, we
would have seen its mask above.

Removing the background mask is as simple as passing
``masks=dog1_all_classes_masks[1:]``, because the background class is the
class with index 0.

Let's now do the same but for an entire batch of images. The code is similar
but involves a bit more juggling with the dimensions.

.. GENERATED FROM PYTHON SOURCE LINES 230-244

.. code-block:: default

    class_dim = 1
    all_classes_masks = normalized_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None, None]
    print(f"shape = {all_classes_masks.shape}, dtype = {all_classes_masks.dtype}")
    # The first dimension is the classes now, so we need to swap it
    all_classes_masks = all_classes_masks.swapaxes(0, 1)

    dogs_with_masks = [
        draw_segmentation_masks(img, masks=mask, alpha=.6)
        for img, mask in zip(batch_int, all_classes_masks)
    ]
    show(dogs_with_masks)

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_008.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_008.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    shape = torch.Size([21, 2, 500, 500]), dtype = torch.bool
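For reference, broadcasting against ``torch.arange`` as above is equivalent to
one-hot encoding the per-pixel predictions. Here is an alternative sketch of
the same computation (not part of the original example):

.. code-block:: default

    # (batch_size, H, W) predictions -> (batch_size, num_classes, H, W)
    # boolean masks, matching all_classes_masks after the swap above.
    preds = normalized_masks.argmax(class_dim)
    one_hot_masks = torch.nn.functional.one_hot(preds, num_classes).permute(0, 3, 1, 2).bool()
    print(torch.equal(one_hot_masks, all_classes_masks))  # True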
.. GENERATED FROM PYTHON SOURCE LINES 245-264

.. _instance_seg_output:

Instance segmentation models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Instance segmentation models have a significantly different output from the
semantic segmentation models. We will see here how to plot the masks for such
models. Let's start by analyzing the output of a Mask R-CNN model. Note that
these models don't require the images to be normalized, so we don't need to
use the normalized batch.

.. note::

    We will here describe the output of a Mask R-CNN model. The models in
    :ref:`object_det_inst_seg_pers_keypoint_det` all have a similar output
    format, but some of them may have extra info like keypoints for
    :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`, and some
    of them may not have masks, like
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`.

.. GENERATED FROM PYTHON SOURCE LINES 264-272

.. code-block:: default

    from torchvision.models.detection import maskrcnn_resnet50_fpn
    model = maskrcnn_resnet50_fpn(pretrained=True, progress=False)
    model = model.eval()

    output = model(batch)
    print(output)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
    [{'boxes': tensor([[219.7444, 168.1722, 400.7379, 384.0263],
            [343.9716, 171.2287, 358.3447, 222.6263],
            [301.0303, 192.6917, 313.8879, 232.3154]], grad_fn=<...>),
      'labels': tensor([18, 1, 1]),
      'scores': tensor([0.9987, 0.7187, 0.6525], grad_fn=<...>),
      'masks': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                ...,
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.]]],

              [[[0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                ...,
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.]]],

              [[[0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                ...,
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.]]]], grad_fn=<...>)},
     {'boxes': tensor([[ 44.6767, 137.9018, 446.5324, 487.3429],
            [  0.0000, 288.0053, 489.9293, 490.2352]], grad_fn=<...>),
      'labels': tensor([18, 15]),
      'scores': tensor([0.9978, 0.0697], grad_fn=<...>),
      'masks': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                ...,
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.]]],

              [[[0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                ...,
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.],
                [0., 0., 0.,  ..., 0., 0., 0.]]]], grad_fn=<...>)}]

.. GENERATED FROM PYTHON SOURCE LINES 273-289

Let's break this down. For each image in the batch, the model outputs some
detections (or instances). The number of detections varies for each input
image. Each instance is described by its bounding box, its label, its score
and its mask.

The way the output is organized is as follows: the output is a list of length
``batch_size``. Each entry in the list corresponds to an input image, and it
is a dict with keys 'boxes', 'labels', 'scores', and 'masks'. Each value
associated with those keys has ``num_instances`` elements in it. In our case
above there are 3 instances detected in the first image, and 2 instances in
the second one.
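To make the varying number of detections concrete, here is a small sketch (not
part of the original example) that counts the instances per image:

.. code-block:: default

    # Each entry of the output list corresponds to one input image.
    for img_idx, out in enumerate(output):
        print(f"Image {img_idx}: {out['boxes'].shape[0]} instances detected")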
The boxes can be plotted with :func:`~torchvision.utils.draw_bounding_boxes`
as above, but here we're more interested in the masks. These masks are quite
different from the masks that we saw above for the semantic segmentation
models.

.. GENERATED FROM PYTHON SOURCE LINES 289-295

.. code-block:: default

    dog1_output = output[0]
    dog1_masks = dog1_output['masks']
    print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, "
          f"min = {dog1_masks.min()}, max = {dog1_masks.max()}")

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    shape = torch.Size([3, 1, 500, 500]), dtype = torch.float32, min = 0.0, max = 0.9999862909317017

.. GENERATED FROM PYTHON SOURCE LINES 296-300

Here the masks correspond to probabilities indicating, for each pixel, how
likely it is to belong to the predicted label of that instance. Those
predicted labels correspond to the 'labels' element in the same output dict.
Let's see which labels were predicted for the instances of the first image.

.. GENERATED FROM PYTHON SOURCE LINES 300-321

.. code-block:: default

    inst_classes = [
        '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
        'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
        'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
        'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
        'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
        'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
        'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
        'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]

    inst_class_to_idx = {cls: idx for (idx, cls) in enumerate(inst_classes)}

    print("For the first dog, the following instances were detected:")
    print([inst_classes[label] for label in dog1_output['labels']])

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    For the first dog, the following instances were detected:
    ['dog', 'person', 'person']

.. GENERATED FROM PYTHON SOURCE LINES 322-329

Interestingly, the model detects two persons in the image. Let's go ahead and
plot those masks. Since :func:`~torchvision.utils.draw_segmentation_masks`
expects boolean masks, we need to convert those probabilities into boolean
values. Remember that the semantics of those masks is "How likely is this
pixel to belong to the predicted class?". As a result, a natural way of
converting those masks into boolean values is to threshold them with the 0.5
probability (one could also choose a different threshold).

.. GENERATED FROM PYTHON SOURCE LINES 329-339

.. code-block:: default

    proba_threshold = 0.5
    dog1_bool_masks = dog1_output['masks'] > proba_threshold
    print(f"shape = {dog1_bool_masks.shape}, dtype = {dog1_bool_masks.dtype}")

    # There's an extra dimension (1) to the masks. We need to remove it
    dog1_bool_masks = dog1_bool_masks.squeeze(1)

    show(draw_segmentation_masks(dog1_int, dog1_bool_masks, alpha=0.9))

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_009.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_009.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    shape = torch.Size([3, 1, 500, 500]), dtype = torch.bool

.. GENERATED FROM PYTHON SOURCE LINES 340-343

The model seems to have properly detected the dog, but it also confused trees
with people. Looking more closely at the scores will help us plot more
relevant masks:

.. GENERATED FROM PYTHON SOURCE LINES 343-346

.. code-block:: default

    print(dog1_output['scores'])

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    tensor([0.9987, 0.7187, 0.6525], grad_fn=<...>)

.. GENERATED FROM PYTHON SOURCE LINES 347-351

Clearly the model is more confident about the dog detection than it is about
the people detections. That's good news. When plotting the masks, we can ask
for only those that have a good score. Let's use a score threshold of .75
here, and also plot the masks of the second dog.
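Before drawing, a small sketch (not part of the original example) shows which
labels survive that filter:

.. code-block:: default

    # Peek at the labels whose score clears the 0.75 threshold used below.
    for out in output:
        keep = out['scores'] > .75
        print([inst_classes[label] for label in out['labels'][keep]])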
.. GENERATED FROM PYTHON SOURCE LINES 351-365

.. code-block:: default

    score_threshold = .75

    boolean_masks = [
        out['masks'][out['scores'] > score_threshold] > proba_threshold
        for out in output
    ]

    dogs_with_masks = [
        draw_segmentation_masks(img, mask.squeeze(1))
        for img, mask in zip(batch_int, boolean_masks)
    ]
    show(dogs_with_masks)

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_010.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_010.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 366-369

The two 'people' masks in the first image were not selected because they have
a lower score than the score threshold. Similarly, in the second image, the
instance with class 15 (which corresponds to 'bench') was not selected.
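As a final sketch for this section (again, not part of the generated example),
the two utilities compose naturally, so we can overlay both the filtered masks
and the corresponding boxes on the same image:

.. code-block:: default

    # Draw the masks first, then the boxes, reusing the thresholds from above.
    img = draw_segmentation_masks(dog1_int, boolean_masks[0].squeeze(1), alpha=0.9)
    img = draw_bounding_boxes(img, output[0]['boxes'][output[0]['scores'] > score_threshold], width=4)
    show(img)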
.. GENERATED FROM PYTHON SOURCE LINES 371-380

Visualizing keypoints
---------------------
The :func:`~torchvision.utils.draw_keypoints` function can be used to draw
keypoints on images. We will see how to use it with torchvision's
KeypointRCNN loaded with
:func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`. We will
first have a look at the output of the model. Note that the keypoint detection
model does not need normalized images.

.. GENERATED FROM PYTHON SOURCE LINES 380-393

.. code-block:: default

    from torchvision.models.detection import keypointrcnn_resnet50_fpn
    from torchvision.io import read_image

    person_int = read_image(str(Path("assets") / "person1.jpg"))
    person_float = convert_image_dtype(person_int, dtype=torch.float)

    model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False)
    model = model.eval()

    outputs = model([person_float])
    print(outputs)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Downloading: "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth" to /root/.cache/torch/hub/checkpoints/keypointrcnn_resnet50_fpn_coco-fc266e95.pth
    [{'boxes': tensor([[124.3751, 177.9242, 327.6353, 574.7064],
            [124.3625, 180.7574, 290.1061, 390.7958]], grad_fn=<...>),
      'labels': tensor([1, 1]),
      'scores': tensor([0.9998, 0.1070], grad_fn=<...>),
      'keypoints': tensor([[[208.0176, 214.2408,   1.0000],
             [208.0176, 207.0375,   1.0000],
             [197.8246, 210.6392,   1.0000],
             [208.0176, 211.8398,   1.0000],
             [178.6378, 217.8425,   1.0000],
             [221.2085, 253.8590,   1.0000],
             [160.6502, 269.4662,   1.0000],
             [243.9929, 304.2822,   1.0000],
             [138.4654, 328.8935,   1.0000],
             [277.5698, 340.8990,   1.0000],
             [153.4551, 374.5144,   1.0000],
             [226.0052, 375.7150,   1.0000],
             [226.0052, 370.3125,   1.0000],
             [221.8081, 455.5516,   1.0000],
             [273.9723, 448.9486,   1.0000],
             [193.6275, 546.1932,   1.0000],
             [273.3727, 545.5930,   1.0000]],

            [[207.8327, 214.6636,   1.0000],
             [207.2343, 207.4622,   1.0000],
             [198.2590, 209.8627,   1.0000],
             [208.4310, 210.4628,   1.0000],
             [178.5134, 218.2642,   1.0000],
             [219.7997, 251.8704,   1.0000],
             [162.3579, 269.2736,   1.0000],
             [245.5289, 304.6800,   1.0000],
             [138.4238, 330.4848,   1.0000],
             [278.4382, 346.0876,   1.0000],
             [153.3826, 374.8929,   1.0000],
             [233.5618, 368.2917,   1.0000],
             [225.7832, 367.6916,   1.0000],
             [289.8069, 357.4897,   1.0000],
             [245.5289, 389.8956,   1.0000],
             [281.4300, 349.0882,   1.0000],
             [209.0294, 389.8956,   1.0000]]], grad_fn=<...>),
      'keypoints_scores': tensor([[16.0163, 16.6672, 15.8312,  4.6510, 14.2053,  8.8280,  9.1136, 12.2084,
             12.1901, 13.8453, 10.7090,  5.5852,  7.5005, 11.3378,  9.3700,  8.2987,
              8.4479],
            [12.9326, 13.8158, 14.9053,  3.9368, 12.9585,  6.4240,  6.8328, 10.4227,
              9.2907, 10.1066, 10.1019,  0.1822,  4.3057, -4.9904, -2.7409, -2.7874,
             -3.9329]], grad_fn=<...>)}]

.. GENERATED FROM PYTHON SOURCE LINES 394-401

As we see, the output contains a list of dictionaries of length
``batch_size``; we passed a single image, so the list has length 1. Each entry
in the list corresponds to an input image, and it is a dict with keys
``boxes``, ``labels``, ``scores``, ``keypoints`` and ``keypoints_scores``.
Each value associated with those keys has ``num_instances`` elements in it. In
our case above there are 2 instances detected in the image.

.. GENERATED FROM PYTHON SOURCE LINES 401-408

.. code-block:: default

    kpts = outputs[0]['keypoints']
    scores = outputs[0]['scores']

    print(kpts)
    print(scores)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    tensor([[[208.0176, 214.2408,   1.0000],
             [208.0176, 207.0375,   1.0000],
             [197.8246, 210.6392,   1.0000],
             [208.0176, 211.8398,   1.0000],
             [178.6378, 217.8425,   1.0000],
             [221.2085, 253.8590,   1.0000],
             [160.6502, 269.4662,   1.0000],
             [243.9929, 304.2822,   1.0000],
             [138.4654, 328.8935,   1.0000],
             [277.5698, 340.8990,   1.0000],
             [153.4551, 374.5144,   1.0000],
             [226.0052, 375.7150,   1.0000],
             [226.0052, 370.3125,   1.0000],
             [221.8081, 455.5516,   1.0000],
             [273.9723, 448.9486,   1.0000],
             [193.6275, 546.1932,   1.0000],
             [273.3727, 545.5930,   1.0000]],

            [[207.8327, 214.6636,   1.0000],
             [207.2343, 207.4622,   1.0000],
             [198.2590, 209.8627,   1.0000],
             [208.4310, 210.4628,   1.0000],
             [178.5134, 218.2642,   1.0000],
             [219.7997, 251.8704,   1.0000],
             [162.3579, 269.2736,   1.0000],
             [245.5289, 304.6800,   1.0000],
             [138.4238, 330.4848,   1.0000],
             [278.4382, 346.0876,   1.0000],
             [153.3826, 374.8929,   1.0000],
             [233.5618, 368.2917,   1.0000],
             [225.7832, 367.6916,   1.0000],
             [289.8069, 357.4897,   1.0000],
             [245.5289, 389.8956,   1.0000],
             [281.4300, 349.0882,   1.0000],
             [209.0294, 389.8956,   1.0000]]], grad_fn=<...>)
    tensor([0.9998, 0.1070], grad_fn=<...>)

.. GENERATED FROM PYTHON SOURCE LINES 409-415

The KeypointRCNN model detects two instances in the image. If you plot the
boxes by using :func:`~torchvision.utils.draw_bounding_boxes` you will
recognize them as the person and the surfboard. If we look at the scores, we
will realize that the model is much more confident about the person than about
the surfboard. We could now set a confidence threshold and only plot the
instances we are confident enough about. Let us set a threshold of 0.75 and
keep the keypoints corresponding to the person.

.. GENERATED FROM PYTHON SOURCE LINES 415-422

.. code-block:: default

    detect_threshold = 0.75
    idx = torch.where(scores > detect_threshold)
    keypoints = kpts[idx]

    print(keypoints)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    tensor([[[208.0176, 214.2408,   1.0000],
             [208.0176, 207.0375,   1.0000],
             [197.8246, 210.6392,   1.0000],
             [208.0176, 211.8398,   1.0000],
             [178.6378, 217.8425,   1.0000],
             [221.2085, 253.8590,   1.0000],
             [160.6502, 269.4662,   1.0000],
             [243.9929, 304.2822,   1.0000],
             [138.4654, 328.8935,   1.0000],
             [277.5698, 340.8990,   1.0000],
             [153.4551, 374.5144,   1.0000],
             [226.0052, 375.7150,   1.0000],
             [226.0052, 370.3125,   1.0000],
             [221.8081, 455.5516,   1.0000],
             [273.9723, 448.9486,   1.0000],
             [193.6275, 546.1932,   1.0000],
             [273.3727, 545.5930,   1.0000]]], grad_fn=<...>)
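Note that ``torch.where`` on a boolean condition is equivalent to plain
boolean indexing here. A one-line sketch (not part of the original example) to
convince ourselves:

.. code-block:: default

    # Boolean indexing selects the same instances as torch.where above.
    print(torch.equal(kpts[scores > detect_threshold], keypoints))  # True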
.. GENERATED FROM PYTHON SOURCE LINES 423-427

Great, now we have the keypoints corresponding to the person. Each keypoint is
represented by x, y coordinates and the visibility. We can now use the
:func:`~torchvision.utils.draw_keypoints` function to draw keypoints. Note
that the utility expects uint8 images.

.. GENERATED FROM PYTHON SOURCE LINES 427-433

.. code-block:: default

    from torchvision.utils import draw_keypoints

    res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
    show(res)

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_011.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_011.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 434-436

As we see, the keypoints appear as colored circles over the image. The COCO
keypoints for a person are ordered and represent the following list.

.. GENERATED FROM PYTHON SOURCE LINES 436-444

.. code-block:: default

    coco_keypoints = [
        "nose", "left_eye", "right_eye", "left_ear", "right_ear",
        "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
        "left_wrist", "right_wrist", "left_hip", "right_hip",
        "left_knee", "right_knee", "left_ankle", "right_ankle",
    ]

.. GENERATED FROM PYTHON SOURCE LINES 445-464

What if we are interested in joining the keypoints? This is especially useful
when creating pose detection or action recognition applications. We can join
the keypoints easily using the ``connectivity`` parameter. A close observation
reveals that we need to join the points in the following order to construct
the human skeleton::

    nose -> left_eye -> left_ear:                             (0, 1), (1, 3)
    nose -> right_eye -> right_ear:                           (0, 2), (2, 4)
    nose -> left_shoulder -> left_elbow -> left_wrist:        (0, 5), (5, 7), (7, 9)
    nose -> right_shoulder -> right_elbow -> right_wrist:     (0, 6), (6, 8), (8, 10)
    left_shoulder -> left_hip -> left_knee -> left_ankle:     (5, 11), (11, 13), (13, 15)
    right_shoulder -> right_hip -> right_knee -> right_ankle: (6, 12), (12, 14), (14, 16)

We will create a list containing these keypoint IDs to be connected.

.. GENERATED FROM PYTHON SOURCE LINES 464-470

.. code-block:: default

    connect_skeleton = [
        (0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6), (5, 7), (6, 8),
        (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16)
    ]

.. GENERATED FROM PYTHON SOURCE LINES 471-473

We pass the above list to the ``connectivity`` parameter to connect the
keypoints.

.. GENERATED FROM PYTHON SOURCE LINES 473-476

.. code-block:: default

    res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
    show(res)

.. image-sg:: /auto_examples/images/sphx_glr_plot_visualization_utils_012.png
   :alt: plot visualization utils
   :srcset: /auto_examples/images/sphx_glr_plot_visualization_utils_012.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  11.766 seconds)


.. _sphx_glr_download_auto_examples_plot_visualization_utils.py:

.. only:: html

  .. container:: sphx-glr-footer
     :class: sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

       :download:`Download Python source code: plot_visualization_utils.py <plot_visualization_utils.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

       :download:`Download Jupyter notebook: plot_visualization_utils.ipynb <plot_visualization_utils.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_