TensorDictSequential

class tensordict.nn.TensorDictSequential(*args, **kwargs)

A sequence of TensorDictModules.

Similarly to nn.Sequential, which passes a tensor through a chain of mappings that each read and write a single tensor, this module reads from and writes to a tensordict by querying each of the input modules in turn. When calling a TensorDictSequential instance with a functional module, it is expected that the parameter lists (and buffers) will be concatenated into a single list.

Parameters:

modules (iterable of TensorDictModules) – ordered sequence of TensorDictModule instances to be run sequentially.

Keyword Arguments:
  • partial_tolerant (bool, optional) – if True, the input tensordict can miss some of the input keys. If so, only the modules that can be executed given the keys that are present will be run. Also, if the input tensordict is a lazy stack of tensordicts AND partial_tolerant is True AND the stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts looking for those that have the required keys, if any.

  • selected_out_keys (iterable of NestedKeys, optional) – the list of out-keys to select. If not provided, all out_keys will be written.

Note

A TensorDictSequential instance may have a long list of output keys, and one may wish to remove some of them after execution for clarity or memory purposes. If this is the case, the method select_out_keys() can be used after instantiation, or selected_out_keys may be passed to the constructor.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> torch.manual_seed(0)
>>> module = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["x+1"]),
...     TensorDictModule(torch.nn.Linear(3, 4), in_keys=["x+1"], out_keys=["w*(x+1)+b"]),
... )
>>> # with tensordict input
>>> print(module(TensorDict({"x": torch.zeros(3)}, [])))
TensorDict(
    fields={
        w*(x+1)+b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        x+1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # with tensor input: returns all the output keys in the order of the modules, i.e. "x+1" and "w*(x+1)+b"
>>> module(x=torch.zeros(3))
(tensor([1., 1., 1.]), tensor([-0.7214, -0.8748,  0.1571, -0.1138], grad_fn=<AddBackward0>))
>>> module(torch.zeros(3))
(tensor([1., 1., 1.]), tensor([-0.7214, -0.8748,  0.1571, -0.1138], grad_fn=<AddBackward0>))

TensorDictSequential supports functional, modular and vmap coding:

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
...     TensorDictModule,
...     TensorDictSequential,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch.distributions import Normal
>>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
>>> net1 = torch.nn.Linear(4, 8)
>>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["params"])
>>> normal_params = TensorDictModule(
...      NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"]
...  )
>>> td_module1 = ProbabilisticTensorDictSequential(
...     module1,
...     normal_params,
...     ProbabilisticTensorDictModule(
...         in_keys=["loc", "scale"],
...         out_keys=["hidden"],
...         distribution_class=Normal,
...         return_log_prob=True,
...     )
... )
>>> module2 = torch.nn.Linear(4, 8)
>>> td_module2 = TensorDictModule(
...    module=module2, in_keys=["hidden"], out_keys=["output"]
... )
>>> td_module = TensorDictSequential(td_module1, td_module2)
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     _ = td_module(td)
>>> print(td)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

In the vmap case:

>>> from torch import vmap
>>> params = params.expand(4)
>>> def func(td, params):
...     with params.to_module(td_module):
...         return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)

forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) → TensorDictBase

When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict.

reset_out_keys()

Resets the out_keys attribute to its original value.

Returns: the same module, with its original out_keys values.

Examples

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.reset_out_keys()
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

select_out_keys(*selected_out_keys) → TensorDictSequential

Selects the keys that will be found in the output tensordict.

This is useful whenever one wants to get rid of intermediate keys in a complicated graph, or when the presence of these keys may trigger unexpected behaviours.

The original out_keys can still be accessed via module.out_keys_source.

Parameters:

*selected_out_keys (a sequence of strings or tuples of strings) – the out_keys that should be found in the output tensordict.

Returns: the same module, modified in-place with updated out_keys.

The simplest usage is with TensorDictModule:

Examples

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

This feature will also work with dispatched arguments:

Examples

>>> mod(torch.zeros(()), torch.ones(()))
tensor(2.)

This change will occur in-place (i.e., the same module will be returned with an updated list of out_keys). It can be reverted using the TensorDictModuleBase.reset_out_keys() method.

Examples

>>> mod.reset_out_keys()
>>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

This will work with other classes too, such as Sequential:

Examples

>>> from tensordict.nn import TensorDictSequential
>>> seq = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]),
...     TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]),
... )
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> seq.select_out_keys("z")
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

select_subsequence(in_keys: Optional[Iterable[NestedKey]] = None, out_keys: Optional[Iterable[NestedKey]] = None) → TensorDictSequential

Returns a new TensorDictSequential with only the modules that are necessary to compute the given output keys with the given input keys.

Parameters:
  • in_keys – input keys of the subsequence we want to select. Keys absent from in_keys are considered irrelevant, and modules that only take such keys as inputs are discarded. The resulting sequential module follows the pattern “all the modules whose output is affected by a change in the value of any of <in_keys>”. If none is provided, the module’s in_keys are assumed.

  • out_keys – output keys of the subsequence we want to select. Only the modules that are necessary to compute the out_keys are kept in the resulting sequence. The resulting sequential module follows the pattern “all the modules that condition the value of the <out_keys> entries”. If none is provided, the module’s out_keys are assumed.

Returns:

A new TensorDictSequential with only the modules that are necessary according to the given input and output keys.

Examples

>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> idn = lambda x: x
>>> module = Seq(
...     Mod(idn, in_keys=["a"], out_keys=["b"]),
...     Mod(idn, in_keys=["b"], out_keys=["c"]),
...     Mod(idn, in_keys=["c"], out_keys=["d"]),
...     Mod(idn, in_keys=["a"], out_keys=["e"]),
... )
>>> # select all modules whose output depends on "a"
>>> module.select_subsequence(in_keys=["a"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['b'])
      (1): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['b'],
          out_keys=['c'])
      (2): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['c'],
          out_keys=['d'])
      (3): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['e'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['b', 'c', 'd', 'e'])
>>> # select all modules whose output depends on "c"
>>> module.select_subsequence(in_keys=["c"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['c'],
          out_keys=['d'])
    ),
    device=cpu,
    in_keys=['c'],
    out_keys=['d'])
>>> # select all modules that affect the value of "c"
>>> module.select_subsequence(out_keys=["c"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['b'])
      (1): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['b'],
          out_keys=['c'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['b', 'c'])
>>> # select all modules that affect the value of "e"
>>> module.select_subsequence(out_keys=["e"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['e'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['e'])

This method propagates to nested sequential:

>>> module = Seq(
...     Seq(
...         Mod(idn, in_keys=["a"], out_keys=["b"]),
...         Mod(idn, in_keys=["b"], out_keys=["c"]),
...     ),
...     Seq(
...         Mod(idn, in_keys=["b"], out_keys=["d"]),
...         Mod(idn, in_keys=["d"], out_keys=["e"]),
...     ),
... )
>>> # select submodules whose output is affected by a change in "b" or "d" AND that contribute to the value of "e"
>>> module.select_subsequence(in_keys=["b", "d"], out_keys=["e"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictSequential(
          module=ModuleList(
            (0): TensorDictModule(
                module=<function <lambda> at 0x129efae50>,
                device=cpu,
                in_keys=['b'],
                out_keys=['d'])
            (1): TensorDictModule(
                module=<function <lambda> at 0x129efae50>,
                device=cpu,
                in_keys=['d'],
                out_keys=['e'])
          ),
          device=cpu,
          in_keys=['b'],
          out_keys=['d', 'e'])
    ),
    device=cpu,
    in_keys=['b'],
    out_keys=['d', 'e'])
