
Manipulating the shape of a TensorDict

Author: Tom Begley

In this tutorial you will learn how to manipulate the shape of a TensorDict and its contents.

When we create a TensorDict we specify a batch_size, which must agree with the leading dimensions of all entries in the TensorDict. Since we have a guarantee that all entries share those dimensions in common, TensorDict is able to expose a number of methods with which we can manipulate the shape of the TensorDict and its contents.
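The batch_size guarantee amounts to a prefix check on the shape of every entry. The following is a minimal pure-Python sketch of that rule, assuming shapes are plain tuples; `check_batch_size` is a hypothetical helper, not part of the tensordict API.

```python
def check_batch_size(entry_shapes, batch_size):
    """Return True if every entry's shape starts with batch_size.

    Hypothetical helper: shapes are plain tuples standing in for
    torch.Size, and the trailing (non-batch) dims are unconstrained.
    """
    batch_size = tuple(batch_size)
    n = len(batch_size)
    return all(tuple(shape[:n]) == batch_size for shape in entry_shapes.values())


shapes = {"a": (3, 4), "b": (3, 4, 5)}
print(check_batch_size(shapes, [3, 4]))  # True: both entries start with (3, 4)
print(check_batch_size(shapes, [3, 5]))  # False: "a" has size 4, not 5, in dim 1
```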

import torch
from tensordict import TensorDict

Indexing a TensorDict

Since the batch dimensions are guaranteed to exist on all entries, we can index them as we please, and each entry of the TensorDict will be indexed in the same way.

a = torch.rand(3, 4)
b = torch.rand(3, 4, 5)
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])

indexed_tensordict = tensordict[:2, 1]
assert indexed_tensordict["a"].shape == torch.Size([2])
assert indexed_tensordict["b"].shape == torch.Size([2, 5])
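Because the same index is applied to the batch dimensions of every entry, the resulting shapes follow from simple bookkeeping. This pure-Python sketch (the helper `indexed_shape` is hypothetical and handles only ints and slices) reproduces the shapes asserted above.

```python
def indexed_shape(entry_shape, batch_dims, index):
    """Shape of one entry after indexing its leading batch dims.

    Sketch only: supports ints and slices. An int removes its batch dim,
    a slice keeps it with the sliced length, and non-batch dims are kept.
    """
    if not isinstance(index, tuple):
        index = (index,)
    if len(index) > batch_dims:
        raise IndexError("too many indices for the batch dimensions")
    out = []
    for dim, size in enumerate(entry_shape[:batch_dims]):
        if dim >= len(index):
            out.append(size)  # batch dims that were not indexed are kept
        elif isinstance(index[dim], slice):
            out.append(len(range(*index[dim].indices(size))))
        elif isinstance(index[dim], int):
            pass  # an integer index drops the dimension
        else:
            raise TypeError("only ints and slices are supported in this sketch")
    return tuple(out) + tuple(entry_shape[batch_dims:])


# Mirrors tensordict[:2, 1] on a TensorDict with batch_size [3, 4]
print(indexed_shape((3, 4), 2, (slice(None, 2), 1)))     # (2,)
print(indexed_shape((3, 4, 5), 2, (slice(None, 2), 1)))  # (2, 5)
```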

Reshaping a TensorDict

TensorDict.reshape works just like torch.Tensor.reshape(). It applies to all of the contents of the TensorDict along the batch dimensions only (note the shape of b in the example below, whose trailing dimension is left unchanged). It also updates the batch_size attribute.

reshaped_tensordict = tensordict.reshape(-1)
assert reshaped_tensordict.batch_size == torch.Size([12])
assert reshaped_tensordict["a"].shape == torch.Size([12])
assert reshaped_tensordict["b"].shape == torch.Size([12, 5])
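The shape rule behind reshape(-1) can be sketched in a few lines of pure Python: the -1 is inferred from the number of batch elements, and the non-batch dims are appended untouched. The helper `reshaped_shape` is hypothetical and follows the usual torch reshape rules for the batch part only.

```python
from math import prod


def reshaped_shape(entry_shape, batch_size, new_batch):
    """Entry shape after reshaping only the batch dims to new_batch.

    Sketch of torch-style reshape rules: at most one -1, inferred from the
    number of batch elements; non-batch dims are appended unchanged.
    """
    numel = prod(batch_size)
    new_batch = list(new_batch)
    if new_batch.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    if -1 in new_batch:
        known = prod(d for d in new_batch if d != -1)
        if known == 0 or numel % known:
            raise ValueError("cannot infer the size of the -1 dimension")
        new_batch[new_batch.index(-1)] = numel // known
    if prod(new_batch) != numel:
        raise ValueError("new batch shape must hold the same number of elements")
    return tuple(new_batch) + tuple(entry_shape[len(batch_size):])


print(reshaped_shape((3, 4, 5), (3, 4), [-1]))     # (12, 5), like "b" above
print(reshaped_shape((3, 4, 5), (3, 4), [2, -1]))  # (2, 6, 5)
```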

Splitting a TensorDict

TensorDict.split is similar to torch.Tensor.split(). It splits the TensorDict into chunks. Each chunk is a TensorDict with the same structure as the original one, but whose entries are views of the corresponding entries in the original TensorDict.

chunks = tensordict.split([3, 1], dim=1)
assert chunks[0].batch_size == torch.Size([3, 3])
assert chunks[1].batch_size == torch.Size([3, 1])
torch.testing.assert_close(chunks[0]["a"], tensordict["a"][:, :-1])
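The batch sizes of the chunks can be worked out ahead of time. This pure-Python sketch (hypothetical helper `split_batch_sizes`) mirrors the two argument forms of torch.Tensor.split: a single chunk size, where the last chunk may be smaller, or a list of section sizes that must add up to the size of the split dimension.

```python
def split_batch_sizes(batch_size, split_size_or_sections, dim):
    """Batch sizes of the chunks returned by a split along dim.

    Sketch mirroring the two forms torch.Tensor.split accepts: an int chunk
    size (the last chunk may be smaller) or a list of section sizes.
    """
    batch_size = tuple(batch_size)
    if dim < 0:
        dim += len(batch_size)
    size = batch_size[dim]
    if isinstance(split_size_or_sections, int):
        step = split_size_or_sections
        sections = [min(step, size - start) for start in range(0, size, step)]
    else:
        sections = list(split_size_or_sections)
        if sum(sections) != size:
            raise ValueError("section sizes must add up to the size of dim")
    return [batch_size[:dim] + (s,) + batch_size[dim + 1:] for s in sections]


print(split_batch_sizes([3, 4], [3, 1], dim=1))  # [(3, 3), (3, 1)]
print(split_batch_sizes([3, 4], 3, dim=1))       # [(3, 3), (3, 1)]
```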


Whenever a function or method accepts a dim argument, negative dimensions are interpreted relative to the batch_size of the TensorDict that the function or method is called on. In particular, if there are nested TensorDict values with different batch sizes, the negative dimension is always interpreted relative to the batch dimensions of the root.

tensordict = TensorDict(
    {
        "a": torch.rand(3, 4),
        "nested": TensorDict({"b": torch.rand(3, 4, 5)}, [3, 4, 5]),
    },
    [3, 4],
)
# dim = -2 will be interpreted as the first dimension throughout, as the root
# TensorDict has 2 batch dimensions, even though the nested TensorDict has 3
chunks = tensordict.split([2, 1], dim=-2)
assert chunks[0].batch_size == torch.Size([2, 4])
assert chunks[0]["nested"].batch_size == torch.Size([2, 4, 5])

As you can see from this example, the TensorDict.split method behaves exactly as though we had replaced dim=-2 with dim=tensordict.batch_dims - 2 before calling.
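The normalization rule can be written as a one-line conversion. In this pure-Python sketch (hypothetical helper `normalize_dim`), the offset is always the number of batch dimensions of the root, which is why dim=-2 lands on the first dimension even inside the nested TensorDict.

```python
def normalize_dim(dim, root_batch_dims):
    """Convert a possibly negative dim into a non-negative one.

    Sketch of the rule described above: the offset is always the number of
    batch dims of the root TensorDict, and the resulting non-negative dim is
    then used unchanged for every nested TensorDict.
    """
    if not -root_batch_dims <= dim < root_batch_dims:
        raise IndexError(f"dim {dim} is out of range for {root_batch_dims} batch dims")
    return dim + root_batch_dims if dim < 0 else dim


root_batch_dims, nested_batch_dims = 2, 3
print(normalize_dim(-2, root_batch_dims))    # 0: the dim actually used everywhere
print(normalize_dim(-2, nested_batch_dims))  # 1: NOT what split(dim=-2) does
```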


TensorDict.unbind is similar to torch.Tensor.unbind(), and conceptually similar to TensorDict.split. It removes the specified dimension and returns a tuple of all slices along that dimension.

slices = tensordict.unbind(dim=1)
assert len(slices) == 4
assert all(s.batch_size == torch.Size([3]) for s in slices)
torch.testing.assert_close(slices[0]["a"], tensordict["a"][:, 0])
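To see what "removes the specified dimension" means for the values themselves, here is a pure-Python sketch of unbind on rectangular nested lists (hypothetical helper, non-negative dims only). Slice i collects every element whose index along dim is i, so slices[0] along dim 1 is the first column, like tensordict["a"][:, 0] above.

```python
def unbind(nested, dim):
    """Return the slices of a rectangular nested list along dim.

    Sketch of unbind semantics on plain lists (non-negative dim only):
    slice i collects every element whose index along dim is i, and each
    slice has one dimension fewer than the input.
    """
    if dim == 0:
        return tuple(nested)
    # Unbind each sub-list one level down, then regroup by slice index.
    per_item = [unbind(item, dim - 1) for item in nested]
    return tuple(list(group) for group in zip(*per_item))


a = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]  # shape (3, 4)
slices = unbind(a, dim=1)
print(len(slices))  # 4
print(slices[0])    # [0, 4, 8], i.e. column 0
```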

Stacking and concatenating

TensorDict can be used in conjunction with torch.cat and torch.stack.

Stacking TensorDict

Stacking can be done lazily or contiguously. A lazy stack is just a list of tensordicts presented as a stack of tensordicts. It allows users to carry a bag of tensordicts with different content shapes, devices or key sets. Another advantage is that the stack operation can be expensive: if only a small subset of keys is required, a lazy stack will be much faster than a proper stack. Lazy stacking relies on the LazyStackedTensorDict class, and values are only stacked on demand, when they are accessed.

from tensordict import LazyStackedTensorDict

cloned_tensordict = tensordict.clone()
stacked_tensordict = LazyStackedTensorDict.lazy_stack(
    [tensordict, cloned_tensordict], dim=0
)

# Previously, torch.stack always returned a lazy stack. For consistency with
# the regular PyTorch API, this behaviour will soon be changed to return only
# dense tensordicts. To control which behaviour you rely on, you can use
# the :func:`~tensordict.utils.set_lazy_legacy` decorator/context manager:

from tensordict.utils import set_lazy_legacy

with set_lazy_legacy(True):  # old behaviour
    lazy_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(lazy_stack, LazyStackedTensorDict)

with set_lazy_legacy(False):  # new behaviour
    dense_stack = torch.stack([tensordict, cloned_tensordict])
assert isinstance(dense_stack, TensorDict)
assert dense_stack.batch_size == torch.Size([2, 3, 4])
assert dense_stack["a"].shape == torch.Size([2, 3, 4])
assert dense_stack["nested", "b"].shape == torch.Size([2, 3, 4, 5])

If we index a LazyStackedTensorDict along the stacking dimension, we recover the original TensorDict instances.

assert stacked_tensordict[0] is tensordict
assert stacked_tensordict[1] is cloned_tensordict

Accessing a key in the LazyStackedTensorDict stacks the corresponding values. If the key corresponds to a nested TensorDict, we recover another LazyStackedTensorDict.

assert stacked_tensordict["a"].shape == torch.Size([2, 3, 4])


Since values are stacked on-demand, accessing an item multiple times will mean it gets stacked multiple times, which is inefficient. If you need to access a value in the stacked TensorDict more than once, you may want to consider converting the LazyStackedTensorDict to a contiguous TensorDict, which can be done with the LazyStackedTensorDict.to_tensordict or LazyStackedTensorDict.contiguous methods.

After calling either of these methods, we will have a regular TensorDict containing the stacked values, and no additional computation is performed when values are accessed.
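To make the trade-off concrete, here is a toy model of lazy stacking in plain Python. It is a hypothetical class, not LazyStackedTensorDict: dicts stand in for TensorDicts, lists stand in for tensors, all items are assumed to share the same keys, and "stacking" just collects the values into a new outer list. It counts how often a value is stacked, so the cost of repeated lazy access is visible.

```python
class ToyLazyStack:
    """Hypothetical stand-in for a lazy stack of dict-like containers."""

    def __init__(self, items):
        self.items = list(items)
        self.stack_calls = 0  # how many times a value had to be stacked

    def __getitem__(self, key):
        if isinstance(key, int):
            # Indexing along the stack dim returns the stored object itself.
            return self.items[key]
        # Key access stacks the values on demand, on every call.
        self.stack_calls += 1
        return [item[key] for item in self.items]

    def to_dense(self):
        # Stack each key once; later lookups are plain dict reads.
        return {key: self[key] for key in self.items[0]}


d0 = {"a": [1, 2], "b": [3, 4]}
d1 = {"a": [5, 6], "b": [7, 8]}
lazy = ToyLazyStack([d0, d1])
assert lazy[0] is d0
lazy["a"]
lazy["a"]
print(lazy.stack_calls)  # 2: the same key was stacked twice
dense = lazy.to_dense()  # stacks "a" and "b" once more each
print(dense["a"])        # [[1, 2], [5, 6]]
print(lazy.stack_calls)  # 4, and reading dense["a"] again adds nothing
```

In this toy, to_dense plays the role that LazyStackedTensorDict.to_tensordict and LazyStackedTensorDict.contiguous play in the real library.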

Concatenating TensorDict

Concatenation is not done lazily. Instead, calling torch.cat() on a list of TensorDict instances simply returns a TensorDict whose entries are the concatenated entries of the elements of the list.

concatenated_tensordict = torch.cat([tensordict, cloned_tensordict], dim=0)
assert isinstance(concatenated_tensordict, TensorDict)
assert concatenated_tensordict.batch_size == torch.Size([6, 4])
assert concatenated_tensordict["nested", "b"].shape == torch.Size([6, 4, 5])
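The batch size of a concatenation follows the same rule as for tensors: every batch dimension must agree except the one being concatenated, whose sizes add up. A pure-Python sketch with a hypothetical helper `cat_batch_size`:

```python
def cat_batch_size(batch_sizes, dim):
    """Batch size of torch.cat over several TensorDicts along dim.

    Sketch: all batch sizes must agree except along dim, whose sizes add up.
    """
    batch_sizes = [tuple(bs) for bs in batch_sizes]
    first = batch_sizes[0]
    if dim < 0:
        dim += len(first)
    for bs in batch_sizes[1:]:
        same_rank = len(bs) == len(first)
        if not same_rank or bs[:dim] + bs[dim + 1:] != first[:dim] + first[dim + 1:]:
            raise ValueError(f"batch sizes {first} and {bs} differ outside dim {dim}")
    total = sum(bs[dim] for bs in batch_sizes)
    return first[:dim] + (total,) + first[dim + 1:]


print(cat_batch_size([(3, 4), (3, 4)], dim=0))  # (6, 4), as in the example above
print(cat_batch_size([(3, 4), (3, 2)], dim=1))  # (3, 6)
```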

Expanding TensorDict

We can expand all of the entries of a TensorDict using TensorDict.expand.

exp_tensordict = tensordict.expand(2, *tensordict.batch_size)
assert exp_tensordict.batch_size == torch.Size([2, 3, 4])
torch.testing.assert_close(exp_tensordict["a"][0], exp_tensordict["a"][1])
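The resulting shapes can be sketched with the rules of torch.Tensor.expand applied to the batch dimensions only: new leading dimensions may be prepended, and an existing batch dimension must either match or be 1. This pure-Python helper `expanded_shape` is hypothetical and only tracks shapes; torch.Tensor.expand itself returns a view rather than copying data, which is why every slice along the new leading dimension holds the same values.

```python
def expanded_shape(entry_shape, batch_size, new_batch):
    """Entry shape after expanding the batch dims to new_batch.

    Sketch modelled on torch.Tensor.expand: new leading dims may be added,
    existing batch dims must either match or be 1 (then they are broadcast),
    and non-batch dims are appended unchanged.
    """
    n_old, n_new = len(batch_size), len(new_batch)
    if n_new < n_old:
        raise ValueError("expand cannot remove batch dimensions")
    out = list(new_batch[: n_new - n_old])  # brand-new leading dims
    for old, new in zip(batch_size, new_batch[n_new - n_old:]):
        if old != new and old != 1:
            raise ValueError(f"cannot expand size {old} to {new}")
        out.append(new)
    return tuple(out) + tuple(entry_shape[n_old:])


print(expanded_shape((3, 4, 5), (3, 4), (2, 3, 4)))  # (2, 3, 4, 5)
print(expanded_shape((3, 1), (3, 1), (3, 4)))        # (3, 4)
```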

Squeezing and Unsqueezing TensorDict

We can squeeze or unsqueeze the contents of a TensorDict with the squeeze() and unsqueeze() methods.

tensordict = TensorDict({"a": torch.rand(3, 1, 4)}, [3, 1, 4])
squeezed_tensordict = tensordict.squeeze()
assert squeezed_tensordict["a"].shape == torch.Size([3, 4])
assert squeezed_tensordict.batch_size == torch.Size([3, 4])

unsqueezed_tensordict = tensordict.unsqueeze(-1)
assert unsqueezed_tensordict["a"].shape == torch.Size([3, 1, 4, 1])
assert unsqueezed_tensordict.batch_size == torch.Size([3, 1, 4, 1])


Until now, operations like unsqueeze(), squeeze(), view(), permute() and transpose() all returned a lazy version of these operations (i.e., a container that stores the original tensordict and applies the operation every time a key is accessed). This behaviour will be deprecated in the future and can already be controlled via the set_lazy_legacy() function:

>>> with set_lazy_legacy(True):
...     lazy_unsqueeze = tensordict.unsqueeze(0)
>>> with set_lazy_legacy(False):
...     dense_unsqueeze = tensordict.unsqueeze(0)

Bear in mind that, as ever, these methods apply only to the batch dimensions. Any non-batch dimensions of the entries will be unaffected.

tensordict = TensorDict({"a": torch.rand(3, 1, 1, 4)}, [3, 1])
squeezed_tensordict = tensordict.squeeze()
# only one of the singleton dimensions is dropped as the other
# is not a batch dimension
assert squeezed_tensordict["a"].shape == torch.Size([3, 1, 4])
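The batch-only rule can be written as shape arithmetic on tuples. In this pure-Python sketch (hypothetical helpers `squeeze_batch` and `unsqueeze_batch`, each returning the new batch size and the new entry shape), squeeze drops only singleton batch dimensions, and unsqueeze(-1) inserts the new singleton right after the last batch dimension rather than at the end of the entry, since negative dims are interpreted relative to the batch size.

```python
def squeeze_batch(entry_shape, batch_size):
    """(new batch_size, new entry shape) after squeeze(): only batch dims of size 1 go."""
    n = len(batch_size)
    new_batch = tuple(d for d in batch_size if d != 1)
    return new_batch, new_batch + tuple(entry_shape[n:])


def unsqueeze_batch(entry_shape, batch_size, dim):
    """(new batch_size, new entry shape) after unsqueeze(dim) on the batch dims."""
    n = len(batch_size)
    if dim < 0:
        dim += n + 1  # counts from the end of the *new* batch shape, as in torch
    new_batch = tuple(batch_size[:dim]) + (1,) + tuple(batch_size[dim:])
    return new_batch, new_batch + tuple(entry_shape[n:])


print(squeeze_batch((3, 1, 1, 4), (3, 1)))        # ((3,), (3, 1, 4))
print(unsqueeze_batch((3, 1, 1, 4), (3, 1), -1))  # ((3, 1, 1), (3, 1, 1, 1, 4))
```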

Viewing a TensorDict

TensorDict also supports view. Under the legacy lazy behaviour described above, this creates a _ViewedTensorDict, which lazily creates views on its contents when they are accessed.

tensordict = TensorDict({"a": torch.arange(12)}, [12])
# no views are created at this step
viewed_tensordict = tensordict.view((2, 3, 2))

# the view of "a" is created on-demand when we access it
assert viewed_tensordict["a"].shape == torch.Size([2, 3, 2])
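A view does not move any data: it reinterprets the same contiguous storage with a new shape, and each index is turned into a flat offset through row-major strides. The pure-Python sketch below illustrates that arithmetic for the torch.arange(12) example, assuming contiguous storage; `contiguous_strides` and `view_lookup` are hypothetical helpers.

```python
from math import prod


def contiguous_strides(shape):
    """Row-major strides of a contiguous shape (the torch default layout)."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def view_lookup(flat, shape, index):
    """Read element `index` of a row-major view of `flat` with `shape`."""
    if prod(shape) != len(flat):
        raise ValueError("view shape must cover exactly the same number of elements")
    offset = sum(i * s for i, s in zip(index, contiguous_strides(shape)))
    return flat[offset]


flat = list(range(12))  # stands in for torch.arange(12)
print(contiguous_strides((2, 3, 2)))            # (6, 2, 1)
print(view_lookup(flat, (2, 3, 2), (1, 2, 1)))  # 11, i.e. 1*6 + 2*2 + 1*1
```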

Permuting batch dimensions

The TensorDict.permute method can be used to permute the batch dimensions, much like torch.permute(). Non-batch dimensions are left untouched.

Under the legacy lazy behaviour, this operation is lazy, so batch dimensions are only permuted when we try to access the entries. As ever, if you are likely to need to access a particular entry multiple times, consider converting to a TensorDict.

tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])
# swap the batch dimensions
permuted_tensordict = tensordict.permute([1, 0])

assert permuted_tensordict["a"].shape == torch.Size([4, 3])
assert permuted_tensordict["b"].shape == torch.Size([4, 3, 5])
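Only the leading batch dimensions are reordered, so the shape of each entry can be computed by permuting its first len(batch_size) sizes and appending the rest. A pure-Python sketch with a hypothetical helper `permuted_shape`:

```python
def permuted_shape(entry_shape, batch_size, dims):
    """Entry shape after permute(dims): only the batch dims are reordered.

    Sketch: dims must be a permutation of range(len(batch_size)); negative
    values are taken relative to the number of batch dims.
    """
    n = len(batch_size)
    dims = [d + n if d < 0 else d for d in dims]
    if sorted(dims) != list(range(n)):
        raise ValueError("dims must be a permutation of the batch dimensions")
    return tuple(entry_shape[d] for d in dims) + tuple(entry_shape[n:])


print(permuted_shape((3, 4), (3, 4), [1, 0]))     # (4, 3)
print(permuted_shape((3, 4, 5), (3, 4), [1, 0]))  # (4, 3, 5): trailing 5 untouched
```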

Using tensordicts as context managers

For a number of reversible operations, tensordicts can be used as context managers. These operations include to_module() for functional calls, unlock_() and lock_(), and shape operations such as view(), permute(), transpose(), squeeze() and unsqueeze(). Here is a quick example with the transpose function:

tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])

with tensordict.transpose(1, 0) as tdt:
    tdt.set("c", torch.ones(4, 3))  # we have permuted the dims

# the ``"c"`` entry is now in the tensordict we used as a context manager:

assert (tensordict.get("c") == 1).all()
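The pattern behind this is "apply an operation on entry, apply its inverse on exit and write the results back". The following is a hypothetical pure-Python sketch of that pattern built on contextlib.contextmanager, not tensordict's implementation: a dict of 2D nested lists stands in for a TensorDict with two batch dimensions, and for a transpose the inverse is the same transpose.

```python
from contextlib import contextmanager


def transpose_2d(rows):
    """Transpose a rectangular nested list."""
    return [list(col) for col in zip(*rows)]


@contextmanager
def transposed(store):
    """Yield a transposed copy of `store`, then transpose it back into `store`.

    Hypothetical sketch: on normal exit every entry of the transposed copy,
    including entries added inside the block, is written back to `store`
    after applying the inverse operation.
    """
    view = {key: transpose_2d(value) for key, value in store.items()}
    yield view
    for key, value in view.items():
        store[key] = transpose_2d(value)


store = {"a": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]}  # batch shape (3, 4)
with transposed(store) as t:
    t["c"] = [[1, 1, 1] for _ in range(4)]  # written with shape (4, 3)
print(len(store["c"]), len(store["c"][0]))  # 3 4
```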

Gathering values in TensorDict

The TensorDict.gather method can be used to gather values along a batch dimension according to an index tensor, much like torch.gather().

index = torch.randint(4, (3, 4))
gathered_tensordict = tensordict.gather(dim=1, index=index)
print("index:\n", index, end="\n\n")
print("tensordict['a']:\n", tensordict["a"], end="\n\n")
print("gathered_tensordict['a']:\n", gathered_tensordict["a"], end="\n\n")
index:
 tensor([[3, 3, 1, 1],
        [1, 0, 0, 1],
        [1, 3, 3, 3]])

tensordict['a']:
 tensor([[0.1383, 0.3027, 0.4667, 0.5540],
        [0.6529, 0.4292, 0.2175, 0.6586],
        [0.6720, 0.3614, 0.4439, 0.2300]])

gathered_tensordict['a']:
 tensor([[0.5540, 0.5540, 0.3027, 0.3027],
        [0.4292, 0.6529, 0.6529, 0.4292],
        [0.3614, 0.2300, 0.2300, 0.2300]])
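Along dim=1 of a 2D tensor, the gather rule is out[i][j] = src[i][index[i][j]]: each output row picks (and may repeat) entries from the matching row of the source. The pure-Python sketch below (hypothetical helper `gather_dim1`) checks that rule against the values printed above; those values came from torch.rand and torch.randint, so a rerun of the script will print different numbers.

```python
def gather_dim1(src, index):
    """Pure-Python version of torch.gather(src, 1, index) for 2D nested lists.

    out[i][j] = src[i][index[i][j]]: each output row picks (and may repeat)
    entries from the matching row of src.
    """
    return [[src[i][k] for k in row] for i, row in enumerate(index)]


index = [[3, 3, 1, 1], [1, 0, 0, 1], [1, 3, 3, 3]]
src = [
    [0.1383, 0.3027, 0.4667, 0.5540],
    [0.6529, 0.4292, 0.2175, 0.6586],
    [0.6720, 0.3614, 0.4439, 0.2300],
]
expected = [
    [0.5540, 0.5540, 0.3027, 0.3027],
    [0.4292, 0.6529, 0.6529, 0.4292],
    [0.3614, 0.2300, 0.2300, 0.2300],
]
print(gather_dim1(src, index) == expected)  # True
```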
