GRUModule
- class torchrl.modules.GRUModule(*args, **kwargs)
An embedder for a GRU module.
This class adds the following functionality to torch.nn.GRU:
Compatibility with TensorDict: the hidden states are reshaped to match the tensordict batch size.
Optional multi-step execution: with torch.nn, one has to choose between torch.nn.GRUCell and torch.nn.GRU, the former being compatible with single-step inputs and the latter with multi-step inputs. This class enables both usages.
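As an illustration of the two torch.nn options mentioned above (not of GRUModule itself), the sketch below runs a single-layer torch.nn.GRU over a whole sequence and a torch.nn.GRUCell step by step with copied weights. Under these assumptions both paths give the same outputs, which is what makes it reasonable for one wrapper to offer both modes. The sizes are arbitrary choices.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, steps = 3, 8, 5

# Multi-step module: consumes a whole sequence in one call.
gru = nn.GRU(input_size, hidden_size, batch_first=True)
# Single-step module: consumes one time step per call.
cell = nn.GRUCell(input_size, hidden_size)

# Copy the GRU's layer-0 weights into the cell so both compute the same function.
with torch.no_grad():
    cell.weight_ih.copy_(gru.weight_ih_l0)
    cell.weight_hh.copy_(gru.weight_hh_l0)
    cell.bias_ih.copy_(gru.bias_ih_l0)
    cell.bias_hh.copy_(gru.bias_hh_l0)

x = torch.randn(2, steps, input_size)  # (batch, time, features)

# Multi-step: one call over the full sequence (h0 defaults to zeros).
out_seq, h_seq = gru(x)

# Single-step: loop over time, carrying the hidden state explicitly.
h = torch.zeros(2, hidden_size)
outs = []
for t in range(steps):
    h = cell(x[:, t], h)
    outs.append(h)
out_loop = torch.stack(outs, dim=1)  # back to (batch, time, hidden)
```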
After construction, the module is not set in recurrent mode, i.e. it will expect single-step inputs.
In recurrent mode, the last dimension of the tensordict is expected to mark the number of steps. There is no constraint on the dimensionality of the tensordict (except that it must be greater than one for temporal inputs).
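The batch-shape handling described above can be pictured with a plain torch.nn.GRU. This is a sketch of the general idea under stated assumptions, not TorchRL's code: all leading tensordict batch dimensions are flattened into one, the last batch dimension is treated as time, and the hidden state is kept as (*batch, num_layers, hidden_size), which matches the rs shapes shown in the examples on this page. The helper run_recurrent is hypothetical, and it ignores episode boundaries inside a sequence (the examples mark those with the is_init entry added by InitTracker).

```python
import torch
from torch import nn

num_layers, input_size, hidden_size = 2, 3, 16
gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)

# Tensordict-style data with batch dims (4, 6): the LAST batch dim (6) is time.
obs = torch.randn(4, 6, input_size)
# Hidden state stored per element as (*batch, time, num_layers, hidden_size).
hidden_td = torch.zeros(4, 6, num_layers, hidden_size)

def run_recurrent(obs, hidden_td):
    """Hypothetical helper: run a GRU over tensordict-shaped inputs."""
    *batch, steps, feat = obs.shape
    # Collapse all leading batch dims into one; keep time as the sequence dim.
    x = obs.reshape(-1, steps, feat)
    # Use the first time step's stored state; nn.GRU wants (num_layers, N, hidden).
    h0 = hidden_td[..., 0, :, :].reshape(-1, num_layers, hidden_size)
    h0 = h0.transpose(0, 1).contiguous()
    out, h_n = gru(x, h0)
    # Restore the tensordict layout for the output and the final state.
    out = out.reshape(*batch, steps, -1)
    h_n = h_n.transpose(0, 1).reshape(*batch, num_layers, hidden_size)
    return out, h_n

out, h_n = run_recurrent(obs, hidden_td)
```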
- Parameters:
input_size – The number of expected features in the input x.
hidden_size – The number of features in the hidden state h.
num_layers – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1
bias – If False, then the layer does not use bias weights. Default: True
dropout – If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0
python_based – If True, will use a full Python implementation of the GRU cell. Default: False
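To see what the constructor arguments listed above control, the sketch below passes the same values directly to torch.nn.GRU (python_based is TorchRL-specific and is omitted). The sizes are arbitrary, and batch_first=True is an assumption chosen to match a (batch, time, features) layout.

```python
import torch
from torch import nn

# Same constructor arguments as listed above, passed to torch.nn.GRU.
gru = nn.GRU(input_size=3, hidden_size=32, num_layers=2, bias=False,
             dropout=0.1, batch_first=True)

# With bias=False only weight matrices are registered: one input-to-hidden
# and one hidden-to-hidden matrix per stacked layer.
names = [name for name, _ in gru.named_parameters()]
print(names)

x = torch.randn(5, 7, 3)  # (batch, time, features)
out, h_n = gru(x)
# out holds the last layer's features at every step;
# h_n holds one final state per layer: (num_layers, batch, hidden_size).
print(out.shape, h_n.shape)
```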
- Keyword Arguments:
in_key (str or tuple of str) – the input key of the module. Exclusive with in_keys. If provided, the recurrent keys are assumed to be ["recurrent_state"] and the in_key will be placed before them.
in_keys (list of str) – a pair of strings corresponding to the input value and recurrent entry. Exclusive with in_key.
out_key (str or tuple of str) – the output key of the module. Exclusive with out_keys. If provided, the recurrent keys are assumed to be [("recurrent_state")] and the out_key will be placed before them.
out_keys (list of str) – a pair of strings corresponding to the output value and the hidden key.
Note: for better integration with TorchRL's environments, the best naming for the output hidden key is ("next", <custom_key>), such that the hidden values are passed from step to step during a rollout.
device (torch.device or compatible) – the device of the module.
gru (torch.nn.GRU, optional) – a GRU instance to be wrapped. Exclusive with other nn.GRU arguments.
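The in_key/in_keys rules above amount to a small key-resolution step. The helper below, resolve_in_keys, is hypothetical and only restates the documented behavior: the two arguments are mutually exclusive, and a single key is placed before the default "recurrent_state" key. Requiring one of the two arguments is an assumption of this sketch.

```python
DEFAULT_RECURRENT_KEY = "recurrent_state"  # default recurrent key named above

def resolve_in_keys(in_key=None, in_keys=None):
    """Hypothetical helper mirroring the documented in_key / in_keys rules."""
    if in_key is not None and in_keys is not None:
        raise ValueError("in_key and in_keys are mutually exclusive")
    if in_key is not None:
        # A single input key (str or tuple of str) goes before the recurrent key.
        return [in_key, DEFAULT_RECURRENT_KEY]
    if in_keys is None or len(in_keys) != 2:
        # Assumption of this sketch: one of the two arguments must be given.
        raise ValueError("in_keys must be a pair: (input value, recurrent entry)")
    return list(in_keys)
```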
- Variables:
recurrent_mode – Returns the recurrent mode of the module.
- Methods:
make_tensordict_primer() – creates the TensorDictPrimer transforms for the environment to be aware of the recurrent states of the RNN.
Note
This module relies on specific recurrent_state keys being present in the input TensorDicts. To generate a TensorDictPrimer transform that will automatically add hidden states to the environment TensorDicts, use the method make_tensordict_primer(). If this class is a submodule in a larger module, the function torchrl.modules.utils.get_primers_from_module() can be called with the parent module to automatically generate the primer transforms required for all submodules, including this one.
Examples
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, GRUModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs"],
...     out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
>>> gru_module_training = gru_module.set_recurrent_mode()
>>> policy_training = Seq(gru_module_training, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3)  # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> print(traj_td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([3, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                rs: Tensor(shape=torch.Size([3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
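The note on primers mentions collecting primer transforms from every submodule of a parent module. The sketch below shows one way such a traversal could work using torch.nn.Module.modules(). HypotheticalRecurrent and collect_primers are made-up names, the "primer" is just a key string rather than a real TensorDictPrimer, and this is not the implementation of get_primers_from_module().

```python
from torch import nn

class HypotheticalRecurrent(nn.Module):
    """Made-up stand-in for a recurrent submodule that knows its state key."""

    def __init__(self, key):
        super().__init__()
        self.key = key

    def make_tensordict_primer(self):
        # The real method returns a TensorDictPrimer; this stand-in returns the key.
        return self.key

def collect_primers(module):
    """Hypothetical: gather one primer per submodule that can make one."""
    # modules() walks the parent and all nested submodules in pre-order.
    return [m.make_tensordict_primer() for m in module.modules()
            if hasattr(m, "make_tensordict_primer")]

policy = nn.Sequential(
    HypotheticalRecurrent("rs_1"),
    nn.Linear(4, 4),
    nn.Sequential(HypotheticalRecurrent("rs_2")),  # nested recurrent submodule
)
primers = collect_primers(policy)
```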
- forward(tensordict: TensorDictBase)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
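The note above is standard torch.nn.Module behavior and can be checked with any module. In the sketch below (plain PyTorch, with nn.Linear as an arbitrary example), a forward hook fires when the module instance is called but not when forward() is called directly.

```python
import torch
from torch import nn

layer = nn.Linear(2, 2)
calls = []

# Record every time the registered forward hook runs.
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(1, 2)
layer(x)          # __call__ runs forward() and then the registered hooks
layer.forward(x)  # calling forward() directly bypasses the hooks
```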
- make_tensordict_primer()
Makes a tensordict primer for the environment.
A TensorDictPrimer object will ensure that the policy is aware of the supplementary inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across processes and dealt with properly.
Not including a TensorDictPrimer in the environment may result in poorly defined behaviors, for instance in parallel settings where a step involves copying the new recurrent state from "next" to the root tensordict, which the step_mdp() method will not be able to do because the recurrent states are not registered within the environment specs.
See torchrl.modules.utils.get_primers_from_module() for a function that generates all primers for a given module.
Examples
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, GRUModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs"],
...     out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
>>> env = env.append_transform(gru_module.make_tensordict_primer())
>>> data_collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=10
... )
>>> for data in data_collector:
...     print(data)
...     break
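The failure mode described for a missing primer can be modeled with plain dictionaries. In this deliberately simplified sketch, hypothetical_step_mdp is a made-up stand-in for TorchRL's step_mdp: it builds the next root from the "next" entry and keeps only keys the environment declares. It shows why an undeclared recurrent state ("rs") would be lost between steps, and why registering it, which is the role the primer plays, keeps it.

```python
def hypothetical_step_mdp(td, spec_keys):
    """Made-up stand-in: build the next root from td["next"], keeping declared keys only."""
    nxt = td["next"]
    return {key: value for key, value in nxt.items() if key in spec_keys}

# One step of data: the policy wrote the new recurrent state under ("next", "rs").
td = {
    "observation": [0.0],
    "rs": [0.0] * 4,
    "next": {"observation": [0.1], "reward": [1.0], "rs": [0.5] * 4},
}

env_keys = {"observation"}  # keys the environment declares without a primer
without_primer = hypothetical_step_mdp(td, env_keys)         # "rs" is dropped
with_primer = hypothetical_step_mdp(td, env_keys | {"rs"})   # "rs" is carried over
```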
- set_recurrent_mode(mode: bool = True)
Returns a new copy of the module that shares the same GRU model but with a different recurrent_mode attribute (if it differs). A copy is created such that the module can be used with divergent behavior in various parts of the code (inference vs training):
Examples
>>> import torch
>>> from torchrl.envs import GymEnv, TransformedEnv, InitTracker, step_mdp
>>> from torchrl.modules import MLP, GRUModule
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru = nn.GRU(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True)
>>> gru_module = GRUModule(gru=gru, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> # building two policies with different behaviors:
>>> policy_inference = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy_training = Seq(gru_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3)  # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> # let's check that both return the same results
>>> td_inf = TensorDict({}, traj_td.shape[:-1])
>>> for td in traj_td.unbind(-1):
...     td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation")))
...     td_inf = policy_inference(td_inf)
...     td_inf = step_mdp(td_inf)
...
>>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"])
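The copy-on-toggle behavior of set_recurrent_mode can be sketched with a shallow copy. HypotheticalGRUWrapper below is a made-up minimal class, not TorchRL's GRUModule, and returning self when the mode is unchanged is an assumption. The point is that copy.copy yields a second wrapper with its own recurrent_mode flag while both wrappers reference the same nn.GRU, so the training and inference policies share weights.

```python
import copy
from torch import nn

class HypotheticalGRUWrapper(nn.Module):
    """Made-up minimal stand-in for a module that toggles recurrent mode by copying."""

    def __init__(self, gru):
        super().__init__()
        self.gru = gru
        self.recurrent_mode = False

    def set_recurrent_mode(self, mode=True):
        if mode == self.recurrent_mode:
            return self  # assumption: no copy needed when the mode is unchanged
        # Shallow copy: the new wrapper references the same nn.GRU (same weights).
        out = copy.copy(self)
        out.recurrent_mode = mode  # set only on the copy
        return out

inference = HypotheticalGRUWrapper(nn.GRU(3, 8, batch_first=True))
training = inference.set_recurrent_mode(True)
```

Because the parameters are shared, an optimizer step taken through the training wrapper also updates the inference wrapper.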