TensorDictModule
- class tensordict.nn.TensorDictModule(*args, **kwargs)
A TensorDictModule is a Python wrapper around an nn.Module that reads from and writes to a TensorDict.
- Parameters:
module (Callable) – a callable, typically a torch.nn.Module, used to map the input to the output parameter space. Its forward method can return a single tensor, a tuple of tensors or even a dictionary. In the latter case, the output keys of the TensorDictModule will be used to populate the output tensordict (i.e., the keys present in out_keys should be present in the dictionary returned by the module forward method).
in_keys (iterable of NestedKeys, Dict[NestedStr, str]) – keys to be read from the input tensordict and passed to the module. If it contains more than one element, the values will be passed in the order given by the in_keys iterable. If in_keys is a dictionary, its keys must correspond to the keys to be read in the tensordict and its values must match the names of the keyword arguments in the function signature.
out_keys (iterable of str) – keys to be written to the input tensordict. The length of out_keys must match the number of tensors returned by the embedded module. Using “_” as a key avoids writing the corresponding tensor to the output.
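The routing rules above can be made concrete with a toy model in plain Python. It assumes ordinary dicts in place of TensorDict. The helper toy_call is hypothetical and is not tensordict's implementation. It only mirrors the documented behavior: inputs are read in in_keys order, the number of outputs must match out_keys, and outputs mapped to “_” are not written.

```python
# Toy model of in_keys / out_keys routing. Plain dicts stand in for
# TensorDict; this is NOT tensordict's implementation.
from typing import Any, Callable, Dict, Sequence


def toy_call(
    fn: Callable[..., Any],
    data: Dict[str, Any],
    in_keys: Sequence[str],
    out_keys: Sequence[str],
) -> Dict[str, Any]:
    # Read the inputs in the order given by in_keys and pass them positionally.
    result = fn(*(data[key] for key in in_keys))
    if isinstance(result, dict):
        # A dict output must contain every name listed in out_keys.
        outputs = tuple(result[key] for key in out_keys)
    elif isinstance(result, tuple):
        outputs = result
    else:
        # A single output is treated as a one-element tuple.
        outputs = (result,)
    if len(outputs) != len(out_keys):
        raise ValueError(f"expected {len(out_keys)} outputs, got {len(outputs)}")
    for key, value in zip(out_keys, outputs):
        # "_" marks an output that is computed but not written.
        if key != "_":
            data[key] = value
    return data


data = {"a": 1, "b": 2}
toy_call(lambda a, b: (a + b, a - b), data, in_keys=["a", "b"], out_keys=["sum", "_"])
print(data)  # {'a': 1, 'b': 2, 'sum': 3}
```

In this toy, the input dict is mutated in place. That matches the default inplace=True behavior described under Keyword Arguments.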
- Keyword Arguments:
inplace (bool or string, optional) – if True (default), the outputs of the module are written in the tensordict provided to the forward() method. If False, a new TensorDict with an empty batch-size and no device is created. If "empty", empty() will be used to create the output tensordict.
Note
If inplace=False and the tensordict passed to the module is a TensorDictBase subclass other than TensorDict, the output will still be a TensorDict instance. Its batch-size will be empty, and it will have no device. Set inplace to "empty" to get the same TensorDictBase subtype, an identical batch-size and device. Use tensordict_out at runtime (see below) to have more fine-grained control over the output.
Note
If inplace=False and a tensordict_out is passed to the forward() method, the tensordict_out takes precedence and receives the outputs. This is how one can control the output type when the tensordict passed to the module is a TensorDictBase subclass other than TensorDict, in which case the output would otherwise still be a TensorDict instance.
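The choice of output container can be sketched in the same toy style. The names pick_output, toy_forward and CustomDict are hypothetical. Plain dicts stand in for TensorDict, and batch size and device are ignored. In this sketch, inplace=True writes into the input, inplace=False writes into a fresh plain dict regardless of the input type, and "empty" writes into an empty container of the input's own type. The toy always lets an explicit tensordict_out take precedence, although the notes above only state this for inplace=False.

```python
# Toy model of the inplace / tensordict_out choices. A dict subclass stands
# in for a TensorDictBase subclass; batch size and device are ignored.
# Hypothetical names; NOT tensordict's implementation.
from typing import Any, Callable, Dict, Optional, Union


class CustomDict(dict):
    """Stand-in for a TensorDictBase subclass other than TensorDict."""


def pick_output(
    data: Dict[str, Any],
    inplace: Union[bool, str] = True,
    tensordict_out: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    if tensordict_out is not None:
        # In this toy an explicit output container always wins.
        return tensordict_out
    if inplace is True:
        return data  # write into the input container itself
    if inplace == "empty":
        return type(data)()  # same type, no entries (cf. .empty())
    return {}  # inplace=False: a fresh plain container, whatever the input type


def toy_forward(
    fn: Callable[[Any], Any],
    data: Dict[str, Any],
    in_key: str,
    out_key: str,
    **kwargs: Any,
) -> Dict[str, Any]:
    out = pick_output(data, **kwargs)
    out[out_key] = fn(data[in_key])
    return out


src = CustomDict(x=1)
a = toy_forward(lambda v: v + 1, src, "x", "y")                   # src itself
b = toy_forward(lambda v: v + 1, src, "x", "z", inplace=False)    # plain dict
c = toy_forward(lambda v: v + 1, src, "x", "z", inplace="empty")  # CustomDict
```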
Embedding a neural network in a TensorDictModule only requires specifying the input and output keys. TensorDictModule supports both functional and regular nn.Module objects. In the functional case, the parameters can be extracted with TensorDict.from_module() and swapped in with to_module(), as in the functional example further below.
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> # one can wrap regular nn.Module
>>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"])
>>> input = torch.ones(2, 3, 128)
>>> tgt = torch.zeros(2, 3, 128)
>>> data = TensorDict({"input": input, "tgt": tgt}, batch_size=[2, 3])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        input: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        out: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        tgt: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=None,
    is_shared=False)
The tensors can also be passed directly:
Examples
>>> out = module(input, tgt)
>>> assert out.shape == input.shape
>>> # we can also wrap regular functions
>>> module = TensorDictModule(lambda x: (x-1, x+1), in_keys=[("input", "x")], out_keys=[("output", "x-1"), ("output", "x+1")])
>>> module(TensorDict({("input", "x"): torch.zeros(())}, batch_size=[]))
TensorDict(
    fields={
        input: TensorDict(
            fields={
                x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                x+1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                x-1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
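Nested keys such as ("output", "x-1") address entries inside sub-tensordicts. The plain-Python analogue below assumes nested dicts in place of nested TensorDicts, and get_nested and set_nested are hypothetical helpers. On write, it creates missing intermediate levels, mirroring how the "output" sub-tensordict appears in the example above.

```python
# Toy model of nested keys: a tuple such as ("output", "x-1") addresses an
# entry inside nested sub-dicts. Plain dicts stand in for TensorDict;
# hypothetical helpers, NOT tensordict's implementation.
from typing import Any, Dict, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def _as_tuple(key: NestedKey) -> Tuple[str, ...]:
    # A plain string is a one-level key.
    return (key,) if isinstance(key, str) else tuple(key)


def get_nested(data: Dict[str, Any], key: NestedKey) -> Any:
    node: Any = data
    for part in _as_tuple(key):
        node = node[part]
    return node


def set_nested(data: Dict[str, Any], key: NestedKey, value: Any) -> None:
    *parents, leaf = _as_tuple(key)
    node = data
    for part in parents:
        # Create the intermediate level if it does not exist yet.
        node = node.setdefault(part, {})
    node[leaf] = value


data = {"input": {"x": 0.0}}
x = get_nested(data, ("input", "x"))
set_nested(data, ("output", "x-1"), x - 1)
set_nested(data, ("output", "x+1"), x + 1)
print(data)  # {'input': {'x': 0.0}, 'output': {'x-1': -1.0, 'x+1': 1.0}}
```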
We can use TensorDictModule to populate a tensordict:
Examples
>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"])
>>> print(module(TensorDict({}, batch_size=[])))
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
A dictionary can also be passed as in_keys to control which keyword argument each value is dispatched to.
Examples
>>> module = TensorDictModule(lambda x, *, y: x+y,
...     in_keys={'1': 'x', '2': 'y'}, out_keys=['z'],
...     )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)
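A dict-valued in_keys works like building a keyword-argument dict: each tensordict key is read and passed under the argument name it maps to. This is a hypothetical sketch with plain dicts (call_with_kwargs is an illustrative name), not the library's code:

```python
# Toy model of dict-valued in_keys. Plain dicts stand in for TensorDict;
# hypothetical helper, NOT tensordict's implementation.
from typing import Any, Callable, Dict


def call_with_kwargs(
    fn: Callable[..., Any], data: Dict[str, Any], in_keys: Dict[str, str]
) -> Any:
    # {"1": "x"} means: read data["1"] and pass it as x=...
    kwargs = {arg_name: data[key] for key, arg_name in in_keys.items()}
    return fn(**kwargs)


result = call_with_kwargs(lambda x, *, y: x + y, {"1": 1.0, "2": 2.0}, {"1": "x", "2": "y"})
print(result)  # 3.0
```

Because values are passed by name, the order of the dict entries does not matter, and keyword-only parameters such as y above can be reached.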
Functional calls to a tensordict module are easy:
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...     module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> params = TensorDict.from_module(td_module)
>>> # functional API
>>> with params.to_module(td_module):
...     td_functional = td_module(td.clone())
>>> print(td_functional)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
In the stateful case:
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...     module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> td_stateful = td_module(td.clone())
>>> print(td_stateful)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
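The functional and stateful calls differ because the functional call temporarily swaps a different set of parameters into the module. The toy analogy below uses plain Python attributes instead of nn.Parameters. Scale, params_of and swapped are hypothetical names that only loosely mirror TensorDict.from_module() and to_module(). The sketch assumes, as in the example above, that the original parameters are back in place once the with block exits.

```python
# Toy analogy for functional calls: snapshot a "module"'s parameters, swap
# other values in for a `with` block, then restore the originals.
# Hypothetical names; NOT tensordict's implementation.
from contextlib import contextmanager
from typing import Dict, Iterator


class Scale:
    """Hypothetical one-parameter 'module' computing weight * x."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def __call__(self, x: float) -> float:
        return self.weight * x


def params_of(module: Scale) -> Dict[str, float]:
    # Copy of the current parameters (cf. TensorDict.from_module).
    return dict(vars(module))


@contextmanager
def swapped(module: Scale, params: Dict[str, float]) -> Iterator[Scale]:
    # Temporarily install `params` (cf. params.to_module(module)).
    original = params_of(module)
    for name, value in params.items():
        setattr(module, name, value)
    try:
        yield module
    finally:
        # Restore the original parameters even if the block raises.
        for name, value in original.items():
            setattr(module, name, value)


m = Scale(2.0)
with swapped(m, {"weight": 10.0}):
    functional_out = m(3.0)  # computed with the swapped-in weight
stateful_out = m(3.0)        # computed with the restored original weight
```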
- forward(tensordict: TensorDictBase = None, args=None, *, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) → TensorDictBase
When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict.
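Here is a minimal sketch of that rule. It assumes a plain dict in place of TensorDict. ToyModule is a hypothetical class, and the real forward() also handles tensordict_out and other arguments.

```python
# Toy model: if no container is passed, keyword arguments become the input.
# Plain dicts stand in for TensorDict; hypothetical class, NOT tensordict's code.
from typing import Any, Callable, Dict, Optional


class ToyModule:
    """Hypothetical wrapper: reads in_key, writes fn(value) to out_key."""

    def __init__(self, fn: Callable[[Any], Any], in_key: str, out_key: str) -> None:
        self.fn = fn
        self.in_key = in_key
        self.out_key = out_key

    def forward(
        self, tensordict: Optional[Dict[str, Any]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        if tensordict is None:
            # No container given: build one from the keyword arguments.
            tensordict = dict(kwargs)
        tensordict[self.out_key] = self.fn(tensordict[self.in_key])
        return tensordict


mod = ToyModule(lambda v: v * 2, in_key="x", out_key="y")
print(mod.forward(x=4))       # {'x': 4, 'y': 8}
print(mod.forward({"x": 1}))  # {'x': 1, 'y': 2}
```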