from warnings import warn

import torch


class ReLU6(torch.nn.ReLU):
    r"""Applies the element-wise function:

    :math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the
    zero_point, and :math:`q(6)` is the quantized representation of number 6.

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.quantized.ReLU6()
        >>> input = torch.randn(2)
        >>> # xdoctest: +SKIP
        >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
        >>> output = m(input)
    """
    def __init__(self, inplace=False):
        super().__init__(inplace)
        self.inplace = inplace

    def forward(self, input):
        return torch.ops.quantized.relu6(input, self.inplace)

    def _get_name(self):
        return 'QuantizedReLU6'

    @staticmethod
    def from_float(mod):
        return ReLU6(mod.inplace)
class Hardswish(torch.nn.Hardswish):
    r"""This is the quantized version of :class:`~torch.nn.Hardswish`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
    """
    def __init__(self, scale, zero_point, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.hardswish(input, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedHardswish'

    @staticmethod
    def from_float(mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return Hardswish(float(scale), int(zero_point))

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(float(scale), int(zero_point))
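A minimal usage sketch, assuming the module is constructed directly rather than produced by ``convert``; the scale and zero_point values below are illustrative only:

    >>> m = torch.ao.nn.quantized.Hardswish(scale=0.1, zero_point=4)
    >>> x = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=128, dtype=torch.quint8)
    >>> y = m(x)  # quantized output carrying the module's scale and zero_point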
class ELU(torch.nn.ELU):
    r"""This is the quantized equivalent of :class:`~torch.nn.ELU`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        alpha: the alpha constant
    """
    def __init__(self, scale, zero_point, alpha=1.):
        super().__init__(alpha)
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, input):
        return torch.ao.nn.quantized.functional.elu(
            input, self.scale, self.zero_point, self.alpha)

    def _get_name(self):
        return 'QuantizedELU'

    @staticmethod
    def from_float(mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return ELU(float(scale), int(zero_point), mod.alpha)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(float(scale), int(zero_point), mod.alpha)
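A sketch of direct use (values illustrative; in the usual flow ``from_float`` supplies the output qparams from the attached observer):

    >>> m = torch.ao.nn.quantized.ELU(scale=0.05, zero_point=20, alpha=1.0)
    >>> x = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=128, dtype=torch.quint8)
    >>> y = m(x)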
class LeakyReLU(torch.nn.LeakyReLU):
    r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        negative_slope: Controls the angle of the negative slope. Default: 1e-2
    """
    def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2,
                 inplace: bool = False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(negative_slope, inplace)
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.leaky_relu(
            input, self.negative_slope, self.inplace, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedLeakyReLU'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
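A brief sketch with illustrative qparams; ``negative_slope`` keeps its usual float meaning:

    >>> m = torch.ao.nn.quantized.LeakyReLU(scale=0.1, zero_point=10, negative_slope=0.01)
    >>> x = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=128, dtype=torch.quint8)
    >>> y = m(x)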
class Sigmoid(torch.nn.Sigmoid):
    r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
    """
    def __init__(self, output_scale: float, output_zero_point: int):
        super().__init__()
        self.output_scale = output_scale
        self.output_zero_point = output_zero_point

    def forward(self, input):
        return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point)

    @classmethod
    def from_float(cls, mod):
        output_scale, output_zero_point = mod.activation_post_process.calculate_qparams()
        return cls(float(output_scale), int(output_zero_point))
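Because sigmoid outputs lie in [0, 1], an output scale of 1/256 with zero_point 0 is a natural choice for quint8; this sketch assumes direct construction with illustrative input qparams:

    >>> m = torch.ao.nn.quantized.Sigmoid(output_scale=1.0 / 256.0, output_zero_point=0)
    >>> x = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=128, dtype=torch.quint8)
    >>> y = m(x)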
class Softmax(torch.nn.Softmax):
    r"""This is the quantized version of :class:`~torch.nn.Softmax`.

    Args:
        dim: A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
    """
    def __init__(self, dim=None, scale=1.0, zero_point=0):
        super().__init__()
        self.dim = dim
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, input):
        dim = self.dim
        if dim is None:
            stacklevel = 3
            # Note: adding the mypy ignore on _get_softmax_dim seems less bad
            # than making `_get_softmax_dim` an official API.
            dim = torch.nn.functional._get_softmax_dim(  # type: ignore[attr-defined]
                "softmax", input.dim(), stacklevel)
        return torch.ops.quantized.softmax(input, dim, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedSoftmax'

    @staticmethod
    def from_float(mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return Softmax(mod.dim, float(scale), int(zero_point))

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(mod.dim, float(scale), int(zero_point))


class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
    _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention

    def _get_name(self):
        return "QuantizedMultiheadAttention"

    @classmethod
    def from_float(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does observed -> quantized only
        raise NotImplementedError("It looks like you are trying to convert a "
                                  "non-observed MHA module. Please, see "
                                  "the examples on quantizable MHAs.")

    @classmethod
    def from_observed(cls, other):
        converted = torch.ao.quantization.convert(other, mapping=None,
                                                  inplace=False,
                                                  remove_qconfig=True,
                                                  convert_custom_config_dict=None)
        converted.__class__ = cls
        # Remove the parameters for the bias_k and bias_v to quantize them
        # TODO: This is a potential source of accuracy drop.
        #       quantized cat takes the scale and zp of the first
        #       element, which might lose the precision in the bias_k
        #       and the bias_v (which are cat'ed with k/v being first).
        if converted.bias_k is not None:
            bias_k = converted._parameters.pop('bias_k')
            sc, zp = torch._choose_qparams_per_tensor(bias_k, reduce_range=False)
            bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8)
            setattr(converted, 'bias_k', bias_k)  # noqa: B010
        if converted.bias_v is not None:
            bias_v = converted._parameters.pop('bias_v')
            sc, zp = torch._choose_qparams_per_tensor(bias_v, reduce_range=False)
            bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8)
            setattr(converted, 'bias_v', bias_v)  # noqa: B010
        return converted


class PReLU(torch.nn.Module):
    r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        num_parameters: number of parameters: 1, or the number of channels at
            input. Default: 1
    """
    def __init__(self, output_scale: float, output_zero_point: int,
                 num_parameters: int = 1) -> None:
        super().__init__()
        self.num_parameters = num_parameters
        self.scale = output_scale
        self.zero_point = output_zero_point
        w = torch.randn(num_parameters, dtype=torch.float)
        qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8)
        self.set_weight(qw)

    def set_weight(self, w: torch.Tensor) -> None:
        self.weight = w

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.ops.quantized.prelu(input, self.weight, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedPReLU'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
        float_wt = mod.weight.float()
        observer = mod.qconfig.weight()
        observer(float_wt)
        if observer.dtype != torch.quint8:
            warn(f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}")
        wt_scale, wt_zp = observer.calculate_qparams()
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.quint8)
        qprelu.set_weight(qweight)
        return qprelu

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
        float_wt = mod.weight.float()
        observer = mod.qconfig.weight()
        observer(float_wt)
        if observer.dtype != torch.quint8:
            warn(f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}")
        wt_scale, wt_zp = observer.calculate_qparams()
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.quint8)
        qprelu.set_weight(qweight)
        return qprelu
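In the normal flow a quantized PReLU is produced by ``from_float`` during convert, which quantizes the learned weight with the weight observer's qparams; a hand-constructed sketch (all values illustrative) looks like:

    >>> m = torch.ao.nn.quantized.PReLU(output_scale=0.1, output_zero_point=128, num_parameters=1)
    >>> w = torch.quantize_per_tensor(torch.tensor([0.25]), scale=0.01, zero_point=0, dtype=torch.quint8)
    >>> m.set_weight(w)
    >>> x = torch.quantize_per_tensor(torch.randn(2, 4), scale=0.1, zero_point=128, dtype=torch.quint8)
    >>> y = m(x)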