TensorDictModuleBase
- class tensordict.nn.TensorDictModuleBase(*args, **kwargs)
Base class for TensorDict modules.
TensorDictModule subclasses are characterized by in_keys and out_keys key-lists that indicate which input entries are read and which output entries are expected to be written. The forward method input/output signature should always follow the convention:
>>> tensordict_out = module.forward(tensordict_in)
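To make the convention concrete without depending on torch, here is a framework-free sketch in which a plain Python dict stands in for a TensorDict. ToyDictModule is a hypothetical class, not part of tensordict: it reads its inputs from in_keys, calls a function, writes the results under out_keys, and returns the dict it received. Writing into the input dict is a simplifying choice of this toy.

```python
from typing import Callable, Dict, List


class ToyDictModule:
    """Hypothetical toy: reads in_keys from a dict, writes out_keys back into it."""

    def __init__(self, fn: Callable, in_keys: List[str], out_keys: List[str]) -> None:
        self.fn = fn
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def forward(self, data: Dict[str, float]) -> Dict[str, float]:
        # Gather inputs in the order given by in_keys.
        outputs = self.fn(*(data[key] for key in self.in_keys))
        if len(self.out_keys) == 1:
            outputs = (outputs,)
        # Write each result under the matching out_key.
        for key, value in zip(self.out_keys, outputs):
            data[key] = value
        # Dict in, dict out: data_out = module.forward(data_in)
        return data


module = ToyDictModule(lambda x, y: (x + 2, y + 2), in_keys=["a", "b"], out_keys=["c", "d"])
data_in = {"a": 0.0, "b": 1.0}
data_out = module.forward(data_in)
print(data_out)  # {'a': 0.0, 'b': 1.0, 'c': 2.0, 'd': 3.0}
```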
- static is_tdmodule_compatible(module)
Checks if a module is compatible with TensorDictModule API.
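This page does not spell out the compatibility criterion. A plausible reading, assumed here and not confirmed by the text above, is duck typing: an object counts as compatible if it exposes both in_keys and out_keys. The hypothetical helper looks_tdmodule_compatible below implements only that assumed check.

```python
def looks_tdmodule_compatible(module: object) -> bool:
    """Hypothetical duck-typed check: does the object expose both key lists?"""
    return hasattr(module, "in_keys") and hasattr(module, "out_keys")


class WithKeys:
    in_keys = ["obs"]
    out_keys = ["action"]


class WithoutOutKeys:
    in_keys = ["obs"]  # out_keys is missing


print(looks_tdmodule_compatible(WithKeys()))        # True
print(looks_tdmodule_compatible(WithoutOutKeys()))  # False
```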
- reset_out_keys()
Resets the out_keys attribute to its original value.
Returns: the same module, with its original out_keys values.
Examples
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.reset_out_keys()
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
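The bookkeeping behind reset_out_keys() can be pictured as keeping an untouched copy of the original keys next to the current selection. ToyOutKeys below is a hypothetical class that only tracks key lists (no computation); it borrows the out_keys_source attribute name documented under select_out_keys(), but it is a sketch, not tensordict's implementation.

```python
from typing import List


class ToyOutKeys:
    """Hypothetical: tracks the selected out_keys and can restore the originals."""

    def __init__(self, out_keys: List[str]) -> None:
        self.out_keys_source = list(out_keys)  # original keys, never modified
        self.out_keys = list(out_keys)         # currently selected keys

    def select_out_keys(self, *out_keys: str) -> "ToyOutKeys":
        self.out_keys = list(out_keys)
        return self  # same object, modified in place

    def reset_out_keys(self) -> "ToyOutKeys":
        # Restore from the untouched copy; returning self allows chaining.
        self.out_keys = list(self.out_keys_source)
        return self


mod = ToyOutKeys(["c", "d"])
mod.select_out_keys("d")
print(mod.out_keys)                   # ['d']
print(mod.reset_out_keys().out_keys)  # ['c', 'd']
```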
- reset_parameters_recursive(parameters: Optional[TensorDictBase] = None) -> Optional[TensorDictBase]
Recursively reset the parameters of the module and its children.
- Parameters:
parameters (TensorDict of parameters, optional) – If None (the default), the module resets its own parameters, as returned by self.parameters(). Otherwise, the parameters in the given tensordict are reset in-place. This is useful for functional modules, whose parameters are not stored in the module itself.
- Returns:
A tensordict of the new parameters, only if parameters was not None.
Examples
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> old_param = net[0].weight.clone()
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> module.reset_parameters_recursive()
>>> (old_param == net[0].weight).any()
tensor(False)
This method also supports functional parameter sampling, where the parameters live in a tensordict:
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> params = TensorDict.from_module(module)
>>> old_params = params.clone(recurse=True)
>>> module.reset_parameters_recursive(params)
>>> (old_params == params).any()
False
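The recursion itself can be pictured with a framework-free toy: a depth-first walk that calls reset_parameters() on every node that defines one and then descends into its children. All names here (ToyLinear, ToyReLU, ToySequential, toy_reset_recursive) are hypothetical; the real method operates on torch modules and also covers the functional tensordict case, which this sketch omits.

```python
import random


class ToyLinear:
    """Hypothetical leaf with parameters and its own reset_parameters()."""

    def __init__(self, size: int, rng: random.Random) -> None:
        self.rng = rng
        self.weight = [0.0] * size

    def reset_parameters(self) -> None:
        # Re-draw every weight (stand-in for a real initializer).
        self.weight = [self.rng.uniform(-1.0, 1.0) for _ in self.weight]


class ToyReLU:
    """Hypothetical leaf without parameters: nothing to reset."""

    children = ()


class ToySequential:
    """Hypothetical container: holds children, has no parameters of its own."""

    def __init__(self, *children) -> None:
        self.children = children


def toy_reset_recursive(node) -> int:
    """Reset node if it can be reset, then visit its children; return #resets."""
    count = 0
    if hasattr(node, "reset_parameters"):
        node.reset_parameters()
        count += 1
    for child in getattr(node, "children", ()):
        count += toy_reset_recursive(child)
    return count


rng = random.Random(0)
net = ToySequential(ToyLinear(3, rng), ToyReLU(), ToySequential(ToyLinear(2, rng)))
n_reset = toy_reset_recursive(net)
print(n_reset)  # 2
```

Only the two ToyLinear leaves are reset, including the one nested inside the inner container, which is the point of walking the tree rather than resetting the top-level object alone.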
- select_out_keys(*out_keys) -> TensorDictModuleBase
Selects the keys that will be found in the output tensordict.
This is useful whenever one wants to get rid of intermediate keys in a complicated graph, or when the presence of these keys may trigger unexpected behaviours.
The original out_keys can still be accessed via module.out_keys_source.
- Parameters:
*out_keys (a sequence of strings or tuples of strings) – the out_keys that should be found in the output tensordict.
Returns: the same module, modified in-place with updated out_keys.
The simplest usage is with TensorDictModule:
Examples
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
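The effect on the output can be mimicked with a dict-based toy. ToySelectModule is hypothetical, and it assumes one simple strategy: every output is still computed, and unselected keys are just not written. This page does not say how tensordict implements the filtering internally.

```python
from typing import Callable, Dict, List


class ToySelectModule:
    """Hypothetical dict-based module supporting out_keys selection."""

    def __init__(self, fn: Callable, in_keys: List[str], out_keys: List[str]) -> None:
        self.fn = fn
        self.in_keys = list(in_keys)
        self.out_keys_source = list(out_keys)  # everything fn produces
        self.out_keys = list(out_keys)         # what gets written

    def select_out_keys(self, *out_keys: str) -> "ToySelectModule":
        self.out_keys = list(out_keys)
        return self

    def __call__(self, data: Dict[str, float]) -> Dict[str, float]:
        results = self.fn(*(data[key] for key in self.in_keys))
        # Compute everything, then write only the selected keys.
        for key, value in zip(self.out_keys_source, results):
            if key in self.out_keys:
                data[key] = value
        return data


mod = ToySelectModule(lambda x, y: (x + 2, y + 2), ["a", "b"], ["c", "d"])
full = mod({"a": 0.0, "b": 1.0})
mod.select_out_keys("d")
selected = mod({"a": 0.0, "b": 1.0})
print(full)      # {'a': 0.0, 'b': 1.0, 'c': 2.0, 'd': 3.0}
print(selected)  # {'a': 0.0, 'b': 1.0, 'd': 3.0}
```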
This feature also works with dispatched arguments:
Examples
>>> mod(torch.zeros(()), torch.ones(()))
tensor(3.)
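Dispatch can be sketched as binding positional arguments to in_keys in order and returning the selected outputs directly instead of a container. ToyDispatchModule is hypothetical; returning a bare value for one selected key matches the single-tensor result shown above, while returning a tuple for several keys is an assumption of this toy.

```python
from typing import Callable, List


class ToyDispatchModule:
    """Hypothetical: binds positional arguments to in_keys, returns values."""

    def __init__(self, fn: Callable, in_keys: List[str], out_keys: List[str]) -> None:
        self.fn = fn
        self.in_keys = list(in_keys)
        self.out_keys_source = list(out_keys)
        self.out_keys = list(out_keys)

    def select_out_keys(self, *out_keys: str) -> "ToyDispatchModule":
        self.out_keys = list(out_keys)
        return self

    def __call__(self, *args):
        if len(args) != len(self.in_keys):
            raise TypeError(f"expected {len(self.in_keys)} positional arguments, got {len(args)}")
        # Positional argument i is treated as the entry for in_keys[i].
        inputs = dict(zip(self.in_keys, args))
        values = self.fn(*(inputs[key] for key in self.in_keys))
        results = dict(zip(self.out_keys_source, values))
        selected = tuple(results[key] for key in self.out_keys)
        # One selected key gives a bare value; several give a tuple.
        return selected[0] if len(selected) == 1 else selected


mod = ToyDispatchModule(lambda x, y: (x + 2, y + 2), ["a", "b"], ["c", "d"])
both = mod(0.0, 1.0)
mod.select_out_keys("d")
only_d = mod(0.0, 1.0)
print(both)    # (2.0, 3.0)
print(only_d)  # 3.0
```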
This change occurs in-place (i.e., the same module is returned with an updated list of out_keys). It can be reverted using the TensorDictModuleBase.reset_out_keys() method.
Examples
>>> mod.reset_out_keys()
>>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This also works with other classes, such as TensorDictSequential:
Examples
>>> from tensordict.nn import TensorDictSequential
>>> seq = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]),
...     TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]),
... )
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> seq.select_out_keys("z")
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
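In a sequence, intermediate keys such as y must still be produced because later steps read them; selection only decides what survives in the returned output. The dict-based toy below (hypothetical ToyStep and ToyPipeline classes, not tensordict's implementation) assumes the simplest strategy: run every step, then drop produced keys that were not selected.

```python
from typing import Callable, Dict


class ToyStep:
    """Hypothetical single-input, single-output step."""

    def __init__(self, fn: Callable, in_key: str, out_key: str) -> None:
        self.fn, self.in_key, self.out_key = fn, in_key, out_key


class ToyPipeline:
    """Hypothetical sequential container with out_keys selection."""

    def __init__(self, *steps: ToyStep) -> None:
        self.steps = list(steps)
        self.out_keys_source = [step.out_key for step in self.steps]
        self.out_keys = list(self.out_keys_source)

    def select_out_keys(self, *out_keys: str) -> "ToyPipeline":
        self.out_keys = list(out_keys)
        return self

    def __call__(self, data: Dict[str, float]) -> Dict[str, float]:
        # Each step reads what earlier steps wrote, so intermediates must exist.
        for step in self.steps:
            data[step.out_key] = step.fn(data[step.in_key])
        # Unselected produced keys are removed only after the whole run.
        for key in self.out_keys_source:
            if key not in self.out_keys:
                data.pop(key, None)
        return data


seq = ToyPipeline(ToyStep(lambda x: x + 1, "x", "y"), ToyStep(lambda x: x + 1, "y", "z"))
full = seq({"x": 0.0})
seq.select_out_keys("z")
trimmed = seq({"x": 0.0})
print(full)     # {'x': 0.0, 'y': 1.0, 'z': 2.0}
print(trimmed)  # {'x': 0.0, 'z': 2.0}
```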