# Source code for torch.ao.nn.quantized.modules.conv
# mypy: allow-untyped-defs
r"""Quantized convolution modules."""
from typing import Optional, List, TypeVar
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
from torch._ops import ops
from torch.nn.common_types import _size_1_t
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.utils import fuse_conv_bn_weights
from .utils import _quantize_weight, WeightedQuantizedModule
__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
_SUPPORTED_PADDING = {
'zeros',
'reflect'
}
def _reverse_repeat_padding(padding: List[int]) -> List[int]:
_reversed_padding_repeated_twice: List[int] = []
N = len(padding)
for idx in range(N):
for _ in range(2):
_reversed_padding_repeated_twice.append(padding[N - idx - 1])
return _reversed_padding_repeated_twice
class _ConvNd(WeightedQuantizedModule):
    r"""Common base for the quantized convolution and transposed-convolution modules.

    Concrete subclasses (e.g. ``Conv2d``, ``ConvTranspose2d``) normalize their
    arguments and then call :meth:`_init` (directly, or through
    ``_ConvTransposeNd.__init__``) to populate the shared attributes. They must
    implement :meth:`set_weight_bias`, :meth:`_weight_bias` and :meth:`bias`.
    ``scale`` / ``zero_point`` are the output quantization parameters that the
    subclasses pass to the quantized kernels in ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        # All subclasses have this signature - See PR #49702s
        # The base constructor is never meant to run: subclasses define their
        # own __init__ and route to ``_init`` instead.
        raise NotImplementedError

    def _init(self, in_channels, out_channels, kernel_size, stride,
              padding, dilation,
              transposed, output_padding,
              groups, bias,
              padding_mode='zeros',
              device=None,
              dtype=None) -> None:
        r"""Validate the arguments and set up the state shared by all subclasses.

        Args:
            in_channels (int): input channel count; must be divisible by ``groups``.
            out_channels (int): output channel count; must be divisible by ``groups``.
            kernel_size: per-dimension kernel size as normalized by the calling
                subclass (via ``_single`` / ``_pair`` / ``_triple``).
            stride, padding, dilation, output_padding: per-dimension values,
                normalized by the calling subclass like ``kernel_size``.
            transposed (bool): ``True`` for transposed convolutions. Selects the
                ``[in, out // groups, *kernel]`` placeholder weight shape instead
                of ``[out, in // groups, *kernel]``.
            groups (int): number of channel groups.
            bias (bool): if truthy, a zero-initialized float bias is created.
            padding_mode (str): must be a member of ``_SUPPORTED_PADDING``.
            device: device for the placeholder weight and bias.
            dtype: accepted for the factory-kwargs convention but filtered out.
                The weight is always ``torch.qint8`` and the bias ``torch.float``.

        Raises:
            ValueError: if either channel count is not divisible by ``groups``,
                or ``padding_mode`` is not supported.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if padding_mode not in _SUPPORTED_PADDING:
            raise ValueError(f"'padding_mode' {padding_mode} is not supported by quantized convolution")
        self.padding_mode = padding_mode
        # Initialize as NCHW. set_weight will internally transpose to NHWC.
        if self.transposed:
            weight_shape = [in_channels, out_channels // self.groups]
        else:
            weight_shape = [out_channels, in_channels // self.groups]
        # Uninitialized placeholder weight (scale=1, zero_point=0). Real values
        # are installed later via set_weight_bias (from_float / from_reference /
        # state-dict loading). 'dtype' is dropped because it is fixed here.
        qweight = torch._empty_affine_quantized(
            weight_shape + list(kernel_size),
            scale=1, zero_point=0, dtype=torch.qint8,
            **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
        bias_float = (
            torch.zeros(out_channels, dtype=torch.float,
                        **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) if bias else None)
        self.set_weight_bias(qweight, bias_float)
        # Default output quantization parameters.
        self.scale = 1.0
        self.zero_point = 0

    def set_weight_bias(self, qweight, bias_float):
        # Abstract: subclasses pack ``qweight`` and ``bias_float`` into
        # ``self._packed_params`` using the matching quantized prepack op.
        raise NotImplementedError

    def bias(self):
        # Abstract: return the float bias (or None) held in the packed params.
        raise NotImplementedError

    def _weight_bias(self):
        # Abstract: return the ``(quantized weight, bias)`` pair unpacked from
        # the subclass's packed parameters.
        raise NotImplementedError

    def extra_repr(self):
        # Always show channels, kernel_size, stride and the output qparams.
        # padding / dilation / output_padding / groups / bias are appended only
        # when they differ from their defaults. Fields are filled from
        # ``self.__dict__``.
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}, scale={scale}, zero_point={zero_point}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias() is None:
            s += ', bias=False'
        return s.format(**self.__dict__)

    # ===== Serialization methods =====
    # The special consideration here is that we have to unpack the weights into
    # their regular QTensor form for serialization. Packed weights should not
    # live outside the process in which they were created, rather they should be
    # derived from the QTensor weight.
    # self
    # |--- weight : Tensor
    # |--- bias : Tensor
    #
    # TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed
    # self
    # |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Write the unpacked weight/bias plus scale/zero_point (as tensors)
        # under ``prefix``, on top of whatever the base class saves.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        (w, b) = self._weight_bias()
        destination[prefix + 'weight'] = w
        destination[prefix + 'bias'] = b
        destination[prefix + 'scale'] = torch.tensor(self.scale)
        destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)

    @torch.jit.export
    def __getstate__(self):
        # Serialize as a flat tuple, with weight/bias unpacked from the packed
        # params. ``__setstate__`` reads these positions by index, so the two
        # methods must always be changed together.
        (w, b) = self._weight_bias()
        return (
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.transposed,
            self.output_padding,
            self.groups,
            self.padding_mode,
            w,
            b,
            self.scale,
            self.zero_point,
            self.training
        )

    # ===== Deserialization methods =====
    # Counterpart to the serialization methods, we must pack the serialized
    # QTensor weight into its packed format for use by the FBGEMM ops.
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Consume the entries written by ``_save_to_state_dict``:
        # 1. re-pack the weight/bias;
        # 2. restore scale/zero_point as Python float/int;
        # 3. pop each key so the base class does not see it.
        # The keys are indexed directly, so a missing one raises KeyError. The
        # base class is then called with ``strict=False``.
        self.set_weight_bias(
            state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
        state_dict.pop(prefix + 'weight')
        state_dict.pop(prefix + 'bias')
        self.scale = float(state_dict[prefix + 'scale'])
        state_dict.pop(prefix + 'scale')
        self.zero_point = int(state_dict[prefix + 'zero_point'])
        state_dict.pop(prefix + 'zero_point')
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False, missing_keys,
            unexpected_keys, error_msgs)

    @torch.jit.export
    def __setstate__(self, state):
        # Inverse of ``__getstate__``: indices must match that tuple layout.
        # Weight/bias (indices 10, 11) are re-packed via set_weight_bias.
        self.in_channels = state[0]
        self.out_channels = state[1]
        self.kernel_size = state[2]
        self.stride = state[3]
        self.padding = state[4]
        self.dilation = state[5]
        self.transposed = state[6]
        self.output_padding = state[7]
        self.groups = state[8]
        self.padding_mode = state[9]
        self.set_weight_bias(state[10], state[11])
        self.scale = state[12]
        self.zero_point = state[13]
        self.training = state[14]

    def __deepcopy__(self, memo):
        # Copy by round-tripping through __getstate__/__setstate__, so the
        # packed params are rebuilt via set_weight_bias rather than copied.
        # ``memo`` is not consulted.
        new_instance = type(self).__new__(type(self))
        torch.nn.Module.__init__(new_instance)
        state = self.__getstate__()
        new_instance.__setstate__(state)
        return new_instance

    def __copy__(self):
        # Shallow copy is implemented as the same state round-trip.
        return self.__deepcopy__({})

    @classmethod
    def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
        r"""Creates a qconv object and returns it.

        Args:
            mod: float (or already-unwrapped fused/QAT) conv module to convert.
                Its ``in_channels``, ``kernel_size``, ``weight``, ``bias``, etc.
                are read.
            activation_post_process: output observer, or ``None``. ``None`` or
                an observer with ``dtype == torch.float`` means dynamic
                quantization: scale/zero_point keep their defaults.
            weight_post_process: weight observer. If ``None``, one is created
                from ``mod.qconfig.weight()``.

        Raises:
            AssertionError: if the weight observer's dtype is not ``torch.qint8``.
        """
        if weight_post_process is None:
            weight_post_process = mod.qconfig.weight()
        # Run the observer on the float weight so it can compute qparams.
        weight_post_process(mod.weight)
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        # the __init__ call used is the one from derived classes and not the one from _ConvNd
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.dilation, mod.groups,
                    mod.bias is not None, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        if activation_post_process is None or activation_post_process.dtype == torch.float:
            return qconv  # dynamic quantization doesn't need scale/zero_point
        else:
            act_scale, act_zp = activation_post_process.calculate_qparams()
            qconv.scale = float(act_scale)
            qconv.zero_point = int(act_zp)
            return qconv

    @staticmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        # Shared ``from_float`` for the non-transposed subclasses. It is a
        # staticmethod taking ``cls`` explicitly; subclasses call it as
        # ``_ConvNd.from_float(cls, mod, ...)``. ``use_precomputed_fake_quant``
        # is not read in this body.
        #
        # Two input kinds are handled:
        #   * QAT modules (have ``weight_fake_quant``). A ConvBn QAT module has
        #     its BN folded into weight/bias first; note this assigns
        #     ``mod.weight`` / ``mod.bias`` in place. The module's own
        #     fake-quant and activation observer are reused.
        #   * Float / fused-float modules. These must be exactly
        #     ``cls._FLOAT_MODULE`` and carry a ``qconfig``. Fused Conv+ReLU /
        #     Conv+Add(+ReLU) containers are unwrapped via ``mod[0]``, and a
        #     fresh weight observer is built from that module's qconfig.
        if hasattr(mod, "weight_fake_quant"):
            # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
            # ".from_float only works for " + cls.__QAT_MODULE.__name__
            if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
                mod.weight, mod.bias = fuse_conv_bn_weights(
                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
            assert hasattr(mod, "activation_post_process"), \
                "Input QAT module must have observer attached"
            weight_post_process = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == cls._FLOAT_MODULE, \
                " nnq." + cls.__name__ + ".from_float only works for " + \
                cls._FLOAT_MODULE.__name__ + " but got:" + str(type(mod))
            assert hasattr(mod, "qconfig"), \
                "Input float module must have qconfig defined."
            activation_post_process = None if not hasattr(
                mod, "activation_post_process") else mod.activation_post_process
            if type(mod) in [cls._NNI_CONV_RELU_MODULE, cls._NNI_CONV_ADD_MODULE, cls._NNI_CONV_ADD_RELU_MODULE]:
                mod = mod[0]
            weight_post_process = mod.qconfig.weight()
        return cls.get_qconv(mod, activation_post_process, weight_post_process)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
        Args:
            ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization
                                utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        """
        # Construct via the subclass __init__, placing the placeholder weight on
        # the reference weight's device, then install the reference's quantized
        # weight and float bias.
        qconv = cls(
            ref_qconv.in_channels,
            ref_qconv.out_channels,
            ref_qconv.kernel_size,  # type: ignore[arg-type]
            ref_qconv.stride,  # type: ignore[arg-type]
            ref_qconv.padding,  # type: ignore[arg-type]
            ref_qconv.dilation,  # type: ignore[arg-type]
            ref_qconv.groups,
            ref_qconv.bias is not None,  # type: ignore[arg-type]
            ref_qconv.padding_mode,
            device=ref_qconv.weight.device,
            dtype=ref_qconv.weight.dtype)
        qweight = ref_qconv.get_quantized_weight()
        qconv.set_weight_bias(qweight, ref_qconv.bias)
        qconv.scale = float(output_scale)
        qconv.zero_point = int(output_zero_point)
        return qconv
class Conv1d(_ConvNd):
    r"""Applies a 1D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv1d`.

    .. note::
        Only `zeros` and `reflect` are supported for the :attr:`padding_mode`
        argument. Non-`zeros` padding is applied to the input in :meth:`forward`.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor): quantized weight, returned by :meth:`weight` after
                         unpacking it from the packed parameters.
        scale (float): scalar for the output scale
        zero_point (int): scalar for the output zero point

    See :class:`~torch.nn.Conv1d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0,
        ...                                     dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv1d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU1d
    # No fused Conv+Add / Conv+Add+ReLU module is mapped for the 1D case.
    _NNI_CONV_ADD_MODULE: None = None
    _NNI_CONV_ADD_RELU_MODULE: None = None

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: _size_1_t = 0,
                 dilation: _size_1_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 device=None,
                 dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = padding if isinstance(padding, str) else _single(padding)
        dilation = _single(dilation)

        # Subclasses of _ConvNd needs to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _single(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv1d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        # Pack the quantized weight and float bias for the conv1d kernel.
        # For non-'zeros' padding modes the padding is applied to the input in
        # forward() with F.pad, so the kernel is packed with zero padding.
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv1d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv1d_prepack(
                w, b, self.stride, _pair(0), self.dilation,
                self.groups)

    def _weight_bias(self):
        # Return the (quantized weight, bias) pair held in the packed params.
        w, b = torch.ops.quantized.conv1d_unpack(self._packed_params)
        return w, b

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != 'zeros':
            # Conv1d has a single spatial dim, so only the first padding entry
            # is used; it is expanded to the (left, right) pair F.pad takes.
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
              utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class Conv2d(_ConvNd):
    r"""Applies a 2D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv2d`.

    .. note::
        Only `zeros` and `reflect` are supported for the :attr:`padding_mode`
        argument. Non-`zeros` padding is applied to the input in :meth:`forward`.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor): quantized weight, returned by :meth:`weight` after
                         unpacking it from the packed parameters.
        scale (float): scalar for the output scale
        zero_point (int): scalar for the output zero point

    See :class:`~torch.nn.Conv2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv2d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU2d
    _NNI_CONV_ADD_MODULE = nni.ConvAdd2d
    _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv2d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        # Pack the quantized weight and float bias for the conv2d kernel.
        # For non-'zeros' padding modes the padding is applied to the input in
        # forward() with F.pad, so the kernel is packed with zero padding.
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv2d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv2d_prepack(
                w, b, self.stride, _pair(0), self.dilation, self.groups)

    def _weight_bias(self):
        # Return the (quantized weight, bias) pair held in the packed params.
        return self._packed_params.unpack()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != 'zeros':
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv2d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
              utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class Conv3d(_ConvNd):
    r"""Applies a 3D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv3d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor): quantized weight, returned by :meth:`weight` after
                         unpacking it from the packed parameters.
        scale (float): scalar for the output scale
        zero_point (int): scalar for the output zero point

    See :class:`~torch.nn.Conv3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
        >>> input = torch.randn(20, 16, 56, 56, 56)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv3d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU3d
    # No fused Conv+Add / Conv+Add+ReLU module is mapped for the 3D case.
    _NNI_CONV_ADD_MODULE: None = None
    _NNI_CONV_ADD_RELU_MODULE: None = None

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        # 'reflect' is in _SUPPORTED_PADDING but is rejected for 3D here.
        # An explicit raise (instead of ``assert``) keeps the check active under
        # ``python -O`` while raising the same AssertionError type as before.
        if padding_mode == 'reflect':
            raise AssertionError("Conv3d does not support reflection padding")
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _triple(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv3d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        # Pack the quantized weight and float bias for the conv3d kernel.
        # For non-'zeros' padding modes the padding would be applied in
        # forward() with F.pad, so the kernel is packed with zero padding.
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv3d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv3d_prepack(
                w, b, self.stride, _triple(0), self.dilation, self.groups)

    def _weight_bias(self):
        # Return the (quantized weight, bias) pair held in the packed params.
        return self._packed_params.unpack()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != 'zeros':
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv3d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
              utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
# === Transposed Convolutions ===
# Type variable bounded by the float convolution base class; it serves as the
# placeholder ``_FLOAT_MODULE`` of ``_ConvTransposeNd`` until a concrete
# subclass overrides it.
MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)
class _ConvTransposeNd(_ConvNd):
    r"""Shared base for the quantized transposed-convolution modules.

    Concrete subclasses forward their normalized arguments to this class's
    ``__init__``. It rejects every padding mode other than ``'zeros'`` and then
    delegates to ``_ConvNd._init``.
    """

    # Placeholder; every concrete subclass overrides this with its float module.
    _FLOAT_MODULE = MOD

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups, bias, padding_mode, device=None, dtype=None):
        if padding_mode != 'zeros':
            raise ValueError(f'Only "zeros" padding mode is supported for {self.__class__.__name__}')
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride,
            padding, dilation, transposed, output_padding,
            groups, bias, padding_mode, device=device, dtype=dtype)

    def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: List[int]) -> List[int]:
        # For each spatial dim: dilation * (kernel_size - 1) - padding.
        return [dilation[dim] * (kernel_size[dim] - 1) - padding[dim]
                for dim in range(len(kernel_size))]

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
              utilities or provided by the user
        """
        # derived classes override cls._FLOAT_MODULE attribute
        float_name = cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        assert type(mod) == cls._FLOAT_MODULE, \
            f' nnq.{cls.__name__}.from_float only works for {float_name}'
        assert hasattr(mod, 'qconfig'), \
            'Input float module must have qconfig defined.'

        # Observe the float weight, then quantize it with the observer's qparams.
        observer = mod.qconfig.weight()
        observer(mod.weight)
        assert observer.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), observer)

        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
        result = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.output_padding, mod.groups,
                     mod.bias is not None, mod.dilation, mod.padding_mode)
        result.set_weight_bias(qweight, mod.bias)

        is_dynamic = (not hasattr(mod, "activation_post_process")
                      or mod.activation_post_process.dtype == torch.float)
        if is_dynamic:
            return result  # dynamic quantization doesn't need scale/zero_point
        act_scale, act_zp = mod.activation_post_process.calculate_qparams()
        result.scale = float(act_scale)
        result.zero_point = int(act_zp)
        return result

    @staticmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module

        Note: this is a staticmethod that receives ``cls`` explicitly; the
        concrete subclasses call it as
        ``_ConvTransposeNd.from_reference(cls, ...)``.

        Args:
            ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization
                                 utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        """
        result = cls(
            ref_qconvt.in_channels,
            ref_qconvt.out_channels,
            ref_qconvt.kernel_size,  # type: ignore[arg-type]
            ref_qconvt.stride,  # type: ignore[arg-type]
            ref_qconvt.padding,  # type: ignore[arg-type]
            ref_qconvt.output_padding,  # type: ignore[arg-type]
            ref_qconvt.groups,
            ref_qconvt.bias is not None,  # type: ignore[arg-type]
            ref_qconvt.dilation,  # type: ignore[arg-type]
            ref_qconvt.padding_mode,
            device=ref_qconvt.weight.device,
            dtype=ref_qconvt.weight.dtype)
        result.set_weight_bias(ref_qconvt.get_quantized_weight(), ref_qconvt.bias)
        result.scale = float(output_scale)
        result.zero_point = int(output_zero_point)
        return result
class ConvTranspose1d(_ConvTransposeNd):
    r"""Applies a 1D transposed convolution operator over an input image
    composed of several input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose1d`.

    .. note:: Currently only the QNNPACK engine is implemented.
        Please, set the `torch.backends.quantized.engine = 'qnnpack'`

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv1d`

    Attributes:
        weight (Tensor): quantized weight, returned by :meth:`weight` after
                         unpacking it from the packed parameters.
        scale (float): scalar for the output scale
        zero_point (int): scalar for the output zero point

    See :class:`~torch.nn.ConvTranspose1d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> from torch.ao.nn import quantized as nnq
        >>> # kernel size 3 and stride 2
        >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2)
        >>> # larger kernel, unit stride and with padding
        >>> m = nnq.ConvTranspose1d(16, 33, 5, stride=1, padding=2)
        >>> input = torch.randn(20, 16, 50)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose1d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        output_padding = _single(output_padding)

        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose1d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        # Pack the quantized weight and float bias for the transposed conv1d kernel.
        self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(
            w, b, self.stride, self.padding, self.output_padding, self.dilation,
            self.groups)

    def _weight_bias(self):
        # Return the (quantized weight, bias) pair held in the packed params.
        w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        return torch.ops.quantized.conv_transpose1d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)
class ConvTranspose2d(_ConvTransposeNd):
    r"""Applies a 2D transposed convolution operator over an input image
    composed of several input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose2d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv2d`

    Attributes:
        weight (Tensor): quantized weight, returned by :meth:`weight` after
                         unpacking it from the packed parameters.
        scale (float): scalar for the output scale
        zero_point (int): scalar for the output zero point

    See :class:`~torch.nn.ConvTranspose2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # QNNPACK or FBGEMM as backend
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> # With square kernels and equal stride
        >>> import torch.ao.nn.quantized as nnq
        >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        output_padding = _pair(output_padding)

        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose2d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        # Pack the quantized weight and float bias for the transposed conv2d kernel.
        self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(
            w, b, self.stride, self.padding, self.output_padding, self.dilation,
            self.groups)

    def _weight_bias(self):
        # Unpacked with the generic ``conv2d_unpack`` op (the 1D variant uses
        # the dedicated ``conv_transpose1d_unpack`` instead).
        w, b = torch.ops.quantized.conv2d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return ops.quantized.conv_transpose2d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)
class ConvTranspose3d(_ConvTransposeNd):
    r"""Applies a 3D transposed convolution operator over an input image
    composed of several input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose3d`.

    .. note:: Currently only the FBGEMM engine is implemented.
        Please, set the `torch.backends.quantized.engine = 'fbgemm'`

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv3d`

    Attributes:
        weight (Tensor): quantized weight, returned by :meth:`weight` after
                         unpacking it from the packed parameters.
        scale (float): scalar for the output scale
        zero_point (int): scalar for the output zero point

    See :class:`~torch.nn.ConvTranspose3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'fbgemm'
        >>> from torch.ao.nn import quantized as nnq
        >>> # With cubic kernels and equal stride
        >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-cubic kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
        >>> input = torch.randn(20, 16, 50, 100, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose3d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        output_padding = _triple(output_padding)

        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose3d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        # Pack the quantized weight and float bias for the transposed conv3d kernel.
        self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(
            w, b, self.stride, self.padding, self.output_padding, self.dilation,
            self.groups)

    def _weight_bias(self):
        # Unpacked with the generic ``conv3d_unpack`` op (the 1D variant uses
        # the dedicated ``conv_transpose1d_unpack`` instead).
        w, b = torch.ops.quantized.conv3d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, T, H, W)`!")
        return ops.quantized.conv_transpose3d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)