
tensordict.nn.dispatch

class tensordict.nn.dispatch(separator='_', source='in_keys', dest='out_keys', auto_batch_size: bool = True)

Allows a function expecting a TensorDict to be called using kwargs.

dispatch() must be used within modules that have an in_keys attribute (or another key source indicated by the source keyword argument) and an out_keys attribute (or another dest key list) specifying which keys are read from and written to the tensordict. The wrapped function must also take a tensordict as its leading argument.

The resulting function returns a single tensor if out_keys contains a single element; otherwise it returns a tuple ordered as the out_keys of the module.

dispatch() can be applied either directly as a decorator or instantiated with arguments when extra options need to be passed.

Parameters:
  • separator (str, optional) – separator that combines sub-keys together for in_keys that are tuples of strings. Defaults to "_".

  • source (str or list of keys, optional) – if a string is provided, it points to the module attribute that contains the list of input keys to be used. If a list is provided instead, it will contain the keys used as input to the module. Defaults to "in_keys" which is the attribute name of TensorDictModule list of input keys.

  • dest (str or list of keys, optional) – if a string is provided, it points to the module attribute that contains the list of output keys to be used. If a list is provided instead, it will contain the keys used as output to the module. Defaults to "out_keys" which is the attribute name of TensorDictModule list of output keys.

  • auto_batch_size (bool, optional) – if True, the batch-size of the input tensordict is determined automatically as the maximum number of common dimensions across all the input tensors. Defaults to True.
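
As an illustrative sketch of how keyword inputs are gathered into a single tensordict, consider the hypothetical module below; the comment about the detected batch size is an assumption based on the auto_batch_size rule stated above:

>>> import torch
>>> from torch import nn
>>> from tensordict.nn import dispatch
>>> class MyBatchModule(nn.Module):  # hypothetical module, for illustration only
...     in_keys = ["x", "y"]
...     out_keys = ["z"]
...
...     @dispatch
...     def forward(self, tensordict):
...         # With auto_batch_size=True, the dimensions shared by all inputs
...         # (here the single leading dimension of size 3) are expected to
...         # make up tensordict.batch_size.
...         tensordict['z'] = tensordict['x'].sum(-1, keepdim=True) + tensordict['y']
...         return tensordict
...
>>> module = MyBatchModule()
>>> z = module(x=torch.zeros(3, 4), y=torch.ones(3, 5))
>>> assert z.shape == (3, 5)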

Examples

>>> import torch
>>> from torch import nn
>>> from tensordict.nn import dispatch
>>> class MyModule(nn.Module):
...     in_keys = ["a"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # equivalently
>>> class MyModule(nn.Module):
...     keys_in = ["a"]
...     keys_out = ["b"]
...
...     @dispatch(source="keys_in", dest="keys_out")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # or this
>>> class MyModule(nn.Module):
...     @dispatch(source=["a"], dest=["b"])
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
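
When out_keys contains more than one entry, the call returns a tuple ordered as out_keys. A minimal sketch, reusing the imports above (the module name is hypothetical):

>>> class MyModuleMultiOut(nn.Module):
...     in_keys = ["a"]
...     out_keys = ["b", "c"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         tensordict['c'] = tensordict['a'] - 1
...         return tensordict
...
>>> module = MyModuleMultiOut()
>>> b, c = module(a=torch.zeros(1, 2))  # tuple ordered as out_keys
>>> assert (b == 1).all() and (c == -1).all()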

dispatch() also works with nested keys, using the default "_" separator.

Examples

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(a_c=torch.zeros(1, 2))
>>> assert (b == 1).all()

If a different separator is desired, it can be indicated with the separator argument of the decorator:

Examples

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch(separator="sep")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(asepc=torch.zeros(1, 2))
>>> assert (b == 1).all()

Since the input keys form an ordered sequence of strings, dispatch() can also be used with positional (unnamed) arguments, in which case their order must match the order of the input keys.

Note

If the first argument is a TensorDictBase instance, it is assumed that dispatch is not being used and that this tensordict contains all the necessary information to be run through the module. In other words, one cannot pass a tensordict instance as the value of the first module input key, since it would be interpreted as the full input tensordict. In general, it is preferable to use dispatch() with tensordict leaves only.

Examples

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c"), "d"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + tensordict["d"]
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(torch.zeros(1, 2), d=torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> b, = module(torch.zeros(1, 2), torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> try:
...     b, = module(torch.zeros(1, 2), a_c=torch.ones(1, 2))  # fails
... except Exception:
...     print("oopsy!")
...
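
As a sketch of the bypass described in the note above, passing a TensorDictBase instance as the first argument skips dispatch entirely and hands the tensordict to the wrapped function as-is (this assumes the MyModuleNest class defined just above):

>>> from tensordict import TensorDict
>>> td = TensorDict({"a": {"c": torch.zeros(1, 2)}, "d": torch.ones(1, 2)}, batch_size=[1])
>>> out = module(td)  # dispatch is bypassed; forward receives the tensordict itself
>>> assert (out["b"] == 1).all()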
