
GRUCell

class torchrl.modules.GRUCell(input_size: int, hidden_size: int, bias: bool = True, device=None, dtype=None)

A gated recurrent unit (GRU) cell that performs the same operation as nn.GRUCell but is implemented entirely in Python.

Note

This class is implemented without relying on cuDNN, which makes it compatible with torch.vmap() and torch.compile().
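Because the update is expressed with ordinary tensor operations, a purely functional version of it can be mapped over an extra leading dimension with torch.vmap(). The sketch below is not TorchRL's source: `gru_step` is a hypothetical helper written for illustration, and it assumes the (reset, update, new) gate ordering used by torch.nn.GRUCell. It checks that the vmapped call agrees with an explicit Python loop.

```python
import torch

def gru_step(x, h, w_ih, w_hh, b_ih, b_hh):
    """Hypothetical functional GRU update: x is (B, H_in), h is (B, H_out)."""
    # Project input and hidden state, then split into reset/update/new chunks.
    gx = x @ w_ih.T + b_ih
    gh = h @ w_hh.T + b_hh
    x_r, x_z, x_n = gx.chunk(3, dim=-1)
    h_r, h_z, h_n = gh.chunk(3, dim=-1)
    r = torch.sigmoid(x_r + h_r)      # reset gate
    z = torch.sigmoid(x_z + h_z)      # update gate
    n = torch.tanh(x_n + r * h_n)     # candidate state
    return (1 - z) * n + z * h        # next hidden state

torch.manual_seed(0)
V, B, H_IN, H_OUT = 4, 2, 10, 20
w_ih = torch.randn(3 * H_OUT, H_IN)
w_hh = torch.randn(3 * H_OUT, H_OUT)
b_ih = torch.randn(3 * H_OUT)
b_hh = torch.randn(3 * H_OUT)

x = torch.randn(V, B, H_IN)
h0 = torch.zeros(V, B, H_OUT)

# Map over the leading V dimension of x and h only; the weights are shared.
batched = torch.vmap(gru_step, in_dims=(0, 0, None, None, None, None))
h1 = batched(x, h0, w_ih, w_hh, b_ih, b_hh)

# Reference: the same computation as an explicit loop over V.
h1_loop = torch.stack([gru_step(x[v], h0[v], w_ih, w_hh, b_ih, b_hh) for v in range(V)])
print(h1.shape)
```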

Examples

>>> import torch
>>> from torchrl.modules.tensordict_module.rnn import GRUCell
>>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu")
>>> B = 2
>>> N_IN = 10
>>> N_OUT = 20
>>> V = 4  # vector size
>>> gru_cell = GRUCell(input_size=N_IN, hidden_size=N_OUT, device=device)

# single call
>>> x = torch.randn(B, N_IN, device=device)
>>> h0 = torch.zeros(B, N_OUT, device=device)
>>> with torch.no_grad():
...     h1 = gru_cell(x, h0)

# vectorised call (not possible with nn.GRUCell)
>>> def call_gru(x, h):
...     h_out = gru_cell(x, h)
...     return h_out
>>> batched_call = torch.vmap(call_gru)
>>> x = torch.randn(V, B, N_IN, device=device)
>>> h0 = torch.zeros(V, B, N_OUT, device=device)
>>> with torch.no_grad():
...     h1 = batched_call(x, h0)

For each element in the batch, the cell computes:

\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\ h' = (1 - z) \odot n + z \odot h \end{array}\end{split}\]

where \(\sigma\) is the sigmoid function, and \(\odot\) is the Hadamard product.
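To make the role of each gate concrete, the equations can be evaluated by hand in the scalar case (input_size = hidden_size = 1), where every W and b is a single number. This is a plain-Python sketch with no PyTorch dependency; `gru_scalar` and the dict-of-gates layout are hypothetical choices made for illustration. It shows that z interpolates between the candidate n and the previous state h.

```python
import math

def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))

def gru_scalar(x: float, h: float, W: dict, b: dict) -> float:
    """Scalar GRU update (input_size = hidden_size = 1).

    W and b are dicts keyed by gate name: 'ir', 'hr', 'iz', 'hz', 'in', 'hn'.
    """
    r = sigmoid(W["ir"] * x + b["ir"] + W["hr"] * h + b["hr"])         # reset gate
    z = sigmoid(W["iz"] * x + b["iz"] + W["hz"] * h + b["hz"])         # update gate
    n = math.tanh(W["in"] * x + b["in"] + r * (W["hn"] * h + b["hn"]))  # candidate state
    return (1 - z) * n + z * h                                          # interpolation

GATES = ("ir", "hr", "iz", "hz", "in", "hn")
zeros = dict.fromkeys(GATES, 0.0)

# All parameters zero: r = z = 0.5 and n = tanh(0) = 0, so h' = 0.5 * h.
print(gru_scalar(0.0, 0.8, zeros, zeros))  # 0.4

# Strongly positive update-gate bias: z -> 1, so h' stays at h.
keep = {**zeros, "iz": 50.0}
print(gru_scalar(0.5, 0.8, zeros, keep))

# Strongly negative update-gate bias: z -> 0, so h' approaches the candidate n = tanh(0.5).
W = {**zeros, "in": 1.0}
replace = {**zeros, "iz": -50.0}
print(gru_scalar(0.5, 0.8, W, replace))
```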

Parameters:
  • input_size – The number of expected features in the input x

  • hidden_size – The number of features in the hidden state h

  • bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True
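The effect of bias on model size follows directly from the parameter shapes listed under Variables. The sketch below counts them; `gru_cell_param_count` is a hypothetical helper, and torch.nn.GRUCell stands in for the TorchRL class so the code runs without TorchRL installed, assuming both use the same parameter layout. For input_size=10 and hidden_size=20 the weights contribute 3·20·10 + 3·20·20 = 1800 scalars, and the two bias vectors add 2·3·20 = 120 more.

```python
import torch

def gru_cell_param_count(input_size: int, hidden_size: int, bias: bool = True) -> int:
    """Hypothetical helper: count the scalars in weight_ih, weight_hh and the biases."""
    count = 3 * hidden_size * input_size    # weight_ih: (3*hidden_size, input_size)
    count += 3 * hidden_size * hidden_size  # weight_hh: (3*hidden_size, hidden_size)
    if bias:
        count += 2 * 3 * hidden_size        # bias_ih and bias_hh: (3*hidden_size,) each
    return count

counts = {}
for bias in (True, False):
    # torch.nn.GRUCell stands in for torchrl's GRUCell here.
    cell = torch.nn.GRUCell(10, 20, bias=bias)
    counts[bias] = sum(p.numel() for p in cell.parameters())

print(counts)  # {True: 1920, False: 1800}
```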

Inputs: input, hidden
  • input : tensor containing input features

  • hidden : tensor containing the initial hidden state for each element in the batch. Defaults to zero if not provided.

Outputs: h’
  • h’ : tensor containing the next hidden state for each element in the batch

Shape:
  • input: \((N, H_{in})\) or \((H_{in})\) tensor containing input features where \(H_{in}\) = input_size.

  • hidden: \((N, H_{out})\) or \((H_{out})\) tensor containing the initial hidden state where \(H_{out}\) = hidden_size. Defaults to zero if not provided.

  • output: \((N, H_{out})\) or \((H_{out})\) tensor containing the next hidden state.
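The shape rules above can be checked quickly: a batched input of shape (N, H_in) yields (N, H_out), an unbatched input of shape (H_in) yields (H_out), and omitting the hidden state is the same as passing zeros. This sketch uses torch.nn.GRUCell as a stand-in for the TorchRL class, assuming the shape conventions documented here match.

```python
import torch

torch.manual_seed(0)
cell = torch.nn.GRUCell(input_size=10, hidden_size=20)  # stand-in for torchrl's GRUCell

# Batched input (N, H_in) -> next hidden state (N, H_out); hidden state omitted.
x = torch.randn(3, 10)
h_batched = cell(x)

# Unbatched input (H_in,) -> next hidden state (H_out,).
h_single = cell(x[0])

# Passing an explicit zero hidden state gives the same result as omitting it.
h_zero = cell(x, torch.zeros(3, 20))
print(h_batched.shape, h_single.shape)
```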

Variables:
  • weight_ih (torch.Tensor) – the learnable input-hidden weights, of shape (3*hidden_size, input_size)

  • weight_hh (torch.Tensor) – the learnable hidden-hidden weights, of shape (3*hidden_size, hidden_size)

  • bias_ih – the learnable input-hidden bias, of shape (3*hidden_size)

  • bias_hh – the learnable hidden-hidden bias, of shape (3*hidden_size)
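Each stacked parameter holds the three per-gate blocks of the equations above. The sketch assumes the (r, z, n) block order used by torch.nn.GRUCell, splits the stacked tensors with `chunk`, recomputes the update by hand, and compares it with the module's output; torch.nn.GRUCell stands in for the TorchRL class.

```python
import torch

torch.manual_seed(0)
H_IN, H_OUT = 10, 20
cell = torch.nn.GRUCell(H_IN, H_OUT)  # stand-in for torchrl's GRUCell
x = torch.randn(4, H_IN)
h = torch.randn(4, H_OUT)

# Split the stacked parameters into per-gate blocks, assuming (r, z, n) order.
W_ir, W_iz, W_in = cell.weight_ih.chunk(3, dim=0)  # each (H_out, H_in)
W_hr, W_hz, W_hn = cell.weight_hh.chunk(3, dim=0)  # each (H_out, H_out)
b_ir, b_iz, b_in = cell.bias_ih.chunk(3)           # each (H_out,)
b_hr, b_hz, b_hn = cell.bias_hh.chunk(3)           # each (H_out,)

with torch.no_grad():
    # The GRU equations, written term by term.
    r = torch.sigmoid(x @ W_ir.T + b_ir + h @ W_hr.T + b_hr)
    z = torch.sigmoid(x @ W_iz.T + b_iz + h @ W_hz.T + b_hz)
    n = torch.tanh(x @ W_in.T + b_in + r * (h @ W_hn.T + b_hn))
    h_manual = (1 - z) * n + z * h
    h_ref = cell(x, h)

print(torch.allclose(h_manual, h_ref, atol=1e-5))
```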

Note

All weights and biases are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{\text{hidden\_size}}\).

On certain ROCm devices, when float16 inputs are used, this module uses a different precision for the backward pass.
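For hidden_size = 20 the initialization bound is \(\sqrt{k} = \sqrt{1/20} \approx 0.2236\). The sketch below checks that every freshly initialized parameter lies inside that bound. It uses torch.nn.GRUCell as a stand-in, assuming the TorchRL class uses the same initialization described in the note; the small tolerance only absorbs float32 rounding.

```python
import math
import torch

hidden_size = 20
cell = torch.nn.GRUCell(10, hidden_size)  # stand-in for torchrl's GRUCell
bound = math.sqrt(1.0 / hidden_size)      # sqrt(k) with k = 1 / hidden_size

# Every weight and bias should satisfy |p| <= sqrt(k).
within = all(p.abs().max().item() <= bound + 1e-6 for p in cell.parameters())
print(round(bound, 4), within)
```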

Examples:

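>>> import torch
>>> from torchrl.modules import GRUCell
>>> rnn = GRUCell(10, 20)
>>> input = torch.randn(6, 3, 10)  # (time, batch, input_size)
>>> hx = torch.randn(3, 20)        # (batch, hidden_size)
>>> output = []
>>> for i in range(6):
...     hx = rnn(input[i], hx)     # one time step
...     output.append(hx)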
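The loop above applies the cell one time step at a time. As a sanity check, the following sketch unrolls a cell over a sequence, stacks the hidden states, and compares the result with a single-layer torch.nn.GRU that shares the same weights. torch.nn.GRUCell stands in for the TorchRL class, assuming the two compute the same function.

```python
import torch

torch.manual_seed(0)
T, B, H_IN, H_OUT = 6, 3, 10, 20
cell = torch.nn.GRUCell(H_IN, H_OUT)  # stand-in for torchrl's GRUCell
gru = torch.nn.GRU(H_IN, H_OUT)       # single layer, input shaped (T, B, H_in)

with torch.no_grad():
    # Copy the cell's parameters into the layer so both compute the same function.
    gru.weight_ih_l0.copy_(cell.weight_ih)
    gru.weight_hh_l0.copy_(cell.weight_hh)
    gru.bias_ih_l0.copy_(cell.bias_ih)
    gru.bias_hh_l0.copy_(cell.bias_hh)

    x = torch.randn(T, B, H_IN)
    hx = torch.zeros(B, H_OUT)
    outputs = []
    for t in range(T):
        hx = cell(x[t], hx)
        outputs.append(hx)
    unrolled = torch.stack(outputs)  # (T, B, H_out)
    reference, h_n = gru(x)          # reference: (T, B, H_out), h_n: (1, B, H_out)

print(unrolled.shape)
```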

forward(input: Tensor, hx: Optional[Tensor] = None) → Tensor

Compute the next hidden state h' from input and hx. If hx is None, a zero hidden state is used.

Note

Although the forward pass is defined in this method, call the module instance itself rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.
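The difference is easy to observe with a forward hook. In this sketch, torch.nn.GRUCell stands in for the TorchRL class (the behaviour comes from torch.nn.Module, so any module would do), and the `calls` list is a hypothetical recorder: calling the instance triggers the hook, while calling forward() directly does not.

```python
import torch

cell = torch.nn.GRUCell(10, 20)  # stand-in; hook behaviour comes from nn.Module
calls = []
handle = cell.register_forward_hook(
    lambda module, args, output: calls.append(tuple(output.shape))
)

x = torch.randn(2, 10)
cell(x)          # __call__ runs forward() and then the registered hook
cell.forward(x)  # a direct forward() call bypasses the hook
handle.remove()

print(len(calls))  # 1
```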
