- class tensordict.nn.TensorDictSequential(*args, **kwargs)
A sequence of TensorDictModules.
Similarly to torch.nn.Sequential, which passes a tensor through a chain of mappings that each read and write a single tensor, this module reads and writes over a tensordict by querying each of the input modules. When calling a TensorDictSequential instance with a functional module, it is expected that the parameter lists (and buffers) will be concatenated in a single list.
- Parameters:
modules (OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]] | List[Callable[[TensorDictBase], TensorDictBase]]) – ordered sequence of callables that take a TensorDictBase as input and return a TensorDictBase. These can be instances of TensorDictModuleBase or any other function that matches this signature. Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked, and thus will not affect the in_keys and out_keys attributes of the TensorDictSequential. Regular dict inputs will be converted to OrderedDict if necessary.
- Keyword Arguments:
partial_tolerant (bool, optional) – if True, the input tensordict can miss some of the input keys. If so, the only modules that will be executed are those that can be executed given the keys that are present. Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is True AND if the stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts looking for those that have the required keys, if any. Defaults to False.
selected_out_keys (iterable of NestedKeys, optional) – the list of out-keys to select. If not provided, all out_keys will be written.
inplace (bool or str, optional) – if True, the input tensordict is modified in-place. If False, a new empty TensorDict instance is created. If "empty", input.empty() is used instead (i.e., the output preserves type, device and batch-size). Defaults to None (relies on sub-modules).
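The effect of partial_tolerant can be pictured with plain dictionaries standing in for tensordicts. The sketch below is not the library's implementation: run_sequence and the (fn, in_keys, out_keys) triples are hypothetical names used only for illustration, and the lazy-stack scanning described above is left out.

```python
def run_sequence(modules, data, partial_tolerant=False):
    """Run (fn, in_keys, out_keys) triples in order over a plain dict.

    Hypothetical stand-in for a module sequence; this is not tensordict code.
    """
    out = dict(data)  # copy so that the caller's dict is left untouched
    for fn, in_keys, out_keys in modules:
        missing = [key for key in in_keys if key not in out]
        if missing:
            if partial_tolerant:
                continue  # skip modules whose inputs are unavailable
            raise KeyError(f"missing input keys: {missing}")
        results = fn(*(out[key] for key in in_keys))
        if not isinstance(results, tuple):
            results = (results,)
        out.update(zip(out_keys, results))
    return out


modules = [
    (lambda x: x + 1, ["x"], ["y"]),
    (lambda a: a * 2, ["a"], ["b"]),  # "a" is never provided below
]
result = run_sequence(modules, {"x": 0}, partial_tolerant=True)
print(result)
```

In this sketch the second step is silently skipped because its input key is absent, whereas the same call without partial_tolerant raises a KeyError.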
A TensorDictSequential instance may have a long list of output keys, and one may wish to remove some of them after execution for clarity or memory purposes. If this is the case, the method select_out_keys() can be used after instantiation, or selected_out_keys may be passed to the constructor.
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> _ = torch.manual_seed(0)  # fix the seed so the Linear weights are reproducible
>>> module = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["x+1"]),
...     TensorDictModule(nn.Linear(3, 4), in_keys=["x+1"], out_keys=["w*(x+1)+b"]),
... )
>>> # with tensordict input
>>> print(module(TensorDict({"x": torch.zeros(3)}, [])))
TensorDict(
    fields={
        w*(x+1)+b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        x+1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # with tensor input: returns all the output keys in the order of the modules, ie "x+1" and "w*(x+1)+b"
>>> module(x=torch.zeros(3))
(tensor([1., 1., 1.]), tensor([-0.7214, -0.8748, 0.1571, -0.1138], grad_fn=<AddBackward0>))
>>> module(torch.zeros(3))
(tensor([1., 1., 1.]), tensor([-0.7214, -0.8748, 0.1571, -0.1138], grad_fn=<AddBackward0>))
TensorDictSequential supports functional, modular and vmap coding.
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
...     TensorDictModule,
...     TensorDictSequential,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from torch.distributions import Normal
>>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
>>> net1 = torch.nn.Linear(4, 8)
>>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["params"])
>>> normal_params = TensorDictModule(
...     NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"]
... )
>>> td_module1 = ProbabilisticTensorDictSequential(
...     module1,
...     normal_params,
...     ProbabilisticTensorDictModule(
...         in_keys=["loc", "scale"],
...         out_keys=["hidden"],
...         distribution_class=Normal,
...         return_log_prob=True,
...     )
... )
>>> module2 = torch.nn.Linear(4, 8)
>>> td_module2 = TensorDictModule(
...     module=module2, in_keys=["hidden"], out_keys=["output"]
... )
>>> td_module = TensorDictSequential(td_module1, td_module2)
>>> # extract the parameters, then run the module with them temporarily set
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     _ = td_module(td)
>>> print(td)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
- In the vmap case:
>>> from torch import vmap
>>> params = params.expand(4)
>>> def func(td, params):
...     with params.to_module(td_module):
...         return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
- forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) → TensorDictBase
When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict.
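The examples in this entry call the module with tensors, either positionally or by keyword, and get the outputs back in out_keys order. A rough picture of that dispatch, using plain dictionaries in place of TensorDict, might look as follows. The names call_with_values and run are hypothetical, and the sketch only mirrors the behaviour shown in the examples rather than the library's code.

```python
def call_with_values(run, in_keys, out_keys, *args, **kwargs):
    """Hypothetical dispatcher: build the input mapping, run, unpack the outputs."""
    data = dict(zip(in_keys, args))  # positional values fill in_keys in order
    data.update(kwargs)  # keyword values are matched by key name
    result = run(data)
    values = tuple(result[key] for key in out_keys)
    # a single output is returned bare rather than as a one-element tuple
    return values[0] if len(values) == 1 else values


def run(data):
    """Stand-in for a two-step sequence: x -> x+1 -> 2*(x+1)."""
    out = dict(data)
    out["x+1"] = out["x"] + 1
    out["2*(x+1)"] = 2 * out["x+1"]
    return out


by_keyword = call_with_values(run, ["x"], ["x+1", "2*(x+1)"], x=0)
by_position = call_with_values(run, ["x"], ["x+1", "2*(x+1)"], 0)
print(by_keyword, by_position)
```

Both calls build the same input mapping, so they return the same pair of values.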
- reset_out_keys()
Resets the out_keys attribute to its original value.
Returns: the same module, with its original out_keys values.
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.reset_out_keys()
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
- select_out_keys(*selected_out_keys) → TensorDictSequential
Selects the keys that will be found in the output tensordict.
This is useful whenever one wants to get rid of intermediate keys in a complicated graph, or when the presence of these keys may trigger unexpected behaviours.
The original out_keys can still be accessed via module.out_keys_source.
- Parameters:
*selected_out_keys (a sequence of strings or tuples of strings) – the out_keys that should be found in the output tensordict.
Returns: the same module, modified in-place with updated out_keys.
The simplest usage is with TensorDictModule:
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This feature will also work with dispatched arguments:
>>> # only "d" is selected, and d = b + 2 with b = 1
>>> mod(torch.zeros(()), torch.ones(()))
tensor(3.)
This change will occur in-place (i.e., the same module will be returned with an updated list of out_keys). It can be reverted using the reset_out_keys() method:
>>> mod.reset_out_keys()
>>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This will work with other classes too, such as Sequential:
>>> from tensordict.nn import TensorDictSequential
>>> seq = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]),
...     TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]),
... )
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> seq.select_out_keys("z")
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
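The relationship between out_keys, out_keys_source, select_out_keys() and reset_out_keys() can be summarised with a small plain-Python model. This is only a sketch of the bookkeeping described in this entry: the class OutKeySelection and its filter method are hypothetical and do not exist in tensordict.

```python
class OutKeySelection:
    """Hypothetical helper that mimics the out_keys bookkeeping with plain dicts."""

    def __init__(self, out_keys):
        self.out_keys_source = list(out_keys)  # full list, kept for later resets
        self.out_keys = list(out_keys)  # keys currently selected

    def select_out_keys(self, *selected_out_keys):
        self.out_keys = list(selected_out_keys)
        return self  # the same object, modified in place

    def reset_out_keys(self):
        self.out_keys = list(self.out_keys_source)
        return self

    def filter(self, data):
        """Drop every produced key that is not currently selected."""
        dropped = set(self.out_keys_source) - set(self.out_keys)
        return {key: value for key, value in data.items() if key not in dropped}


computed = {"x": 0, "y": 1, "z": 2}  # "y" and "z" were written by modules
selection = OutKeySelection(["y", "z"])
selected = selection.select_out_keys("z").filter(computed)
restored = selection.reset_out_keys().filter(computed)
print(selected)
print(restored)
```

As in the Sequential example, input keys such as "x" are left alone; only keys written by the modules are candidates for removal.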
- select_subsequence(in_keys: Optional[Iterable[NestedKey]] = None, out_keys: Optional[Iterable[NestedKey]] = None) → TensorDictSequential
Returns a new TensorDictSequential with only the modules that are necessary to compute the given output keys with the given input keys.
- Parameters:
in_keys – input keys of the subsequence we want to select. All the keys absent from in_keys will be considered as non-relevant, and modules that just take these keys as inputs will be discarded. The resulting sequential module will follow the pattern "all the modules whose output will be affected by a different value for any key in <in_keys>". If none is provided, the module's in_keys are assumed.
out_keys – output keys of the subsequence we want to select. Only the modules that are necessary to get the out_keys will be found in the resulting sequence. The resulting sequential module will follow the pattern "all the modules that condition the value of <out_keys> entries". If none is provided, the module's out_keys are assumed.
- Returns:
A new TensorDictSequential with only the modules that are necessary according to the given input and output keys.
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> idn = lambda x: x
>>> module = Seq(
...     Mod(idn, in_keys=["a"], out_keys=["b"]),
...     Mod(idn, in_keys=["b"], out_keys=["c"]),
...     Mod(idn, in_keys=["c"], out_keys=["d"]),
...     Mod(idn, in_keys=["a"], out_keys=["e"]),
... )
>>> # select all modules whose output depend on "a"
>>> module.select_subsequence(in_keys=["a"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['b'])
      (1): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['b'],
          out_keys=['c'])
      (2): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['c'],
          out_keys=['d'])
      (3): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['e'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['b', 'c', 'd', 'e'])
>>> # select all modules whose output depend on "c"
>>> module.select_subsequence(in_keys=["c"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['c'],
          out_keys=['d'])
    ),
    device=cpu,
    in_keys=['c'],
    out_keys=['d'])
>>> # select all modules that affect the value of "c"
>>> module.select_subsequence(out_keys=["c"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['b'])
      (1): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['b'],
          out_keys=['c'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['b', 'c'])
>>> # select all modules that affect the value of "e"
>>> module.select_subsequence(out_keys=["e"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['e'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['e'])
This method propagates to nested sequential:
>>> module = Seq(
...     Seq(
...         Mod(idn, in_keys=["a"], out_keys=["b"]),
...         Mod(idn, in_keys=["b"], out_keys=["c"]),
...     ),
...     Seq(
...         Mod(idn, in_keys=["b"], out_keys=["d"]),
...         Mod(idn, in_keys=["d"], out_keys=["e"]),
...     ),
... )
>>> # select submodules whose output will be affected by a change in "b" or "d" AND which output is "e"
>>> module.select_subsequence(in_keys=["b", "d"], out_keys=["e"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictSequential(
          module=ModuleList(
            (0): TensorDictModule(
                module=<function <lambda> at 0x129efae50>,
                device=cpu,
                in_keys=['b'],
                out_keys=['d'])
            (1): TensorDictModule(
                module=<function <lambda> at 0x129efae50>,
                device=cpu,
                in_keys=['d'],
                out_keys=['e'])
          ),
          device=cpu,
          in_keys=['b'],
          out_keys=['d', 'e'])
    ),
    device=cpu,
    in_keys=['b'],
    out_keys=['d', 'e'])
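One way to picture the selection rule is as two passes over the module list: a forward pass that keeps the modules reachable from the given input keys, then a backward pass that keeps those contributing to the requested output keys. The plain-Python sketch below works on (in_keys, out_keys) pairs instead of real modules. It assumes that a module is kept in the forward pass only when all of its inputs are available, which reproduces the examples above on a flattened module list; the library's actual logic, in particular for nested sequences, may differ.

```python
def select_subsequence(modules, in_keys=None, out_keys=None):
    """Prune a list of (in_keys, out_keys) pairs; hypothetical plain-Python sketch."""
    # Defaults: keys read before being written, and every key written.
    default_in, default_out = [], []
    for mod_in, mod_out in modules:
        default_in += [k for k in mod_in if k not in default_in + default_out]
        default_out += [k for k in mod_out if k not in default_out]
    available = set(default_in if in_keys is None else in_keys)
    wanted = set(default_out if out_keys is None else out_keys)

    # Forward pass: keep modules whose inputs are all available (assumption).
    reachable = []
    for index, (mod_in, mod_out) in enumerate(modules):
        if all(key in available for key in mod_in):
            reachable.append(index)
            available.update(mod_out)

    # Backward pass: keep modules that contribute to a wanted key.
    kept = []
    for index in reversed(reachable):
        mod_in, mod_out = modules[index]
        if any(key in wanted for key in mod_out):
            kept.append(index)
            wanted.update(mod_in)
    return [modules[index] for index in sorted(kept)]


modules = [(["a"], ["b"]), (["b"], ["c"]), (["c"], ["d"]), (["a"], ["e"])]
from_c = select_subsequence(modules, in_keys=["c"])
to_c = select_subsequence(modules, out_keys=["c"])
print(from_c)
print(to_c)
```

Running the backward pass only over the modules that survived the forward pass is what lets both constraints apply at once, as in the nested example with in_keys=["b", "d"] and out_keys=["e"].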
The inplace argument gives fine-grained control over the output type. For instance, it makes it possible to write the result of the computational graph into the input object without tracking the intermediate tensors.
>>> import torch
>>> from tensordict import TensorClass
>>> from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
>>>
>>> class MyClass(TensorClass):
...     input: torch.Tensor
...     output: torch.Tensor | None = None
>>>
>>> obj = MyClass(torch.randn(2, 3), batch_size=(2,))
>>>
>>> model = Seq(
...     Mod(
...         lambda x: (x + 1, x - 1),
...         in_keys=["input"],
...         out_keys=[("intermediate", "0"), ("intermediate", "1")],
...         inplace=False
...     ),
...     Mod(
...         lambda y0, y1: y0 * y1,
...         in_keys=[("intermediate", "0"), ("intermediate", "1")],
...         out_keys=["output"],
...         inplace=False
...     ),
...     inplace=True,
... )
>>> print(model(obj))
MyClass(
    input=Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
    output=Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
    output=None,
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)