import math
import torch
import warnings
from .module import Module
from ..parameter import Parameter
from ..utils.rnn import PackedSequence
class RNNBase(Module):
def __init__(self, mode, input_size, hidden_size,
num_layers=1, bias=True, batch_first=False,
dropout=0, bidirectional=False):
super(RNNBase, self).__init__()
self.mode = mode
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = dropout
self.dropout_state = {}
self.bidirectional = bidirectional
num_directions = 2 if bidirectional else 1
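        # Each gate gets its own slice of the weight matrices: LSTM stacks
        # four gates (input, forget, cell, output), GRU stacks three
        # (reset, update, new), and the plain RNN modes use a single block.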
if mode == 'LSTM':
gate_size = 4 * hidden_size
elif mode == 'GRU':
gate_size = 3 * hidden_size
else:
gate_size = hidden_size
self._all_weights = []
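        # Register one (w_ih, w_hh, b_ih, b_hh) group per layer and direction.
        # Layers above the first consume the (possibly concatenated) outputs of
        # the previous layer, hence hidden_size * num_directions below.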
for layer in range(num_layers):
for direction in range(num_directions):
layer_input_size = input_size if layer == 0 else hidden_size * num_directions
w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
b_ih = Parameter(torch.Tensor(gate_size))
b_hh = Parameter(torch.Tensor(gate_size))
layer_params = (w_ih, w_hh, b_ih, b_hh)
suffix = '_reverse' if direction == 1 else ''
param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
if bias:
param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
param_names = [x.format(layer, suffix) for x in param_names]
for name, param in zip(param_names, layer_params):
setattr(self, name, param)
self._all_weights.append(param_names)
self.flatten_parameters()
self.reset_parameters()
def flatten_parameters(self):
"""Resets parameter data pointer so that they can use faster code paths.
Right now, this works only if the module is on the GPU and cuDNN is enabled.
Otherwise, it's a no-op.
"""
any_param = next(self.parameters()).data
if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param):
self._data_ptrs = []
return
# If any parameters alias, we fall back to the slower, copying code path. This is
# a sufficient check, because overlapping parameter buffers that don't completely
# alias would break the assumptions of the uniqueness check in
# Module.named_parameters().
unique_data_ptrs = set(p.data_ptr() for l in self.all_weights for p in l)
if len(unique_data_ptrs) != sum(len(l) for l in self.all_weights):
self._data_ptrs = []
return
with torch.cuda.device_of(any_param):
# This is quite ugly, but it allows us to reuse the cuDNN code without larger
# modifications. It's really a low-level API that doesn't belong in here, but
# let's make this exception.
from torch.backends.cudnn import rnn
from torch.backends import cudnn
from torch.nn._functions.rnn import CudnnRNN
handle = cudnn.get_handle()
with warnings.catch_warnings(record=True):
fn = CudnnRNN(
self.mode,
self.input_size,
self.hidden_size,
num_layers=self.num_layers,
batch_first=self.batch_first,
dropout=self.dropout,
train=self.training,
bidirectional=self.bidirectional,
dropout_state=self.dropout_state,
)
# Initialize descriptors
fn.datatype = cudnn._typemap[any_param.type()]
fn.x_descs = cudnn.descriptor(any_param.new(1, self.input_size), 1)
fn.rnn_desc = rnn.init_rnn_descriptor(fn, handle)
# Allocate buffer to hold the weights
self._param_buf_size = rnn.get_num_weights(handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
fn.weight_buf = any_param.new(self._param_buf_size).zero_()
fn.w_desc = rnn.init_weight_descriptor(fn, fn.weight_buf)
# Slice off views into weight_buf
all_weights = [[p.data for p in l] for l in self.all_weights]
params = rnn.get_parameters(fn, handle, fn.weight_buf)
# Copy weights and update their storage
rnn._copyParams(all_weights, params)
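            # Re-point each original parameter at its view into weight_buf so
            # subsequent updates write directly into the flat cuDNN buffer.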
for orig_layer_param, new_layer_param in zip(all_weights, params):
for orig_param, new_param in zip(orig_layer_param, new_layer_param):
orig_param.set_(new_param.view_as(orig_param))
self._data_ptrs = list(p.data.data_ptr() for p in self.parameters())
def _apply(self, fn):
ret = super(RNNBase, self)._apply(fn)
self.flatten_parameters()
return ret
def reset_parameters(self):
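        # Uniform init in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)] for all
        # weights and biases.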
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def check_forward_args(self, input, hidden, batch_sizes):
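        # Packed input arrives flattened to (total_steps, input_size); padded
        # input is 3-D: (seq_len, batch, input_size), or (batch, seq_len,
        # input_size) when batch_first=True.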
is_input_packed = batch_sizes is not None
expected_input_dim = 2 if is_input_packed else 3
if input.dim() != expected_input_dim:
raise RuntimeError(
'input must have {} dimensions, got {}'.format(
expected_input_dim, input.dim()))
if self.input_size != input.size(-1):
raise RuntimeError(
'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))
if is_input_packed:
mini_batch = batch_sizes[0]
else:
mini_batch = input.size(0) if self.batch_first else input.size(1)
num_directions = 2 if self.bidirectional else 1
expected_hidden_size = (self.num_layers * num_directions,
mini_batch, self.hidden_size)
def check_hidden_size(hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
if tuple(hx.size()) != expected_hidden_size:
raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
if self.mode == 'LSTM':
check_hidden_size(hidden[0], expected_hidden_size,
'Expected hidden[0] size {}, got {}')
check_hidden_size(hidden[1], expected_hidden_size,
'Expected hidden[1] size {}, got {}')
else:
check_hidden_size(hidden, expected_hidden_size)
def forward(self, input, hx=None):
is_packed = isinstance(input, PackedSequence)
if is_packed:
input, batch_sizes = input
max_batch_size = batch_sizes[0]
else:
batch_sizes = None
max_batch_size = input.size(0) if self.batch_first else input.size(1)
if hx is None:
num_directions = 2 if self.bidirectional else 1
hx = torch.autograd.Variable(input.data.new(self.num_layers *
num_directions,
max_batch_size,
self.hidden_size).zero_(), requires_grad=False)
if self.mode == 'LSTM':
hx = (hx, hx)
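        # Fast path: if the parameters still alias the flat buffer built by
        # flatten_parameters(), hand that single buffer to the backend.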
has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs
if has_flat_weights:
first_data = next(self.parameters()).data
assert first_data.storage().size() == self._param_buf_size
flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size]))
else:
flat_weight = None
self.check_forward_args(input, hx, batch_sizes)
func = self._backend.RNN(
self.mode,
self.input_size,
self.hidden_size,
num_layers=self.num_layers,
batch_first=self.batch_first,
dropout=self.dropout,
train=self.training,
bidirectional=self.bidirectional,
batch_sizes=batch_sizes,
dropout_state=self.dropout_state,
flat_weight=flat_weight
)
output, hidden = func(input, self.all_weights, hx)
if is_packed:
output = PackedSequence(output, batch_sizes)
return output, hidden
def __repr__(self):
s = '{name}({input_size}, {hidden_size}'
if self.num_layers != 1:
s += ', num_layers={num_layers}'
if self.bias is not True:
s += ', bias={bias}'
if self.batch_first is not False:
s += ', batch_first={batch_first}'
if self.dropout != 0:
s += ', dropout={dropout}'
if self.bidirectional is not False:
s += ', bidirectional={bidirectional}'
s += ')'
return s.format(name=self.__class__.__name__, **self.__dict__)
def __setstate__(self, d):
super(RNNBase, self).__setstate__(d)
self.__dict__.setdefault('_data_ptrs', [])
if 'all_weights' in d:
self._all_weights = d['all_weights']
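            # Newer checkpoints store parameter names; older ones stored the
            # parameter objects themselves, so rebuild the name lists below.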
if isinstance(self._all_weights[0][0], str):
return
num_layers = self.num_layers
num_directions = 2 if self.bidirectional else 1
self._all_weights = []
for layer in range(num_layers):
for direction in range(num_directions):
suffix = '_reverse' if direction == 1 else ''
weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
weights = [x.format(layer, suffix) for x in weights]
if self.bias:
self._all_weights += [weights]
else:
self._all_weights += [weights[:2]]
@property
def all_weights(self):
return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]
class RNN(RNNBase):
r"""Applies a multi-layer Elman RNN with tanh or ReLU non-linearity to an
input sequence.
For each element in the input sequence, each layer computes the following
function:
.. math::
h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})
where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is
the hidden state of the previous layer at time `t` or :math:`input_t`
for the first layer. If nonlinearity='relu', then `ReLU` is used instead
of `tanh`.
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
num_layers: Number of recurrent layers.
nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
bias: If ``False``, then the layer does not use bias weights b_ih and b_hh.
Default: ``True``
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature)
dropout: If non-zero, introduces a dropout layer on the outputs of each
RNN layer except the last layer
bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``
Inputs: input, h_0
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence. The input can also be a packed variable length
sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
for details.
- **h_0** (num_layers * num_directions, batch, hidden_size): tensor
containing the initial hidden state for each element in the batch.
Defaults to zero if not provided.
Outputs: output, h_n
- **output** (seq_len, batch, hidden_size * num_directions): tensor
          containing the output features (h_t) from the last layer of the RNN,
          for each t. If a :class:`torch.nn.utils.rnn.PackedSequence` has
been given as the input, the output will also be a packed sequence.
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor
          containing the hidden state for t=seq_len.
Attributes:
weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
of shape `(hidden_size x input_size)` for k=0. Otherwise, the shape is
            `(hidden_size x num_directions * hidden_size)`
weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
of shape `(hidden_size x hidden_size)`
bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
of shape `(hidden_size)`
bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
of shape `(hidden_size)`
Examples::
>>> rnn = nn.RNN(10, 20, 2)
>>> input = Variable(torch.randn(5, 3, 10))
>>> h0 = Variable(torch.randn(2, 3, 20))
>>> output, hn = rnn(input, h0)
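        >>> # Illustrative: same layer sizes with ReLU nonlinearity
        >>> relu_rnn = nn.RNN(10, 20, 2, nonlinearity='relu')
        >>> output, hn = relu_rnn(input, h0)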
"""
def __init__(self, *args, **kwargs):
if 'nonlinearity' in kwargs:
if kwargs['nonlinearity'] == 'tanh':
mode = 'RNN_TANH'
elif kwargs['nonlinearity'] == 'relu':
mode = 'RNN_RELU'
else:
raise ValueError("Unknown nonlinearity '{}'".format(
kwargs['nonlinearity']))
del kwargs['nonlinearity']
else:
mode = 'RNN_TANH'
super(RNN, self).__init__(mode, *args, **kwargs)
class LSTM(RNNBase):
r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
sequence.
For each element in the input sequence, each layer computes the following
function:
.. math::
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}
where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
state at time `t`, :math:`x_t` is the hidden state of the previous layer at
time `t` or :math:`input_t` for the first layer, and :math:`i_t`,
:math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell,
and out gates, respectively.
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
num_layers: Number of recurrent layers.
bias: If ``False``, then the layer does not use bias weights b_ih and b_hh.
Default: ``True``
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature)
dropout: If non-zero, introduces a dropout layer on the outputs of each
RNN layer except the last layer
bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``
Inputs: input, (h_0, c_0)
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence.
The input can also be a packed variable length sequence.
See :func:`torch.nn.utils.rnn.pack_padded_sequence` for details.
- **h_0** (num_layers \* num_directions, batch, hidden_size): tensor
containing the initial hidden state for each element in the batch.
- **c_0** (num_layers \* num_directions, batch, hidden_size): tensor
containing the initial cell state for each element in the batch.
If (h_0, c_0) is not provided, both **h_0** and **c_0** default to zero.
Outputs: output, (h_n, c_n)
- **output** (seq_len, batch, hidden_size * num_directions): tensor
containing the output features `(h_t)` from the last layer of the RNN,
for each t. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
given as the input, the output will also be a packed sequence.
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor
containing the hidden state for t=seq_len
- **c_n** (num_layers * num_directions, batch, hidden_size): tensor
containing the cell state for t=seq_len
Attributes:
weight_ih_l[k] : the learnable input-hidden weights of the k-th layer
`(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size x input_size)`
weight_hh_l[k] : the learnable hidden-hidden weights of the k-th layer
`(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size x hidden_size)`
bias_ih_l[k] : the learnable input-hidden bias of the k-th layer
`(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
bias_hh_l[k] : the learnable hidden-hidden bias of the k-th layer
`(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`
Examples::
>>> rnn = nn.LSTM(10, 20, 2)
>>> input = Variable(torch.randn(5, 3, 10))
>>> h0 = Variable(torch.randn(2, 3, 20))
>>> c0 = Variable(torch.randn(2, 3, 20))
>>> output, hn = rnn(input, (h0, c0))
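        >>> # Illustrative: a bidirectional variant; h0/c0 grow to
        >>> # num_layers * num_directions along the first dimension
        >>> birnn = nn.LSTM(10, 20, 2, bidirectional=True)
        >>> bh0 = Variable(torch.randn(4, 3, 20))
        >>> bc0 = Variable(torch.randn(4, 3, 20))
        >>> output, (hn, cn) = birnn(input, (bh0, bc0))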
"""
def __init__(self, *args, **kwargs):
super(LSTM, self).__init__('LSTM', *args, **kwargs)
class GRU(RNNBase):
r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
For each element in the input sequence, each layer computes the following
function:
.. math::
\begin{array}{ll}
r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\
\end{array}
where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden
state of the previous layer at time `t` or :math:`input_t` for the first
    layer, and :math:`r_t`, :math:`z_t`, :math:`n_t` are the reset, update,
    and new gates, respectively.
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
num_layers: Number of recurrent layers.
bias: If ``False``, then the layer does not use bias weights b_ih and b_hh.
Default: ``True``
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature)
dropout: If non-zero, introduces a dropout layer on the outputs of each
RNN layer except the last layer
bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``
Inputs: input, h_0
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence. The input can also be a packed variable length
sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
for details.
- **h_0** (num_layers * num_directions, batch, hidden_size): tensor
containing the initial hidden state for each element in the batch.
Defaults to zero if not provided.
Outputs: output, h_n
- **output** (seq_len, batch, hidden_size * num_directions): tensor
containing the output features h_t from the last layer of the RNN,
for each t. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
given as the input, the output will also be a packed sequence.
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor
containing the hidden state for t=seq_len
Attributes:
weight_ih_l[k] : the learnable input-hidden weights of the k-th layer
(W_ir|W_iz|W_in), of shape `(3*hidden_size x input_size)`
weight_hh_l[k] : the learnable hidden-hidden weights of the k-th layer
(W_hr|W_hz|W_hn), of shape `(3*hidden_size x hidden_size)`
bias_ih_l[k] : the learnable input-hidden bias of the k-th layer
(b_ir|b_iz|b_in), of shape `(3*hidden_size)`
bias_hh_l[k] : the learnable hidden-hidden bias of the k-th layer
(b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
Examples::
>>> rnn = nn.GRU(10, 20, 2)
>>> input = Variable(torch.randn(5, 3, 10))
>>> h0 = Variable(torch.randn(2, 3, 20))
>>> output, hn = rnn(input, h0)
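        >>> # Illustrative: feeding a packed variable-length batch
        >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        >>> packed = pack_padded_sequence(input, [5, 5, 5])
        >>> packed_out, hn = rnn(packed)
        >>> output, lengths = pad_packed_sequence(packed_out)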
"""
def __init__(self, *args, **kwargs):
super(GRU, self).__init__('GRU', *args, **kwargs)
class RNNCellBase(Module):
def __repr__(self):
s = '{name}({input_size}, {hidden_size}'
if 'bias' in self.__dict__ and self.bias is not True:
s += ', bias={bias}'
if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
s += ', nonlinearity={nonlinearity}'
s += ')'
return s.format(name=self.__class__.__name__, **self.__dict__)
class RNNCell(RNNCellBase):
r"""An Elman RNN cell with tanh or ReLU non-linearity.
.. math::
h' = \tanh(w_{ih} * x + b_{ih} + w_{hh} * h + b_{hh})
If nonlinearity='relu', then ReLU is used in place of tanh.
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
bias: If ``False``, then the layer does not use bias weights b_ih and b_hh.
Default: ``True``
nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
Inputs: input, hidden
- **input** (batch, input_size): tensor containing input features
- **hidden** (batch, hidden_size): tensor containing the initial hidden
state for each element in the batch.
Outputs: h'
- **h'** (batch, hidden_size): tensor containing the next hidden state
for each element in the batch
Attributes:
weight_ih: the learnable input-hidden weights, of shape
            `(hidden_size x input_size)`
weight_hh: the learnable hidden-hidden weights, of shape
`(hidden_size x hidden_size)`
bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`
Examples::
>>> rnn = nn.RNNCell(10, 20)
>>> input = Variable(torch.randn(6, 3, 10))
>>> hx = Variable(torch.randn(3, 20))
>>> output = []
>>> for i in range(6):
... hx = rnn(input[i], hx)
... output.append(hx)
"""
def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
super(RNNCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.nonlinearity = nonlinearity
self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(hidden_size, hidden_size))
if bias:
self.bias_ih = Parameter(torch.Tensor(hidden_size))
self.bias_hh = Parameter(torch.Tensor(hidden_size))
else:
self.register_parameter('bias_ih', None)
self.register_parameter('bias_hh', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, input, hx):
if self.nonlinearity == "tanh":
func = self._backend.RNNTanhCell
elif self.nonlinearity == "relu":
func = self._backend.RNNReLUCell
else:
raise RuntimeError(
"Unknown nonlinearity: {}".format(self.nonlinearity))
return func(
input, hx,
self.weight_ih, self.weight_hh,
self.bias_ih, self.bias_hh,
)
class LSTMCell(RNNCellBase):
r"""A long short-term memory (LSTM) cell.
.. math::
\begin{array}{ll}
i = \mathrm{sigmoid}(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
f = \mathrm{sigmoid}(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
        g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
o = \mathrm{sigmoid}(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c') \\
\end{array}
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
bias: If `False`, then the layer does not use bias weights `b_ih` and
`b_hh`. Default: ``True``
Inputs: input, (h_0, c_0)
- **input** (batch, input_size): tensor containing input features
- **h_0** (batch, hidden_size): tensor containing the initial hidden
state for each element in the batch.
        - **c_0** (batch, hidden_size): tensor containing the initial cell state
for each element in the batch.
Outputs: h_1, c_1
- **h_1** (batch, hidden_size): tensor containing the next hidden state
for each element in the batch
- **c_1** (batch, hidden_size): tensor containing the next cell state
for each element in the batch
Attributes:
weight_ih: the learnable input-hidden weights, of shape
`(4*hidden_size x input_size)`
weight_hh: the learnable hidden-hidden weights, of shape
`(4*hidden_size x hidden_size)`
bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`
Examples::
>>> rnn = nn.LSTMCell(10, 20)
>>> input = Variable(torch.randn(6, 3, 10))
>>> hx = Variable(torch.randn(3, 20))
>>> cx = Variable(torch.randn(3, 20))
>>> output = []
>>> for i in range(6):
... hx, cx = rnn(input[i], (hx, cx))
... output.append(hx)
"""
def __init__(self, input_size, hidden_size, bias=True):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size))
if bias:
self.bias_ih = Parameter(torch.Tensor(4 * hidden_size))
self.bias_hh = Parameter(torch.Tensor(4 * hidden_size))
else:
self.register_parameter('bias_ih', None)
self.register_parameter('bias_hh', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, input, hx):
return self._backend.LSTMCell(
input, hx,
self.weight_ih, self.weight_hh,
self.bias_ih, self.bias_hh,
)
class GRUCell(RNNCellBase):
r"""A gated recurrent unit (GRU) cell
.. math::
\begin{array}{ll}
r = \mathrm{sigmoid}(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
z = \mathrm{sigmoid}(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
h' = (1 - z) * n + z * h
\end{array}
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
bias: If `False`, then the layer does not use bias weights `b_ih` and
`b_hh`. Default: `True`
Inputs: input, hidden
- **input** (batch, input_size): tensor containing input features
- **hidden** (batch, hidden_size): tensor containing the initial hidden
state for each element in the batch.
Outputs: h'
        - **h'** (batch, hidden_size): tensor containing the next hidden state
for each element in the batch
Attributes:
weight_ih: the learnable input-hidden weights, of shape
`(3*hidden_size x input_size)`
weight_hh: the learnable hidden-hidden weights, of shape
`(3*hidden_size x hidden_size)`
bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`
Examples::
>>> rnn = nn.GRUCell(10, 20)
>>> input = Variable(torch.randn(6, 3, 10))
>>> hx = Variable(torch.randn(3, 20))
>>> output = []
>>> for i in range(6):
... hx = rnn(input[i], hx)
... output.append(hx)
"""
def __init__(self, input_size, hidden_size, bias=True):
super(GRUCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(3 * hidden_size, hidden_size))
if bias:
self.bias_ih = Parameter(torch.Tensor(3 * hidden_size))
self.bias_hh = Parameter(torch.Tensor(3 * hidden_size))
else:
self.register_parameter('bias_ih', None)
self.register_parameter('bias_hh', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, input, hx):
return self._backend.GRUCell(
input, hx,
self.weight_ih, self.weight_hh,
self.bias_ih, self.bias_hh,
)