
TorchVision Object Detection Finetuning Tutorial

For this tutorial, we will be finetuning a pre-trained Mask R-CNN model on the Penn-Fudan Database for Pedestrian Detection and Segmentation. It contains 170 images with 345 instances of pedestrians, and we will use it to illustrate how to use the new features in torchvision in order to train an object detection and instance segmentation model on a custom dataset.

Note

This tutorial works only with torchvision version >=0.16 or nightly. If you’re using torchvision<=0.15, please follow this tutorial instead.
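If you are unsure which version you have, the small helper below can check it. `meets_min_version` is a hypothetical name. It compares only the leading major.minor part of a version string, such as `torchvision.__version__`, against 0.16. It ignores patch numbers and suffixes like `+cu121` or `.dev20231001`.

```python
import re


def meets_min_version(version: str, minimum=(0, 16)) -> bool:
    """Return True if the "major.minor" prefix of `version` is >= `minimum`."""
    # Only the leading digits are parsed, so local/dev suffixes are ignored.
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return (int(match.group(1)), int(match.group(2))) >= tuple(minimum)
```

Usage: `import torchvision; meets_min_version(torchvision.__version__)`.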

Defining the Dataset

The reference scripts for training object detection, instance segmentation and person keypoint detection make it easy to add new custom datasets. The dataset should inherit from the standard torch.utils.data.Dataset class and implement __len__ and __getitem__.

The only specific requirement is that the dataset's __getitem__ should return a tuple:

  • image: torchvision.tv_tensors.Image of shape [3, H, W], a pure tensor, or a PIL Image of size (W, H)

  • target: a dict containing the following fields

    • boxes, torchvision.tv_tensors.BoundingBoxes of shape [N, 4]: the coordinates of the N bounding boxes in [x0, y0, x1, y1] format, ranging from 0 to W and 0 to H

    • labels, integer torch.Tensor of shape [N]: the label for each bounding box. 0 always represents the background class.

    • image_id, int: an image identifier. It should be unique across all the images in the dataset, and is used during evaluation.

    • area, float torch.Tensor of shape [N]: the area of the bounding box. This is used during evaluation with the COCO metric, to separate the metric scores between small, medium and large boxes.

    • iscrowd, uint8 torch.Tensor of shape [N]: instances with iscrowd=True will be ignored during evaluation.

    • (optionally) masks, torchvision.tv_tensors.Mask of shape [N, H, W]: the segmentation masks for each one of the objects

If your dataset complies with the above requirements, it will work with both the training and the evaluation code from the reference scripts. The evaluation code uses pycocotools, which can be installed with pip install pycocotools.

Note

On Windows, install pycocotools from the gautamchitnis/cocoapi repository with the command:

pip install git+https://github.com/gautamchitnis/cocoapi.git@cocodataset-master#subdirectory=PythonAPI

One note on the labels. The model considers class 0 as background. If your dataset does not contain the background class, you should not have 0 in your labels. For example, assuming you have just two classes, cat and dog, you can define 1 (not 0) to represent cats and 2 to represent dogs. So, for instance, if one of the images has both classes, your labels tensor should look like [1, 2].
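The rule above amounts to numbering the foreground classes from 1. The helper `build_label_map` below is a hypothetical sketch of that, using the cat/dog example. The number of classes passed to the model is then the number of foreground classes plus one for the background.

```python
def build_label_map(class_names):
    """Map class names to integer labels, reserving 0 for the background."""
    return {name: idx for idx, name in enumerate(class_names, start=1)}


label_map = build_label_map(["cat", "dog"])
num_classes = len(label_map) + 1  # foreground classes + background

# Labels for an image that contains one cat and one dog:
labels = [label_map[name] for name in ["cat", "dog"]]
```

In a dataset, these labels would be wrapped as `torch.tensor(labels, dtype=torch.int64)`. For PennFudan, the same reasoning gives one foreground class, so num_classes = 2.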

Additionally, if you want to use aspect ratio grouping during training (so that each batch only contains images with similar aspect ratios), it is recommended to also implement a get_height_and_width method, which returns the height and the width of the image. If this method is not provided, we query all elements of the dataset via __getitem__, which loads each image into memory and is slower than a custom method.
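One way to implement get_height_and_width without decoding the image is to read the size from the file header. A PNG file, such as the ones in PennFudan, starts with an 8-byte signature and the header of the IHDR chunk. Right after that, it stores the width and height as big-endian 32-bit integers. The helper `png_height_and_width` below is a hypothetical sketch of this idea. It returns (height, width), in the order described above.

```python
import struct

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_height_and_width(path):
    """Return (height, width) of a PNG file by reading only its first 24 bytes."""
    with open(path, "rb") as f:
        header = f.read(24)
    # Bytes 0-7: signature, 8-11: IHDR length, 12-15: b"IHDR", 16-23: width, height.
    if len(header) < 24 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError(f"{path} does not look like a PNG file")
    width, height = struct.unpack(">II", header[16:24])
    return height, width
```

Take a dataset like the PennFudan one defined below. Its `get_height_and_width(self, idx)` could then return `png_height_and_width(os.path.join(self.root, "PNGImages", self.imgs[idx]))`.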

Writing a custom dataset for PennFudan

Let’s write a dataset for the PennFudan dataset. First, let’s download the dataset and extract the zip file:

wget https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip -P data
cd data && unzip PennFudanPed.zip

We have the following folder structure:

PennFudanPed/
  PedMasks/
    FudanPed00001_mask.png
    FudanPed00002_mask.png
    FudanPed00003_mask.png
    FudanPed00004_mask.png
    ...
  PNGImages/
    FudanPed00001.png
    FudanPed00002.png
    FudanPed00003.png
    FudanPed00004.png

Here is an example of an image and its segmentation mask:

import matplotlib.pyplot as plt
from torchvision.io import read_image


image = read_image("data/PennFudanPed/PNGImages/FudanPed00046.png")
mask = read_image("data/PennFudanPed/PedMasks/FudanPed00046_mask.png")

plt.figure(figsize=(16, 8))
plt.subplot(121)
plt.title("Image")
plt.imshow(image.permute(1, 2, 0))
plt.subplot(122)
plt.title("Mask")
plt.imshow(mask.permute(1, 2, 0))
[Figure: side-by-side panels titled "Image" and "Mask"]

So each image has a corresponding segmentation mask, where each color corresponds to a different instance. Let’s write a torch.utils.data.Dataset class for this dataset. In the code below, we wrap images, bounding boxes and masks into torchvision.tv_tensors.TVTensor classes so that we can apply torchvision's built-in transformations (the new Transforms API) to this object detection and segmentation task. Namely, image tensors are wrapped by torchvision.tv_tensors.Image, bounding boxes by torchvision.tv_tensors.BoundingBoxes and masks by torchvision.tv_tensors.Mask. Since the torchvision.tv_tensors.TVTensor classes are torch.Tensor subclasses, the wrapped objects are also tensors and inherit the plain torch.Tensor API. For more information about torchvision tv_tensors, see this documentation.

import os
import torch

from torchvision.io import read_image
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = read_image(img_path)
        mask = read_image(mask_path)
        # instances are encoded as different colors
        obj_ids = torch.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]
        num_objs = len(obj_ids)

        # split the color-encoded mask into a set
        # of binary masks
        masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)

        # get bounding box coordinates for each mask
        boxes = masks_to_boxes(masks)

        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)

        image_id = idx
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        # Wrap sample and targets into torchvision tv_tensors:
        img = tv_tensors.Image(img)

        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
        target["masks"] = tv_tensors.Mask(masks)
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

That’s all for the dataset. Now let’s define a model that can perform predictions on this dataset.

Defining your model

In this tutorial, we will be using Mask R-CNN, which is built on top of Faster R-CNN. Faster R-CNN is a model that predicts both bounding boxes and class scores for potential objects in the image.

[Figure: Faster R-CNN (../_static/img/tv_tutorial/tv_image03.png)]

Mask R-CNN adds an extra branch to Faster R-CNN, which also predicts segmentation masks for each instance.

[Figure: Mask R-CNN (../_static/img/tv_tutorial/tv_image04.png)]

There are two common situations where one might want to modify one of the available models in the TorchVision Model Zoo. The first is when we want to start from a pre-trained model and just finetune the last layer. The other is when we want to replace the backbone of the model with a different one (for faster predictions, for example).

Let’s see how to do each of these in the following sections.

1 - Finetuning from a pretrained model

Let’s suppose that you want to start from a model pre-trained on COCO and want to finetune it for your particular classes. Here is a possible way of doing it:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2  # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /var/lib/jenkins/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth


2 - Modifying the model to add a different backbone

import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

# load a pre-trained model for classification and return
# only the features
backbone = torchvision.models.mobilenet_v2(weights="DEFAULT").features
# ``FasterRCNN`` needs to know the number of
# output channels in a backbone. For mobilenet_v2, it's 1280
# so we need to add it here
backbone.out_channels = 1280

# let's make the RPN generate 5 x 3 anchors per spatial
# location, with 5 different sizes and 3 different aspect
# ratios. We have a Tuple[Tuple[int]] because each feature
# map could potentially have different sizes and
# aspect ratios
anchor_generator = AnchorGenerator(
    sizes=((32, 64, 128, 256, 512),),
    aspect_ratios=((0.5, 1.0, 2.0),)
)

# let's define what are the feature maps that we will
# use to perform the region of interest cropping, as well as
# the size of the crop after rescaling.
# if your backbone returns a Tensor, featmap_names is expected to
# be ['0']. More generally, the backbone should return an
# ``OrderedDict[Tensor]``, and in ``featmap_names`` you can choose which
# feature maps to use.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(
    featmap_names=['0'],
    output_size=7,
    sampling_ratio=2
)

# put the pieces together inside a Faster-RCNN model
model = FasterRCNN(
    backbone,
    num_classes=2,
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler
)
Downloading: "https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth" to /var/lib/jenkins/.cache/torch/hub/checkpoints/mobilenet_v2-7ebf99e0.pth

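The comments in the code above say the RPN generates 5 x 3 = 15 anchors per spatial location. The pure-Python sketch below lists the width and height of each of those anchors. It assumes torchvision's convention that an aspect ratio is height / width and that each anchor keeps an area of size**2. torchvision additionally centers the anchors and rounds their coordinates, which this sketch skips.

```python
import math

sizes = (32, 64, 128, 256, 512)
aspect_ratios = (0.5, 1.0, 2.0)


def anchor_shapes(sizes, aspect_ratios):
    """(width, height) of each anchor at one location, ratio-major order."""
    shapes = []
    for ratio in aspect_ratios:
        h_ratio = math.sqrt(ratio)  # height / width == ratio, width * height == size**2
        for size in sizes:
            shapes.append((size / h_ratio, size * h_ratio))
    return shapes


shapes = anchor_shapes(sizes, aspect_ratios)
```

In torchvision itself, `anchor_generator.num_anchors_per_location()` should report [15] for this configuration.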

Object detection and instance segmentation model for PennFudan Dataset

Since our dataset is very small, we want to finetune from a pre-trained model, so we will follow approach number 1.

Here we want to also compute the instance segmentation masks, so we will be using Mask R-CNN:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_model_instance_segmentation(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )

    return model

That’s it! This makes the model ready to be trained and evaluated on your custom dataset.

Putting everything together

In references/detection/, we have a number of helper functions to simplify training and evaluating detection models. Here, we will use references/detection/engine.py and references/detection/utils.py. Just download everything under references/detection to your folder and use it from there. On Linux, if you have wget, you can download the files needed here with the commands below:

os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/engine.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/utils.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_utils.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_eval.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/transforms.py")
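The `os.system("wget ...")` calls above assume a shell with wget available. A portable alternative, for example on Windows, could fetch the same five files using only Python's standard-library `urllib.request`. The sketch below does this. The helper name `download_reference_scripts` is hypothetical, and files that already exist are skipped.

```python
import os
import urllib.request

BASE_URL = "https://raw.githubusercontent.com/pytorch/vision/main/references/detection/"
FILES = ["engine.py", "utils.py", "coco_utils.py", "coco_eval.py", "transforms.py"]


def download_reference_scripts(dest=".", files=FILES, base_url=BASE_URL):
    """Download each file into `dest` unless it is already there; return the paths."""
    os.makedirs(dest, exist_ok=True)
    paths = []
    for name in files:
        path = os.path.join(dest, name)
        if not os.path.exists(path):
            urllib.request.urlretrieve(base_url + name, path)
        paths.append(path)
    return paths
```

These URLs point at the main branch. You could point `base_url` at a release tag that matches your installed torchvision version instead, which may avoid mismatches between the scripts and the library.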

Since v0.15.0, torchvision provides a new Transforms API that makes it easy to write data augmentation pipelines for object detection and segmentation tasks.

Let’s write some helper functions for data augmentation / transformation:

from torchvision.transforms import v2 as T


def get_transform(train):
    transforms = []
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    transforms.append(T.ToDtype(torch.float, scale=True))
    transforms.append(T.ToPureTensor())
    return T.Compose(transforms)

Testing forward() method (Optional)

Before iterating over the dataset, it’s good to see what the model expects during training and inference time on sample data.

import utils


model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
dataset = PennFudanDataset('data/PennFudanPed', get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=utils.collate_fn
)

# For Training
images, targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images, targets)  # Returns losses and detections
print(output)

# For inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)  # Returns predictions
print(predictions[0])
{'loss_classifier': tensor(0.0689, grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(0.0268, grad_fn=<DivBackward0>), 'loss_objectness': tensor(0.0055, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_rpn_box_reg': tensor(0.0036, grad_fn=<DivBackward0>)}
{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward0>)}

Let’s now write the main function which performs the training and the validation:

from engine import train_one_epoch, evaluate

# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# our dataset has two classes only - background and person
num_classes = 2
# use our dataset and defined transformations
dataset = PennFudanDataset('data/PennFudanPed', get_transform(train=True))
dataset_test = PennFudanDataset('data/PennFudanPed', get_transform(train=False))

# split the dataset in train and test set
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=utils.collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    collate_fn=utils.collate_fn
)

# get the model using our helper function
model = get_model_instance_segmentation(num_classes)

# move model to the right device
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# and a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

# let's train it just for 2 epochs
num_epochs = 2

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

print("That's it!")
Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /var/lib/jenkins/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

Epoch: [0]  [ 0/60]  eta: 0:00:55  lr: 0.000090  loss: 3.8920 (3.8920)  loss_classifier: 0.4864 (0.4864)  loss_box_reg: 0.2539 (0.2539)  loss_mask: 3.1356 (3.1356)  loss_objectness: 0.0107 (0.0107)  loss_rpn_box_reg: 0.0055 (0.0055)  time: 0.9283  data: 0.2106  max mem: 4549
Epoch: [0]  [10/60]  eta: 0:00:18  lr: 0.000936  loss: 1.6996 (2.3417)  loss_classifier: 0.3916 (0.3631)  loss_box_reg: 0.2626 (0.2681)  loss_mask: 1.0991 (1.6869)  loss_objectness: 0.0139 (0.0193)  loss_rpn_box_reg: 0.0049 (0.0043)  time: 0.3739  data: 0.0245  max mem: 4549
Epoch: [0]  [20/60]  eta: 0:00:13  lr: 0.001783  loss: 0.9805 (1.5784)  loss_classifier: 0.2409 (0.2750)  loss_box_reg: 0.2626 (0.2757)  loss_mask: 0.3496 (1.0027)  loss_objectness: 0.0139 (0.0179)  loss_rpn_box_reg: 0.0051 (0.0071)  time: 0.3053  data: 0.0052  max mem: 4549
Epoch: [0]  [30/60]  eta: 0:00:09  lr: 0.002629  loss: 0.6030 (1.2400)  loss_classifier: 0.1016 (0.2110)  loss_box_reg: 0.2572 (0.2585)  loss_mask: 0.2146 (0.7479)  loss_objectness: 0.0093 (0.0155)  loss_rpn_box_reg: 0.0056 (0.0070)  time: 0.2856  data: 0.0046  max mem: 4549
Epoch: [0]  [40/60]  eta: 0:00:05  lr: 0.003476  loss: 0.5206 (1.0532)  loss_classifier: 0.0699 (0.1749)  loss_box_reg: 0.2426 (0.2513)  loss_mask: 0.1821 (0.6071)  loss_objectness: 0.0049 (0.0128)  loss_rpn_box_reg: 0.0049 (0.0071)  time: 0.2578  data: 0.0046  max mem: 4549
Epoch: [0]  [50/60]  eta: 0:00:02  lr: 0.004323  loss: 0.3508 (0.9180)  loss_classifier: 0.0386 (0.1485)  loss_box_reg: 0.1586 (0.2311)  loss_mask: 0.1575 (0.5210)  loss_objectness: 0.0016 (0.0105)  loss_rpn_box_reg: 0.0047 (0.0069)  time: 0.2390  data: 0.0046  max mem: 4549
Epoch: [0]  [59/60]  eta: 0:00:00  lr: 0.005000  loss: 0.3382 (0.8374)  loss_classifier: 0.0380 (0.1339)  loss_box_reg: 0.1350 (0.2182)  loss_mask: 0.1587 (0.4690)  loss_objectness: 0.0012 (0.0092)  loss_rpn_box_reg: 0.0048 (0.0071)  time: 0.2443  data: 0.0046  max mem: 4549
Epoch: [0] Total time: 0:00:16 (0.2824 s / it)
creating index...
index created!
Test:  [ 0/50]  eta: 0:00:19  model_time: 0.1739 (0.1739)  evaluator_time: 0.0047 (0.0047)  time: 0.3844  data: 0.2051  max mem: 4549
Test:  [49/50]  eta: 0:00:00  model_time: 0.0449 (0.0935)  evaluator_time: 0.0043 (0.0059)  time: 0.0753  data: 0.0025  max mem: 4549
Test: Total time: 0:00:05 (0.1078 s / it)
Averaged stats: model_time: 0.0449 (0.0935)  evaluator_time: 0.0043 (0.0059)
Accumulating evaluation results...
DONE (t=0.01s).
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.701
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.972
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.894
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.322
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.630
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.723
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.318
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.750
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.750
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.400
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.736
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.762
IoU metric: segm
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.712
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.978
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.912
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.342
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.339
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.732
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.322
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.751
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.751
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.633
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.727
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.757
Epoch: [1]  [ 0/60]  eta: 0:00:28  lr: 0.005000  loss: 0.3231 (0.3231)  loss_classifier: 0.0407 (0.0407)  loss_box_reg: 0.1253 (0.1253)  loss_mask: 0.1402 (0.1402)  loss_objectness: 0.0079 (0.0079)  loss_rpn_box_reg: 0.0090 (0.0090)  time: 0.4709  data: 0.2407  max mem: 4549
Epoch: [1]  [10/60]  eta: 0:00:14  lr: 0.005000  loss: 0.3231 (0.3219)  loss_classifier: 0.0430 (0.0413)  loss_box_reg: 0.0950 (0.1036)  loss_mask: 0.1607 (0.1682)  loss_objectness: 0.0014 (0.0019)  loss_rpn_box_reg: 0.0063 (0.0069)  time: 0.2994  data: 0.0263  max mem: 4549
Epoch: [1]  [20/60]  eta: 0:00:10  lr: 0.005000  loss: 0.2995 (0.2890)  loss_classifier: 0.0395 (0.0375)  loss_box_reg: 0.0869 (0.0926)  loss_mask: 0.1431 (0.1517)  loss_objectness: 0.0010 (0.0016)  loss_rpn_box_reg: 0.0049 (0.0055)  time: 0.2637  data: 0.0048  max mem: 4549
Epoch: [1]  [30/60]  eta: 0:00:07  lr: 0.005000  loss: 0.2444 (0.2793)  loss_classifier: 0.0256 (0.0358)  loss_box_reg: 0.0786 (0.0883)  loss_mask: 0.1301 (0.1479)  loss_objectness: 0.0010 (0.0017)  loss_rpn_box_reg: 0.0026 (0.0057)  time: 0.2317  data: 0.0047  max mem: 4549
Epoch: [1]  [40/60]  eta: 0:00:05  lr: 0.005000  loss: 0.2498 (0.2748)  loss_classifier: 0.0263 (0.0354)  loss_box_reg: 0.0763 (0.0837)  loss_mask: 0.1430 (0.1488)  loss_objectness: 0.0011 (0.0017)  loss_rpn_box_reg: 0.0028 (0.0052)  time: 0.2294  data: 0.0047  max mem: 4549
Epoch: [1]  [50/60]  eta: 0:00:02  lr: 0.005000  loss: 0.2789 (0.2802)  loss_classifier: 0.0357 (0.0365)  loss_box_reg: 0.0763 (0.0854)  loss_mask: 0.1626 (0.1512)  loss_objectness: 0.0011 (0.0017)  loss_rpn_box_reg: 0.0030 (0.0053)  time: 0.2396  data: 0.0047  max mem: 4549
Epoch: [1]  [59/60]  eta: 0:00:00  lr: 0.005000  loss: 0.2860 (0.2806)  loss_classifier: 0.0406 (0.0370)  loss_box_reg: 0.0897 (0.0853)  loss_mask: 0.1563 (0.1513)  loss_objectness: 0.0008 (0.0016)  loss_rpn_box_reg: 0.0044 (0.0054)  time: 0.2477  data: 0.0046  max mem: 4549
Epoch: [1] Total time: 0:00:15 (0.2504 s / it)
creating index...
index created!
Test:  [ 0/50]  eta: 0:00:14  model_time: 0.0583 (0.0583)  evaluator_time: 0.0044 (0.0044)  time: 0.2858  data: 0.2224  max mem: 4549
Test:  [49/50]  eta: 0:00:00  model_time: 0.0510 (0.0525)  evaluator_time: 0.0036 (0.0048)  time: 0.0595  data: 0.0024  max mem: 4549
Test: Total time: 0:00:03 (0.0660 s / it)
Averaged stats: model_time: 0.0510 (0.0525)  evaluator_time: 0.0036 (0.0048)
Accumulating evaluation results...
DONE (t=0.01s).
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.789
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.980
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.952
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.357
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.653
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.811
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.364
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.827
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.827
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.467
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.764
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.844
IoU metric: segm
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.741
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.982
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.917
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.330
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.495
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.769
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.344
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.778
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.778
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.500
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.673
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.797
That's it!

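In the run shown above, after one epoch of training we obtain a COCO-style box mAP (AP @ IoU=0.50:0.95) of 70.1 and a mask mAP of 71.2. After the second epoch, these rise to 78.9 and 74.1, respectively.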
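The lr column in the log can be reconstructed by hand. The training subset has 170 - 50 = 120 images, and with batch_size=2 that gives the 60 iterations per epoch seen as "[59/60]". The sketch below rests on two assumptions. First, it assumes the first-epoch warmup in engine.py is a linear ramp, like `torch.optim.lr_scheduler.LinearLR`, from 1/1000 of the base lr over min(1000, len(data_loader) - 1) = 59 iterations. Second, it assumes the ramp is stepped once per iteration and logged after stepping. These assumptions reproduce the logged values. The helper names are hypothetical. The sketch also shows that with step_size=3 and num_epochs = 2, StepLR never actually decays the learning rate in this run.

```python
def warmup_lr(iteration, base_lr=0.005, warmup_factor=1.0 / 1000, warmup_iters=59):
    """LR logged at `iteration` of epoch 0 (iteration i has taken i + 1 scheduler steps)."""
    steps = min(iteration + 1, warmup_iters)
    return base_lr * (warmup_factor + (1.0 - warmup_factor) * steps / warmup_iters)


def step_lr(epoch, base_lr=0.005, step_size=3, gamma=0.1):
    """LR used during `epoch` under StepLR(step_size=3, gamma=0.1)."""
    return base_lr * gamma ** (epoch // step_size)


train_size = 170 - 50                # images in the training Subset
iters_per_epoch = train_size // 2    # batch_size=2
```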

But what do the predictions look like? Let’s take one image from the dataset and check:

import matplotlib.pyplot as plt

from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks


image = read_image("data/PennFudanPed/PNGImages/FudanPed00046.png")
eval_transform = get_transform(train=False)

model.eval()
with torch.no_grad():
    x = eval_transform(image)
    # convert RGBA -> RGB and move to device
    x = x[:3, ...].to(device)
    predictions = model([x, ])
    pred = predictions[0]


image = (255.0 * (image - image.min()) / (image.max() - image.min())).to(torch.uint8)
image = image[:3, ...]
pred_labels = [f"pedestrian: {score:.3f}" for label, score in zip(pred["labels"], pred["scores"])]
pred_boxes = pred["boxes"].long()
output_image = draw_bounding_boxes(image, pred_boxes, pred_labels, colors="red")

masks = (pred["masks"] > 0.7).squeeze(1)
output_image = draw_segmentation_masks(output_image, masks, alpha=0.5, colors="blue")


plt.figure(figsize=(12, 12))
plt.imshow(output_image.permute(1, 2, 0))
[Figure: FudanPed00046.png with the predicted boxes and segmentation masks drawn on it]

The results look good!
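The plotting code above first turns the raw predictions into drawable inputs. In eval mode, `pred["masks"]` holds per-pixel probabilities of shape [N, 1, H, W]. Comparing it with 0.7 and calling `squeeze(1)` gives a boolean [N, H, W] tensor. The sketch below reproduces just that step on a made-up prediction dict with random mask values and invented scores. It also adds an optional score filter, which the tutorial itself does not apply; the 0.5 cutoff is an arbitrary assumption.

```python
import torch

# Hypothetical Mask R-CNN output for 2 detections on a 4 x 5 image.
pred = {
    "scores": torch.tensor([0.98, 0.41]),
    "labels": torch.tensor([1, 1]),
    "masks": torch.rand(2, 1, 4, 5),  # probabilities in [0, 1], shape [N, 1, H, W]
}

binary_masks = (pred["masks"] > 0.7).squeeze(1)  # bool, [N, H, W]
captions = [f"pedestrian: {score:.3f}" for score in pred["scores"]]

# Optional: drop low-confidence detections before drawing (cutoff is an assumption).
keep = pred["scores"] > 0.5
kept_masks = binary_masks[keep]
kept_captions = [c for c, k in zip(captions, keep.tolist()) if k]
```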

Wrapping up

In this tutorial, you have learned how to create your own training pipeline for object detection models on a custom dataset. For that, you wrote a torch.utils.data.Dataset class that returns the images and the ground truth boxes and segmentation masks. You also leveraged a Mask R-CNN model pre-trained on COCO train2017 in order to perform transfer learning on this new dataset.

For a more complete example, which includes multi-machine / multi-GPU training, check references/detection/train.py, which is present in the torchvision repository.

Total running time of the script: ( 0 minutes 54.648 seconds)
