LSTMNet

class torchrl.modules.LSTMNet(out_features: int, lstm_kwargs: Dict, mlp_kwargs: Dict, device: DEVICE_TYPING | None = None, *, lstm_backend: str | None = None)[source]

An embedder for an LSTM preceded by an MLP.

The forward method returns the output along with the hidden states of the current step (the input hidden states) and of the next step (the output hidden states), much as the environment returns an ‘observation’ and ‘next_observation’ pair.

Because the LSTM kernel only returns the last hidden state, hidden states are padded with zeros such that they have the right size to be stored in a TensorDict of size [batch x time_steps].

If a 2D tensor is provided as input, it is assumed to be a batch of data with only one time step. This means that we explicitly expect users to unsqueeze inputs of a single batch with multiple time steps (see the single-trajectory example below).

Parameters:
  • out_features (int) – number of output features.

  • lstm_kwargs (dict) – the keyword arguments for the LSTM layer.

  • mlp_kwargs (dict) – the keyword arguments for the MLP layer.

  • device (torch.device, optional) – the device where the module should be instantiated.

Keyword Arguments:
  • lstm_backend (str, optional) – one of "torchrl" or "torch", indicating where the LSTM class is to be retrieved from. The "torchrl" backend (LSTM) is slower but works with vmap() and should work with compile(). Defaults to "torch".
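
For instance, to opt into the TorchRL backend, pass lstm_backend="torchrl" at construction. A minimal sketch (the sizes are arbitrary and mirror the Examples below):

>>> # sketch only: select the vmap/compile-friendly TorchRL LSTM backend
>>> net = LSTMNet(
...     10,
...     {"input_size": 5, "hidden_size": 5},
...     {"out_features": 5},
...     lstm_backend="torchrl",
... )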

Examples

>>> import torch
>>> from torchrl.modules import LSTMNet
>>> batch = 7
>>> time_steps = 6
>>> in_features = 4
>>> out_features = 10
>>> hidden_size = 5
>>> net = LSTMNet(
...     out_features,
...     {"input_size": hidden_size, "hidden_size": hidden_size},
...     {"out_features": hidden_size},
... )
>>> # test single step vs multi-step
>>> x = torch.randn(batch, time_steps, in_features)  # 3 dims (batch x time x features) = multi-step
>>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
>>> x = torch.randn(batch, in_features)  # 2 dims = single step
>>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
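
A single unbatched trajectory of shape [time_steps, features] would be read as a batch of single steps, so it needs an explicit batch dimension first. A minimal sketch:

>>> x = torch.randn(time_steps, in_features)  # one unbatched trajectory
>>> x = x.unsqueeze(0)  # add the batch dim: [1, time_steps, in_features]
>>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
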
forward(input: torch.Tensor, hidden0_in: torch.Tensor | None = None, hidden1_in: torch.Tensor | None = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor][source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself rather than this method, since the former takes care of running the registered hooks while the latter silently ignores them.
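
As a sketch of stateful rollout usage, and assuming the hidden states returned for one step can be passed back unchanged as hidden0_in/hidden1_in on the next call:

>>> x_t = torch.randn(batch, in_features)  # step t (2D = single step)
>>> y_t, h0_in, h1_in, h0_out, h1_out = net(x_t)
>>> x_tp1 = torch.randn(batch, in_features)  # step t+1
>>> y_tp1, *_ = net(x_tp1, h0_out, h1_out)  # resume from the previous hidden state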
