# Source code for torch.ao.nn.quantized.modules.normalization
# mypy: allow-untyped-defs
import torch
# Public API of this module: the quantized normalization layers defined below.
__all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
class LayerNorm(torch.nn.LayerNorm):
    r"""This is the quantized version of :class:`~torch.nn.LayerNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    ``weight`` and ``bias`` are assigned as given, replacing the affine
    parameters created by the float base class. ``scale`` and ``zero_point``
    are registered as buffers, so they are part of the module's state.
    """

    def __init__(self, normalized_shape, weight, bias, scale, zero_point, eps=1e-5,
                 elementwise_affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine,
                         **factory_kwargs)
        # Overwrite the parameters the base class just initialised with the
        # caller-supplied ones (e.g. taken from the float module).
        self.weight = weight
        self.bias = bias
        # Output quantization parameters, kept as buffers rather than plain
        # attributes so they are saved/loaded with the module.
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply ``torch.ops.quantized.layer_norm`` to ``input``.

        ``input`` is presumably a quantized tensor (TODO confirm the dtypes the
        kernel accepts); the output is produced with ``self.scale`` and
        ``self.zero_point``.
        """
        return torch.ops.quantized.layer_norm(
            input, self.normalized_shape, weight=self.weight, bias=self.bias,
            eps=self.eps, output_scale=self.scale, output_zero_point=self.zero_point)

    def _get_name(self):
        return 'QuantizedLayerNorm'

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized LayerNorm from an observed float ``mod``.

        The output qparams come from ``mod.activation_post_process.calculate_qparams()``.
        ``use_precomputed_fake_quant`` is accepted for API compatibility and is
        not used here.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized LayerNorm from ``mod`` with explicit output qparams."""
        return cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
class GroupNorm(torch.nn.GroupNorm):
    r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    ``weight`` and ``bias`` are assigned as given, replacing the affine
    parameters created by the float base class. ``scale`` and ``zero_point``
    are registered as buffers, so they are part of the module's state.
    """
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']

    def __init__(self, num_groups, num_channels, weight, bias, scale, zero_point, eps=1e-5,
                 affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
        # Overwrite the parameters the base class just initialised with the
        # caller-supplied ones (e.g. taken from the float module).
        self.weight = weight
        self.bias = bias
        # Output quantization parameters, kept as buffers so they are
        # saved/loaded with the module.
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply ``torch.ops.quantized.group_norm`` to ``input``.

        ``input`` is presumably a quantized tensor (TODO confirm); the output
        is produced with ``self.scale`` and ``self.zero_point``.
        """
        return torch.ops.quantized.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedGroupNorm'

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized GroupNorm from an observed float ``mod``.

        The output qparams come from ``mod.activation_post_process.calculate_qparams()``.
        ``use_precomputed_fake_quant`` is accepted for API compatibility and is
        not used here.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized GroupNorm from ``mod`` with explicit output qparams.

        Mirrors ``from_reference`` on the other normalization modules in this file.
        """
        return cls(
            mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.affine)
class InstanceNorm1d(torch.nn.InstanceNorm1d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    ``weight`` and ``bias`` are assigned as given, replacing the affine
    parameters created by the float base class. ``momentum`` and
    ``track_running_stats`` are forwarded to the base class but are not used
    by :meth:`forward`.
    """

    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        # Overwrite the parameters the base class just initialised with the
        # caller-supplied ones (e.g. taken from the float module).
        self.weight = weight
        self.bias = bias
        # Output quantization parameters, kept as buffers so they are
        # saved/loaded with the module.
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply ``torch.ops.quantized.instance_norm`` to ``input``.

        ``input`` is presumably a quantized tensor (TODO confirm); the output
        is produced with ``self.scale`` and ``self.zero_point``.
        """
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm1d'

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized InstanceNorm1d from an observed float ``mod``.

        The output qparams come from ``mod.activation_post_process.calculate_qparams()``.
        ``use_precomputed_fake_quant`` is accepted for API compatibility and is
        not used here.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        # Keyword arguments are required: positionally, the slot after ``eps``
        # is ``momentum``, so passing ``mod.affine`` there would silently set
        # momentum to a bool and reset ``affine`` to False.
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            eps=mod.eps, momentum=mod.momentum, affine=mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized InstanceNorm1d from ``mod`` with explicit output qparams."""
        # Keyword arguments keep ``affine`` out of the ``momentum`` slot.
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            eps=mod.eps, momentum=mod.momentum, affine=mod.affine)
class InstanceNorm2d(torch.nn.InstanceNorm2d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    ``weight`` and ``bias`` are assigned as given, replacing the affine
    parameters created by the float base class. ``momentum`` and
    ``track_running_stats`` are forwarded to the base class but are not used
    by :meth:`forward`.
    """

    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        # Overwrite the parameters the base class just initialised with the
        # caller-supplied ones (e.g. taken from the float module).
        self.weight = weight
        self.bias = bias
        # Output quantization parameters, kept as buffers so they are
        # saved/loaded with the module.
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply ``torch.ops.quantized.instance_norm`` to ``input``.

        ``input`` is presumably a quantized tensor (TODO confirm); the output
        is produced with ``self.scale`` and ``self.zero_point``.
        """
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm2d'

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized InstanceNorm2d from an observed float ``mod``.

        The output qparams come from ``mod.activation_post_process.calculate_qparams()``.
        ``use_precomputed_fake_quant`` is accepted for API compatibility and is
        not used here.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        # Keyword arguments are required: positionally, the slot after ``eps``
        # is ``momentum``, so passing ``mod.affine`` there would silently set
        # momentum to a bool and reset ``affine`` to False.
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            eps=mod.eps, momentum=mod.momentum, affine=mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized InstanceNorm2d from ``mod`` with explicit output qparams."""
        # Keyword arguments keep ``affine`` out of the ``momentum`` slot.
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            eps=mod.eps, momentum=mod.momentum, affine=mod.affine)
class InstanceNorm3d(torch.nn.InstanceNorm3d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    ``weight`` and ``bias`` are assigned as given, replacing the affine
    parameters created by the float base class. ``momentum`` and
    ``track_running_stats`` are forwarded to the base class but are not used
    by :meth:`forward`.
    """

    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        # Overwrite the parameters the base class just initialised with the
        # caller-supplied ones (e.g. taken from the float module).
        self.weight = weight
        self.bias = bias
        # Output quantization parameters, kept as buffers so they are
        # saved/loaded with the module.
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply ``torch.ops.quantized.instance_norm`` to ``input``.

        ``input`` is presumably a quantized tensor (TODO confirm); the output
        is produced with ``self.scale`` and ``self.zero_point``.
        """
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm3d'

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized InstanceNorm3d from an observed float ``mod``.

        The output qparams come from ``mod.activation_post_process.calculate_qparams()``.
        ``use_precomputed_fake_quant`` is accepted for API compatibility and is
        not used here.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        # Keyword arguments are required: positionally, the slot after ``eps``
        # is ``momentum``, so passing ``mod.affine`` there would silently set
        # momentum to a bool and reset ``affine`` to False.
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            eps=mod.eps, momentum=mod.momentum, affine=mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build a quantized InstanceNorm3d from ``mod`` with explicit output qparams."""
        # Keyword arguments keep ``affine`` out of the ``momentum`` slot.
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            eps=mod.eps, momentum=mod.momentum, affine=mod.affine)