Shortcuts

Source code for torch.ao.nn.quantized.modules.normalization

# mypy: allow-untyped-defs
import torch


# Public API of this module: the quantized normalization layers. Each one
# mirrors the float ``torch.nn`` layer of the same name.
__all__ = [
    "LayerNorm",
    "GroupNorm",
    "InstanceNorm1d",
    "InstanceNorm2d",
    "InstanceNorm3d",
]


class LayerNorm(torch.nn.LayerNorm):
    r"""Quantized counterpart of :class:`~torch.nn.LayerNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    ``weight`` and ``bias`` are installed exactly as supplied (either may be
    ``None``). The output quantization parameters are kept as buffers named
    ``scale`` and ``zero_point``.
    """

    def __init__(
        self,
        normalized_shape,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        elementwise_affine=True,
        device=None,
        dtype=None,
    ) -> None:
        tensor_opts = {"device": device, "dtype": dtype}
        super().__init__(
            normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            **tensor_opts,
        )
        # Overwrite whatever the float base class created with the
        # caller-provided affine parameters.
        self.weight = weight
        self.bias = bias
        # Output qparams live in buffers so they are part of the module state.
        scale_buf = torch.tensor(scale, **tensor_opts)
        zero_point_buf = torch.tensor(zero_point, **tensor_opts)
        self.register_buffer("scale", scale_buf)
        self.register_buffer("zero_point", zero_point_buf)

    def forward(self, input):
        quantized_layer_norm = torch.ops.quantized.layer_norm
        return quantized_layer_norm(
            input,
            self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
            output_scale=self.scale,
            output_zero_point=self.zero_point,
        )

    def _get_name(self):
        return "QuantizedLayerNorm"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized LayerNorm from an observed float ``mod``.

        The output qparams come from ``mod.activation_post_process``.
        """
        observer = mod.activation_post_process
        out_scale, out_zero_point = observer.calculate_qparams()
        return cls(
            normalized_shape=mod.normalized_shape,
            weight=mod.weight,
            bias=mod.bias,
            scale=float(out_scale),
            zero_point=int(out_zero_point),
            eps=mod.eps,
            elementwise_affine=mod.elementwise_affine,
        )

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized LayerNorm from a reference ``mod`` and explicit qparams."""
        return cls(
            normalized_shape=mod.normalized_shape,
            weight=mod.weight,
            bias=mod.bias,
            scale=float(scale),
            zero_point=int(zero_point),
            eps=mod.eps,
            elementwise_affine=mod.elementwise_affine,
        )
class GroupNorm(torch.nn.GroupNorm):
    r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    ``weight`` and ``bias`` are installed exactly as supplied (either may be
    ``None``). The output quantization parameters are stored as buffers named
    ``scale`` and ``zero_point``.
    """

    __constants__ = ["num_groups", "num_channels", "eps", "affine"]

    def __init__(
        self,
        num_groups,
        num_channels,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        affine=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
        # Replace the parameters created by the float base class with the
        # caller-supplied ones.
        self.weight = weight
        self.bias = bias
        # Output qparams are buffers so they are part of the module state.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.group_norm(
            input,
            self.num_groups,
            self.weight,
            self.bias,
            self.eps,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        return "QuantizedGroupNorm"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized GroupNorm from an observed float ``mod``.

        The output qparams are taken from ``mod.activation_post_process``.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return cls.from_reference(mod, scale, zero_point) if False else cls(
            mod.num_groups,
            mod.num_channels,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            mod.affine,
        )

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized GroupNorm from a reference ``mod`` and explicit qparams.

        Added for parity with the sibling quantized normalization modules,
        which all expose ``from_reference(mod, scale, zero_point)``.

        Args:
            mod: module exposing ``num_groups``, ``num_channels``, ``weight``,
                ``bias``, ``eps`` and ``affine``.
            scale: output quantization scale (converted with ``float``).
            zero_point: output zero point (converted with ``int``).
        """
        return cls(
            mod.num_groups,
            mod.num_channels,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            mod.eps,
            mod.affine,
        )
class InstanceNorm1d(torch.nn.InstanceNorm1d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    ``weight`` and ``bias`` are installed exactly as supplied (either may be
    ``None``). ``momentum`` and ``track_running_stats`` are only forwarded to
    the float base class; :meth:`forward` does not read them.
    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Replace the parameters (or None placeholders) created by the float
        # base class with the caller-supplied ones.
        self.weight = weight
        self.bias = bias
        # Output qparams are buffers so they are part of the module state.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm1d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized InstanceNorm1d from an observed float ``mod``.

        The output qparams are taken from ``mod.activation_post_process``.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        # Hyper-parameters go by keyword: positionally, ``mod.affine`` would
        # land in the ``momentum`` slot and ``affine`` would silently fall
        # back to its default of False.
        new_mod = cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized InstanceNorm1d from a reference ``mod`` and explicit qparams."""
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )
class InstanceNorm2d(torch.nn.InstanceNorm2d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    ``weight`` and ``bias`` are installed exactly as supplied (either may be
    ``None``). ``momentum`` and ``track_running_stats`` are only forwarded to
    the float base class; :meth:`forward` does not read them.
    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Replace the parameters (or None placeholders) created by the float
        # base class with the caller-supplied ones.
        self.weight = weight
        self.bias = bias
        # Output qparams are buffers so they are part of the module state.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized InstanceNorm2d from an observed float ``mod``.

        The output qparams are taken from ``mod.activation_post_process``.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        # Hyper-parameters go by keyword: positionally, ``mod.affine`` would
        # land in the ``momentum`` slot and ``affine`` would silently fall
        # back to its default of False.
        new_mod = cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized InstanceNorm2d from a reference ``mod`` and explicit qparams."""
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )
class InstanceNorm3d(torch.nn.InstanceNorm3d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    ``weight`` and ``bias`` are installed exactly as supplied (either may be
    ``None``). ``momentum`` and ``track_running_stats`` are only forwarded to
    the float base class; :meth:`forward` does not read them.
    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Replace the parameters (or None placeholders) created by the float
        # base class with the caller-supplied ones.
        self.weight = weight
        self.bias = bias
        # Output qparams are buffers so they are part of the module state.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized InstanceNorm3d from an observed float ``mod``.

        The output qparams are taken from ``mod.activation_post_process``.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        # Hyper-parameters go by keyword: positionally, ``mod.affine`` would
        # land in the ``momentum`` slot and ``affine`` would silently fall
        # back to its default of False.
        new_mod = cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized InstanceNorm3d from a reference ``mod`` and explicit qparams."""
        return cls(
            mod.num_features,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            momentum=mod.momentum,
            affine=mod.affine,
        )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources