
TensorDictModule

class tensordict.nn.TensorDictModule(*args, **kwargs)

A TensorDictModule is a Python wrapper around an nn.Module that reads from and writes to a TensorDict.

Parameters:
  • module (Callable) – a callable, typically a torch.nn.Module, used to map the input to the output parameter space. Its forward method can return a single tensor, a tuple of tensors or even a dictionary. In the latter case, the output keys of the TensorDictModule will be used to populate the output tensordict (i.e., the keys present in out_keys should be present in the dictionary returned by the module's forward method).

  • in_keys (iterable of NestedKeys, Dict[NestedStr, str]) – keys to be read from the input tensordict and passed to the module. If it contains more than one element, the values will be passed in the order given by the in_keys iterable. If in_keys is a dictionary, its keys are the keys to read from the tensordict and its values are the names of the keyword arguments in the function signature that receive those values.

  • out_keys (iterable of str) – keys to be written to the input tensordict. The length of out_keys must match the number of tensors returned by the embedded module. Using “_” as a key prevents the corresponding tensor from being written to the output.

Keyword Arguments:

inplace (bool or string, optional) –

If True (default), the outputs of the module are written to the tensordict provided to the forward() method. If False, a new TensorDict with an empty batch-size and no device is created. If "empty", empty() will be used to create the output tensordict.

Note

If inplace=False and the tensordict passed to the module is a TensorDictBase subclass other than TensorDict, the output will still be a TensorDict instance, with an empty batch-size and no device. Set inplace="empty" to get the same TensorDictBase subtype, batch-size and device as the input. Use tensordict_out at runtime (see below) for more fine-grained control over the output.

Note

If inplace=False and a tensordict_out is passed to the forward() method, the tensordict_out will prevail. This is how one can control the type of the output: without a tensordict_out, if the tensordict passed to the module is a TensorDictBase subclass other than TensorDict, the output will still be a TensorDict instance.
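The choice of output container described by the inplace option and the two notes above can be summarised as a small decision function. This is a hedged plain-Python sketch, not tensordict code: ToyTensorDict, ToySubclass and select_output are hypothetical stand-ins (a dict carrying batch_size and device attributes), and the sketch assumes that an explicit tensordict_out takes precedence in every mode, whereas the note only states this for inplace=False.

```python
from typing import Optional, Tuple, Union


class ToyTensorDict(dict):
    """Hypothetical stand-in for TensorDict: a dict with batch_size and device."""

    def __init__(self, *args, batch_size: Tuple[int, ...] = (),
                 device: Optional[str] = None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.batch_size = tuple(batch_size)
        self.device = device

    def empty(self) -> "ToyTensorDict":
        # Same subclass, batch_size and device, but no entries.
        return type(self)(batch_size=self.batch_size, device=self.device)


class ToySubclass(ToyTensorDict):
    """Stands in for a TensorDictBase subclass other than TensorDict."""


def select_output(
    tensordict: ToyTensorDict,
    inplace: Union[bool, str] = True,
    tensordict_out: Optional[ToyTensorDict] = None,
) -> ToyTensorDict:
    """Return the container the module outputs would be written to."""
    if tensordict_out is not None:
        # Assumption: an explicit tensordict_out prevails in every mode.
        return tensordict_out
    if inplace is True:
        return tensordict  # write into the input tensordict itself
    if inplace == "empty":
        return tensordict.empty()  # same subtype, batch-size and device
    if inplace is False:
        # Plain container with an empty batch-size and no device,
        # regardless of the input's subtype.
        return ToyTensorDict()
    raise ValueError(f"unsupported inplace value: {inplace!r}")


src = ToySubclass({"a": 1}, batch_size=(2, 3), device="cpu")
print(type(select_output(src, inplace=False)).__name__)  # ToyTensorDict
print(select_output(src, inplace="empty").batch_size)  # (2, 3)
```

Keeping the checks in this order makes the precedence explicit: a runtime tensordict_out first, then the inplace mode chosen at construction time.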

Embedding a neural network in a TensorDictModule only requires specifying the input and output keys. TensorDictModule supports both regular and functional calls to nn.Module objects. In the functional case, as in the functional example further down, the parameters are extracted with TensorDict.from_module() and temporarily loaded into the module with to_module().

Examples

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> # one can wrap regular nn.Module
>>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"])
>>> input = torch.ones(2, 3, 128)
>>> tgt = torch.zeros(2, 3, 128)
>>> data = TensorDict({"input": input, "tgt": tgt}, batch_size=[2, 3])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        input: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        out: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        tgt: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=None,
    is_shared=False)

We can also pass the tensors directly:

Examples

>>> out = module(input, tgt)
>>> assert out.shape == input.shape
>>> # we can also wrap regular functions
>>> module = TensorDictModule(
...     lambda x: (x-1, x+1),
...     in_keys=[("input", "x")],
...     out_keys=[("output", "x-1"), ("output", "x+1")],
... )
>>> module(TensorDict({("input", "x"): torch.zeros(())}, batch_size=[]))
TensorDict(
    fields={
        input: TensorDict(
            fields={
                x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                x+1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                x-1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

We can use TensorDictModule to populate a tensordict:

Examples

>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"])
>>> print(module(TensorDict({}, batch_size=[])))
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

A dictionary can also be passed as in_keys to dispatch values to specific keyword arguments.

Examples

>>> module = TensorDictModule(lambda x, *, y: x+y,
...     in_keys={'1': 'x', '2': 'y'}, out_keys=['z'],
...     )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)
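The dispatching rules described in the parameter list and illustrated by the examples above (positional arguments for a list of in_keys, keyword arguments for a dict, nested tuple keys, and “_” to drop an output) can be summarised with a toy sketch in plain Python. Everything here is hypothetical: a plain dict stands in for a TensorDict, floats stand in for tensors, and toy_module_call, get_nested and set_nested are illustrative helpers, not tensordict functions. It is not how TensorDictModule is implemented, and batch size, device and the inplace options are ignored.

```python
from typing import Any, Callable, Dict, Iterable, Tuple, Union

# Hypothetical toy model: a plain dict stands in for a TensorDict and a
# tuple key such as ("output", "x") addresses an entry of a nested sub-dict.
Key = Union[str, Tuple[str, ...]]


def get_nested(data: dict, key: Key) -> Any:
    """Read a possibly nested key from a dict of dicts."""
    parts = (key,) if isinstance(key, str) else key
    for part in parts:
        data = data[part]
    return data


def set_nested(data: dict, key: Key, value: Any) -> None:
    """Write a possibly nested key, creating intermediate dicts as needed."""
    parts = (key,) if isinstance(key, str) else key
    for part in parts[:-1]:
        data = data.setdefault(part, {})
    data[parts[-1]] = value


def toy_module_call(
    module: Callable[..., Any],
    data: dict,
    in_keys: Union[Iterable[Key], Dict[Key, str]],
    out_keys: Iterable[Key],
) -> dict:
    """Route entries of `data` through `module` and write the results back."""
    if isinstance(in_keys, dict):
        # Dict in_keys: {tensordict key: keyword-argument name}.
        kwargs = {name: get_nested(data, key) for key, name in in_keys.items()}
        out = module(**kwargs)
    else:
        # Iterable in_keys: values are passed positionally, in order.
        out = module(*(get_nested(data, key) for key in in_keys))
    out_keys = list(out_keys)
    if isinstance(out, dict):
        # Dict output: each out_key (except "_") is looked up in the returned dict.
        values = [None if key == "_" else out[key] for key in out_keys]
    elif isinstance(out, tuple):
        values = list(out)
    else:
        values = [out]
    if len(values) != len(out_keys):
        raise ValueError(f"expected {len(out_keys)} outputs, got {len(values)}")
    for key, value in zip(out_keys, values):
        if key == "_":
            continue  # "_" discards the corresponding output
        set_nested(data, key, value)
    return data


# Usage, mirroring the examples above with floats instead of tensors.
nested = toy_module_call(
    lambda x: (x - 1, x + 1),
    {"input": {"x": 0.0}},
    in_keys=[("input", "x")],
    out_keys=[("output", "x-1"), ("output", "x+1")],
)
print(nested["output"])  # {'x-1': -1.0, 'x+1': 1.0}

summed = toy_module_call(
    lambda x, *, y: x + y, {"1": 1.0, "2": 2.0},
    in_keys={"1": "x", "2": "y"}, out_keys=["z"],
)
print(summed["z"])  # 3.0
```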

Functional calls to a tensordict module are easy:

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...    module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> params = TensorDict.from_module(td_module)
>>> # functional API
>>> with params.to_module(td_module):
...     td_functional = td_module(td.clone())
>>> print(td_functional)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

In the stateful case:
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...    module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> td_stateful = td_module(td.clone())
>>> print(td_stateful)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
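The functional example relies on params.to_module(td_module) acting as a context manager that installs the given parameters for the duration of the with block. The idea of swapping parameters in and restoring them afterwards can be sketched without torch. ToyLinear, from_module and swap_params below are hypothetical plain-Python stand-ins rather than tensordict APIs, and the sketch assumes the original parameters are restored when the block exits, including when it raises.

```python
from contextlib import contextmanager
from typing import Dict, Iterator


class ToyLinear:
    """Hypothetical one-parameter 'module' computing weight * x."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def __call__(self, x: float) -> float:
        return self.weight * x


def from_module(module: ToyLinear) -> Dict[str, float]:
    """Snapshot the module's parameters into a plain dict."""
    return {"weight": module.weight}


@contextmanager
def swap_params(module: ToyLinear, params: Dict[str, float]) -> Iterator[ToyLinear]:
    """Install `params` on `module` for the duration of the with block."""
    saved = from_module(module)
    for name, value in params.items():
        setattr(module, name, value)
    try:
        yield module
    finally:
        # Restore the original parameters, even if the block raised.
        for name, value in saved.items():
            setattr(module, name, value)


layer = ToyLinear(2.0)
with swap_params(layer, {"weight": 10.0}):
    print(layer(3.0))  # 30.0, computed with the swapped-in weight
print(layer(3.0))  # 6.0, the original weight is back (stateful call)
```

The try/finally inside the generator is what guarantees restoration: an exception raised in the with block propagates through the finally clause before being re-raised.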
forward(tensordict: TensorDictBase = None, args=None, *, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) → TensorDictBase

When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict.
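The rule above, that keyword arguments are used to build the input when no tensordict is passed, can be sketched with a plain dict standing in for TensorDict. toy_forward and its signature are hypothetical; the real forward() also handles tensordict_out, batch size and device, none of which this sketch models.

```python
from typing import Any, Callable, Optional, Sequence


def toy_forward(
    module: Callable[..., Any],
    in_keys: Sequence[str],
    out_keys: Sequence[str],
    tensordict: Optional[dict] = None,
    **kwargs: Any,
) -> dict:
    """Hypothetical toy: mirrors only the 'build from kwargs' rule of forward()."""
    if tensordict is None:
        # No tensordict given: the keyword arguments become its entries.
        tensordict = dict(kwargs)
    out = module(*(tensordict[key] for key in in_keys))
    outs = out if isinstance(out, tuple) else (out,)
    if len(outs) != len(out_keys):
        raise ValueError(f"expected {len(out_keys)} outputs, got {len(outs)}")
    for key, value in zip(out_keys, outs):
        tensordict[key] = value
    return tensordict


# Usage: the inputs are given as keyword arguments instead of a tensordict.
result = toy_forward(lambda a, b: a + b, ["a", "b"], ["c"], a=1, b=2)
print(result)  # {'a': 1, 'b': 2, 'c': 3}
```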
