LSTMModule
- class torchrl.modules.LSTMModule(*args, **kwargs)
An embedder for an LSTM module.
This class adds the following functionality to torch.nn.LSTM:
- Compatibility with TensorDict: the hidden states are reshaped to match the tensordict batch size.
- Optional multi-step execution: with torch.nn, one has to choose between torch.nn.LSTMCell and torch.nn.LSTM, the former being compatible with single-step inputs and the latter with multi-step inputs. This class enables both usages.
After construction, the module is not set in recurrent mode, i.e., it expects single-step inputs.
In recurrent mode, the last dimension of the tensordict is expected to mark the number of time steps. There is no constraint on the dimensionality of the tensordict (except that it must be greater than one for temporal inputs).
Note
This class can handle multiple consecutive trajectories along the time dimension, but the final hidden values should not be trusted in those cases (i.e., they should not be re-used for a consecutive trajectory). The reason is that the LSTM returns only the last hidden value, which, for the padded inputs we provide, can correspond to a 0-filled input.
- Parameters:
  - input_size – The number of expected features in the input x.
  - hidden_size – The number of features in the hidden state h.
  - num_layers – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1.
  - bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True.
  - dropout – If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0.
  - python_based – If True, will use a full Python implementation of the LSTM cell. Default: False.
- Keyword Arguments:
  - in_key (str or tuple of str) – the input key of the module. Exclusive use with in_keys. If provided, the recurrent keys are assumed to be ["recurrent_state_h", "recurrent_state_c"] and the in_key will be appended before these (see the sketch below these lists).
  - in_keys (list of str) – a triplet of strings corresponding to the input value, first and second hidden key. Exclusive with in_key.
  - out_key (str or tuple of str) – the output key of the module. Exclusive use with out_keys. If provided, the recurrent keys are assumed to be [("next", "recurrent_state_h"), ("next", "recurrent_state_c")] and the out_key will be appended before these.
  - out_keys (list of str) – a triplet of strings corresponding to the output value, first and second hidden key.
    Note: for a better integration with TorchRL's environments, the best naming for the output hidden key is ("next", <custom_key>), such that the hidden values are passed from step to step during a rollout.
  - device (torch.device or compatible) – the device of the module.
  - lstm (torch.nn.LSTM, optional) – an LSTM instance to be wrapped. Exclusive with other nn.LSTM arguments.
- Variables:
  - recurrent_mode – Returns the recurrent mode of the module.
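When a single input/output key is enough, the in_key/out_key shorthand described above can be used instead of full key triplets. Below is a minimal, hedged sketch assuming the default recurrent key names listed in the keyword arguments ("recurrent_state_h"/"recurrent_state_c" and their ("next", ...) counterparts); the exact tensordict contents depend on the environment.
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import LSTMModule
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> # single-key shorthand: the recurrent keys fall back to their defaults
>>> lstm_module = LSTMModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_key="observation",
...     out_key="intermediate")
>>> td = lstm_module(env.reset())
>>> # the recurrent states should land under ("next", "recurrent_state_h") / ("next", "recurrent_state_c")
>>> assert ("next", "recurrent_state_h") in td.keys(True)
>>> assert ("next", "recurrent_state_c") in td.keys(True)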
Examples
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> lstm_module = LSTMModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs_h", "rs_c"],
...     out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                rs_c: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                rs_h: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
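Because the output hidden keys live under ("next", ...), a rollout collected with this policy carries the recurrent state from one step to the next. Continuing the example above, a minimal sketch (the shape shown is an assumption based on the settings used here, i.e. a single LSTM layer with hidden_size=64):
>>> rollout = env.rollout(3, policy)
>>> # each step stores the recurrent state it produced under "next"
>>> rollout["next", "rs_h"].shape  # expected: torch.Size([3, 1, 64]) with these settings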
- forward(tensordict: TensorDictBase)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
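A minimal sketch of a direct call, reusing the lstm_module built in the first example (in_keys=["observation", "rs_h", "rs_c"]); as in the env.reset() example above, the recurrent entries can be absent from the input tensordict, although the exact handling of missing entries may depend on the TorchRL version:
>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
...     {"observation": torch.randn(3), "is_init": torch.ones(1, dtype=torch.bool)},
...     batch_size=[])
>>> td = lstm_module(td)
>>> td["intermediate"].shape  # expected: torch.Size([64]) for hidden_size=64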
- set_recurrent_mode(mode: bool = True)
Returns a new copy of the module that shares the same lstm model but with a different recurrent_mode attribute (if it differs).
A copy is created such that the module can be used with divergent behaviour in various parts of the code (inference vs training):
Examples
>>> import torch
>>> from torchrl.envs import TransformedEnv, InitTracker, step_mdp
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True)
>>> lstm_module = LSTMModule(lstm=lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> # building two policies with different behaviours:
>>> policy_inference = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy_training = Seq(lstm_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3)  # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> # let's check that both return the same results
>>> td_inf = TensorDict({}, traj_td.shape[:-1])
>>> for td in traj_td.unbind(-1):
...     td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation")))
...     td_inf = policy_inference(td_inf)
...     td_inf = step_mdp(td_inf)
...
>>> torch.testing.assert_close(td_inf["hidden0"], traj_td[..., -1]["next", "hidden0"])