WrapModule
- class tensordict.nn.WrapModule(*args, **kwargs)
A wrapper around any callable that processes TensorDict instances.
This wrapper is useful when building TensorDictSequential stacks and when a transform requires the entire TensorDict instance to be visible, rather than only the selected entries that TensorDictModule passes to its module.
- Parameters:
func (Callable[[TensorDictBase], TensorDictBase]) – A callable that takes a TensorDictBase instance and returns a transformed TensorDictBase instance.
- Keyword Arguments:
inplace (bool, optional) – If True, the input TensorDict is modified in place. Otherwise, the output of func is returned as-is, which is a new TensorDict unless func modifies its input in place and returns it. Defaults to False.
in_keys (list of NestedKey, optional) – If provided, indicates which entries are read by the module. This is not checked; it only informs TensorDictSequential about the input keys of the wrapped module. Defaults to [].
out_keys (list of NestedKey, optional) – If provided, indicates which entries are written by the module. This is not checked; it only informs TensorDictSequential about the output keys of the wrapped module. Defaults to [].
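As a rough illustration of the `inplace` keyword, the sketch below models the wrapper using plain Python dicts in place of TensorDict. `MiniWrap` and `add_y` are hypothetical names. The in-place path is assumed to copy the result's entries back into the input, which is a guess at the behavior and not a description of tensordict's source.

```python
from typing import Callable, Dict, List, Optional

# Plain dicts stand in for TensorDict; this is a simplified model,
# not the tensordict implementation.
Data = Dict[str, object]


class MiniWrap:
    """Hypothetical stand-in for WrapModule, illustrating the documented keywords."""

    def __init__(
        self,
        func: Callable[[Data], Data],
        *,
        inplace: bool = False,
        in_keys: Optional[List[str]] = None,
        out_keys: Optional[List[str]] = None,
    ) -> None:
        self.func = func
        self.inplace = inplace
        # Key lists are metadata only: nothing below checks them against the data.
        self.in_keys = list(in_keys) if in_keys is not None else []
        self.out_keys = list(out_keys) if out_keys is not None else []

    def __call__(self, data: Data) -> Data:
        result = self.func(data)
        if self.inplace and result is not data:
            # Assumed in-place behaviour: copy the result's entries into the input.
            data.update(result)
            return data
        # inplace=False: hand back whatever func returned, untouched.
        return result


def add_y(d: Data) -> Data:
    # Builds a new dict and leaves its argument alone.
    return {**d, "y": d["x"] * 2}


src = {"x": 3}
out_copy = MiniWrap(add_y)(src)                   # new object; src still lacks "y"
out_inplace = MiniWrap(add_y, inplace=True)(src)  # src itself gains "y"
print(out_copy is src, "y" in src, out_inplace is src)
```

With `inplace=False` the caller receives whatever `func` returns. With `inplace=True` the caller's own object carries the new entries, which matters when other references to that object exist.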
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule
>>> seq = Seq(
...     Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]),
...     WrapModule(lambda td: td.reshape(-1)),
... )
>>> td = TensorDict(x=torch.ones(3, 4, 5), batch_size=[3, 4])
>>> td = seq(td)
>>> assert td.shape == (12,)
>>> assert (td["y"] == 2).all()
>>> assert td["y"].shape == (12, 5)
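Because `in_keys` and `out_keys` are declarations rather than checks, a sequence can only account for the keys a WrapModule reports. The sketch below is a simplified model of how a chain of modules might aggregate declared keys. It is an assumption and not TensorDictSequential's code. A key counts as an input of the chain if some module reads it before any earlier module has written it. `KeySpec` and `chain_keys` are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class KeySpec:
    """Hypothetical record of the keys a module declares (not a tensordict class)."""
    in_keys: List[str] = field(default_factory=list)
    out_keys: List[str] = field(default_factory=list)


def chain_keys(specs: List[KeySpec]) -> Tuple[List[str], List[str]]:
    """Aggregate declared keys across a chain of modules, in order.

    Inputs: keys read by some module before any earlier module wrote them.
    Outputs: every key written by any module, in first-written order.
    """
    ins: List[str] = []
    outs: List[str] = []
    for spec in specs:
        for key in spec.in_keys:
            if key not in outs and key not in ins:
                ins.append(key)
        for key in spec.out_keys:
            if key not in outs:
                outs.append(key)
    return ins, outs


# Mirrors the example above: Mod(in_keys=["x"], out_keys=["y"]) followed by
# a WrapModule left with its default (empty) key lists.
print(chain_keys([KeySpec(["x"], ["y"]), KeySpec()]))
# Declaring in_keys=["z"] on the wrapper makes that read visible to the chain.
print(chain_keys([KeySpec(["x"], ["y"]), KeySpec(in_keys=["z"])]))
```

If the wrapped callable read "z" but kept the default empty lists, the chain would not report "z" as an input.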
- forward(data: TensorDictBase) → TensorDictBase
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
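The note can be illustrated with a toy model of the dispatch. `HookedModule` and `Double` below are hypothetical, stdlib-only stand-ins. The real `torch.nn.Module.__call__` does much more, and real hooks registered with `register_forward_hook` receive the positional inputs as a tuple rather than a single value. The sketch shows only one thing: calling the instance runs registered hooks, whereas calling `forward()` directly skips them.

```python
from typing import Any, Callable, List


class HookedModule:
    """Toy model of why calling the instance differs from calling .forward()."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable[[Any, Any, Any], Any]] = []

    def register_forward_hook(self, hook: Callable[[Any, Any, Any], Any]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError

    def __call__(self, x: Any) -> Any:
        out = self.forward(x)
        # Hooks run only on this path; a non-None return replaces the output.
        for hook in self._forward_hooks:
            replaced = hook(self, x, out)
            if replaced is not None:
                out = replaced
        return out


class Double(HookedModule):
    def forward(self, x: int) -> int:
        return 2 * x


calls: List[int] = []
m = Double()
m.register_forward_hook(lambda mod, inp, out: calls.append(out))
m(3)          # runs the hook, which records 6
m.forward(3)  # bypasses the hook entirely
print(calls)
```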