# Source code for torch.ao.nn.qat.modules.conv
# mypy: allow-untyped-defs
from typing import ClassVar, Tuple, Type, Union
import torch
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _pair, _single, _triple
# Public API of this module: the QAT convolution classes defined below.
__all__ = ["Conv1d", "Conv2d", "Conv3d"]
class _ConvNd(nn.modules.conv._ConvNd):
_FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]]
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Union[str, Tuple[int, ...]],
dilation: Tuple[int, ...],
transposed: bool,
output_padding: Tuple[int, ...],
groups: int,
bias: bool,
padding_mode: str,
qconfig=None,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
nn.modules.conv._ConvNd.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
bias,
padding_mode,
**factory_kwargs,
)
assert qconfig, "qconfig must be provided for QAT module"
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@staticmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module
Args:
`mod`: a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, (
"qat."
+ cls.__name__
+ ".from_float only works for "
+ cls._FLOAT_MODULE.__name__
)
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
assert mod.qconfig, "Input float module must have a valid qconfig"
if issubclass(type(mod), _FusedModule):
mod = mod[0]
qconfig = mod.qconfig
qat_conv = cls(
mod.in_channels,
mod.out_channels,
mod.kernel_size,
stride=mod.stride,
padding=mod.padding,
dilation=mod.dilation,
groups=mod.groups,
bias=mod.bias is not None,
padding_mode=mod.padding_mode,
qconfig=qconfig,
)
qat_conv.weight = mod.weight
qat_conv.bias = mod.bias
return qat_conv
def to_float(self):
"""This works for both single qat conv, and the qat conv - relu modules
to convert the qat module to a floating point module
"""
cls = type(self)
conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined]
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.padding,
self.dilation,
self.groups,
self.bias is not None,
self.padding_mode,
)
conv.weight = torch.nn.Parameter(self.weight.detach())
if self.bias is not None:
conv.bias = torch.nn.Parameter(self.bias.detach())
# conv relu
if issubclass(cls, _FusedModule):
modules = [conv]
assert hasattr(cls, "_FLOAT_RELU_MODULE")
relu = cls._FLOAT_RELU_MODULE()
modules.append(relu)
fused = cls._FLOAT_MODULE(*modules)
fused.train(self.training)
return fused
else:
return conv
class Conv1d(_ConvNd, nn.Conv1d):
r"""
A Conv1d module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as :class:`~torch.nn.Conv1d`
Similar to :class:`~torch.nn.Conv2d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d
_FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
qconfig=None,
device=None,
dtype=None,
) -> None:
kernel_size_ = _single(kernel_size)
stride_ = _single(stride)
padding_ = padding if isinstance(padding, str) else _single(padding)
dilation_ = _single(dilation)
super().__init__(
in_channels,
out_channels,
kernel_size_,
stride=stride_,
padding=padding_,
dilation=dilation_,
transposed=False,
output_padding=_single(0),
groups=groups,
bias=bias,
padding_mode=padding_mode,
qconfig=qconfig,
device=device,
dtype=dtype,
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
[docs]class Conv2d(_ConvNd, nn.Conv2d):
r"""
A Conv2d module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as `torch.nn.Conv2d`, please see
https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
for documentation.
Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d
_FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
qconfig=None,
device=None,
dtype=None,
) -> None:
kernel_size_ = _pair(kernel_size)
stride_ = _pair(stride)
padding_ = padding if isinstance(padding, str) else _pair(padding)
dilation_ = _pair(dilation)
super().__init__(
in_channels,
out_channels,
kernel_size_,
stride=stride_,
padding=padding_,
dilation=dilation_,
transposed=False,
output_padding=_pair(0),
groups=groups,
bias=bias,
padding_mode=padding_mode,
qconfig=qconfig,
device=device,
dtype=dtype,
)
def forward(self, input):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
[docs]class Conv3d(_ConvNd, nn.Conv3d):
r"""
A Conv3d module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as `torch.nn.Conv3d`, please see
https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
for documentation.
Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d
_FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_3_t,
stride: _size_3_t = 1,
padding: Union[str, _size_3_t] = 0,
dilation: _size_3_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
qconfig=None,
device=None,
dtype=None,
) -> None:
kernel_size_ = _triple(kernel_size)
stride_ = _triple(stride)
padding_ = padding if isinstance(padding, str) else _triple(padding)
dilation_ = _triple(dilation)
super().__init__(
in_channels,
out_channels,
kernel_size_,
stride=stride_,
padding=padding_,
dilation=dilation_,
transposed=False,
output_padding=_triple(0),
groups=groups,
bias=bias,
padding_mode=padding_mode,
qconfig=qconfig,
device=device,
dtype=dtype,
)
def forward(self, input):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)