
Getting started with transforms v2

Most computer vision tasks are not supported out of the box by torchvision.transforms v1, since it only supports images. torchvision.transforms.v2 enables jointly transforming images, videos, bounding boxes, and masks. This example showcases the core functionality of the new torchvision.transforms.v2 API.

import pathlib

import torch
import torchvision


def load_data():
    from torchvision.io import read_image
    from torchvision import datapoints
    from torchvision.ops import masks_to_boxes

    assets_directory = pathlib.Path("assets")

    path = assets_directory / "FudanPed00054.png"
    image = datapoints.Image(read_image(str(path)))
    merged_masks = read_image(str(assets_directory / "FudanPed00054_mask.png"))

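    # Each pixel holds an instance id; drop the smallest unique value (0, the background)
    labels = torch.unique(merged_masks)[1:]

    # Broadcast (1, H, W) against (N, 1, 1) to get one boolean mask per instance: (N, H, W)
    masks = datapoints.Mask(merged_masks == labels.view(-1, 1, 1))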

    bounding_boxes = datapoints.BoundingBox(
        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
    )

    return path, image, bounding_boxes, masks, labels
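The mask handling in load_data() can be tried without the asset files. The sketch below assumes a tiny synthetic merged mask in which 0 is the background, and reproduces the labels and per-instance masks steps with plain torch. boxes_from_masks is a hypothetical helper that mirrors the idea behind torchvision.ops.masks_to_boxes (XYXY boxes from the smallest and largest nonzero coordinates); it is not the library function itself.

```python
import torch

# Hypothetical synthetic input: a (1, H, W) merged mask where 0 is background
# and every other value is the id of one instance.
merged_masks = torch.zeros(1, 6, 8, dtype=torch.uint8)
merged_masks[0, 1:3, 1:4] = 1  # instance 1: rows 1-2, columns 1-3
merged_masks[0, 3:6, 5:8] = 2  # instance 2: rows 3-5, columns 5-7

# Same two steps as load_data(): drop background id 0, then broadcast
# (1, H, W) against (N, 1, 1) to get one boolean mask per instance.
labels = torch.unique(merged_masks)[1:]
masks = merged_masks == labels.view(-1, 1, 1)


def boxes_from_masks(masks: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: XYXY boxes from (N, H, W) boolean masks.

    x_max and y_max are the largest nonzero indices, i.e. inclusive.
    """
    boxes = torch.zeros(masks.shape[0], 4)
    for i, mask in enumerate(masks):
        ys, xs = torch.where(mask)  # row and column indices of the instance pixels
        boxes[i] = torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]).float()
    return boxes


boxes = boxes_from_masks(masks)
print(labels.tolist())  # [1, 2]
print(tuple(masks.shape))  # (2, 6, 8)
print(boxes.tolist())  # [[1.0, 1.0, 3.0, 2.0], [5.0, 3.0, 7.0, 5.0]]
```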

The torchvision.transforms.v2 API supports images, videos, bounding boxes, and instance and semantic segmentation masks. Thus, it offers native support for many computer vision tasks, such as image and video classification, object detection, and instance and semantic segmentation. At the same time, the interface is unchanged, making torchvision.transforms.v2 a drop-in replacement for the existing torchvision.transforms API, aka v1.

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as transforms

transform = transforms.Compose(
    [
        transforms.ColorJitter(contrast=0.5),
        transforms.RandomRotation(30),
        transforms.CenterCrop(480),
    ]
)
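Compose applies its transforms one after another, feeding each output into the next, so the order of the list matters. The sketch below illustrates that behavior with a hypothetical ToyCompose class and toy numeric transforms; it is not the torchvision implementation. Bundling several positional inputs into one tuple is an assumption of this sketch, chosen so that every step sees the whole sample.

```python
from typing import Any, Callable, Sequence


class ToyCompose:
    """Hypothetical minimal Compose: runs transforms in list order."""

    def __init__(self, transforms: Sequence[Callable[[Any], Any]]) -> None:
        self.transforms = list(transforms)

    def __call__(self, *inputs: Any) -> Any:
        # Several positional inputs are bundled into one tuple so every
        # transform sees the whole sample; a single input is used as-is.
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            sample = transform(sample)  # output of one step feeds the next
        return sample


def add_one(x: int) -> int:
    return x + 1


def double(x: int) -> int:
    return x * 2


print(ToyCompose([add_one, double])(3))  # 8
print(ToyCompose([double, add_one])(3))  # 7
print(ToyCompose([lambda s: tuple(v * 10 for v in s)])(1, 2))  # (10, 20)
```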

torchvision.transforms.v2 natively supports jointly transforming multiple inputs while making sure that any random behavior (for example, the angle sampled by RandomRotation) is applied consistently across all inputs. However, it doesn’t enforce a specific input structure or order.

path, image, bounding_boxes, masks, labels = load_data()

torch.manual_seed(0)
new_image = transform(image)  # Image Classification
new_image, new_bounding_boxes, new_labels = transform(image, bounding_boxes, labels)  # Object Detection
new_image, new_bounding_boxes, new_masks, new_labels = transform(
    image, bounding_boxes, masks, labels
)  # Instance Segmentation
new_image, new_target = transform((image, {"boxes": bounding_boxes, "labels": labels}))  # Arbitrary Structure
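The consistency guarantee boils down to drawing the random parameters once per call and reusing them for every input. The sketch below illustrates this with a hypothetical joint_random_hflip helper (not a torchvision function) that flips a (C, H, W) image and its XYXY boxes together. A horizontal flip stands in for the rotation above only because its effect on box coordinates is easy to check by hand.

```python
import torch


def joint_random_hflip(image: torch.Tensor, boxes: torch.Tensor, p: float = 0.5):
    """Hypothetical helper: randomly flip a (C, H, W) image and its XYXY boxes together.

    The random decision is drawn once per call and reused for every input,
    so the boxes always stay aligned with the image.
    """
    flip = bool(torch.rand(()) < p)  # a single draw shared by all inputs
    if not flip:
        return image, boxes, flip
    width = image.shape[-1]
    new_image = torch.flip(image, dims=[-1])  # mirror along the width axis
    new_boxes = boxes.clone()
    new_boxes[:, 0] = width - boxes[:, 2]  # new x_min comes from the old x_max
    new_boxes[:, 2] = width - boxes[:, 0]  # new x_max comes from the old x_min
    return new_image, new_boxes, flip


image = torch.arange(24, dtype=torch.float).reshape(2, 3, 4)  # toy (C, H, W) image, W = 4
boxes = torch.tensor([[0.0, 0.0, 1.0, 2.0]])

torch.manual_seed(0)
new_image, new_boxes, flipped = joint_random_hflip(image, boxes)
print(flipped, new_boxes.tolist())
```

Whatever the seed, either both outputs are flipped or neither is; the image and boxes never disagree.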

Under the hood, torchvision.transforms.v2 relies on torchvision.datapoints to dispatch each input to the appropriate function for its data type (see Datapoints FAQ). Note, however, that as a regular user you likely don’t have to touch this yourself; see Transforms v2: End-to-end object detection example.
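Dispatching on the input type means each kind of input gets its own kernel, while unknown types fall back to a no-op. The sketch below illustrates that idea with functools.singledispatch and hypothetical ToyImage and ToyBoxes classes. This is a swapped-in standard-library mechanism chosen for illustration, not how torchvision.datapoints is implemented.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import List


@dataclass
class ToyImage:  # hypothetical stand-in for an image datapoint: rows of pixels
    pixels: List[List[int]]


@dataclass
class ToyBoxes:  # hypothetical stand-in for XYXY boxes on a canvas of the given width
    boxes: List[List[float]]
    width: int


@singledispatch
def hflip(inpt):
    # Fallback for unregistered types: pass the input through unchanged.
    return inpt


@hflip.register
def _(inpt: ToyImage):
    return ToyImage([row[::-1] for row in inpt.pixels])  # mirror every row


@hflip.register
def _(inpt: ToyBoxes):
    flipped = [[inpt.width - x2, y1, inpt.width - x1, y2] for x1, y1, x2, y2 in inpt.boxes]
    return ToyBoxes(flipped, inpt.width)


print(hflip(ToyImage([[1, 2, 3], [4, 5, 6]])).pixels)  # [[3, 2, 1], [6, 5, 4]]
print(hflip(ToyBoxes([[0.0, 0.0, 1.0, 2.0]], width=3)).boxes)  # [[2.0, 0.0, 3.0, 2.0]]
print(hflip("FudanPed00054.png"))  # FudanPed00054.png
```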

All “foreign” types, such as str or pathlib.Path objects, are passed through unchanged, which allows storing extra information directly with the sample:

sample = {"path": path, "image": image}
new_sample = transform(sample)

assert new_sample["path"] is sample["path"]
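Accepting arbitrary nesting while passing foreign types through can be pictured as flattening the sample into leaves, transforming only the leaves of interest, and rebuilding the original structure. The flatten and apply_to_leaves helpers below are hypothetical and hand-rolled in the spirit of PyTorch's pytree utilities; floats stand in for tensors so the sketch runs on its own.

```python
from typing import Any, Callable, List, Tuple


def flatten(obj: Any) -> Tuple[List[Any], Callable[[List[Any]], Any]]:
    """Hypothetical helper: split nested tuples/lists/dicts into leaves plus a rebuild function."""
    if isinstance(obj, dict):
        keys = list(obj)
        leaves, rebuild_values = flatten([obj[k] for k in keys])
        return leaves, lambda new: dict(zip(keys, rebuild_values(new)))
    if isinstance(obj, (tuple, list)):
        parts = [flatten(item) for item in obj]
        leaves = [leaf for part_leaves, _ in parts for leaf in part_leaves]

        def rebuild(new: List[Any]) -> Any:
            out, start = [], 0
            for part_leaves, part_rebuild in parts:
                end = start + len(part_leaves)
                out.append(part_rebuild(new[start:end]))
                start = end
            return tuple(out) if isinstance(obj, tuple) else out

        return leaves, rebuild
    return [obj], lambda new: new[0]  # anything else is a single leaf


def apply_to_leaves(fn: Callable[[Any], Any], sample: Any, wanted: Callable[[Any], bool]) -> Any:
    """Transform the leaves selected by `wanted`; return every other leaf as the same object."""
    leaves, rebuild = flatten(sample)
    return rebuild([fn(leaf) if wanted(leaf) else leaf for leaf in leaves])


sample = (1.0, {"boxes": [2.0, 3.0], "path": "assets/FudanPed00054.png"})
new_sample = apply_to_leaves(lambda x: x * 2, sample, lambda leaf: isinstance(leaf, float))
print(new_sample)  # (2.0, {'boxes': [4.0, 6.0], 'path': 'assets/FudanPed00054.png'})
print(new_sample[1]["path"] is sample[1]["path"])  # True
```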

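As stated above, torchvision.transforms.v2 is a drop-in replacement for torchvision.transforms and thus also supports transforming plain torch.Tensor objects as images or videos where applicable. This is achieved with a simple heuristic: if the sample contains an explicit image or video, every plain tensor is passed through; otherwise, the first plain tensor is transformed as an image or video and any remaining plain tensors are passed through: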

plain_tensor_image = torch.rand(image.shape)

print(image.shape, plain_tensor_image.shape)

# Passing a plain tensor together with an explicit image does not transform the plain tensor
plain_tensor_image, image = transform(plain_tensor_image, image)

print(image.shape, plain_tensor_image.shape)

# Passing a plain tensor without an explicit image transforms the plain tensor
plain_tensor_image, _ = transform(plain_tensor_image, bounding_boxes)

print(image.shape, plain_tensor_image.shape)

Out:

torch.Size([3, 498, 533]) torch.Size([3, 498, 533])
torch.Size([3, 480, 480]) torch.Size([3, 498, 533])
torch.Size([3, 480, 480]) torch.Size([3, 480, 480])
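The printed shapes follow the plain-tensor heuristic: an explicit image makes every plain tensor pass through, and without one the first plain tensor is treated as the image. The sketch below encodes that rule with hypothetical stand-in classes (PlainTensor, ExplicitImage, ExplicitBoxes) and labels each input as transformed as an image, handled as boxes, or passed through. It ignores PIL images and videos for brevity; the two example calls mirror the two transform calls above.

```python
from typing import Any, List


class PlainTensor:
    """Hypothetical stand-in for a plain torch.Tensor."""


class ExplicitImage(PlainTensor):
    """Hypothetical stand-in for an explicit image datapoint."""


class ExplicitBoxes(PlainTensor):
    """Hypothetical stand-in for a bounding-box datapoint."""


def classify_inputs(flat_inputs: List[Any]) -> List[str]:
    """Label each input as 'image', 'boxes', or 'pass-through' following the heuristic."""
    has_explicit_image = any(isinstance(x, ExplicitImage) for x in flat_inputs)
    # With an explicit image present, no plain tensor may be promoted to "image".
    plain_claimed = has_explicit_image
    labels = []
    for x in flat_inputs:
        if isinstance(x, ExplicitImage):
            labels.append("image")
        elif isinstance(x, ExplicitBoxes):
            labels.append("boxes")
        elif isinstance(x, PlainTensor) and not plain_claimed:
            labels.append("image")  # the first plain tensor is treated as the image
            plain_claimed = True
        else:
            labels.append("pass-through")  # later plain tensors and foreign types
    return labels


print(classify_inputs([PlainTensor(), ExplicitImage()]))  # ['pass-through', 'image']
print(classify_inputs([PlainTensor(), ExplicitBoxes()]))  # ['image', 'boxes']
```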
