Getting started with transforms v2
Most computer vision tasks are not supported out of the box by torchvision.transforms v1, since it only supports images. torchvision.transforms.v2 enables jointly transforming images, videos, bounding boxes, and masks. This example showcases the core functionality of the new torchvision.transforms.v2 API.
import pathlib

import torch
import torchvision


def load_data():
    from torchvision.io import read_image
    from torchvision import datapoints
    from torchvision.ops import masks_to_boxes

    assets_directory = pathlib.Path("assets")

    path = assets_directory / "FudanPed00054.png"
    image = datapoints.Image(read_image(str(path)))

    merged_masks = read_image(str(assets_directory / "FudanPed00054_mask.png"))
    labels = torch.unique(merged_masks)[1:]

    masks = datapoints.Mask(merged_masks == labels.view(-1, 1, 1))

    bounding_boxes = datapoints.BoundingBox(
        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
    )

    return path, image, bounding_boxes, masks, labels
The torchvision.transforms.v2 API supports images, videos, bounding boxes, and instance and segmentation masks. Thus, it offers native support for many computer vision tasks, like image and video classification, object detection, or instance and semantic segmentation. Still, the interface is the same, making torchvision.transforms.v2 a drop-in replacement for the existing torchvision.transforms API, aka v1.
# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

import torchvision.transforms.v2 as transforms

transform = transforms.Compose(
    [
        transforms.ColorJitter(contrast=0.5),
        transforms.RandomRotation(30),
        transforms.CenterCrop(480),
    ]
)
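Since the interface matches v1, the pipeline above also accepts the inputs v1 supported, such as PIL images. A minimal sketch to illustrate this, assuming Pillow is available (the input size here is made up):

import PIL.Image

pil_image = PIL.Image.new("RGB", (533, 498))

# v2 transforms handle PIL images just like v1
print(transform(pil_image).size)  # (480, 480) after the CenterCrop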
torchvision.transforms.v2 natively supports jointly transforming multiple inputs while making sure that potential random behavior is consistent across all inputs. However, it doesn’t enforce a specific input structure or order.
path, image, bounding_boxes, masks, labels = load_data()

torch.manual_seed(0)
new_image = transform(image)  # Image Classification
new_image, new_bounding_boxes, new_labels = transform(image, bounding_boxes, labels)  # Object Detection
new_image, new_bounding_boxes, new_masks, new_labels = transform(
    image, bounding_boxes, masks, labels
)  # Instance Segmentation
new_image, new_target = transform((image, {"boxes": bounding_boxes, "labels": labels}))  # Arbitrary Structure
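Consistency here means that, for example, the random rotation uses the same angle for every input in a single call. A minimal sketch to check this, reusing the image loaded above (the explicit wrapping via torchvision.datapoints is part of the BETA API):

from torchvision import datapoints

img_a = datapoints.Image(image)
img_b = datapoints.Image(image)

# both explicit images receive the identical random jitter, rotation, and crop
out_a, out_b = transform(img_a, img_b)
assert torch.equal(out_a, out_b)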
Under the hood, torchvision.transforms.v2 relies on torchvision.datapoints for the dispatch to the appropriate function for the input data; see the Datapoints FAQ for details. Note, however, that as a regular user you likely don’t have to touch this yourself. See also Transforms v2: End-to-end object detection example.
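To get a feel for the dispatch, here is a minimal sketch using the BETA torchvision.transforms.v2.functional namespace (the toy tensors and sizes are made up): the same functional routes to a type-specific kernel, and the result keeps its datapoint type.

from torchvision import datapoints
import torchvision.transforms.v2.functional as F

# toy inputs, just to show the type-based dispatch
toy_image = datapoints.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))
toy_mask = datapoints.Mask(torch.zeros(1, 64, 64, dtype=torch.uint8))

# the same call is routed to a kernel that fits the input type
print(type(F.resize(toy_image, [32, 32])))  # datapoints.Image
print(type(F.resize(toy_mask, [32, 32])))   # datapoints.Mask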
All “foreign” types like str’s or pathlib.Path’s are passed through, allowing you to store extra information directly with the sample:
sample = {"path": path, "image": image}
new_sample = transform(sample)
assert new_sample["path"] is sample["path"]
As stated above, torchvision.transforms.v2 is a drop-in replacement for torchvision.transforms and thus also supports transforming plain torch.Tensor’s as an image or video, if applicable. This is achieved with a simple heuristic:

1. If we find an explicit image or video (torchvision.datapoints.Image, torchvision.datapoints.Video, or PIL.Image.Image) in the input, all other plain tensors are passed through.
2. If there is no explicit image or video, only the first plain torch.Tensor will be transformed as image or video, while all others will be passed through.
plain_tensor_image = torch.rand(image.shape)
print(image.shape, plain_tensor_image.shape)
# passing a plain tensor together with an explicit image will not transform the former
plain_tensor_image, image = transform(plain_tensor_image, image)
print(image.shape, plain_tensor_image.shape)
# passing a plain tensor without an explicit image will transform the former
plain_tensor_image, _ = transform(plain_tensor_image, bounding_boxes)
print(image.shape, plain_tensor_image.shape)
torch.Size([3, 498, 533]) torch.Size([3, 498, 533])
torch.Size([3, 480, 480]) torch.Size([3, 498, 533])
torch.Size([3, 480, 480]) torch.Size([3, 480, 480])
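If you do want a plain tensor to be transformed even though an explicit image is present, you can opt it in by wrapping it yourself. A minimal sketch, continuing from the objects above (the wrapping via torchvision.datapoints is part of the BETA API):

from torchvision import datapoints

# wrapping turns a plain tensor into an explicit image, so it is
# no longer passed through when another explicit image is present
wrapped = datapoints.Image(torch.rand(3, 498, 533))
new_wrapped, new_image = transform(wrapped, image)
print(new_wrapped.shape, new_image.shape)  # both are [3, 480, 480] now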