GRUModule

class torchrl.modules.GRUModule(*args, **kwargs)

An embedder for a GRU module.

This class adds the following functionality to torch.nn.GRU:

  • Compatibility with TensorDict: the hidden states are reshaped to match the tensordict batch size.

  • Optional multi-step execution: with torch.nn, one has to choose between torch.nn.GRUCell and torch.nn.GRU, the former being compatible with single-step inputs and the latter with multi-step inputs. This class supports both usages.
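The single-step versus multi-step distinction can be illustrated with plain torch.nn. The sketch below is not TorchRL code; the sizes and variable names are illustrative assumptions. It copies the weights of an nn.GRU into an nn.GRUCell and checks that looping the cell over time reproduces the sequence output.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from the documentation above).
input_size, hidden_size, batch, steps = 3, 8, 2, 5

# Multi-step module: consumes a whole [batch, time, feature] sequence per call.
gru = nn.GRU(input_size, hidden_size, batch_first=True)

# Single-step module: consumes one [batch, feature] input per call.
cell = nn.GRUCell(input_size, hidden_size)

# Copy the weights so that both modules compute the same function.
with torch.no_grad():
    cell.weight_ih.copy_(gru.weight_ih_l0)
    cell.weight_hh.copy_(gru.weight_hh_l0)
    cell.bias_ih.copy_(gru.bias_ih_l0)
    cell.bias_hh.copy_(gru.bias_hh_l0)

x = torch.randn(batch, steps, input_size)

with torch.no_grad():
    # Multi-step: a single call over the full sequence.
    seq_out, h_last = gru(x)  # [batch, steps, hidden], [1, batch, hidden]

    # Single-step: loop over time, carrying the hidden state by hand.
    h = torch.zeros(batch, hidden_size)
    step_outs = []
    for t in range(steps):
        h = cell(x[:, t], h)
        step_outs.append(h)
    step_out = torch.stack(step_outs, dim=1)  # [batch, steps, hidden]
```

Both paths yield the same values up to floating-point tolerance, which is why a single wrapper can expose either behaviour over one set of weights.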

After construction, the module is not set in recurrent mode, i.e., it expects single-step inputs.

In recurrent mode, the last dimension of the tensordict is expected to index the time steps. There is no constraint on the dimensionality of the tensordict (except that it must be greater than one for temporal inputs).
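As a rough illustration of a layout where the time dimension comes last among the batch dimensions, the sketch below uses plain torch.nn.GRU on a tensor with two leading batch dimensions. The shapes are hypothetical, and the flatten-then-restore approach is an assumption about how such data can be handled, not a description of TorchRL's internals.

```python
import torch
from torch import nn

torch.manual_seed(0)

gru = nn.GRU(input_size=3, hidden_size=8, batch_first=True)

# Hypothetical data: batch dimensions [2, 4], then 5 time steps, then 3 features.
x = torch.randn(2, 4, 5, 3)
*batch_dims, steps, features = x.shape

# Collapse every leading batch dimension so that nn.GRU sees [N, T, F].
flat_x = x.reshape(-1, steps, features)
with torch.no_grad():
    flat_out, flat_h = gru(flat_x)  # [N, T, H], [num_layers, N, H]

# Restore the original batch dimensions around the time dimension.
out = flat_out.reshape(*batch_dims, steps, -1)
```

Here `out` has shape [2, 4, 5, 8]: the leading dimensions are untouched and only the feature dimension changes.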

Parameters:
  • input_size – The number of expected features in the input x

  • hidden_size – The number of features in the hidden state h

  • num_layers – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1

  • bias – If False, then the layer does not use bias weights. Default: True

  • dropout – If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0

  • python_based – If True, will use a full Python implementation of the GRU cell. Default: False
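For readers curious what a Python-level GRU cell computes, here is a minimal sketch of the standard GRU update written with basic tensor operations and checked against torch.nn.GRUCell. The function name gru_cell_step is hypothetical, and TorchRL's own python_based implementation may be organized differently.

```python
import torch
from torch import nn


def gru_cell_step(x, h, weight_ih, weight_hh, bias_ih, bias_hh):
    """One GRU step; gates are stacked in the order reset, update, new."""
    gi = x @ weight_ih.T + bias_ih  # input contributions, [batch, 3 * hidden]
    gh = h @ weight_hh.T + bias_hh  # hidden contributions, [batch, 3 * hidden]
    i_r, i_z, i_n = gi.chunk(3, dim=-1)
    h_r, h_z, h_n = gh.chunk(3, dim=-1)
    r = torch.sigmoid(i_r + h_r)    # reset gate
    z = torch.sigmoid(i_z + h_z)    # update gate
    n = torch.tanh(i_n + r * h_n)   # candidate state
    return (1 - z) * n + z * h      # new hidden state


torch.manual_seed(0)
cell = nn.GRUCell(input_size=3, hidden_size=8)
x = torch.randn(2, 3)
h = torch.randn(2, 8)

with torch.no_grad():
    expected = cell(x, h)
    manual = gru_cell_step(
        x, h, cell.weight_ih, cell.weight_hh, cell.bias_ih, cell.bias_hh
    )
```

The two results agree up to floating-point tolerance, since the weight layout matches the one documented for torch.nn.GRUCell.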

Keyword Arguments:
  • in_key (str or tuple of str) – the input key of the module. Mutually exclusive with in_keys. If provided, the recurrent keys are assumed to be ["recurrent_state"] and in_key is placed before them.

  • in_keys (list of str) – a pair of strings corresponding to the input value and the recurrent entry. Mutually exclusive with in_key.

  • out_key (str or tuple of str) – the output key of the module. Mutually exclusive with out_keys. If provided, the recurrent keys are assumed to be [("recurrent_state")] and out_key is placed before them.

  • out_keys (list of str) – a pair of strings corresponding to the output value and the hidden key.

    Note

    For a better integration with TorchRL's environments, the best naming for the output hidden key is ("next", <custom_key>), such that the hidden values are passed from step to step during a rollout.

  • device (torch.device or compatible) – the device of the module.

  • gru (torch.nn.GRU, optional) – a GRU instance to be wrapped. Exclusive with other nn.GRU arguments.

Variables:

recurrent_mode – Returns the recurrent mode of the module.

set_recurrent_mode()

Controls whether the module should be executed in recurrent mode.

Examples

>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import GRUModule, MLP
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs"],
...     out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
>>> gru_module_training = gru_module.set_recurrent_mode()
>>> policy_training = Seq(gru_module_training, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3) # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> print(traj_td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([3, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                rs: Tensor(shape=torch.Size([3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)

forward(tensordict: TensorDictBase)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

set_recurrent_mode(mode: bool = True)

Returns a new copy of the module that shares the same gru model but with a different recurrent_mode attribute (if it differs).

A copy is created such that the module can be used with divergent behaviour in various parts of the code (inference vs training):

Examples

>>> from torchrl.envs import GymEnv, TransformedEnv, InitTracker, step_mdp
>>> from torchrl.modules import GRUModule, MLP
>>> from tensordict import TensorDict
>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru = nn.GRU(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True)
>>> gru_module = GRUModule(gru=gru, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> # building two policies with different behaviours:
>>> policy_inference = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy_training = Seq(gru_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3) # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> # let's check that both return the same results
>>> td_inf = TensorDict({}, traj_td.shape[:-1])
>>> for td in traj_td.unbind(-1):
...     td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation")))
...     td_inf = policy_inference(td_inf)
...     td_inf = step_mdp(td_inf)
...
>>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"])
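The copy-with-shared-weights pattern described for set_recurrent_mode can be sketched independently of TorchRL. The class RecurrentWrapper below is a hypothetical stand-in, and returning the module itself when the mode is unchanged is an assumed reading of "if it differs"; neither is taken from the library's source.

```python
import torch
from torch import nn


class RecurrentWrapper(nn.Module):
    """Hypothetical stand-in for a module carrying a recurrent_mode flag."""

    def __init__(self, gru, recurrent_mode=False):
        super().__init__()
        self.gru = gru
        self.recurrent_mode = recurrent_mode

    def set_recurrent_mode(self, mode=True):
        # Assumption: nothing to copy when the requested mode is already set.
        if mode is self.recurrent_mode:
            return self
        # The new wrapper reuses the same nn.GRU object, so weights are shared.
        return RecurrentWrapper(self.gru, recurrent_mode=mode)


gru = nn.GRU(3, 8, batch_first=True)
inference = RecurrentWrapper(gru)
training = inference.set_recurrent_mode(True)

# An in-place update made through one wrapper is visible through the other.
with torch.no_grad():
    training.gru.weight_ih_l0.zero_()
weights_shared = bool((inference.gru.weight_ih_l0 == 0).all())
```

Because both wrappers hold the same nn.GRU instance, training one of them updates the weights used by the other, while each keeps its own recurrent_mode flag.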
