as_tensordict_module
- class tensordict.nn.as_tensordict_module(*, in_keys: Union[List[NestedKey], NestedKey], out_keys: Union[List[NestedKey], NestedKey])
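A decorator that converts a function (or method) into a TensorDictModule. The resulting module reads the entries stored under in_keys from the input TensorDict, passes them to the decorated function, and writes the result under out_keys.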
- Parameters:
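in_keys (List[NestedKey] | NestedKey | None, optional) – The input key(s) of the resulting TensorDictModule: the entries read from the input TensorDict and passed, in order, as positional arguments to the decorated function.
out_keys (List[NestedKey] | NestedKey | None, optional) – The output key(s) of the resulting TensorDictModule: the entries under which the function's output(s) are written.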
- Returns:
A decorator that can be applied to a function to convert it into a TensorDictModule.
- Return type:
Callable
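Both parameters accept a NestedKey, which is either a plain string or a tuple of strings such as ("next", "pixels") that addresses an entry inside a nested TensorDict. The sketch below illustrates how such a tuple key can be resolved against nested plain dicts, which stand in for nested TensorDicts here. get_nested and set_nested are hypothetical helpers written for illustration only; they are not tensordict functions.

```python
from typing import Any, Tuple, Union

# A nested key is either a plain string or a tuple of strings (one per level).
NestedKey = Union[str, Tuple[str, ...]]


def get_nested(data: dict, key: NestedKey) -> Any:
    """Hypothetical helper: follow a (possibly nested) key through nested dicts."""
    path = (key,) if isinstance(key, str) else key
    for part in path:
        data = data[part]
    return data


def set_nested(data: dict, key: NestedKey, value: Any) -> None:
    """Hypothetical helper: write a value, creating intermediate dicts as needed."""
    path = (key,) if isinstance(key, str) else key
    for part in path[:-1]:
        data = data.setdefault(part, {})
    data[path[-1]] = value


data = {"obs": {"pixels": 3}}
# Read ("obs", "pixels"), add 1, and write the result under ("next", "pixels").
set_nested(data, ("next", "pixels"), get_nested(data, ("obs", "pixels")) + 1)
print(data)  # {'obs': {'pixels': 3}, 'next': {'pixels': 4}}
```

Creating the intermediate "next" level on write roughly mirrors how a TensorDict creates a sub-TensorDict when a value is set under a nested key.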
Examples
>>> from tensordict import TensorDict
>>> from tensordict.nn import as_tensordict_module
>>> class MyClass:
...     @as_tensordict_module(in_keys="c", out_keys="d")
...     def my_method(self, c):
...         return c + 1
>>> obj = MyClass()
>>> result = obj.my_method(TensorDict(c=0))
>>> print(result["d"])
tensor(1)
>>> @as_tensordict_module(in_keys="c", out_keys="d")
... def my_function(c):
...     return c + 1
>>> result = my_function(TensorDict(c=0))
>>> print(result["d"])
tensor(1)
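The examples above use a single input key and a single output key. The sketch below illustrates the general key routing with several keys: the values under in_keys are passed to the function in order, and a tuple return value is split across out_keys in order. It assumes a plain dict stands in for a TensorDict, and dict_module is a hypothetical name; this is an illustration of the idea, not the tensordict implementation.

```python
from typing import Callable, List, Union

# Hypothetical stand-in: a plain dict plays the role of a TensorDict.
Keys = Union[str, List[str]]


def dict_module(*, in_keys: Keys, out_keys: Keys) -> Callable:
    """Hypothetical decorator mimicking how a TensorDictModule routes keys."""
    # Normalize single keys to one-element lists, since either form is allowed.
    in_list = [in_keys] if isinstance(in_keys, str) else list(in_keys)
    out_list = [out_keys] if isinstance(out_keys, str) else list(out_keys)

    def decorator(func: Callable) -> Callable:
        def wrapper(data: dict) -> dict:
            # Read the inputs in the order given by in_keys.
            outputs = func(*(data[k] for k in in_list))
            # Treat a single return value as a one-element tuple.
            if not isinstance(outputs, tuple):
                outputs = (outputs,)
            # Write each output under the matching out_key, in order,
            # updating the input dict in place and returning it.
            for key, value in zip(out_list, outputs):
                data[key] = value
            return data

        return wrapper

    return decorator


@dict_module(in_keys=["a", "b"], out_keys=["sum", "diff"])
def add_sub(a, b):
    return a + b, a - b


result = add_sub({"a": 5, "b": 3})
print(result)  # {'a': 5, 'b': 3, 'sum': 8, 'diff': 2}
```

Writing the outputs back into the input container and returning it follows the default behavior of TensorDictModule, which updates the TensorDict it is called with.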