class Quantize(torch.nn.Module):
    r"""Quantizes an incoming tensor

    Args:
     `scale`: scale of the output Quantized Tensor
     `zero_point`: zero_point of output Quantized Tensor
     `dtype`: data type of output Quantized Tensor
     `factory_kwargs`: Dictionary of kwargs used for configuring initialization of internal buffers. Currently, `device` and `dtype` are supported.
        Example: `factory_kwargs={'device': 'cuda', 'dtype': torch.float64}` will initialize internal buffers as type `torch.float64` on the current CUDA device.
        Note that `dtype` only applies to floating-point buffers.

    Examples::
        >>> t = torch.tensor([[1., -1.], [1., -1.]])
        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
        >>> qm = Quantize(scale, zero_point, dtype)
        >>> qt = qm(t)
        >>> print(qt)
        tensor([[ 1., -1.],
                [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2)
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, scale, zero_point, dtype, factory_kwargs=None):
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        super(Quantize, self).__init__()
        self.register_buffer('scale', torch.tensor([scale], **factory_kwargs))
        self.register_buffer('zero_point',
                             torch.tensor([zero_point], dtype=torch.long,
                                          **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
        self.dtype = dtype

    def forward(self, X):
        return torch.quantize_per_tensor(X, float(self.scale),
                                         int(self.zero_point), self.dtype)

    @staticmethod
    def from_float(mod):
        assert hasattr(mod, 'activation_post_process')
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return Quantize(scale.float().item(), zero_point.long().item(), mod.activation_post_process.dtype)

    def extra_repr(self):
        return 'scale={}, zero_point={}, dtype={}'.format(self.scale, self.zero_point, self.dtype)
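
The `forward` method above delegates to `torch.quantize_per_tensor`. As a rough illustration of the arithmetic involved, here is a minimal pure-Python sketch. It assumes the usual per-tensor affine scheme (round `x / scale`, add `zero_point`, clamp to the integer range) and the signed 8-bit range for `torch.qint8`. The helper names `quantize_value` and `dequantize_value` are hypothetical and not part of PyTorch.

```python
# Illustrative sketch only: a pure-Python model of per-tensor affine
# quantization. Helper names and range constants are assumptions made for
# this example; they are not PyTorch APIs.

QINT8_MIN, QINT8_MAX = -128, 127  # assumed integer range for torch.qint8


def quantize_value(x, scale, zero_point, qmin=QINT8_MIN, qmax=QINT8_MAX):
    """Map a float to its stored integer value, clamped to [qmin, qmax]."""
    q = round(x / scale) + zero_point  # Python's round() rounds halves to even
    return max(qmin, min(qmax, q))


def dequantize_value(q, scale, zero_point):
    """Recover the float value that a stored integer represents."""
    return (q - zero_point) * scale


# Same inputs as the docstring example: scale=1.0, zero_point=2.
t = [[1.0, -1.0], [1.0, -1.0]]
scale, zero_point = 1.0, 2

q = [[quantize_value(x, scale, zero_point) for x in row] for row in t]
dq = [[dequantize_value(v, scale, zero_point) for v in row] for row in q]

print(q)   # [[3, 1], [3, 1]]
print(dq)  # [[1.0, -1.0], [1.0, -1.0]]
```

Under these assumptions the stored integers are 3 and 1, while the values they represent are 1.0 and -1.0. That matches the docstring example, where the printed quantized tensor shows the represented values rather than the raw integers.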