WrapModule

class tensordict.nn.WrapModule(*args, **kwargs)

A wrapper around any callable that processes TensorDict instances.

This wrapper is useful when building TensorDictSequential stacks and when a transform requires the entire TensorDict instance to be visible.

Parameters:

func (Callable[[TensorDictBase], TensorDictBase]) – A callable that takes a TensorDictBase instance and returns a transformed TensorDictBase instance.

Keyword Arguments:
  • inplace (bool, optional) – If True, the input TensorDict is modified in-place. Otherwise, a new TensorDict is returned, unless the function itself modifies its input in-place and returns it. Defaults to False.

  • in_keys (list of NestedKey, optional) – If provided, indicates which entries are read by the module. These keys are not checked; they only inform TensorDictSequential about the input keys of the wrapped module. Defaults to [].

  • out_keys (list of NestedKey, optional) – If provided, indicates which entries are written by the module. These keys are not checked; they only inform TensorDictSequential about the output keys of the wrapped module. Defaults to [].

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule
>>> seq = Seq(
...     Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]),
...     WrapModule(lambda td: td.reshape(-1)),
... )
>>> td = TensorDict(x=torch.ones(3, 4, 5), batch_size=[3, 4])
>>> td = seq(td)
>>> assert td.shape == (12,)
>>> assert (td["y"] == 2).all()
>>> assert td["y"].shape == (12, 5)

forward(data: TensorDictBase) → TensorDictBase

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of calling forward directly, since the former takes care of running the registered hooks while the latter silently ignores them.
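
The difference described in the note can be illustrated with a plain torch.nn.Module; the same reasoning should carry over to WrapModule, which is an nn.Module subclass. This is only a sketch: the Doubler class and the calls list are hypothetical names introduced for the illustration.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    # Hypothetical toy module used only to demonstrate hook behaviour.
    def forward(self, x):
        return x * 2


calls = []
module = Doubler()
# Record every time the forward hook fires.
handle = module.register_forward_hook(
    lambda mod, args, output: calls.append("hook")
)

x = torch.ones(2)

module(x)  # goes through __call__, so the registered hook runs
n_after_call = len(calls)

module.forward(x)  # bypasses __call__, so the hook is skipped
n_after_forward = len(calls)

handle.remove()
```

After both invocations the hook has fired only once, for the call made through the module instance.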
