# Source code for torch.ao.nn.quantizable.modules.rnn
# mypy: allow-untyped-defs
import numbers
from typing import Optional, Tuple
import warnings
import torch
from torch import Tensor
"""
We recreate all the RNN modules here because the quantization flow needs each
module decomposed into its building blocks so that those blocks can be observed.
"""
# Public API of this module: the quantizable LSTM cell and the full LSTM.
__all__ = ["LSTMCell", "LSTM"]
class LSTMCell(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM) cell.

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`

    The cell is decomposed into ``Linear``, ``FloatFunctional`` and activation
    submodules so that each intermediate operation can be observed separately
    during the quantization flow.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """
    _FLOAT_MODULE = torch.nn.LSTMCell

    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim
        self.bias = bias

        # Input and hidden projections. Each produces all four gates stacked
        # along dim 1 in the order (input, forget, cell, output); see the
        # ``chunk(4, 1)`` in ``forward``.
        self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.gates = torch.ao.nn.quantized.FloatFunctional()

        self.input_gate = torch.nn.Sigmoid()
        self.forget_gate = torch.nn.Sigmoid()
        self.cell_gate = torch.nn.Tanh()
        self.output_gate = torch.nn.Sigmoid()

        # A separate FloatFunctional per elementwise op, so that each op can be
        # observed/quantized independently.
        self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
        self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
        self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()
        self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()

        # (scale, zero_point) and dtypes used by ``initialize_hidden`` when it
        # must create zero states for a quantized input.
        self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0)
        self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0)
        self.hidden_state_dtype: torch.dtype = torch.quint8
        self.cell_state_dtype: torch.dtype = torch.quint8

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        """Run one LSTM step.

        Args:
            x: input for this step; ``x.shape[0]`` is used as the batch size
                (presumably ``(batch, input_size)``).
            hidden: ``(hx, cx)``. If it is ``None`` or either element is
                ``None``, zero states are created via ``initialize_hidden``.

        Returns:
            ``(hy, cy)``, the new hidden and cell states.
        """
        if hidden is None or hidden[0] is None or hidden[1] is None:
            # Create the zero states on the input's device and, for float
            # inputs, with its dtype. Otherwise ``self.hgates(hx)`` would mix
            # devices/dtypes (e.g. a float64 or non-default-device cell).
            if x.is_quantized:
                hidden = self.initialize_hidden(x.shape[0], True, device=x.device)
            else:
                hidden = self.initialize_hidden(x.shape[0], False, device=x.device, dtype=x.dtype)
        hx, cx = hidden

        igates = self.igates(x)
        hgates = self.hgates(hx)
        gates = self.gates.add(igates, hgates)

        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

        input_gate = self.input_gate(input_gate)
        forget_gate = self.forget_gate(forget_gate)
        cell_gate = self.cell_gate(cell_gate)
        out_gate = self.output_gate(out_gate)

        # cy = f * cx + i * g
        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
        cy = fgate_cx_igate_cgate

        # hy = o * tanh(cy)
        # TODO: make this tanh a member of the module so its qparams can be configured
        tanh_cy = torch.tanh(cy)
        hy = self.ogate_cy.mul(out_gate, tanh_cy)
        return hy, cy

    def initialize_hidden(self, batch_size: int, is_quantized: bool = False,
                          device: Optional[torch.device] = None,
                          dtype: Optional[torch.dtype] = None) -> Tuple[Tensor, Tensor]:
        """Create zero ``(h, c)`` states of shape ``(batch_size, hidden_size)``.

        Args:
            batch_size: number of rows in each state.
            is_quantized: if True, the zeros are quantized per tensor using
                ``initial_*_state_qparams`` and ``*_state_dtype``.
            device: device for the states (default device when ``None``).
            dtype: dtype for float (non-quantized) states. It is ignored when
                ``is_quantized`` is True: quantized states are always produced
                from default-dtype zeros.

        Returns:
            ``(h, c)`` zero tensors.
        """
        if is_quantized:
            dtype = None
        h = torch.zeros((batch_size, self.hidden_size), device=device, dtype=dtype)
        c = torch.zeros((batch_size, self.hidden_size), device=device, dtype=dtype)
        if is_quantized:
            (h_scale, h_zp) = self.initial_hidden_state_qparams
            (c_scale, c_zp) = self.initial_cell_state_qparams
            h = torch.quantize_per_tensor(h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype)
            c = torch.quantize_per_tensor(c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype)
        return h, c

    def _get_name(self):
        return 'QuantizableLSTMCell'

    @classmethod
    def from_params(cls, wi, wh, bi=None, bh=None):
        """Uses the weights and biases to create a new LSTM cell.

        Args:
            wi, wh: Weights for the input and hidden layers. ``wi.shape[1]`` is
                taken as the input size and ``wh.shape[1]`` as the hidden size.
            bi, bh: Biases for the input and hidden layers. Either both are
                ``None`` (the cell is built without bias) or both are given.
        """
        assert (bi is None) == (bh is None)  # Either both None or both have values
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
        cell = cls(input_dim=input_size, hidden_dim=hidden_size,
                   bias=(bi is not None))
        cell.igates.weight = torch.nn.Parameter(wi)
        if bi is not None:
            cell.igates.bias = torch.nn.Parameter(bi)
        cell.hgates.weight = torch.nn.Parameter(wh)
        if bh is not None:
            cell.hgates.bias = torch.nn.Parameter(bh)
        return cell

    @classmethod
    def from_float(cls, other, use_precomputed_fake_quant=False):
        """Build an observed cell from a float ``torch.nn.LSTMCell``.

        ``other`` must be exactly ``torch.nn.LSTMCell`` (subclasses are
        rejected) and must carry a ``qconfig``, which is copied to the new cell
        and its two projections. ``use_precomputed_fake_quant`` is accepted for
        API compatibility and is not used here.
        """
        assert type(other) is cls._FLOAT_MODULE
        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
        observed = cls.from_params(other.weight_ih, other.weight_hh,
                                   other.bias_ih, other.bias_hh)
        observed.qconfig = other.qconfig
        observed.igates.qconfig = other.qconfig
        observed.hgates.qconfig = other.qconfig
        return observed
class _LSTMSingleLayer(torch.nn.Module):
    r"""A single one-directional LSTM layer.

    Unlike a cell, which handles one time step, a layer walks a whole sequence
    along dim 0 and feeds each step through the same :class:`LSTMCell`.
    """

    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__()
        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, device=device, dtype=dtype)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        """Run the cell over every step ``x[t]`` for ``t`` in ``range(x.shape[0])``.

        Returns the per-step hidden outputs stacked along dim 0, together with
        the final ``(h, c)`` state.
        """
        step_outputs = []
        for t in range(x.shape[0]):
            hidden = self.cell(x[t], hidden)
            step_outputs.append(hidden[0])  # type: ignore[index]
        return torch.stack(step_outputs, 0), hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        """Build a layer whose cell comes from ``LSTMCell.from_params``."""
        prebuilt_cell = LSTMCell.from_params(*args, **kwargs)
        new_layer = cls(prebuilt_cell.input_size, prebuilt_cell.hidden_size, prebuilt_cell.bias)
        new_layer.cell = prebuilt_cell
        return new_layer
class _LSTMLayer(torch.nn.Module):
    r"""A single bi-directional LSTM layer.

    Holds a forward-direction ``layer_fw`` and, when ``bidirectional`` is True,
    a reverse-direction ``layer_bw`` that runs on the time-flipped sequence.
    """

    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 batch_first: bool = False, bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
        if self.bidirectional:
            self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        """Run the layer over a sequence.

        Args:
            x: input sequence; time is dim 0 (dim 1 when ``batch_first``).
            hidden: optional ``(h, c)``. When bidirectional, each of ``h`` and
                ``c`` is indexed ``[0]`` for the forward and ``[1]`` for the
                reverse direction.

        Returns:
            ``(result, (h, c))``. When bidirectional, ``result`` concatenates
            both directions along its last dim, and ``h``/``c`` stack the two
            directions' final states along a new dim 0.
        """
        if self.batch_first:
            x = x.transpose(0, 1)
        if hidden is None:
            hx_fw, cx_fw = (None, None)
        else:
            hx_fw, cx_fw = hidden
        hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
        if self.bidirectional:
            # Split the incoming states into forward ([0]) and reverse ([1]) parts.
            if hx_fw is None:
                hx_bw = None
            else:
                hx_bw = hx_fw[1]
                hx_fw = hx_fw[0]
            if cx_fw is None:
                cx_bw = None
            else:
                cx_bw = cx_fw[1]
                cx_fw = cx_fw[0]
            if hx_bw is not None and cx_bw is not None:
                hidden_bw = hx_bw, cx_bw
        if hx_fw is None and cx_fw is None:
            hidden_fw = None
        else:
            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)

        if hasattr(self, 'layer_bw') and self.bidirectional:
            # The reverse direction runs on the time-flipped input; its outputs
            # are flipped back so both directions are aligned per time step.
            x_reversed = x.flip(0)
            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
            result_bw = result_bw.flip(0)

            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            if hidden_fw is None and hidden_bw is None:
                h = None
                c = None
            elif hidden_fw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_bw)
            elif hidden_bw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_fw)
            else:
                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
        else:
            result = result_fw
            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]

        if self.batch_first:
            result.transpose_(0, 1)

        return result, (h, c)

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        r"""
        There is no FP equivalent of this class. This function is here just to
        mimic the behavior of the `prepare` within the `torch.ao.quantization`
        flow.

        Args:
            other: float LSTM providing ``weight_ih_l{idx}``-style parameters
                and, if bidirectional, their ``_reverse`` counterparts.
            layer_idx: which layer's parameters to copy.
            qconfig: used when ``other`` has no ``qconfig`` attribute.
            **kwargs: optional overrides for ``input_size``, ``hidden_size``,
                ``bias``, ``batch_first`` and ``bidirectional``.
        """
        assert hasattr(other, 'qconfig') or (qconfig is not None)

        input_size = kwargs.get('input_size', other.input_size)
        hidden_size = kwargs.get('hidden_size', other.hidden_size)
        bias = kwargs.get('bias', other.bias)
        batch_first = kwargs.get('batch_first', other.batch_first)
        bidirectional = kwargs.get('bidirectional', other.bidirectional)

        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, 'qconfig', qconfig)
        wi = getattr(other, f'weight_ih_l{layer_idx}')
        wh = getattr(other, f'weight_hh_l{layer_idx}')
        bi = getattr(other, f'bias_ih_l{layer_idx}', None)
        bh = getattr(other, f'bias_hh_l{layer_idx}', None)

        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)

        # Copy the reverse weights only if the new layer is actually
        # bidirectional. ``forward`` ignores ``layer_bw`` unless
        # ``self.bidirectional`` is set, so attaching it otherwise would only
        # register a dead submodule.
        if bidirectional and other.bidirectional:
            wi = getattr(other, f'weight_ih_l{layer_idx}_reverse')
            wh = getattr(other, f'weight_hh_l{layer_idx}_reverse')
            bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None)
            bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None)
            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
        return layer
class LSTM(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM).

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`

    Attributes:
        layers : instances of the `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples below.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
        >>> # To get the forward-direction weights of the first layer:
        >>> # xdoctest: +SKIP
        >>> print(rnn.layers[0].layer_fw.cell.igates.weight)
        tensor([[...]])
        >>> print(rnn.layers[0].layer_fw.cell.hgates.weight)
        tensor([[...]])
        >>> # The reverse direction is ``rnn.layers[i].layer_bw``, which is
        >>> # only created when ``bidirectional=True``.
    """
    _FLOAT_MODULE = torch.nn.LSTM

    def __init__(self, input_size: int, hidden_size: int,
                 num_layers: int = 1, bias: bool = True,
                 batch_first: bool = False, dropout: float = 0.,
                 bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.training = False  # Default to eval mode. If we want to train, we will explicitly set to training.

        # Validate before converting. Otherwise ``float(dropout)`` would fail
        # first with an unrelated error for non-numeric values (e.g. ``None``).
        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        self.dropout = float(dropout)
        if dropout > 0:
            warnings.warn("dropout option for quantizable LSTM is ignored. "
                          "If you are training, please, use nn.LSTM version "
                          "followed by `prepare` step.")
            if num_layers == 1:
                warnings.warn("dropout option adds dropout after all but last "
                              "recurrent layer, so non-zero dropout expects "
                              f"num_layers greater than 1, but got dropout={dropout} "
                              f"and num_layers={num_layers}")

        # Every layer after the first consumes the previous layer's output.
        # When bidirectional, that output concatenates both directions along
        # its last dim, so its width is hidden_size * num_directions (same as
        # torch.nn.LSTM).
        num_directions = 2 if bidirectional else 1
        layers = [_LSTMLayer(self.input_size, self.hidden_size,
                             self.bias, batch_first=False,
                             bidirectional=self.bidirectional, **factory_kwargs)]
        layers.extend(
            _LSTMLayer(self.hidden_size * num_directions, self.hidden_size,
                       self.bias, batch_first=False,
                       bidirectional=self.bidirectional, **factory_kwargs)
            for _ in range(1, num_layers)
        )
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        """Run all layers over the sequence ``x``.

        Args:
            x: input sequence; time is dim 0, batch is dim 1 (swapped when
                ``batch_first``).
            hidden: either ``(h0, c0)`` tensors reshapeable to
                ``(num_layers, num_directions, batch, hidden_size)``, or a
                per-layer sequence of ``(h, c)`` pairs. Zeros are used when
                ``None`` (quantized with scale 1.0 / zero point 0 if ``x`` is
                quantized).

        Returns:
            ``(output, (h_n, c_n))``, where ``h_n``/``c_n`` have shape
            ``(num_layers * num_directions, batch, hidden_size)``.
        """
        if self.batch_first:
            x = x.transpose(0, 1)

        max_batch_size = x.size(1)
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            zeros = torch.zeros(num_directions, max_batch_size,
                                self.hidden_size, dtype=torch.float,
                                device=x.device)
            # Unidirectional layers expect (batch, hidden) states.
            zeros.squeeze_(0)
            if x.is_quantized:
                zeros = torch.quantize_per_tensor(zeros, scale=1.0,
                                                  zero_point=0, dtype=x.dtype)
            hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
        else:
            hidden_non_opt = torch.jit._unwrap_optional(hidden)
            if isinstance(hidden_non_opt[0], Tensor):
                hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size)
                cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size)
                hxcx = [(hx[idx].squeeze(0), cx[idx].squeeze(0)) for idx in range(self.num_layers)]
            else:
                hxcx = hidden_non_opt

        hx_list = []
        cx_list = []
        for idx, layer in enumerate(self.layers):
            x, (h, c) = layer(x, hxcx[idx])
            hx_list.append(torch.jit._unwrap_optional(h))
            cx_list.append(torch.jit._unwrap_optional(c))
        hx_tensor = torch.stack(hx_list)
        cx_tensor = torch.stack(cx_list)

        # We are creating another dimension for bidirectional case
        # need to collapse it
        hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
        cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, (hx_tensor, cx_tensor)

    def _get_name(self):
        return 'QuantizableLSTM'

    @classmethod
    def from_float(cls, other, qconfig=None):
        """Build an observed quantizable LSTM from a float ``torch.nn.LSTM``.

        Each layer's weights are copied via ``_LSTMLayer.from_float``. The
        result is then passed through ``prepare_qat`` if ``other`` is in
        training mode, or ``prepare`` otherwise.
        """
        assert isinstance(other, cls._FLOAT_MODULE)
        assert (hasattr(other, 'qconfig') or qconfig)
        observed = cls(other.input_size, other.hidden_size, other.num_layers,
                       other.bias, other.batch_first, other.dropout,
                       other.bidirectional)
        observed.qconfig = getattr(other, 'qconfig', qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
                                                         batch_first=False)

        # Prepare the model
        if other.training:
            observed.train()
            observed = torch.ao.quantization.prepare_qat(observed, inplace=True)
        else:
            observed.eval()
            observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        raise NotImplementedError("It looks like you are trying to convert a "
                                  "non-quantizable LSTM module. Please, see "
                                  "the examples on quantizable LSTMs.")