
Datapoints FAQ

The torchvision.datapoints namespace was introduced together with torchvision.transforms.v2. This example showcases what these datapoints are and how they behave. This is a fairly low-level topic that most users will not need to worry about: you do not need to understand the internals of datapoints to rely on torchvision.transforms.v2 effectively. It may, however, be useful for advanced users who want to implement their own datasets or transforms, or to work directly with datapoints.

import PIL.Image

import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints

What are datapoints?

Datapoints are zero-copy tensor subclasses:

tensor = torch.rand(3, 256, 256)
image = datapoints.Image(tensor)

assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()
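"Zero-copy" means the subclass is a new Python type over the same memory. The sketch below shows this with plain PyTorch. MyImage is a hypothetical stand-in, not datapoints.Image. Tensor.as_subclass is one standard way to reinterpret a tensor as a subclass without copying, but this page does not say whether torchvision uses exactly this call internally.

```python
import torch


class MyImage(torch.Tensor):
    """Hypothetical stand-in for an image datapoint; it adds no behaviour of its own."""


tensor = torch.zeros(3, 2, 2)
# `as_subclass` gives the same memory a new Python type; no data is copied.
image = tensor.as_subclass(MyImage)

assert type(image) is MyImage and isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()

# The storage is shared, so a write through the original tensor is visible in the subclass.
tensor[0, 0, 0] = 1.0
assert image[0, 0, 0].item() == 1.0
```

Because the memory is shared, wrapping is cheap. It also means in-place changes to either object affect both.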

Under the hood, they are needed in torchvision.transforms.v2 to correctly dispatch to the appropriate function for the input data.
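The idea behind type-based dispatch can be sketched with Python's functools.singledispatch. One call site routes each input to the kernel registered for its type. This is only an analogy: transforms.v2 has its own dispatch machinery. ToyImage, ToyBoundingBox, and horizontal_flip are hypothetical names chosen for this illustration, and the pass-through fallback for unregistered types is an assumption of the sketch.

```python
from functools import singledispatch


class ToyImage(list):
    """Hypothetical image type: a list of pixel rows."""


class ToyBoundingBox(list):
    """Hypothetical XYXY box type that carries the image width as metadata."""

    def __init__(self, values, width):
        super().__init__(values)
        self.width = width


@singledispatch
def horizontal_flip(inpt):
    # Fallback for unregistered types (an assumption of this sketch): return the input unchanged.
    return inpt


@horizontal_flip.register
def _(inpt: ToyImage):
    # Reverse every row of pixels.
    return ToyImage(row[::-1] for row in inpt)


@horizontal_flip.register
def _(inpt: ToyBoundingBox):
    # Mirror the x coordinates around the image width; y coordinates are unchanged.
    x1, y1, x2, y2 = inpt
    return ToyBoundingBox([inpt.width - x2, y1, inpt.width - x1, y2], inpt.width)


# The same call routes each input to the kernel registered for its type.
sample = (ToyImage([[0, 1], [2, 3]]), ToyBoundingBox([17, 16, 344, 495], width=512), "label")
flipped = [horizontal_flip(x) for x in sample]
print(flipped)  # [[[1, 0], [3, 2]], [168, 16, 495, 495], 'label']
```

A plain tensor holding box coordinates would look the same as one holding pixels. The wrapper type is what tells the dispatcher which kernel to run.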

What datapoints are supported?

So far, torchvision.datapoints supports four types of datapoints: Image, Video, BoundingBox, and Mask.

How do I construct a datapoint?

Each datapoint class accepts any tensor-like data that can be turned into a Tensor:

image = datapoints.Image([[[[0, 1], [1, 0]]]])
print(image)
Image([[[[0, 1],
         [1, 0]]]], )

Like other PyTorch creation ops, the constructor also accepts the dtype, device, and requires_grad parameters.

float_image = datapoints.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)
Image([[[0., 1.],
        [1., 0.]]], grad_fn=<AliasBackward0>, )

In addition, Image and Mask accept a PIL.Image.Image directly:

image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg"))
print(image.shape, image.dtype)
torch.Size([3, 512, 512]) torch.uint8

In general, the datapoints can also store additional metadata that complements the underlying tensor. For example, BoundingBox stores the coordinate format as well as the spatial size of the corresponding image alongside the actual values:

bounding_box = datapoints.BoundingBox(
    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
)
print(bounding_box)
BoundingBox([ 17,  16, 344, 495], format=BoundingBoxFormat.XYXY, spatial_size=torch.Size([512, 512]))
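The sketch below shows why the format and spatial size matter. The same four numbers describe different boxes depending on the format, and operations such as clamping need the image size. BoxFormat, to_xyxy, from_xyxy, and clamp_xyxy are hypothetical helpers written for this illustration. The three enum members follow the layouts of torchvision's BoundingBoxFormat (XYXY, XYWH, CXCYWH), but torchvision's own conversion utilities are not reproduced here.

```python
from enum import Enum


class BoxFormat(Enum):
    """Hypothetical mirror of the three box layouts of BoundingBoxFormat."""

    XYXY = "XYXY"  # (x1, y1, x2, y2): top-left and bottom-right corners
    XYWH = "XYWH"  # (x1, y1, width, height)
    CXCYWH = "CXCYWH"  # (center x, center y, width, height)


def to_xyxy(box, fmt):
    """Convert a box given in `fmt` to XYXY."""
    a, b, c, d = box
    if fmt is BoxFormat.XYXY:
        return [a, b, c, d]
    if fmt is BoxFormat.XYWH:
        return [a, b, a + c, b + d]
    if fmt is BoxFormat.CXCYWH:
        return [a - c / 2, b - d / 2, a + c / 2, b + d / 2]
    raise ValueError(f"unsupported format: {fmt}")


def from_xyxy(box, fmt):
    """Convert an XYXY box to `fmt`."""
    x1, y1, x2, y2 = box
    if fmt is BoxFormat.XYXY:
        return [x1, y1, x2, y2]
    if fmt is BoxFormat.XYWH:
        return [x1, y1, x2 - x1, y2 - y1]
    if fmt is BoxFormat.CXCYWH:
        return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
    raise ValueError(f"unsupported format: {fmt}")


def clamp_xyxy(box, spatial_size):
    """Clip an XYXY box to the image; spatial_size is (height, width), like image.shape[-2:]."""
    height, width = spatial_size

    def clip(value, upper):
        return min(max(value, 0), upper)

    x1, y1, x2, y2 = box
    return [clip(x1, width), clip(y1, height), clip(x2, width), clip(y2, height)]


box = [17, 16, 344, 495]
print(from_xyxy(box, BoxFormat.XYWH))  # [17, 16, 327, 479]
print(from_xyxy(box, BoxFormat.CXCYWH))  # [180.5, 255.5, 327, 479]
print(clamp_xyxy([-5, 10, 600, 495], spatial_size=(512, 512)))  # [0, 10, 512, 495]
```

Storing the format and spatial size next to the values lets a transform interpret and bound the coordinates without the caller passing that information separately.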

Do I have to wrap the output of the datasets myself?

Only if you are using custom datasets. For the built-in ones, you can use torchvision.datasets.wrap_dataset_for_transforms_v2(). The function also supports subclasses of the built-in datasets: if your custom dataset subclasses a built-in one and returns the same output type, you do not have to wrap it manually either.

How do the datapoints behave inside a computation?

Datapoints look and feel just like regular tensors: everything that is supported on a plain torch.Tensor also works on datapoints. However, for most operations involving datapoints it cannot be safely inferred whether the result should retain the datapoint type, so we return a plain tensor instead (this might change; see the note below):

assert isinstance(image, datapoints.Image)

new_image = image + 0

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)

Note

This “unwrapping” behaviour is something we’re actively seeking feedback on. If you find this surprising or if you have any suggestions on how to better support your use-cases, please reach out to us via this issue: https://github.com/pytorch/vision/issues/7319

There are two exceptions to this rule:

  1. The operations clone(), to(), and requires_grad_() retain the datapoint type.

  2. In-place operations on datapoints cannot change the type of the datapoint they are called on. However, if you chain them in the fluent style, the returned value is unwrapped:

image = datapoints.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

assert isinstance(image, datapoints.Image)
print(image)

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
assert (new_image == image).all()
Image([[[2, 4],
        [4, 2]]], )
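The sketch below reproduces the rules above in plain PyTorch with a tensor subclass. Results are unwrapped, clone(), to(), and requires_grad_() keep the type, and in-place ops leave the original's type alone while returning a plain tensor. ToyDatapoint is a hypothetical class written for this illustration. It relies on the documented __torch_function__ subclassing hook and on Tensor.as_subclass, and it should not be read as torchvision's actual implementation.

```python
import torch


class ToyDatapoint(torch.Tensor):
    """Hypothetical tensor subclass mimicking the unwrapping rules described above."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # The default implementation runs `func` and re-wraps tensor outputs as `cls`.
        output = super().__torch_function__(func, types, args, kwargs)
        # Exception 1: these methods keep the subclass type.
        if func in (torch.Tensor.clone, torch.Tensor.to, torch.Tensor.requires_grad_):
            return output
        # Everything else is unwrapped. For in-place methods such as `add_`, `output` is the
        # mutated input itself; `as_subclass` returns a new plain-tensor alias of the same
        # memory, so the original object keeps its ToyDatapoint type (exception 2).
        if isinstance(output, cls):
            return output.as_subclass(torch.Tensor)
        return output


image = torch.rand(3, 4, 4).as_subclass(ToyDatapoint)
assert type(image + 0) is torch.Tensor  # ordinary ops unwrap
assert type(image.clone()) is ToyDatapoint  # clone() keeps the type
assert type(image.to(torch.float64)) is ToyDatapoint  # to() keeps the type

x = torch.tensor([[0, 1], [1, 0]]).as_subclass(ToyDatapoint)
y = x.add_(1).mul_(2)  # fluent chain of in-place ops
assert type(x) is ToyDatapoint  # the original keeps its type
assert type(y) is torch.Tensor  # the returned value is unwrapped
print(y.tolist())  # [[2, 4], [4, 2]]
```

Because y is an alias of x's memory, both hold the same values after the chain. Only their Python types differ.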
