
LSTM

class torchrl.modules.LSTM(input_size: int, hidden_size: int, num_layers: int = 1, batch_first: bool = True, bias: bool = True, dropout: float = 0.0, bidirectional: bool = False, proj_size: int = 0, device=None, dtype=None)

A PyTorch module for executing multiple steps of a multi-layer LSTM. The module follows the torch.nn.LSTM interface, but it is implemented purely in Python. Unlike torch.nn.LSTM, batch_first defaults to True.

Note

This class is implemented without relying on cuDNN, which makes it compatible with torch.vmap() and torch.compile().

Examples

>>> import torch
>>> from torchrl.modules.tensordict_module.rnn import LSTM
>>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu")
>>> B = 2
>>> T = 4
>>> N_IN = 10
>>> N_OUT = 20
>>> N_LAYERS = 2
>>> V = 4  # size of the leading dimension mapped over by torch.vmap
>>> lstm = LSTM(
...     input_size=N_IN,
...     hidden_size=N_OUT,
...     device=device,
...     num_layers=N_LAYERS,
... )

>>> # single call
>>> x = torch.randn(B, T, N_IN, device=device)  # (batch, time, feature) since batch_first=True
>>> h0 = torch.zeros(N_LAYERS, B, N_OUT, device=device)
>>> c0 = torch.zeros(N_LAYERS, B, N_OUT, device=device)
>>> with torch.no_grad():
...     output, (h1, c1) = lstm(x, (h0, c0))
...

>>> # vectorised call (not possible with nn.LSTM)
>>> def call_lstm(x, h, c):
...     output, (h_out, c_out) = lstm(x, (h, c))
...     return output, h_out, c_out
...
>>> batched_call = torch.vmap(call_lstm)
>>> x = torch.randn(V, B, T, N_IN, device=device)
>>> h0 = torch.zeros(V, N_LAYERS, B, N_OUT, device=device)
>>> c0 = torch.zeros(V, N_LAYERS, B, N_OUT, device=device)
>>> with torch.no_grad():
...     output, h1, c1 = batched_call(x, h0, c0)

__init__(input_size, hidden_size, num_layers=1, batch_first=True, bias=True, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)

Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence. For each element in the input sequence, each layer computes the following function:

\[\begin{array}{ll} i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ h_t = o_t \odot \tanh(c_t) \end{array}\]

where \(h_t\) is the hidden state at time t, \(c_t\) is the cell state at time t, \(x_t\) is the input at time t, \(h_{t-1}\) is the hidden state of the layer at time t-1 or the initial hidden state at time 0, and \(i_t\), \(f_t\), \(g_t\), \(o_t\) are the input, forget, cell, and output gates, respectively. \(\sigma\) is the sigmoid function, and \(\odot\) is the Hadamard product.
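The gate equations can be checked numerically. The following is a minimal sketch of a single time step written directly from the equations; the helper name lstm_step is hypothetical, and it assumes the (i|f|g|o) row stacking of the weights described under Variables. torch.nn.LSTMCell is used only as a reference to compare against.

```python
import torch


def lstm_step(x_t, h_prev, c_prev, w_ih, w_hh, b_ih, b_hh):
    """One LSTM time step written directly from the gate equations.

    Assumes w_ih has shape (4*hidden_size, input_size) and w_hh has shape
    (4*hidden_size, hidden_size), with rows stacked as (i | f | g | o).
    """
    gates = x_t @ w_ih.T + b_ih + h_prev @ w_hh.T + b_hh
    i_pre, f_pre, g_pre, o_pre = gates.chunk(4, dim=-1)
    i_t = torch.sigmoid(i_pre)  # input gate
    f_t = torch.sigmoid(f_pre)  # forget gate
    g_t = torch.tanh(g_pre)     # cell candidate
    o_t = torch.sigmoid(o_pre)  # output gate
    c_t = f_t * c_prev + i_t * g_t  # '*' is the elementwise (Hadamard) product
    h_t = o_t * torch.tanh(c_t)
    return h_t, c_t


torch.manual_seed(0)
cell = torch.nn.LSTMCell(input_size=3, hidden_size=5)
x = torch.randn(2, 3)   # (batch, input_size)
h0 = torch.randn(2, 5)  # (batch, hidden_size)
c0 = torch.randn(2, 5)

with torch.no_grad():
    h_ref, c_ref = cell(x, (h0, c0))
    h1, c1 = lstm_step(x, h0, c0, cell.weight_ih, cell.weight_hh,
                       cell.bias_ih, cell.bias_hh)
```

Stacking the four gates row-wise means a single matrix multiply per input produces all pre-activations, which chunk(4) then splits.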

In a multilayer LSTM, the input \(x^{(l)}_t\) of the \(l\)-th layer (\(l \ge 2\)) is the hidden state \(h^{(l-1)}_t\) of the previous layer multiplied by dropout \(\delta^{(l-1)}_t\), where each \(\delta^{(l-1)}_t\) is a Bernoulli random variable that is \(0\) with probability dropout.
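To illustrate that layer \(l\) consumes the hidden-state sequence of layer \(l-1\), the sketch below rebuilds a two-layer torch.nn.LSTM from two single-layer modules with copied weights. It sets dropout=0.0 so that no Bernoulli mask is applied and the comparison is deterministic; the variable names are illustrative.

```python
import torch

torch.manual_seed(0)
# Two stacked layers, no dropout, so the comparison is deterministic.
stacked = torch.nn.LSTM(input_size=4, hidden_size=6, num_layers=2, dropout=0.0)
layer0 = torch.nn.LSTM(input_size=4, hidden_size=6)
layer1 = torch.nn.LSTM(input_size=6, hidden_size=6)

# Copy layer k of the stacked module into the matching single-layer module.
with torch.no_grad():
    for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
        getattr(layer0, name + "_l0").copy_(getattr(stacked, name + "_l0"))
        getattr(layer1, name + "_l0").copy_(getattr(stacked, name + "_l1"))

x = torch.randn(7, 3, 4)  # (L, N, H_in) with batch_first=False
with torch.no_grad():
    out_stacked, (h_stacked, c_stacked) = stacked(x)
    out0, (h_l0, c_l0) = layer0(x)
    out1, (h_l1, c_l1) = layer1(out0)  # layer 1 consumes layer 0's h_t sequence
```

The final states of the stacked module are the per-layer final states concatenated along the first dimension.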

If proj_size > 0 is specified, an LSTM with projections is used. This changes the LSTM cell in two ways. First, the dimension of \(h_t\) changes from hidden_size to proj_size (the dimensions of \(W_{hi}\) change accordingly). Second, the output hidden state of each layer is multiplied by a learnable projection matrix: \(h_t = W_{hr}h_t\). As a consequence, the output of the LSTM network has a different shape as well. See the Inputs/Outputs sections below for the exact dimensions of all variables. More details can be found in https://arxiv.org/abs/1402.1128.
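A short sketch of the shape changes introduced by proj_size, using torch.nn.LSTM as the reference module (this assumes the class documented here follows the same conventions); the sizes are arbitrary:

```python
import torch

torch.manual_seed(0)
H_IN, HIDDEN, PROJ = 10, 20, 5
lstm = torch.nn.LSTM(input_size=H_IN, hidden_size=HIDDEN, proj_size=PROJ)

x = torch.randn(7, 3, H_IN)  # (L, N, H_in)
with torch.no_grad():
    output, (h_n, c_n) = lstm(x)

# h_t (and therefore output) has proj_size features; c_t keeps hidden_size.
print(output.shape)  # torch.Size([7, 3, 5])
print(h_n.shape)     # torch.Size([1, 3, 5])
print(c_n.shape)     # torch.Size([1, 3, 20])
# W_hr maps the hidden_size-dimensional vector down to proj_size ...
print(lstm.weight_hr_l0.shape)  # torch.Size([5, 20])
# ... and the recurrent weights now read a proj_size-dimensional h_{t-1}.
print(lstm.weight_hh_l0.shape)  # torch.Size([80, 5])
```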

Parameters:
  • input_size – The number of expected features in the input x

  • hidden_size – The number of features in the hidden state h

  • num_layers – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1

  • bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True

  • batch_first – If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature). Note that this does not apply to hidden or cell states. See the Inputs/Outputs sections below for details. Default: False for torch.nn.LSTM; the signature of torchrl.modules.LSTM uses batch_first=True.

  • dropout – If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0

  • bidirectional – If True, becomes a bidirectional LSTM. Default: False

  • proj_size – If > 0, will use LSTM with projections of corresponding size. Default: 0
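The effect of batch_first can be sketched by giving two torch.nn.LSTM modules identical weights and feeding one of them the transposed input. This is illustrative only; the module names are hypothetical. Note that h_n and c_n keep the (num_layers, batch, hidden_size) layout either way.

```python
import torch

torch.manual_seed(0)
seq_first = torch.nn.LSTM(input_size=4, hidden_size=6, batch_first=False)
bf_lstm = torch.nn.LSTM(input_size=4, hidden_size=6, batch_first=True)
bf_lstm.load_state_dict(seq_first.state_dict())  # identical weights

x_seq = torch.randn(7, 3, 4)     # (seq, batch, feature)
x_batch = x_seq.transpose(0, 1)  # (batch, seq, feature)

with torch.no_grad():
    out_seq, (h_seq, c_seq) = seq_first(x_seq)
    out_batch, (h_batch, c_batch) = bf_lstm(x_batch)
```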

Inputs: input, (h_0, c_0)
  • input: tensor of shape \((L, H_{in})\) for unbatched input, \((L, N, H_{in})\) when batch_first=False or \((N, L, H_{in})\) when batch_first=True containing the features of the input sequence. The input can also be a packed variable length sequence. See torch.nn.utils.rnn.pack_padded_sequence() or torch.nn.utils.rnn.pack_sequence() for details.

  • h_0: tensor of shape \((D * \text{num\_layers}, H_{out})\) for unbatched input or \((D * \text{num\_layers}, N, H_{out})\) containing the initial hidden state for each element in the input sequence. Defaults to zeros if (h_0, c_0) is not provided.

  • c_0: tensor of shape \((D * \text{num\_layers}, H_{cell})\) for unbatched input or \((D * \text{num\_layers}, N, H_{cell})\) containing the initial cell state for each element in the input sequence. Defaults to zeros if (h_0, c_0) is not provided.

where:

\[\begin{aligned} N ={} & \text{batch size} \\ L ={} & \text{sequence length} \\ D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ H_{in} ={} & \text{input\_size} \\ H_{cell} ={} & \text{hidden\_size} \\ H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \end{aligned}\]

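The default initial state can be checked directly. A minimal sketch with torch.nn.LSTM (unidirectional with proj_size=0, so \(D = 1\) and \(H_{out} = H_{cell}\)) showing that omitting (h_0, c_0) gives the same result as passing zeros:

```python
import torch

torch.manual_seed(0)
num_layers, batch, hidden = 2, 3, 6
lstm = torch.nn.LSTM(input_size=4, hidden_size=hidden, num_layers=num_layers)
x = torch.randn(7, batch, 4)  # (L, N, H_in)

# Shape (D * num_layers, N, H_out) with D = 1.
h0 = torch.zeros(num_layers, batch, hidden)
c0 = torch.zeros(num_layers, batch, hidden)

with torch.no_grad():
    out_default, (h_default, c_default) = lstm(x)      # no (h_0, c_0) given
    out_zeros, (h_zeros, c_zeros) = lstm(x, (h0, c0))  # explicit zeros
```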
Outputs: output, (h_n, c_n)
  • output: tensor of shape \((L, D * H_{out})\) for unbatched input, \((L, N, D * H_{out})\) when batch_first=False or \((N, L, D * H_{out})\) when batch_first=True containing the output features (h_t) from the last layer of the LSTM, for each t. If a torch.nn.utils.rnn.PackedSequence has been given as the input, the output will also be a packed sequence. When bidirectional=True, output will contain a concatenation of the forward and reverse hidden states at each time step in the sequence.

  • h_n: tensor of shape \((D * \text{num\_layers}, H_{out})\) for unbatched input or \((D * \text{num\_layers}, N, H_{out})\) containing the final hidden state for each element in the sequence. When bidirectional=True, h_n will contain a concatenation of the final forward and reverse hidden states, respectively.

  • c_n: tensor of shape \((D * \text{num\_layers}, H_{cell})\) for unbatched input or \((D * \text{num\_layers}, N, H_{cell})\) containing the final cell state for each element in the sequence. When bidirectional=True, c_n will contain a concatenation of the final forward and reverse cell states, respectively.
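The shape rules in the Inputs/Outputs sections reduce to a little arithmetic. The function below is a hypothetical helper (lstm_shapes is not part of torch or torchrl) that encodes them in plain Python, with N=None standing for unbatched input; "h" covers both h_0 and h_n, and "c" covers both c_0 and c_n.

```python
def lstm_shapes(*, L, input_size, hidden_size, N=None, num_layers=1,
                bidirectional=False, proj_size=0, batch_first=False):
    """Expected shapes of input, hidden state, cell state and output.

    Follows the definitions of N, L, D, H_in, H_cell and H_out above.
    N=None means an unbatched input, for which batch_first is ignored.
    """
    D = 2 if bidirectional else 1
    H_in = input_size
    H_cell = hidden_size
    H_out = proj_size if proj_size > 0 else hidden_size
    if N is None:
        return {
            "input": (L, H_in),
            "h": (D * num_layers, H_out),
            "c": (D * num_layers, H_cell),
            "output": (L, D * H_out),
        }
    seq = (N, L) if batch_first else (L, N)
    return {
        "input": seq + (H_in,),
        "h": (D * num_layers, N, H_out),
        "c": (D * num_layers, N, H_cell),
        "output": seq + (D * H_out,),
    }


shapes = lstm_shapes(L=5, N=3, input_size=10, hidden_size=20, num_layers=2,
                     bidirectional=True, proj_size=8, batch_first=True)
print(shapes)
# {'input': (3, 5, 10), 'h': (4, 3, 8), 'c': (4, 3, 20), 'output': (3, 5, 16)}
```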

Variables:
  • weight_ih_l[k] – the learnable input-hidden weights of the \(\text{k}^{th}\) layer (W_ii|W_if|W_ig|W_io), of shape (4*hidden_size, input_size) for k = 0. For k > 0, the shape is (4*hidden_size, num_directions * hidden_size), or (4*hidden_size, num_directions * proj_size) if proj_size > 0 was specified.

  • weight_hh_l[k] – the learnable hidden-hidden weights of the \(\text{k}^{th}\) layer (W_hi|W_hf|W_hg|W_ho), of shape (4*hidden_size, hidden_size). If proj_size > 0 was specified, the shape will be (4*hidden_size, proj_size).

  • bias_ih_l[k] – the learnable input-hidden bias of the \(\text{k}^{th}\) layer (b_ii|b_if|b_ig|b_io), of shape (4*hidden_size)

  • bias_hh_l[k] – the learnable hidden-hidden bias of the \(\text{k}^{th}\) layer (b_hi|b_hf|b_hg|b_ho), of shape (4*hidden_size)

  • weight_hr_l[k] – the learnable projection weights of the \(\text{k}^{th}\) layer of shape (proj_size, hidden_size). Only present when proj_size > 0 was specified.

  • weight_ih_l[k]_reverse – Analogous to weight_ih_l[k] for the reverse direction. Only present when bidirectional=True.

  • weight_hh_l[k]_reverse – Analogous to weight_hh_l[k] for the reverse direction. Only present when bidirectional=True.

  • bias_ih_l[k]_reverse – Analogous to bias_ih_l[k] for the reverse direction. Only present when bidirectional=True.

  • bias_hh_l[k]_reverse – Analogous to bias_hh_l[k] for the reverse direction. Only present when bidirectional=True.

  • weight_hr_l[k]_reverse – Analogous to weight_hr_l[k] for the reverse direction. Only present when bidirectional=True and proj_size > 0 was specified.
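A quick way to see these names and shapes is to list named_parameters() of a small module. The sketch uses torch.nn.LSTM with two layers, bidirectional=True and proj_size=5 (arbitrary illustrative sizes), which registers every variable listed above:

```python
import torch

lstm = torch.nn.LSTM(input_size=12, hidden_size=20, num_layers=2,
                     bidirectional=True, proj_size=5)

# Map each registered parameter name to its shape.
shapes = {name: tuple(p.shape) for name, p in lstm.named_parameters()}
for name, shape in shapes.items():
    print(name, shape)
```

With hidden_size=20, every gate-stacked tensor has 4 * 20 = 80 rows; for k = 1 the input width is num_directions * proj_size = 2 * 5 = 10.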

Note

All the weights and biases are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{\text{hidden\_size}}\).
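A minimal check of this initialization range, using torch.nn.LSTM with arbitrary sizes; with hidden_size=16 the bound is \(\sqrt{1/16} = 0.25\):

```python
import math

import torch

torch.manual_seed(0)
hidden_size = 16
lstm = torch.nn.LSTM(input_size=8, hidden_size=hidden_size, num_layers=2)

bound = math.sqrt(1.0 / hidden_size)  # sqrt(k) with k = 1 / hidden_size
max_abs = max(p.abs().max().item() for p in lstm.parameters())
print(bound, max_abs)  # max_abs never exceeds bound
```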

Note

For bidirectional LSTMs, forward and backward are directions 0 and 1, respectively. For example, when batch_first=False, the two directions of output can be separated with output.view(seq_len, batch, num_directions, hidden_size).

Note

For bidirectional LSTMs, h_n is not equivalent to the last element of output; the former contains the final forward and reverse hidden states, while the latter contains the final forward hidden state and the initial reverse hidden state.
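The two bidirectional notes above can be illustrated together. This sketch (torch.nn.LSTM, one layer, batch_first=False, illustrative sizes) splits output into its two directions and compares them with h_n:

```python
import torch

torch.manual_seed(0)
L, N, H = 7, 3, 6
lstm = torch.nn.LSTM(input_size=4, hidden_size=H, bidirectional=True)
x = torch.randn(L, N, 4)  # (seq_len, batch, input_size)

with torch.no_grad():
    output, (h_n, c_n) = lstm(x)

# Split the last dimension into (num_directions, hidden_size):
# index 0 is the forward direction, index 1 the reverse direction.
directions = output.view(L, N, 2, H)
forward_out = directions[:, :, 0]
reverse_out = directions[:, :, 1]

# h_n[0]: forward state after the last time step.
# h_n[1]: reverse state after sweeping back to t = 0, i.e. reverse_out[0],
# not the reverse entry of the last element of output.
```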

Note

The batch_first argument is ignored for unbatched inputs.
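A sketch of this behaviour with torch.nn.LSTM: for an unbatched (L, H_in) input, two modules that differ only in batch_first produce the same result. The module names are illustrative.

```python
import torch

torch.manual_seed(0)
lstm_bf = torch.nn.LSTM(input_size=4, hidden_size=6, batch_first=True)
lstm_sf = torch.nn.LSTM(input_size=4, hidden_size=6, batch_first=False)
lstm_sf.load_state_dict(lstm_bf.state_dict())  # same weights

x = torch.randn(7, 4)  # unbatched: (L, H_in)
with torch.no_grad():
    out_bf, (h_bf, c_bf) = lstm_bf(x)
    out_sf, (h_sf, c_sf) = lstm_sf(x)
```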

Note

proj_size should be smaller than hidden_size.

Examples:

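>>> import torch
>>> from torch import nn
>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)  # (L, N, H_in), since nn.LSTM defaults to batch_first=False
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))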

forward(input, hx=None)

Run the LSTM over input, starting from hx = (h_0, c_0). If hx is None, both states default to zeros. Returns output, (h_n, c_n); see the Inputs/Outputs sections above for their shapes.

Note

Although the forward pass is defined in this method, call the module instance (e.g. lstm(x, hx)) instead of calling forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.
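The difference can be observed with a forward hook. This sketch registers a hypothetical record_call hook on a torch.nn.LSTM and calls the module both ways:

```python
import torch

lstm = torch.nn.LSTM(input_size=4, hidden_size=6)
calls = []


def record_call(module, inputs, output):
    # Forward hooks receive the module, its positional inputs and its output.
    calls.append(type(module).__name__)


handle = lstm.register_forward_hook(record_call)
x = torch.randn(5, 2, 4)

with torch.no_grad():
    lstm(x)          # goes through Module.__call__, so the hook runs
    lstm.forward(x)  # bypasses Module.__call__, so the hook is skipped

handle.remove()
print(calls)  # ['LSTM']
```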
