tensordict.nn.dispatch
- class tensordict.nn.dispatch(separator='_', source='in_keys', dest='out_keys', auto_batch_size: bool = True)
Allows for a function expecting a TensorDict to be called using kwargs.
dispatch() must be used within modules that have an in_keys attribute (or another source of keys indicated by the source keyword argument) and an out_keys attribute (or another dest key list), indicating which keys are read from and written to the tensordict. The wrapped function should also take a tensordict as its leading argument.

The resulting function returns a single tensor if there is a single element in out_keys; otherwise it returns a tuple ordered as the out_keys of the module.

dispatch() can be used either directly as a decorator or, when extra arguments need to be passed, as a class instantiated with those arguments.

- Parameters:
  - separator (str, optional) – separator that combines sub-keys together for in_keys that are tuples of strings. Defaults to "_".
  - source (str or list of keys, optional) – if a string is provided, it points to the module attribute that contains the list of input keys to be used. If a list is provided instead, it will contain the keys used as input to the module. Defaults to "in_keys", which is the attribute name of the TensorDictModule list of input keys.
  - dest (str or list of keys, optional) – if a string is provided, it points to the module attribute that contains the list of output keys to be used. If a list is provided instead, it will contain the keys used as output to the module. Defaults to "out_keys", which is the attribute name of the TensorDictModule list of output keys.
  - auto_batch_size (bool, optional) – if True, the batch-size of the input tensordict is determined automatically as the maximum number of common dimensions across all the input tensors. Defaults to True.
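The auto_batch_size rule can be pictured with plain shape tuples. The sketch below assumes that "common dimensions" means the leading dimensions shared by every input; the helper name common_leading_dims is hypothetical and is not part of tensordict, so treat it as an illustration of the idea rather than the library's code.

```python
from typing import Sequence, Tuple


def common_leading_dims(shapes: Sequence[Tuple[int, ...]]) -> Tuple[int, ...]:
    """Return the longest prefix of dimensions shared by every shape.

    Illustrative helper only (assumption): it is not a tensordict function.
    """
    if not shapes:
        return ()
    prefix = []
    # zip stops at the shortest shape, so the prefix is never longer than it.
    for dims in zip(*shapes):
        if len(set(dims)) != 1:
            break
        prefix.append(dims[0])
    return tuple(prefix)


# Inputs of shape (3, 4) and (3, 4, 5) share the leading dimensions (3, 4).
batch_size = common_leading_dims([(3, 4), (3, 4, 5)])
# Inputs that already differ in their first dimension share nothing.
no_batch = common_leading_dims([(3, 4), (2, 4)])
```

Under this reading, inputs that disagree on their first dimension lead to an empty batch size.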
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict.nn import dispatch
>>> class MyModule(nn.Module):
...     in_keys = ["a"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # equivalently
>>> class MyModule(nn.Module):
...     keys_in = ["a"]
...     keys_out = ["b"]
...
...     @dispatch(source="keys_in", dest="keys_out")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # or this
>>> class MyModule(nn.Module):
...     @dispatch(source=["a"], dest=["b"])
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
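To make the return convention concrete, here is a toy stand-in that uses plain dicts in place of tensordicts. The names toy_dispatch, AddOne and SumAndDiff are hypothetical, and the decorator is a deliberately simplified sketch (keyword arguments only, no separator handling), not the tensordict implementation.

```python
import functools


def toy_dispatch(func):
    """Toy stand-in for the decorator, using plain dicts (assumption)."""

    @functools.wraps(func)
    def wrapper(self, **kwargs):
        # Gather the keyword arguments named in in_keys into one container.
        container = {key: kwargs[key] for key in self.in_keys}
        result = func(self, container)
        # Read the outputs back in the order given by out_keys.
        outputs = tuple(result[key] for key in self.out_keys)
        # One output key yields a bare value, several yield a tuple.
        return outputs[0] if len(outputs) == 1 else outputs

    return wrapper


class AddOne:
    in_keys = ["a"]
    out_keys = ["b"]

    @toy_dispatch
    def forward(self, container):
        container["b"] = container["a"] + 1
        return container


class SumAndDiff:
    in_keys = ["x", "y"]
    out_keys = ["total", "diff"]

    @toy_dispatch
    def forward(self, container):
        container["total"] = container["x"] + container["y"]
        container["diff"] = container["x"] - container["y"]
        return container


single = AddOne().forward(a=1)                # one out key: a bare value
total, diff = SumAndDiff().forward(x=5, y=3)  # two out keys: a tuple
```

The tuple follows the order of out_keys, so callers can unpack it without looking the keys up again.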
dispatch() also works with nested keys, using the default "_" separator.

Examples
>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b = module(a_c=torch.zeros(1, 2))  # single out key: a tensor is returned
>>> assert (b == 1).all()
If another separator is wanted, it can be indicated with the separator argument in the constructor:

Examples
>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch(separator="sep")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b = module(asepc=torch.zeros(1, 2))  # single out key: a tensor is returned
>>> assert (b == 1).all()
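The keyword names used in the two examples above (a_c and asepc) follow from joining the sub-keys of a nested key with the separator. A minimal sketch of that mapping, with a hypothetical helper kwarg_name that is not part of tensordict:

```python
def kwarg_name(key, separator="_"):
    """Map an input key to the keyword name used when calling the module.

    Hypothetical helper for illustration (assumption); not a tensordict function.
    """
    if isinstance(key, str):
        # Flat keys are used as-is.
        return key
    # Nested keys (tuples of strings) are joined with the separator.
    return separator.join(key)


default_name = kwarg_name(("a", "c"))                  # default "_" separator
custom_name = kwarg_name(("a", "c"), separator="sep")  # custom separator
flat_name = kwarg_name("d")                            # flat key, unchanged
```

Note that the separator is inserted verbatim between sub-keys, so a flat key "a_c" and the nested key ("a", "c") map to the same keyword name in this sketch.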
Since the input keys form an ordered sequence, dispatch() can also be used with unnamed (positional) arguments, whose order must match the order of the input keys.

Note
If the first argument is a TensorDictBase instance, it is assumed that dispatch is not being used and that this tensordict contains all the necessary information to be run through the module. In other words, one cannot decompose a tensordict when the first key of the module inputs points to a tensordict instance. In general, it is preferred to use dispatch() with tensordict leaves only.

Examples
>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c"), "d"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + tensordict["d"]
...         return tensordict
...
>>> module = MyModuleNest()
>>> b = module(torch.zeros(1, 2), d=torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> b = module(torch.zeros(1, 2), torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> try:
...     # fails: ("a", "c") is given both positionally and as a_c
...     b = module(torch.zeros(1, 2), a_c=torch.ones(1, 2))
... except Exception:
...     print("oopsy!")
...
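The three calls above can be read as a simple binding rule: positional arguments are consumed in input-key order, remaining keys are looked up by keyword, and a key supplied both ways is rejected. The sketch below assumes that rule; bind_inputs is a hypothetical helper, and the ValueError it raises is an arbitrary choice for the sketch, not necessarily the exception the library uses.

```python
def bind_inputs(in_keys, args, kwargs, separator="_"):
    """Pair positional and keyword arguments with input keys, in key order.

    Simplified illustration (assumption); the real decorator may differ.
    """
    args = list(args)
    kwargs = dict(kwargs)
    bound = {}
    for key in in_keys:
        name = key if isinstance(key, str) else separator.join(key)
        if args:
            if name in kwargs:
                # The same input was supplied positionally and by keyword.
                raise ValueError(f"{name!r} given both positionally and by keyword")
            bound[key] = args.pop(0)
        elif name in kwargs:
            bound[key] = kwargs.pop(name)
    return bound


in_keys = [("a", "c"), "d"]
mixed = bind_inputs(in_keys, (0.0,), {"d": 1.0})     # positional + keyword
positional = bind_inputs(in_keys, (0.0, 1.0), {})    # positional only
try:
    bind_inputs(in_keys, (0.0,), {"a_c": 1.0})       # ("a", "c") given twice
    clash_detected = False
except ValueError:
    clash_detected = True
```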