Source code for torch.ao.nn.intrinsic.quantized.modules.conv_relu
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F
from torch.nn.utils import fuse_conv_bn_weights
__all__ = [
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
]
_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
# TODO: factor out the common parts to ConvNd
[docs]class ConvReLU1d(nnq.Conv1d):
r"""
A ConvReLU1d module is a fused module of Conv1d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.
Attributes:
Same as torch.ao.nn.quantized.Conv1d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 3:
raise ValueError("Input shape must be `(N, C, L)`!")
if self.padding_mode != "zeros":
# Padding in Conv1d is stored as (p, p), need to get (p,)
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv1d_relu(
input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvReLU1d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
mod.bias,
mod.bn.running_mean,
mod.bn.running_var,
mod.bn.eps,
mod.bn.weight,
mod.bn.bias,
)
return super().from_float(mod, use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
assert (
type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d
), "BatchNorm1d should be fused into Conv1d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
[docs]class ConvReLU2d(nnq.Conv2d):
r"""
A ConvReLU2d module is a fused module of Conv2d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
Attributes:
Same as torch.ao.nn.quantized.Conv2d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv2d_relu(
input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvReLU2d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
mod.bias,
mod.bn.running_mean,
mod.bn.running_var,
mod.bn.eps,
mod.bn.weight,
mod.bn.bias,
)
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
assert (
type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d
), "BatchNorm2d should be fused into Conv2d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
[docs]class ConvReLU3d(nnq.Conv3d):
r"""
A ConvReLU3d module is a fused module of Conv3d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.
Attributes: Same as torch.ao.nn.quantized.Conv3d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
assert padding_mode != "reflect", "Conv3d does not support reflection padding"
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 5:
raise ValueError("Input shape must be `(N, C, D, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv3d_relu(
input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvReLU3d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
mod.bias,
mod.bn.running_mean,
mod.bn.running_var,
mod.bn.eps,
mod.bn.weight,
mod.bn.bias,
)
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
assert (
type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d
), "BatchNorm3d should be fused into Conv3d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)