Shortcuts

Source code for torch.ao.nn.qat.modules.conv

# mypy: allow-untyped-defs
from typing import ClassVar, Tuple, Type, Union

import torch
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _pair, _single, _triple


# Public API of this module: the QAT convolution classes defined below.
__all__ = ["Conv1d", "Conv2d", "Conv3d"]


class _ConvNd(nn.modules.conv._ConvNd):
    r"""Shared base class for the quantization-aware-training conv modules.

    Adds a ``weight_fake_quant`` submodule, built from ``qconfig.weight``,
    that is applied to the weight on every forward pass.

    Subclasses are expected to define:
        _FLOAT_MODULE: the module type accepted by :meth:`from_float` (and the
            fused container type used by :meth:`to_float` for fused modules).
        _FLOAT_CONV_MODULE: the float convolution type built by
            :meth:`to_float`.
    """

    _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, ...],
        stride: Tuple[int, ...],
        padding: Tuple[int, ...],
        dilation: Tuple[int, ...],
        transposed: bool,
        output_padding: Tuple[int, ...],
        groups: int,
        bias: bool,
        padding_mode: str,
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        r"""Initialize the float conv state, then attach the weight fake-quant.

        Args:
            qconfig: object whose ``weight`` attribute is a factory accepting a
                ``factory_kwargs`` keyword; must be truthy.
            device, dtype: forwarded both to the float base initializer and,
                as ``factory_kwargs``, to ``qconfig.weight``.

        The remaining arguments are passed positionally to
        ``nn.modules.conv._ConvNd.__init__``.

        Raises:
            AssertionError: if ``qconfig`` is falsy.
        """
        # The base initializer is called explicitly (not through super()):
        # concrete subclasses also inherit from the float nn.ConvXd classes,
        # so super() would resolve to a different __init__ in their MRO.
        nn.modules.conv._ConvNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
            device=device,
            dtype=dtype,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(
            factory_kwargs={"device": device, "dtype": dtype}
        )

    def forward(self, input):
        r"""Convolve ``input`` using the fake-quantized weight."""
        fake_quantized_weight = self.weight_fake_quant(self.weight)
        return self._conv_forward(input, fake_quantized_weight, self.bias)

    @staticmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module

        Declared as a static method that takes ``cls`` explicitly; subclasses
        wrap it in a classmethod and pass their own class through.

        Args:
            cls: the QAT class to instantiate.
            mod: a float module, either produced by torch.ao.quantization
                utilities or directly from user. Its exact type must equal
                ``cls._FLOAT_MODULE`` and it must carry a truthy ``qconfig``.
            use_precomputed_fake_quant: accepted for interface compatibility;
                not used by this implementation.

        Returns:
            A ``cls`` instance whose ``weight`` and ``bias`` are the very same
            objects as those of ``mod``.

        Raises:
            AssertionError: on a type mismatch or a missing/falsy ``qconfig``.
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            f"qat.{cls.__name__}.from_float only works for "
            f"{cls._FLOAT_MODULE.__name__}"  # type: ignore[attr-defined]
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        if isinstance(mod, _FusedModule):
            # For a fused container the convolution is the first child.
            mod = mod[0]  # type: ignore[index]

        conv_options = {
            "stride": mod.stride,
            "padding": mod.padding,
            "dilation": mod.dilation,
            "groups": mod.groups,
            "bias": mod.bias is not None,
            "padding_mode": mod.padding_mode,
            "qconfig": mod.qconfig,
        }
        qat_conv = cls(
            mod.in_channels, mod.out_channels, mod.kernel_size, **conv_options
        )
        # Reuse the float module's parameters instead of the freshly
        # initialized ones.
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        """Convert this qat module back to a floating point module.

        This works for both single qat conv, and the qat conv - relu modules.
        A ``_FLOAT_CONV_MODULE`` is built from this module's hyper-parameters
        and given new Parameters wrapping the detached weight (and bias, when
        present). For a fused class the conv is additionally combined with a
        ``_FLOAT_RELU_MODULE`` instance inside ``_FLOAT_MODULE``, and the
        result is put in the same train/eval mode as this module.
        """
        cls = type(self)
        has_bias = self.bias is not None
        float_conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined, operator]
            self.in_channels,
            self.out_channels,
            self.kernel_size,  # type: ignore[arg-type]
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.dilation,  # type: ignore[arg-type]
            self.groups,
            has_bias,
            self.padding_mode,
        )
        float_conv.weight = torch.nn.Parameter(self.weight.detach())
        if has_bias:
            float_conv.bias = torch.nn.Parameter(self.bias.detach())

        # Plain conv: nothing more to assemble.
        if not issubclass(cls, _FusedModule):
            return float_conv

        # conv relu
        assert hasattr(cls, "_FLOAT_RELU_MODULE")
        relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
        fused = cls._FLOAT_MODULE(float_conv, relu)  # type: ignore[arg-type, attr-defined, operator]
        fused.train(self.training)
        return fused


class Conv1d(_ConvNd, nn.Conv1d):
    r"""
    Quantization-aware-training variant of :class:`~torch.nn.Conv1d`.

    The constructor mirrors :class:`~torch.nn.Conv1d`, with an additional
    ``qconfig`` argument from which the weight FakeQuantize module is built.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d  # type: ignore[assignment,misc]
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        r"""Normalize the size arguments with ``_single`` and delegate to the base class.

        The module is always created as a non-transposed convolution with an
        ``output_padding`` of ``_single(0)``.
        """
        # A string padding spec is forwarded untouched; anything else is
        # normalized like the other size arguments.
        padding_ = padding if isinstance(padding, str) else _single(padding)
        super().__init__(
            in_channels,
            out_channels,
            _single(kernel_size),
            stride=_single(stride),
            padding=padding_,
            dilation=_single(dilation),
            transposed=False,
            output_padding=_single(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Build a QAT ``Conv1d`` from the float module ``mod``.

        Thin classmethod wrapper: the parent ``from_float`` takes the target
        class as an explicit first argument, so ``cls`` is passed through.
        """
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class Conv2d(_ConvNd, nn.Conv2d):
    r"""
    A Conv2d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv2d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
    for documentation.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d  # type: ignore[assignment,misc]
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        r"""Normalize the size arguments with ``_pair`` and delegate to the base class.

        A string ``padding`` is forwarded unchanged. The module is always
        created as a non-transposed convolution with an ``output_padding`` of
        ``_pair(0)``.
        """
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_pair(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        r"""Convolve ``input`` using the fake-quantized weight."""
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Build a QAT ``Conv2d`` from the float module ``mod``.

        Thin classmethod wrapper: the parent ``from_float`` takes the target
        class as an explicit first argument, so ``cls`` is passed through.
        """
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
class Conv3d(_ConvNd, nn.Conv3d):
    r"""
    A Conv3d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv3d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
    for documentation.

    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d  # type: ignore[assignment,misc]
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: Union[str, _size_3_t] = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        r"""Normalize the size arguments with ``_triple`` and delegate to the base class.

        A string ``padding`` is forwarded unchanged. The module is always
        created as a non-transposed convolution with an ``output_padding`` of
        ``_triple(0)``.
        """
        kernel_size_ = _triple(kernel_size)
        stride_ = _triple(stride)
        padding_ = padding if isinstance(padding, str) else _triple(padding)
        dilation_ = _triple(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_triple(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        r"""Convolve ``input`` using the fake-quantized weight."""
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Build a QAT ``Conv3d`` from the float module ``mod``.

        Thin classmethod wrapper: the parent ``from_float`` takes the target
        class as an explicit first argument, so ``cls`` is passed through.
        """
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources