TensorDictModule

class tensordict.nn.TensorDictModule(*args, **kwargs)

A TensorDictModule is a Python wrapper around an nn.Module that reads from and writes to a TensorDict.

Parameters:
  • module (Callable) – a callable, typically a torch.nn.Module, used to map the input to the output parameter space. Its forward method can return a single tensor, a tuple of tensors or even a dictionary. In the latter case, the output keys of the TensorDictModule are used to populate the output tensordict (i.e., the keys present in out_keys should be present in the dictionary returned by the module's forward method); see the sketch after this parameter list.

  • in_keys (iterable of NestedKeys, Dict[NestedStr, str]) – keys to be read from the input tensordict and passed to the module. If it contains more than one element, the values will be passed in the order given by the in_keys iterable. If in_keys is a dictionary, its keys must correspond to the keys to be read from the tensordict, and its values must match the names of the keyword arguments in the function signature.

  • out_keys (iterable of str) – keys to be written to the input tensordict. The length of out_keys must match the number of tensors returned by the embedded module. Using “_” as a key avoids writing that tensor to the output.
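
A minimal sketch of both behaviours follows, with a lambda standing in for a real nn.Module and illustrative key names (imports as in the examples below):

>>> # hedged sketch: a callable returning a dict, with out_keys selecting its entries
>>> mod = TensorDictModule(
...     lambda x: {"y": x + 1, "z": x - 1},
...     in_keys=["x"],
...     out_keys=["y", "z"],
... )
>>> td = mod(TensorDict({"x": torch.zeros(())}, batch_size=[]))
>>> td["y"], td["z"]
(tensor(1.), tensor(-1.))
>>> # hedged sketch: "_" discards the second output instead of writing it
>>> mod = TensorDictModule(lambda x: (x + 1, x - 1), in_keys=["x"], out_keys=["y", "_"])
>>> mod(TensorDict({"x": torch.zeros(())}, batch_size=[]))["y"]
tensor(1.)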

Embedding a neural network in a TensorDictModule only requires specifying the input and output keys. TensorDictModule supports both functional and regular nn.Module objects. In the functional case, the parameters (and buffers) are swapped into the module before the call, as shown in the functional example further below:

Examples

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> # one can wrap regular nn.Module
>>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"])
>>> input = torch.ones(2, 3, 128)
>>> tgt = torch.zeros(2, 3, 128)
>>> data = TensorDict({"input": input, "tgt": tgt}, batch_size=[2, 3])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        input: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        out: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        tgt: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=None,
    is_shared=False)

We can also pass the tensors directly:

Examples

>>> out = module(input, tgt)
>>> assert out.shape == input.shape
>>> # we can also wrap regular functions
>>> module = TensorDictModule(lambda x: (x-1, x+1), in_keys=[("input", "x")], out_keys=[("output", "x-1"), ("output", "x+1")])
>>> module(TensorDict({("input", "x"): torch.zeros(())}, batch_size=[]))
TensorDict(
    fields={
        input: TensorDict(
            fields={
                x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                x+1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                x-1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

We can use TensorDictModule to populate a tensordict:

Examples

>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"])
>>> print(module(TensorDict({}, batch_size=[])))
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Another feature is passing a dictionary as in_keys to control how the values are dispatched to specific keyword arguments.

Examples

>>> module = TensorDictModule(lambda x, *, y: x+y,
...     in_keys={'1': 'x', '2': 'y'}, out_keys=['z'],
...     )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)

Functional calls to a tensordict module are easy:

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...    module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> params = TensorDict.from_module(td_module)
>>> # functional API
>>> with params.to_module(td_module):
...     td_functional = td_module(td.clone())
>>> print(td_functional)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

In the stateful case:

>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...    module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> td_stateful = td_module(td.clone())
>>> print(td_stateful)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

One can use the vmap operator to call the functional module.

Examples

>>> from torch import vmap
>>> params = TensorDict.from_module(td_module)
>>> params_repeat = params.expand(4)
>>> print(params_repeat)
TensorDict(
    fields={
        module: TensorDict(
            fields={
                bias_hh: Tensor(shape=torch.Size([4, 24]), device=cpu, dtype=torch.float32, is_shared=False),
                bias_ih: Tensor(shape=torch.Size([4, 24]), device=cpu, dtype=torch.float32, is_shared=False),
                weight_hh: Tensor(shape=torch.Size([4, 24, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                weight_ih: Tensor(shape=torch.Size([4, 24, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([4]),
    device=None,
    is_shared=False)
>>> def func(td, params):
...     with params.to_module(td_module):
...         return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td.clone(), params_repeat)
>>> print(td_vmap)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)

forward(tensordict: TensorDictBase, *args, tensordict_out: TensorDictBase | None = None, **kwargs: Any) → TensorDictBase

When the tensordict parameter is not set, the keyword arguments are used to create an instance of TensorDict.
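
A minimal sketch of this calling convention, assuming keyword arguments named after the in_keys are dispatched to the wrapped callable (the lambda and key names are illustrative only):

>>> # hedged sketch: keyword arguments named after the in_keys replace the tensordict
>>> module = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"])
>>> module(x=torch.zeros(()))  # no tensordict is passed; one is built from the kwargs
tensor(1.)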
