GRUCell
- class torchrl.modules.GRUCell(input_size: int, hidden_size: int, bias: bool = True, device=None, dtype=None)
A gated recurrent unit (GRU) cell that performs the same operation as nn.GRUCell but is fully coded in Python.
Note
This class is implemented without relying on CuDNN, which makes it compatible with torch.vmap() and torch.compile().
Examples
>>> import torch
>>> from torchrl.modules.tensordict_module.rnn import GRUCell
>>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu")
>>> B = 2
>>> N_IN = 10
>>> N_OUT = 20
>>> V = 4  # vector size
>>> gru_cell = GRUCell(input_size=N_IN, hidden_size=N_OUT, device=device)
>>> # single call
>>> x = torch.randn(B, N_IN, device=device)
>>> h0 = torch.zeros(B, N_OUT, device=device)
>>> with torch.no_grad():
...     h1 = gru_cell(x, h0)
>>> # vectorised call - not possible with nn.GRUCell
>>> def call_gru(x, h):
...     h_out = gru_cell(x, h)
...     return h_out
>>> batched_call = torch.vmap(call_gru)
>>> x = torch.randn(V, B, N_IN, device=device)
>>> h0 = torch.zeros(V, B, N_OUT, device=device)
>>> with torch.no_grad():
...     h1 = batched_call(x, h0)
A gated recurrent unit (GRU) cell.
\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\ h' = (1 - z) \odot n + z \odot h \end{array}\end{split}\]
where \(\sigma\) is the sigmoid function, and \(\odot\) is the Hadamard product.
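The four equations above can be written out by hand and checked against a reference cell. The sketch below is illustrative only: it uses torch.nn.GRUCell as the reference so that it runs without torchrl, and manual_gru_step is a hypothetical helper name, not part of either library. It assumes the stacked weights are ordered as reset, update, new gate, which is the layout PyTorch documents for its GRU modules.

```python
import torch
from torch import nn

torch.manual_seed(0)
cell = nn.GRUCell(input_size=10, hidden_size=20)
x = torch.randn(3, 10)   # (N, H_in)
h = torch.randn(3, 20)   # (N, H_out)


def manual_gru_step(x, h, w_ih, w_hh, b_ih, b_hh):
    """Hypothetical helper: one GRU step written directly from the equations."""
    gi = x @ w_ih.T + b_ih                 # input projections, shape (N, 3 * H_out)
    gh = h @ w_hh.T + b_hh                 # hidden projections, shape (N, 3 * H_out)
    i_r, i_z, i_n = gi.chunk(3, dim=-1)    # assumed order: reset, update, new
    h_r, h_z, h_n = gh.chunk(3, dim=-1)
    r = torch.sigmoid(i_r + h_r)           # reset gate
    z = torch.sigmoid(i_z + h_z)           # update gate
    n = torch.tanh(i_n + r * h_n)          # candidate state
    return (1 - z) * n + z * h             # next hidden state h'


with torch.no_grad():
    expected = cell(x, h)
    manual = manual_gru_step(
        x, h, cell.weight_ih, cell.weight_hh, cell.bias_ih, cell.bias_hh
    )

matches = torch.allclose(expected, manual, atol=1e-5)
```

Because the step is built only from matrix products and elementwise operations, a pure-Python formulation of this kind is what allows transforms such as torch.vmap() to trace through the cell.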
- Parameters:
input_size – The number of expected features in the input x
hidden_size – The number of features in the hidden state h
bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True
- Inputs: input, hidden
input : tensor containing input features
hidden : tensor containing the initial hidden state for each element in the batch. Defaults to zero if not provided.
- Outputs: h’
h’ : tensor containing the next hidden state for each element in the batch
- Shape:
input: \((N, H_{in})\) or \((H_{in})\) tensor containing input features where \(H_{in}\) = input_size.
hidden: \((N, H_{out})\) or \((H_{out})\) tensor containing the initial hidden state where \(H_{out}\) = hidden_size. Defaults to zero if not provided.
output: \((N, H_{out})\) or \((H_{out})\) tensor containing the next hidden state.
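The shape rules listed above can be checked with a short sketch. It uses torch.nn.GRUCell as a stand-in so that it runs without torchrl; the assumption, based on the documentation above, is that torchrl.modules.GRUCell accepts the same shapes. Unbatched (1-D) inputs require a reasonably recent PyTorch release.

```python
import torch
from torch import nn

torch.manual_seed(0)
cell = nn.GRUCell(input_size=10, hidden_size=20)

with torch.no_grad():
    # Batched input (N, H_in) gives a batched output (N, H_out).
    x = torch.randn(4, 10)
    batched = cell(x)

    # Unbatched input (H_in,) gives an unbatched output (H_out,).
    unbatched = cell(torch.randn(10))

    # Omitting the hidden state is equivalent to passing zeros.
    same_as_zero_state = torch.allclose(cell(x), cell(x, torch.zeros(4, 20)))
```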
- Variables:
weight_ih (torch.Tensor) – the learnable input-hidden weights, of shape (3*hidden_size, input_size)
weight_hh (torch.Tensor) – the learnable hidden-hidden weights, of shape (3*hidden_size, hidden_size)
bias_ih – the learnable input-hidden bias, of shape (3*hidden_size)
bias_hh – the learnable hidden-hidden bias, of shape (3*hidden_size)
Note
All the weights and biases are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{1}{\text{hidden\_size}}\)
On certain ROCm devices, when using float16 inputs this module will use different precision for backward.
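The parameter shapes and the initialization range described above can be inspected directly. This is a minimal sketch that again uses torch.nn.GRUCell as a stand-in, on the assumption that the torchrl cell exposes the same attributes; the small tolerance added to the bound is only there to absorb float32 rounding.

```python
import math

import torch
from torch import nn

input_size, hidden_size = 10, 20
cell = nn.GRUCell(input_size, hidden_size)

# Each parameter stacks the three gates along the first dimension.
shapes = {name: tuple(p.shape) for name, p in cell.named_parameters()}

# All parameters are drawn from U(-sqrt(k), sqrt(k)) with k = 1 / hidden_size.
bound = math.sqrt(1.0 / hidden_size)
within_bound = all(
    p.abs().max().item() <= bound + 1e-6 for p in cell.parameters()
)

# With bias=False the bias attributes exist but are set to None.
no_bias = nn.GRUCell(input_size, hidden_size, bias=False)
bias_disabled = no_bias.bias_ih is None and no_bias.bias_hh is None
```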
Examples:
>>> import torch
>>> from torch import nn
>>> rnn = nn.GRUCell(10, 20)
>>> input = torch.randn(6, 3, 10)  # (time, batch, features)
>>> hx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
...     hx = rnn(input[i], hx)
...     output.append(hx)
- forward(input: Tensor, hx: Optional[Tensor] = None) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
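The difference between calling the module and calling forward() directly can be seen with a forward hook. The sketch below uses torch.nn.GRUCell for illustration; record_shape is a hypothetical hook name, and the same behaviour is expected of any torch.nn.Module subclass.

```python
import torch
from torch import nn

cell = nn.GRUCell(input_size=10, hidden_size=20)
x = torch.randn(3, 10)

calls = []


def record_shape(module, args, output):
    """Hypothetical hook: remember the shape of every output it sees."""
    calls.append(tuple(output.shape))


handle = cell.register_forward_hook(record_shape)

with torch.no_grad():
    cell(x)          # goes through Module.__call__, so the hook runs
    cell.forward(x)  # bypasses __call__, so the hook is skipped

handle.remove()
```

After both calls, calls holds a single entry, which shows that only the first call triggered the hook.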