• Docs >
  • Manipulating the keys of a TensorDict
Shortcuts

Manipulating the keys of a TensorDict

Author: Tom Begley

In this tutorial you will learn how to work with and manipulate the keys in a TensorDict, including getting and setting keys, iterating over keys, manipulating nested values, and flattening the keys.

Setting and getting keys

We can set and get keys using the same syntax as a Python dict

import torch
from tensordict.tensordict import TensorDict

tensordict = TensorDict()

# set a key
a = torch.rand(10)
tensordict["a"] = a

# retrieve the value stored under "a"
assert tensordict["a"] is a

Note

Unlike a Python dict, all keys in the TensorDict must be strings. However as we will see, it is also possible to use tuples of strings to manipulate nested values.

We can also use the methods .get() and .set to accomplish the same thing.

tensordict = TensorDict()

# set a key
a = torch.rand(10)
tensordict.set("a", a)

# retrieve the value stored under "a"
assert tensordict.get("a") is a

Like dict, we can provide a default value to get that should be returned in case the requested key is not found.

assert tensordict.get("banana", a) is a

Similarly, like dict, we can use the TensorDict.setdefault() to get the value of a particular key, returning a default value if that key is not found, and also setting that value in the TensorDict.

assert tensordict.setdefault("banana", a) is a
# a is now stored under "banana"
assert tensordict["banana"] is a

Deleting keys is also achieve in the same way as a Python dict, using the del statement and the chosen key. Equivalently we could use the TensorDict.del_ method.

del tensordict["banana"]

Furthermore, when setting keys with .set() we can use the keyword argument inplace=True to make an inplace update, or equivalently use the .set_() method.

tensordict.set("a", torch.zeros(10), inplace=True)

# all the entries of the "a" tensor are now zero
assert (tensordict.get("a") == 0).all()
# but it's still the same tensor as before
assert tensordict.get("a") is a

# we can achieve the same with set_
tensordict.set_("a", torch.ones(10))
assert (tensordict.get("a") == 1).all()
assert tensordict.get("a") is a

Renaming keys

To rename a key, simply use the TensorDict.rename_key_ method. The value stored under the original key will remain in the TensorDict, but the key will be changed to the specified new key.

tensordict.rename_key_("a", "b")
assert tensordict.get("b") is a
print(tensordict)
TensorDict(
    fields={
        b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Updating multiple values

The TensorDict.update method can be used to update a TensorDict` with another one or with a dict. Keys that already exist are overwritten, and keys that do not already exist are created.

tensordict = TensorDict({"a": torch.rand(10), "b": torch.rand(10)}, [10])
tensordict.update(TensorDict({"a": torch.zeros(10), "c": torch.zeros(10)}, [10]))
assert (tensordict["a"] == 0).all()
assert (tensordict["b"] != 0).all()
assert (tensordict["c"] == 0).all()
print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

Nested values

The values of a TensorDict can themselves be a TensorDict. We can add nested values during instantiation, either by adding TensorDict directly, or using nested dictionaries

# creating nested values with a nested dict
nested_tensordict = TensorDict(
    {"a": torch.rand(2, 3), "double_nested": {"a": torch.rand(2, 3)}}, [2, 3]
)
# creating nested values with a TensorDict
tensordict = TensorDict({"a": torch.rand(2), "nested": nested_tensordict}, [2])

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2, 3]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

To access these nested values, we can use tuples of strings. For example

double_nested_a = tensordict["nested", "double_nested", "a"]
nested_a = tensordict.get(("nested", "a"))

Similarly we can set nested values using tuples of strings

tensordict["nested", "double_nested", "b"] = torch.rand(2, 3)
tensordict.set(("nested", "b"), torch.rand(2, 3))

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2, 3]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Iterating over a TensorDict’s contents

We can iterate over the keys of a TensorDict using the .keys() method.

for key in tensordict.keys():
    print(key)
a
nested

By default this will iterate only over the top-level keys in the TensorDict, however it is possible to recursively iterate over all of the keys in the TensorDict with the keyword argument include_nested=True. This will iterate recursively over all keys in any nested TensorDicts, returning nested keys as tuples of strings.

for key in tensordict.keys(include_nested=True):
    print(key)
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'double_nested')
('nested', 'b')
nested

In case you want to only iterate over keys corresponding to Tensor values, you can additionally specify leaves_only=True.

for key in tensordict.keys(include_nested=True, leaves_only=True):
    print(key)
a
('nested', 'a')
('nested', 'double_nested', 'a')
('nested', 'double_nested', 'b')
('nested', 'b')

Much like dict, there are also .values and .items methods which accept the same keyword arguments.

for key, value in tensordict.items(include_nested=True):
    if isinstance(value, TensorDict):
        print(f"{key} is a TensorDict")
    else:
        print(f"{key} is a Tensor")
a is a Tensor
nested is a TensorDict
('nested', 'a') is a Tensor
('nested', 'double_nested') is a TensorDict
('nested', 'double_nested', 'a') is a Tensor
('nested', 'double_nested', 'b') is a Tensor
('nested', 'b') is a Tensor

Checking for existence of a key

To check if a key exists in a TensorDict, use the in operator in conjunction with .keys().

Note

Performing key in tensordict.keys() does efficient dict lookups of keys (recursively at each level in the nested case), and so performance is not negatively impacted when there is a large number of keys in the TensorDict.

assert "a" in tensordict.keys()
# to check for nested keys, set include_nested=True
assert ("nested", "a") in tensordict.keys(include_nested=True)
assert ("nested", "banana") not in tensordict.keys(include_nested=True)

Flattening and unflattening nested keys

We can flatten a TensorDict with nested values using the .flatten_keys() method.

print(tensordict, end="\n\n")
print(tensordict.flatten_keys(separator="."))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2, 3]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Given a TensorDict that has been flattened, it is possible to unflatten it again with the .unflatten_keys() method.

flattened_tensordict = tensordict.flatten_keys(separator=".")
print(flattened_tensordict, end="\n\n")
print(flattened_tensordict.unflatten_keys(separator="."))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                double_nested: TensorDict(
                    fields={
                        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([2]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

This can be particularly useful when manipulating the parameters of a torch.nn.Module, as we can end up with a TensorDict whose structure mimics the module structure.

import torch.nn as nn

module = nn.Sequential(
    nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 10)),
    nn.Linear(10, 1),
)
params = TensorDict(dict(module.named_parameters()), []).unflatten_keys()

print(params)
TensorDict(
    fields={
        0: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        bias: Parameter(shape=torch.Size([50]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Parameter(shape=torch.Size([50, 100]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                1: TensorDict(
                    fields={
                        bias: Parameter(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                        weight: Parameter(shape=torch.Size([10, 50]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        1: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([1, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Selecting and excluding keys

We can obtain a new TensorDict with a subset of the keys by using TensorDict.select, which returns a new TensorDict containing only the specified keys, or :meth: TensorDict.exclude <tensordict.TensorDict.exclude>, which returns a new TensorDict with the specified keys omitted.

print("Select:")
print(tensordict.select("a", ("nested", "a")), end="\n\n")
print("Exclude:")
print(tensordict.exclude(("nested", "b"), ("nested", "double_nested")))
Select:
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Exclude:
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

Total running time of the script: (0 minutes 0.009 seconds)

Gallery generated by Sphinx-Gallery

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources