# LSTMCell

class torch.nn.LSTMCell(input_size: int, hidden_size: int, bias: bool = True)

A long short-term memory (LSTM) cell.

$\begin{array}{ll} i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}$

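where $x$ is the input, $h$ and $c$ are the current hidden and cell states, $h'$ and $c'$ are the next hidden and cell states, $\sigma$ is the sigmoid function, and $*$ is the Hadamard (element-wise) product.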
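As a rough illustration of the equations above, the following sketch recomputes a single step by hand and compares it with the module's own output. It assumes the four gate blocks in `weight_ih`, `weight_hh`, and the biases are stacked in the order i, f, g, o (the layout documented for `nn.LSTM`); the sizes and variable names are arbitrary choices for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=10, hidden_size=20)

x = torch.randn(3, 10)  # (batch, input_size)
h = torch.randn(3, 20)  # (batch, hidden_size)
c = torch.randn(3, 20)  # (batch, hidden_size)

with torch.no_grad():
    # All four gate pre-activations at once: shape (batch, 4*hidden_size).
    gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh

    # Assumption: gate blocks are stacked in the order i, f, g, o.
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)

    c_manual = f * c + i * g             # c' = f * c + i * g
    h_manual = o * torch.tanh(c_manual)  # h' = o * tanh(c')

    # Reference result from the module itself.
    h_ref, c_ref = cell(x, (h, c))

matches = (torch.allclose(h_manual, h_ref, atol=1e-5)
           and torch.allclose(c_manual, c_ref, atol=1e-5))
```

If the assumed gate order holds, `matches` should be `True` up to floating-point tolerance.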

Parameters
• input_size – The number of expected features in the input x

• hidden_size – The number of features in the hidden state h

• bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True

Inputs: input, (h_0, c_0)
• input of shape (batch, input_size): tensor containing input features

• h_0 of shape (batch, hidden_size): tensor containing the initial hidden state for each element in the batch.

• c_0 of shape (batch, hidden_size): tensor containing the initial cell state for each element in the batch.

If (h_0, c_0) is not provided, both h_0 and c_0 default to zero.

Outputs: (h_1, c_1)
• h_1 of shape (batch, hidden_size): tensor containing the next hidden state for each element in the batch

• c_1 of shape (batch, hidden_size): tensor containing the next cell state for each element in the batch
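A minimal sketch of the input and output shapes described above, including the case where the state tuple is omitted. The sizes (batch of 3, `input_size=10`, `hidden_size=20`) are illustrative only.

```python
import torch
from torch import nn

cell = nn.LSTMCell(input_size=10, hidden_size=20)
x = torch.randn(3, 10)  # (batch, input_size)

with torch.no_grad():
    # State omitted: h_0 and c_0 default to zeros.
    h1, c1 = cell(x)

    # Passing explicit zero states should give the same result.
    zeros = torch.zeros(3, 20)
    h1_explicit, c1_explicit = cell(x, (zeros, zeros))

same = torch.allclose(h1, h1_explicit) and torch.allclose(c1, c1_explicit)
# h1 and c1 both have shape (batch, hidden_size) = (3, 20)
```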

Variables
• weight_ih – the learnable input-hidden weights, of shape (4*hidden_size, input_size)

• weight_hh – the learnable hidden-hidden weights, of shape (4*hidden_size, hidden_size)

• bias_ih – the learnable input-hidden bias, of shape (4*hidden_size)

• bias_hh – the learnable hidden-hidden bias, of shape (4*hidden_size)
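The parameter shapes listed above can be inspected directly. This short sketch uses arbitrary sizes (`input_size=10`, `hidden_size=20`) and also shows what happens to the bias attributes when `bias=False`.

```python
from torch import nn

cell = nn.LSTMCell(input_size=10, hidden_size=20)

# Map each parameter name to its shape; 4*hidden_size = 80 here.
shapes = {name: tuple(p.shape) for name, p in cell.named_parameters()}
# weight_ih: (80, 10), weight_hh: (80, 20), bias_ih: (80,), bias_hh: (80,)

# With bias=False the bias attributes exist but are None,
# so only the two weight matrices are registered as parameters.
no_bias = nn.LSTMCell(input_size=10, hidden_size=20, bias=False)
no_bias_names = [name for name, _ in no_bias.named_parameters()]
```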

Note

All the weights and biases are initialized from $\mathcal{U}(-\sqrt{k}, \sqrt{k})$, where $k = \frac{1}{\text{hidden\_size}}$.
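As a quick sanity check of the initialization range, the sketch below verifies that every freshly initialized parameter lies within $\pm\sqrt{k}$. The `hidden_size` value and the small tolerance for float32 rounding are assumptions made for the example.

```python
import math

import torch
from torch import nn

hidden_size = 20
cell = nn.LSTMCell(input_size=10, hidden_size=hidden_size)

# k = 1 / hidden_size, so the bound is sqrt(k) (about 0.2236 for hidden_size=20).
bound = math.sqrt(1.0 / hidden_size)

# Allow a tiny tolerance for float32 rounding at the interval edges.
within = all(p.abs().max().item() <= bound + 1e-6 for p in cell.parameters())
```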

Examples:

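>>> import torch
>>> from torch import nn
>>> rnn = nn.LSTMCell(10, 20)  # (input_size, hidden_size)
>>> input = torch.randn(6, 3, 10)  # (time_steps, batch, input_size)
>>> hx = torch.randn(3, 20)  # (batch, hidden_size)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):  # feed one time step at a time
...     hx, cx = rnn(input[i], (hx, cx))
...     output.append(hx)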