
TVTensors FAQ


TVTensors are Tensor subclasses introduced together with torchvision.transforms.v2. This example showcases what these TVTensors are and how they behave.

Warning

Intended Audience: Unless you're writing your own transforms or your own TVTensors, you probably do not need to read this guide. This is a fairly low-level topic that most users will not need to worry about: you do not need to understand the internals of TVTensors to use torchvision.transforms.v2 effectively. It may, however, be useful for advanced users implementing their own datasets or transforms, or working directly with TVTensors.

import PIL.Image

import torch
from torchvision import tv_tensors

What are TVTensors?

TVTensors are zero-copy tensor subclasses:

tensor = torch.rand(3, 256, 256)
image = tv_tensors.Image(tensor)

assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()

Under the hood, they are needed in torchvision.transforms.v2 to correctly dispatch to the appropriate function for the input data.

torchvision.tv_tensors supports four types of TVTensors: Image, Video, BoundingBoxes, and Mask.

What can I do with a TVTensor?

TVTensors look and feel just like regular tensors, because they are tensors. Everything that is supported on a plain torch.Tensor, like .sum() or any torch.* operator, also works on TVTensors. See "I had a TVTensor but now I have a Tensor. Help!" below for a few gotchas.

How do I construct a TVTensor?

Using the constructor

Each TVTensor class takes any tensor-like data that can be turned into a Tensor:

image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
print(image)
# Image([[[[0, 1],
#          [1, 0]]]], )

Like other PyTorch creation ops, the constructor also accepts the dtype, device, and requires_grad parameters.

float_image = tv_tensors.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)
# Image([[[0., 1.],
#         [1., 0.]]], grad_fn=<AliasBackward0>, )

Image and Mask can also take a PIL.Image.Image directly:

image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
print(image.shape, image.dtype)
# torch.Size([3, 512, 512]) torch.uint8

Some TVTensors require additional metadata in order to be constructed. For example, BoundingBoxes requires the coordinate format as well as the size of the corresponding image (canvas_size) alongside the actual values. This metadata is required to transform the bounding boxes correctly.

bboxes = tv_tensors.BoundingBoxes(
    [[17, 16, 344, 495], [0, 10, 0, 10]],
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=image.shape[-2:]
)
print(bboxes)
# BoundingBoxes([[ 17,  16, 344, 495],
#                [  0,  10,   0,  10]], format=BoundingBoxFormat.XYXY, canvas_size=torch.Size([512, 512]))

Using tv_tensors.wrap()

You can also use the wrap() function to wrap a tensor into a TVTensor of the same type, and with the same metadata, as a reference object passed via like. This is useful when you already have an object of the desired type, which typically happens when writing transforms: you just want to wrap the output like the input.

new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size

new_bboxes inherits its metadata (format and canvas_size) from bboxes, but you can override any of it by passing it as a keyword argument to wrap().

I had a TVTensor but now I have a Tensor. Help!

By default, operations on TVTensor objects will return a pure Tensor:

assert isinstance(bboxes, tv_tensors.BoundingBoxes)

# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3

assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)

Note

This behavior only affects native torch operations. If you use the built-in torchvision transforms or functionals, the output always has the same type as the input (pure Tensor or TVTensor).

But I want a TVTensor back!

You can re-wrap a pure tensor into a TVTensor by calling the TVTensor constructor, or by using the wrap() function (see "How do I construct a TVTensor?" above for details):

new_bboxes = bboxes + 3
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)

Alternatively, you can use set_return_type() as a global setting for the whole program, or as a context manager (read its docs to learn about the caveats):

with tv_tensors.set_return_type("TVTensor"):
    new_bboxes = bboxes + 3
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)

Why is this happening?

For performance reasons. TVTensor classes are Tensor subclasses, so any operation involving a TVTensor object will go through the __torch_function__ protocol. This induces a small overhead, which we want to avoid when possible. This doesn’t matter for built-in torchvision transforms because we can avoid the overhead there, but it could be a problem in your model’s forward.

The alternative isn't much better anyway. For every operation where preserving the TVTensor type makes sense, there are just as many operations where returning a pure Tensor is preferable: for example, is img.sum() still an Image? If we were to preserve TVTensor types all the way, even the model's logits or the output of the loss function would end up being of type Image, and surely that's not desirable.

Note

This behaviour is something we’re actively seeking feedback on. If you find this surprising or if you have any suggestions on how to better support your use-cases, please reach out to us via this issue: https://github.com/pytorch/vision/issues/7319

Exceptions

There are a few exceptions to this “unwrapping” rule: clone(), to(), torch.Tensor.detach(), and requires_grad_() retain the TVTensor type.

In-place operations on TVTensors, like obj.add_(), preserve the type of obj. However, the value returned by an in-place operation is a pure tensor:

image = tv_tensors.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

# image got transformed in-place and is still a TVTensor Image, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, tv_tensors.Image)
print(image)
# Image([[[2, 4],
#         [4, 2]]], )

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, tv_tensors.Image)
assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr()
