LSTMCell
- class torchrl.modules.LSTMCell(input_size: int, hidden_size: int, bias: bool = True, device=None, dtype=None)
A long short-term memory (LSTM) cell that performs the same operation as nn.LSTMCell but is fully coded in Python.
Note
This class is implemented without relying on CuDNN, which makes it compatible with torch.vmap() and torch.compile().
Examples
>>> import torch
>>> from torchrl.modules.tensordict_module.rnn import LSTMCell
>>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu")
>>> B = 2
>>> N_IN = 10
>>> N_OUT = 20
>>> V = 4  # vector size
>>> lstm_cell = LSTMCell(input_size=N_IN, hidden_size=N_OUT, device=device)
>>> # single call
>>> x = torch.randn(B, 10, device=device)
>>> h0 = torch.zeros(B, 20, device=device)
>>> c0 = torch.zeros(B, 20, device=device)
>>> with torch.no_grad():
...     (h1, c1) = lstm_cell(x, (h0, c0))
>>> # vectorised call - not possible with nn.LSTMCell
>>> def call_lstm(x, h, c):
...     h_out, c_out = lstm_cell(x, (h, c))
...     return h_out, c_out
>>> batched_call = torch.vmap(call_lstm)
>>> x = torch.randn(V, B, 10, device=device)
>>> h0 = torch.zeros(V, B, 20, device=device)
>>> c0 = torch.zeros(V, B, 20, device=device)
>>> with torch.no_grad():
...     (h1, c1) = batched_call(x, h0, c0)
A long short-term memory (LSTM) cell.
\[
\begin{array}{ll}
i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
c' = f \odot c + i \odot g \\
h' = o \odot \tanh(c')
\end{array}
\]
where \(\sigma\) is the sigmoid function and \(\odot\) is the Hadamard product.
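As a rough illustration of the equations above, the sketch below computes one step in plain PyTorch. It is not torchrl's implementation: it assumes the stacked parameter layout described under Variables, with gate rows ordered i, f, g, o as in torch.nn.LSTMCell, and `lstm_step` is a hypothetical helper name. The check at the end compares it against a torch.nn.LSTMCell that shares the same weights.

```python
import torch


def lstm_step(x, h, c, weight_ih, weight_hh, bias_ih=None, bias_hh=None):
    """Hypothetical helper: one LSTM step with gate rows stacked as i, f, g, o."""
    # Input and hidden projections, each of shape (..., 4 * hidden_size).
    gates = x @ weight_ih.T + h @ weight_hh.T
    if bias_ih is not None:
        gates = gates + bias_ih + bias_hh
    # Split the pre-activations into the four gates along the last dimension.
    i, f, g, o = gates.chunk(4, dim=-1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_next = f * c + i * g            # c' = f * c + i * g (elementwise)
    h_next = o * torch.tanh(c_next)   # h' = o * tanh(c')
    return h_next, c_next


# Compare against torch.nn.LSTMCell using the same parameters.
torch.manual_seed(0)
ref = torch.nn.LSTMCell(10, 20)
x = torch.randn(3, 10)
h0 = torch.randn(3, 20)
c0 = torch.randn(3, 20)
with torch.no_grad():
    h_ref, c_ref = ref(x, (h0, c0))
    h1, c1 = lstm_step(x, h0, c0, ref.weight_ih, ref.weight_hh, ref.bias_ih, ref.bias_hh)
print(torch.allclose(h1, h_ref, atol=1e-5), torch.allclose(c1, c_ref, atol=1e-5))
```

Because `chunk` works on the last dimension and the matrix products broadcast over leading dimensions, the same function also accepts unbatched inputs of shape (input_size).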
- Parameters:
input_size – The number of expected features in the input x
hidden_size – The number of features in the hidden state h
bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True
- Inputs: input, (h_0, c_0)
input of shape (batch, input_size) or (input_size): tensor containing input features
h_0 of shape (batch, hidden_size) or (hidden_size): tensor containing the initial hidden state
c_0 of shape (batch, hidden_size) or (hidden_size): tensor containing the initial cell state
If (h_0, c_0) is not provided, both h_0 and c_0 default to zero.
- Outputs: (h_1, c_1)
h_1 of shape (batch, hidden_size) or (hidden_size): tensor containing the next hidden state
c_1 of shape (batch, hidden_size) or (hidden_size): tensor containing the next cell state
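The input and output conventions above can be exercised with torch.nn.LSTMCell, which the class description says this cell mirrors (the sizes here are arbitrary). The sketch assumes the same calling convention carries over to torchrl's LSTMCell: omitting (h_0, c_0) matches passing zeros, and an unbatched input of shape (input_size) yields states of shape (hidden_size).

```python
import torch

torch.manual_seed(0)
cell = torch.nn.LSTMCell(input_size=10, hidden_size=20)

x = torch.randn(4, 10)  # (batch, input_size)
with torch.no_grad():
    # Omitting the state is equivalent to passing zero-initialized h_0 and c_0.
    h_default, c_default = cell(x)
    zeros = torch.zeros(4, 20)
    h_zero, c_zero = cell(x, (zeros, zeros))

    # Unbatched call: input of shape (input_size,) gives states of shape (hidden_size,).
    h_single, c_single = cell(torch.randn(10))

print(h_default.shape, h_single.shape)
```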
- Variables:
weight_ih (torch.Tensor) – the learnable input-hidden weights, of shape (4*hidden_size, input_size)
weight_hh (torch.Tensor) – the learnable hidden-hidden weights, of shape (4*hidden_size, hidden_size)
bias_ih – the learnable input-hidden bias, of shape (4*hidden_size)
bias_hh – the learnable hidden-hidden bias, of shape (4*hidden_size)
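The parameter names and shapes above can be inspected on torch.nn.LSTMCell; this sketch assumes torchrl's LSTMCell registers the same names, as the Variables list indicates. With bias=False, the bias entries are registered as None rather than as parameters.

```python
import torch

input_size, hidden_size = 10, 20
cell = torch.nn.LSTMCell(input_size, hidden_size)
# The four gates are stacked along the first dimension, hence 4 * hidden_size rows.
for name, p in cell.named_parameters():
    print(name, tuple(p.shape))

# With bias=False only the two weight matrices are registered as parameters.
no_bias = torch.nn.LSTMCell(input_size, hidden_size, bias=False)
print([name for name, _ in no_bias.named_parameters()])
```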
Note
All the weights and biases are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{1}{\text{hidden\_size}}\)
On certain ROCm devices, when using float16 inputs, this module will use different precision for backward.
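The initialization rule in the note can be written out explicitly. In this sketch, `init_uniform_` is a hypothetical helper that re-draws every parameter from U(-sqrt(k), sqrt(k)) with k = 1/hidden_size; torch.nn.LSTMCell's default initialization already uses the same bound, so the helper only makes the rule visible.

```python
import math

import torch


def init_uniform_(cell: torch.nn.Module, hidden_size: int) -> float:
    """Hypothetical helper: re-draw all parameters from U(-sqrt(k), sqrt(k)), k = 1 / hidden_size."""
    bound = math.sqrt(1.0 / hidden_size)
    with torch.no_grad():  # in-place updates on leaf parameters need no_grad
        for p in cell.parameters():
            p.uniform_(-bound, bound)
    return bound


hidden_size = 20
cell = torch.nn.LSTMCell(10, hidden_size)
bound = init_uniform_(cell, hidden_size)
print(max(p.abs().max().item() for p in cell.parameters()), "<=", bound)
```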
Examples:
>>> import torch
>>> from torch import nn
>>> rnn = nn.LSTMCell(10, 20)  # (input_size, hidden_size)
>>> input = torch.randn(2, 3, 10)  # (time_steps, batch, input_size)
>>> hx = torch.randn(3, 20)  # (batch, hidden_size)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(input.size()[0]):
...     hx, cx = rnn(input[i], (hx, cx))
...     output.append(hx)
>>> output = torch.stack(output, dim=0)
- forward(input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) → Tuple[Tensor, Tensor]
Compute one LSTM step for input and return (h_1, c_1). If hx is None, both h_0 and c_0 default to zero.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance rather than this method directly, since the former takes care of running the registered hooks while the latter silently ignores them.
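The difference described in the note can be observed with a forward hook. The sketch uses torch.nn.LSTMCell to stay self-contained; the behavior comes from torch.nn.Module, so it applies to any module, including this one. `calls` is just a list used to count hook invocations.

```python
import torch

cell = torch.nn.LSTMCell(10, 20)
calls = []
# A forward hook receives (module, inputs, output) after each hooked forward pass.
handle = cell.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(2, 10)
with torch.no_grad():
    cell(x)          # __call__ runs forward() and the registered hooks
    cell.forward(x)  # calling forward() directly bypasses the hooks
handle.remove()
print(len(calls))
```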