tensordict.nn package

The tensordict.nn package makes it possible to flexibly use TensorDict within ML pipelines.

Since TensorDict turns parts of one’s code into a key-based structure, it is now possible to build complex graph structures using these keys as hooks. The basic building block is TensorDictModule, which wraps a torch.nn.Module instance with a list of input and output keys:

>>> from torch.nn import Transformer
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> import torch
>>> module = TensorDictModule(Transformer(), in_keys=["feature", "target"], out_keys=["prediction"])
>>> data = TensorDict({"feature": torch.randn(10, 11, 512), "target": torch.randn(10, 11, 512)}, [10, 11])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        feature: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
        prediction: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
        target: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32)},
    batch_size=torch.Size([10, 11]),
    device=None,
    is_shared=False)

One does not necessarily need to use TensorDictModule; a custom torch.nn.Module with an ordered list of input and output keys (named module.in_keys and module.out_keys) will suffice, as in the sketch below.
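
As a minimal sketch (the class name, key names, and the doubling operation here are purely illustrative), such a module reads its inputs from, and writes its outputs to, the tensordict it receives:

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> class DoubleModule(nn.Module):
...     in_keys = ["x"]   # keys read from the input tensordict
...     out_keys = ["y"]  # keys written back to it
...     def forward(self, tensordict):
...         tensordict["y"] = tensordict["x"] * 2
...         return tensordict
...
>>> td = TensorDict({"x": torch.randn(3, 4)}, batch_size=[3])
>>> td = DoubleModule()(td)
>>> assert "y" in td.keys()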

A key pain point for many PyTorch users is the inability of nn.Sequential to handle modules with multiple inputs. Working with key-based graphs easily solves that problem, as each node in the sequence knows what data needs to be read and where to write it.

For this purpose, we provide the TensorDictSequential class, which passes data through a sequence of TensorDictModules. Each module in the sequence takes its input from, and writes its output to, the original TensorDict, meaning it’s possible for modules in the sequence to ignore the output of their predecessors, or to take additional input from the tensordict as necessary. Here’s an example:

>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential
>>> class Net(nn.Module):
...     def __init__(self, input_size=100, hidden_size=50, output_size=10):
...         super().__init__()
...         self.fc1 = nn.Linear(input_size, hidden_size)
...         self.fc2 = nn.Linear(hidden_size, output_size)
...
...     def forward(self, x):
...         x = torch.relu(self.fc1(x))
...         return self.fc2(x)
...
>>> class Masker(nn.Module):
...     def forward(self, x, mask):
...         return torch.softmax(x * mask, dim=1)
...
>>> net = TensorDictModule(
...     Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
...     Masker(),
...     in_keys=[("intermediate", "x"), ("input", "mask")],
...     out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>>
>>> td = TensorDict(
...     {
...         "input": TensorDict(
...             {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
...             batch_size=[32],
...         )
...     },
...     batch_size=[32],
... )
>>> td = module(td)
>>> print(td)
TensorDict(
    fields={
        input: TensorDict(
            fields={
                mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
                x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        intermediate: TensorDict(
            fields={
                x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                probabilities: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)

We can also select sub-graphs easily through the select_subsequence() method:

>>> sub_module = module.select_subsequence(out_keys=[("intermediate", "x")])
>>> td = TensorDict(
...     {
...         "input": TensorDict(
...             {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
...             batch_size=[32],
...         )
...     },
...     batch_size=[32],
... )
>>> sub_module(td)
>>> print(td)  # the "output" has not been computed
TensorDict(
    fields={
        input: TensorDict(
            fields={
                mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
                x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False),
        intermediate: TensorDict(
            fields={
                x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)

Finally, tensordict.nn comes with a ProbabilisticTensorDictModule that allows one to build distributions from network outputs and to obtain summary statistics or samples from them (along with the distribution parameters):

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.nn.distributions import NormalParamWrapper
>>> from tensordict.nn.prototype import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
... )
>>> from torch.distributions import Normal
>>> td = TensorDict(
...     {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.GRUCell(4, 8)
>>> module = TensorDictModule(
...     NormalParamWrapper(net), in_keys=["input", "hidden"], out_keys=["loc", "scale"]
... )
>>> prob_module = ProbabilisticTensorDictModule(
...     in_keys=["loc", "scale"],
...     out_keys=["sample"],
...     distribution_class=Normal,
...     return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(module, prob_module)
>>> td_module(td)
>>> print(td)
TensorDict(
    fields={
        hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),
        input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        sample: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        sample_log_prob: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

TensorDictModuleBase(*args, **kwargs)

Base class for TensorDict modules.

TensorDictModule(*args, **kwargs)

A TensorDictModule is a Python wrapper around an nn.Module that reads and writes to a TensorDict.

ProbabilisticTensorDictModule(*args, **kwargs)

A probabilistic TD Module.

TensorDictSequential(*args, **kwargs)

A sequence of TensorDictModules.

TensorDictModuleWrapper(*args, **kwargs)

Wrapper class for TensorDictModule objects.

CudaGraphModule(module[, warmup, in_keys, ...])

A cudagraph wrapper for PyTorch callables.

Ensembles

The functional approach enables a straightforward ensemble implementation. We can duplicate and reinitialize model copies using tensordict.nn.EnsembleModule:

>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.nn import EnsembleModule
>>> from tensordict import TensorDict
>>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
>>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
>>> ensemble = EnsembleModule(mod, num_copies=3)
>>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
>>> ensemble(data)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 10]),
    device=None,
    is_shared=False)

EnsembleModule(*args, **kwargs)

Module that wraps a module and repeats it to form an ensemble.

Compiling TensorDictModules

Since v0.5, TensorDict components are compatible with torch.compile(). For instance, a TensorDictSequential module can be compiled with torch.compile and reach a runtime similar to that of a regular PyTorch module wrapped in a TensorDictModule.
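
As an illustrative sketch (the layer sizes and key names below are arbitrary), compilation works as for any other callable:

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> seq = TensorDictSequential(
...     TensorDictModule(nn.Linear(4, 8), in_keys=["x"], out_keys=["hidden"]),
...     TensorDictModule(nn.Linear(8, 2), in_keys=["hidden"], out_keys=["out"]),
... )
>>> seq = torch.compile(seq)  # compile the whole key-based graph
>>> td = seq(TensorDict({"x": torch.randn(32, 4)}, batch_size=[32]))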

Distributions

NormalParamExtractor([scale_mapping, scale_lb])

A non-parametric nn.Module that splits its input into loc and scale parameters.

AddStateIndependentNormalScale([...])

A nn.Module that adds trainable state-independent scale parameters.

CompositeDistribution(params, ...[, ...])

A composition of distributions.

Delta(param[, atol, rtol, batch_shape, ...])

Delta distribution.

OneHotCategorical([logits, probs])

One-hot categorical distribution.

TruncatedNormal(loc, scale, a, b[, ...])

Truncated Normal distribution.

Utils

make_tensordict([input_dict, batch_size, device])

Returns a TensorDict created from the keyword arguments or an input dictionary.

dispatch([separator, source, dest, ...])

Allows a function expecting a TensorDict to be called using kwargs (see the sketch at the end of this section).

set_interaction_type([type])

Sets the sampling behavior of all ProbabilisticTDModules to the desired interaction type.

inv_softplus(bias)

Inverse softplus function.

biased_softplus(bias[, min_val])

A biased softplus module.

set_skip_existing([mode, in_key_attr, ...])

A context manager for skipping existing nodes in a TensorDict graph.

skip_existing()

Returns whether or not existing entries in a tensordict should be re-computed by a module.
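
For instance, because TensorDictModule relies on dispatch, a module defined through in_keys/out_keys can also be called with plain keyword arguments and then returns its output tensors directly. A minimal sketch (the wrapped module and key names are illustrative):

>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(nn.Linear(3, 4), in_keys=["feature"], out_keys=["prediction"])
>>> prediction = module(feature=torch.randn(10, 3))  # no TensorDict required
>>> prediction.shape
torch.Size([10, 4])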
