TensorDictModule
- class tensordict.nn.TensorDictModule(*args, **kwargs)
A TensorDictModule is a Python wrapper around an nn.Module that reads from and writes to a TensorDict.
- Parameters:
module (Callable[[Any], Any]) – a callable, typically a torch.nn.Module, used to map the input to the output parameter space. Its forward method can return a single tensor, a tuple of tensors or even a dictionary. In the latter case, the output keys of the TensorDictModule will be used to populate the output tensordict (i.e. the keys present in out_keys should be present in the dictionary returned by the module forward method).
in_keys (iterable of NestedKeys, Dict[NestedStr, str]) – keys to be read from the input tensordict and passed to the module. If it contains more than one element, the values will be passed in the order given by the in_keys iterable. If in_keys is a dictionary, its keys must correspond to the keys to be read in the tensordict and its values must match the names of the keyword arguments in the function signature. If out_to_in_map is True, the mapping is inverted so that the keys correspond to the keyword arguments in the function signature.
out_keys (iterable of str) – keys to be written to the output tensordict (by default, the input tensordict). The length of out_keys must match the number of tensors returned by the embedded module. Using “_” as a key avoids writing the corresponding tensor to the output.
- Keyword Arguments:
out_to_in_map (bool, optional) – if True, in_keys is read as if the keys are the argument names of the forward() method and the values are the keys in the input TensorDict. If False or None (default), the keys are considered to be the input keys and the values the method’s argument names.
Warning
The default value of out_to_in_map will change from False to True in the v0.9 release.
inplace (bool or string, optional) – if True (default), the outputs of the module are written in the tensordict provided to the forward() method. If False, a new TensorDict with an empty batch-size and no device is created. If "empty", empty() will be used to create the output tensordict.
Note
If inplace=False and the tensordict passed to the module is a TensorDictBase subclass other than TensorDict, the output will still be a TensorDict instance. Its batch-size will be empty, and it will have no device. Set inplace to "empty" to get the same TensorDictBase subtype, an identical batch-size and device. Use tensordict_out at runtime (see below) for more fine-grained control over the output.
Note
If inplace=False and a tensordict_out is passed to the forward() method, the tensordict_out will prevail. This is how one can control the output type: without it, if the tensordict passed to the module is a TensorDictBase subclass other than TensorDict, the output will still be a TensorDict instance.
method (str, optional) – the method to be called in the module, if any. Defaults to __call__.
method_kwargs (Dict[str, Any], optional) – additional keyword arguments to be passed to the module’s method being called.
strict (bool, optional) – if True, the module will raise an exception if any of the inputs is missing from the input tensordict. Otherwise, None will be used as a placeholder. Defaults to False.
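To make the interplay of in_keys, out_keys, the “_” placeholder and strict concrete, here is a minimal plain-Python model of the read, call, write cycle described above. It is a sketch under stated assumptions, not the library's implementation: apply_module is a hypothetical helper, a flat dict stands in for a TensorDict, and outputs are always written in place (as with inplace=True).

```python
from typing import Any, Callable, Dict, List, Sequence


def apply_module(
    data: Dict[str, Any],
    fn: Callable[..., Any],
    in_keys: Sequence[str],
    out_keys: Sequence[str],
    strict: bool = False,
) -> Dict[str, Any]:
    """Hypothetical helper modelling the read/call/write cycle on a plain dict."""
    # 1. Read inputs in the order given by in_keys.
    inputs: List[Any] = []
    for key in in_keys:
        if key in data:
            inputs.append(data[key])
        elif strict:
            raise KeyError(f"missing input key {key!r}")
        else:
            inputs.append(None)  # placeholder when strict=False
    # 2. Call the wrapped callable.
    result = fn(*inputs)
    # 3. Normalize the result to one value per out_key.
    if isinstance(result, dict):
        # A dict result is indexed by the out_keys ("_" entries are ignored).
        outputs = [None if key == "_" else result[key] for key in out_keys]
    elif isinstance(result, tuple):
        outputs = list(result)
    else:
        outputs = [result]
    if len(outputs) != len(out_keys):
        raise ValueError(f"expected {len(out_keys)} outputs, got {len(outputs)}")
    # 4. Write outputs in place, skipping "_".
    for key, value in zip(out_keys, outputs):
        if key != "_":
            data[key] = value
    return data


data = {"x": 3}
apply_module(data, lambda x: (x - 1, x + 1), in_keys=["x"], out_keys=["lo", "_"])
print(data)  # {'x': 3, 'lo': 2}
```

With in_keys=[] the callable is simply invoked with no arguments, which is the pattern the "populate a tensordict" example further down relies on.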
Embedding a neural network in a TensorDictModule only requires specifying the input and output keys. TensorDictModule supports both functional and regular nn.Module objects. In the functional case, the parameters (and buffers) must be provided explicitly, e.g. by extracting them with TensorDict.from_module() and loading them with to_module(), as in the functional example below.
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> # one can wrap regular nn.Module
>>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"])
>>> input = torch.ones(2, 3, 128)
>>> tgt = torch.zeros(2, 3, 128)
>>> data = TensorDict({"input": input, "tgt": tgt}, batch_size=[2, 3])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        input: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        out: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        tgt: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=None,
    is_shared=False)
We can also pass the tensors directly:
Examples
>>> out = module(input, tgt)
>>> assert out.shape == input.shape
>>> # we can also wrap regular functions
>>> module = TensorDictModule(lambda x: (x-1, x+1), in_keys=[("input", "x")], out_keys=[("output", "x-1"), ("output", "x+1")])
>>> module(TensorDict({("input", "x"): torch.zeros(())}, batch_size=[]))
TensorDict(
    fields={
        input: TensorDict(
            fields={
                x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                x+1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                x-1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
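The tuple keys in the previous example address nested entries: ("output", "x-1") means the entry "x-1" inside the sub-tensordict "output", which is created on write. As a rough plain-Python illustration of that addressing (not tensordict's own code; set_nested and get_nested are hypothetical helpers, and nested dicts stand in for nested TensorDicts):

```python
from typing import Any, Dict, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def _path(key: NestedKey) -> Tuple[str, ...]:
    # A plain string is treated as a one-element path.
    return (key,) if isinstance(key, str) else tuple(key)


def set_nested(data: Dict[str, Any], key: NestedKey, value: Any) -> None:
    """Write value at key, creating intermediate dicts as needed."""
    *parents, leaf = _path(key)
    node = data
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


def get_nested(data: Dict[str, Any], key: NestedKey) -> Any:
    """Read the value stored at key."""
    node = data
    for part in _path(key):
        node = node[part]
    return node


data: Dict[str, Any] = {}
set_nested(data, ("input", "x"), 0.0)
x = get_nested(data, ("input", "x"))
out_keys = [("output", "x-1"), ("output", "x+1")]
for key, value in zip(out_keys, (x - 1, x + 1)):
    set_nested(data, key, value)
print(data)  # {'input': {'x': 0.0}, 'output': {'x-1': -1.0, 'x+1': 1.0}}
```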
We can use TensorDictModule to populate a tensordict:
Examples
>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"])
>>> print(module(TensorDict({}, batch_size=[])))
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
Another feature is passing a dictionary as input keys, to control the dispatching of values to specific keyword arguments.
Examples
>>> module = TensorDictModule(lambda x, *, y: x+y,
...     in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], out_to_in_map=False
... )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)
If out_to_in_map is set to True, then the in_keys mapping is reversed. This way, one can use the same input key for different keyword arguments.
Examples
>>> module = TensorDictModule(lambda x, *, y, z: x+y+z,
...     in_keys={'x': '1', 'y': '2', 'z': '2'}, out_keys=['t'], out_to_in_map=True
... )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['t']
tensor(5.)
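The two dictionary conventions can be modelled in a few lines of plain Python. The sketch below is an assumption-laden illustration, not tensordict's dispatch code: resolve_inputs is a hypothetical helper, a dict stands in for the tensordict, and every mapped value is passed as a keyword argument while a plain sequence of keys is passed positionally.

```python
import collections.abc
from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union

InKeys = Union[Sequence[str], Mapping[str, str]]


def resolve_inputs(
    data: Mapping[str, Any], in_keys: InKeys, out_to_in_map: bool = False
) -> Tuple[List[Any], Dict[str, Any]]:
    """Hypothetical helper: turn in_keys into (args, kwargs) for the callable."""
    if not isinstance(in_keys, collections.abc.Mapping):
        # A sequence of keys: values are passed positionally, in order.
        return [data[key] for key in in_keys], {}
    if out_to_in_map:
        # {kwarg_name: data_key}: one data key may feed several kwargs.
        return [], {kwarg: data[data_key] for kwarg, data_key in in_keys.items()}
    # {data_key: kwarg_name}: each data key can appear only once.
    return [], {kwarg: data[data_key] for data_key, kwarg in in_keys.items()}


data = {"1": 1.0, "2": 2.0}

args, kwargs = resolve_inputs(data, {"1": "x", "2": "y"})
print((lambda x, *, y: x + y)(*args, **kwargs))  # 3.0

args, kwargs = resolve_inputs(data, {"x": "1", "y": "2", "z": "2"}, out_to_in_map=True)
print((lambda x, *, y, z: x + y + z)(*args, **kwargs))  # 5.0
```

The asymmetry explains why the reversed form exists: in the {data_key: kwarg_name} form, dictionary keys are unique, so "2" cannot be listed twice to feed both y and z.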
We can specify the method to be called within a module. Compared to using a lambda function or similar around the module’s method, this has the advantage that the module attributes (params, buffers, submodules) will be exposed.
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> class MyNet(nn.Module):
...     def my_func(self, tensor: torch.Tensor, *, an_integer: int):
...         return tensor + an_integer
...
>>> s = Seq(
...     {
...         "a": lambda td: td+1,
...         "b": lambda td: td * 2,
...         "c": Mod(MyNet(), in_keys=["a"], out_keys=["b"], method="my_func", method_kwargs={"an_integer": 2}),
...     }
... )
>>> td = s(TensorDict(a=0))
>>> assert td["b"] == 4
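The method and method_kwargs options amount to looking up a named method on the stored module and calling it with extra keyword arguments. A minimal plain-Python sketch of that idea follows; MethodWrapper is a hypothetical class (not part of tensordict), and MyNet here is a torch-free stand-in for the nn.Module above.

```python
from typing import Any, Dict, Optional


class MyNet:
    """Torch-free stand-in for the nn.Module in the example above."""

    offset = 0  # an attribute the method reads through self

    def my_func(self, value: float, *, an_integer: int) -> float:
        return value + an_integer + self.offset


class MethodWrapper:
    """Hypothetical wrapper: stores the module and the name of the method to call."""

    def __init__(
        self,
        module: Any,
        method: str = "__call__",
        method_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.module = module  # kept as an attribute, so it stays inspectable
        self.method = method
        self.method_kwargs = dict(method_kwargs or {})

    def __call__(self, *inputs: Any) -> Any:
        # Resolve the bound method at call time and forward the extra kwargs.
        return getattr(self.module, self.method)(*inputs, **self.method_kwargs)


wrapped = MethodWrapper(MyNet(), method="my_func", method_kwargs={"an_integer": 2})
print(wrapped(2))  # 4
print(wrapped.module.offset)  # 0, reachable through the wrapper
```

Because the wrapper holds the module as an attribute rather than hiding it inside a lambda's closure, the module's attributes remain reachable from the wrapper, which is the property the paragraph above highlights for params, buffers and submodules.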
Functional calls to a tensordict module are easy:
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...     module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> params = TensorDict.from_module(td_module)
>>> # functional API
>>> with params.to_module(td_module):
...     td_functional = td_module(td.clone())
>>> print(td_functional)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
In the stateful case:
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...     module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> td_stateful = td_module(td.clone())
>>> print(td_stateful)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
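The functional and stateful calls differ only in where the parameters come from. Conceptually, the with params.to_module(td_module): block swaps the extracted parameters into the module for the duration of the block and restores the originals afterwards. Here is a minimal plain-Python model of that swap, assuming scalar attributes stand in for tensors; Affine and swap_params are hypothetical names, and this is not how tensordict implements to_module.

```python
import contextlib
from typing import Any, Dict, Iterator


class Affine:
    """Hypothetical stand-in for a module with two scalar parameters."""

    def __init__(self, weight: float, bias: float) -> None:
        self.weight = weight
        self.bias = bias

    def __call__(self, x: float) -> float:
        return self.weight * x + self.bias


@contextlib.contextmanager
def swap_params(module: Any, params: Dict[str, Any]) -> Iterator[Any]:
    """Temporarily replace module attributes, restoring them on exit."""
    saved = {name: getattr(module, name) for name in params}
    for name, value in params.items():
        setattr(module, name, value)
    try:
        yield module
    finally:
        # Restore the module's own parameters even if the body raised.
        for name, value in saved.items():
            setattr(module, name, value)


net = Affine(weight=2.0, bias=1.0)
with swap_params(net, {"weight": 10.0, "bias": 0.0}):
    functional_out = net(3.0)  # uses the swapped-in parameters
stateful_out = net(3.0)  # back to the module's own parameters
print(functional_out, stateful_out)  # 30.0 7.0
```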
- forward(tensordict: TensorDictBase = None, args=None, *, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) → TensorDictBase
When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict.
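The forward() signature ties together several options described earlier: kwargs stand in for a missing tensordict, and tensordict_out, when given, takes precedence over inplace. The following plain-Python sketch models that precedence with dicts standing in for tensordicts. forward, pick_output and MyDict are hypothetical names, the sketch writes only the module outputs into a fresh container, and batch-size and device handling are not modelled.

```python
from typing import Any, Callable, Dict, Optional, Sequence, Union


def pick_output(
    tensordict: Dict[str, Any],
    tensordict_out: Optional[Dict[str, Any]] = None,
    inplace: Union[bool, str] = True,
) -> Dict[str, Any]:
    """Hypothetical helper: choose the container the outputs are written to."""
    if tensordict_out is not None:
        return tensordict_out  # an explicit tensordict_out always wins
    if inplace is True:
        return tensordict  # default: write into the input
    if inplace == "empty":
        return type(tensordict)()  # same container subtype, no entries
    return {}  # inplace=False: a fresh plain container


def forward(
    fn: Callable[..., Any],
    in_keys: Sequence[str],
    out_keys: Sequence[str],
    tensordict: Optional[Dict[str, Any]] = None,
    *,
    tensordict_out: Optional[Dict[str, Any]] = None,
    inplace: Union[bool, str] = True,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical model of forward(): kwargs build the input when tensordict is None."""
    if tensordict is None:
        tensordict = dict(kwargs)
    out = pick_output(tensordict, tensordict_out, inplace)
    result = fn(*(tensordict[key] for key in in_keys))
    results = result if isinstance(result, tuple) else (result,)
    for key, value in zip(out_keys, results):
        out[key] = value
    return out


class MyDict(dict):
    """Stand-in for a TensorDictBase subclass other than TensorDict."""


print(forward(lambda x: x + 1, ["x"], ["y"], x=1))  # {'x': 1, 'y': 2}
print(type(forward(lambda x: x + 1, ["x"], ["y"], MyDict(x=1), inplace=False)).__name__)  # dict
print(type(forward(lambda x: x + 1, ["x"], ["y"], MyDict(x=1), inplace="empty")).__name__)  # MyDict
```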