Slicing, Indexing, and Masking

Author: Tom Begley

In this tutorial you will learn how to slice, index, and mask a TensorDict.

As discussed in the tutorial Manipulating the shape of a TensorDict, when we create a TensorDict we specify a batch_size, which must agree with the leading dimensions of all entries in the TensorDict. Because every entry is guaranteed to share those leading dimensions, we can index and mask the batch dimensions in the same way that we would index a torch.Tensor. The indices are applied along the batch dimensions to every entry in the TensorDict.

For example, given a TensorDict with two batch dimensions, tensordict[0] returns a new TensorDict with the same structure, and whose values correspond to the first “row” of each entry in the original TensorDict.

import torch
from tensordict import TensorDict

tensordict = TensorDict(
    {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)

print(tensordict[0])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)

The same syntax applies as for regular tensors. For example, to drop the first row of each entry we could index as follows:

print(tensordict[1:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 4]),
    device=None,
    is_shared=False)

We can index multiple dimensions simultaneously:

print(tensordict[:, 2:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)

We can also use Ellipsis to represent as many : as would be needed to make the selection tuple the same length as tensordict.batch_dims.

print(tensordict[..., 2:])
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)

Setting Values with Indexing

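In general, tensordict[index] = new_tensordict will work as long as the batch size of new_tensordict is compatible with the batch size of tensordict[index]. In the example below, tensordict[:-1] has batch size [2, 4], which matches the batch size of td2.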

tensordict = TensorDict(
    {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)

td2 = TensorDict({"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4])
tensordict[:-1] = td2
print(tensordict["a"], tensordict["b"])
tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]]) tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.]])

Masking

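We mask a TensorDict the same way we mask a tensor. Indexing with a boolean mask that has the same shape as the batch size keeps the elements where the mask is True and flattens the masked batch dimensions into one, so the example below returns a TensorDict with batch size [6].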

mask = torch.tensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]], dtype=torch.bool)
tensordict[mask]
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([6, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([6]),
    device=None,
    is_shared=False)
