Source code for torch.ao.nn.quantized.modules.embedding_ops
import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import Optional, List  # noqa: F401

from .utils import _hide_packed_params_repr
from .utils import _quantize_weight

__all__ = ['EmbeddingPackedParams', 'Embedding', 'EmbeddingBag']


class EmbeddingPackedParams(torch.nn.Module):
    _version = 1

    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
        super().__init__()
        self.dtype = dtype
        if self.dtype in [torch.quint8, torch.quint4x2]:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            wq = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim],
                                                           scales=scales, zero_points=zero_points,
                                                           axis=0, dtype=self.dtype)
            self.set_weight(wq)
        else:
            raise NotImplementedError(f'Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}')

    @torch.jit.export
    def set_weight(self, weight: torch.Tensor) -> None:
        if self.dtype in [torch.quint8, torch.quint4x2]:
            self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
        else:
            raise NotImplementedError('Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2.')

    @torch.jit.export
    def _weight(self):
        if self.dtype in [torch.quint8, torch.quint4x2]:
            return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)
        else:
            raise NotImplementedError('Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2.')

    def forward(self, x):
        return x

    # Version 1
    #   self
    #   |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase
    #   |--- dtype : torch.dtype

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'dtype'] = self.dtype
        destination[prefix + '_packed_weight'] = self._weight()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.dtype = state_dict[prefix + 'dtype']
        state_dict.pop(prefix + 'dtype')

        weight = state_dict[prefix + '_packed_weight']
        state_dict.pop(prefix + '_packed_weight')
        self.set_weight(weight)

        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                      missing_keys, unexpected_keys, error_msgs)

    def __repr__(self):
        return self._weight().__repr__()
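
# A minimal usage sketch of EmbeddingPackedParams (not part of the original
# module; assumes a CPU build with the quantized embedding ops available).
# set_weight() packs a per-channel quantized tensor via
# quantized.embedding_bag_prepack, and _weight() unpacks it back:
#
#   >>> params = EmbeddingPackedParams(num_embeddings=4, embedding_dim=8)
#   >>> params._weight().shape  # unpacking recovers [num_embeddings, embedding_dim]
#   torch.Size([4, 8])
#   >>> params._weight().dtype
#   torch.quint8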
class Embedding(torch.nn.Module):
    r"""
    A quantized Embedding module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.Embedding`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding for documentation.

    Similar to :class:`~torch.nn.Embedding`, attributes will be randomly
    initialized at module creation time and will be overwritten later

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        >>> output = m(indices)
        >>> print(output.size())
        torch.Size([9, 12])

    """
    _version = 1

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None, dtype=torch.quint8) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.dtype = dtype

        if _weight is None:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            qweight = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim],
                                                                scales=scales, zero_points=zero_points,
                                                                axis=0, dtype=torch.quint8)
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            qweight = _weight

        self._packed_params = EmbeddingPackedParams(num_embeddings, embedding_dim, dtype)
        self._packed_params.set_weight(qweight)

    def forward(self, indices: Tensor) -> Tensor:
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_4bit(self._packed_params._packed_weight, indices)
        else:
            return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)

    def _get_name(self):
        return 'QuantizedEmbedding'

    def __repr__(self):
        return _hide_packed_params_repr(self, EmbeddingPackedParams)

    def extra_repr(self):
        extra_repr_str = 'num_embeddings={}, embedding_dim={}, dtype={}, qscheme={}'.format(
            self.num_embeddings, self.embedding_dim, self._packed_params.dtype,
            self.weight().qscheme())
        return extra_repr_str

    def set_weight(self, w: torch.Tensor) -> None:
        self._packed_params.set_weight(w)

    def weight(self):
        return self._packed_params._weight()
    @classmethod
    def from_float(cls, mod):
        r"""Create a quantized embedding module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user
        """
        if hasattr(mod, 'weight_fake_quant'):
            assert type(mod) == torch.ao.nn.qat.Embedding, 'nnq.' + cls.__name__ + '.from_float ' + \
                'with fake quant only works for ' + torch.ao.nn.qat.Embedding.__name__
            weight_observer = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
                nn.Embedding.__name__
            assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined'
            from torch.ao.quantization import float_qparams_weight_only_qconfig
            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
        assert is_float_qparams_qconfig, \
            'Embedding quantization is only supported with float_qparams_weight_only_qconfig.'

        assert dtype == torch.quint8 or dtype == torch.quint4x2, \
            f'The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}'

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized Embedding module and pass in the quantized weight
        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim)
        qembedding.set_weight(qweight)
        return qembedding
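
# A hedged end-to-end sketch of the conversion path from_float implements
# (not part of the original file; assumes eager-mode quantization on CPU).
# Attaching float_qparams_weight_only_qconfig to a float nn.Embedding and
# calling from_float yields a QuantizedEmbedding with packed quint8 weights:
#
#   >>> from torch.ao.quantization import float_qparams_weight_only_qconfig
#   >>> float_embedding = nn.Embedding(num_embeddings=10, embedding_dim=12)
#   >>> float_embedding.qconfig = float_qparams_weight_only_qconfig
#   >>> q_embedding = Embedding.from_float(float_embedding)
#   >>> q_embedding(torch.tensor([1, 2, 3])).shape
#   torch.Size([3, 12])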
class EmbeddingBag(Embedding):
    r"""
    A quantized EmbeddingBag module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.EmbeddingBag`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.EmbeddingBag for documentation.

    Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
    initialized at module creation time and will be overwritten later

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        >>> output = m(indices, offsets)
        >>> print(output.size())
        torch.Size([5, 12])

    """
    _version = 1

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 mode: str = 'sum', sparse: bool = False, _weight: Optional[Tensor] = None,
                 include_last_offset: bool = False, dtype=torch.quint8) -> None:
        super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)

        self.mode = mode
        self.pruned_weights = False
        self.include_last_offset = include_last_offset
        self.dtype = dtype

    def forward(self, indices: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None,
                compressed_indices_mapping: Optional[Tensor] = None) -> Tensor:
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_bag_4bit(self._packed_params._packed_weight, indices, offsets, False, 0,
                                                          self.pruned_weights, per_sample_weights,
                                                          compressed_indices_mapping, self.include_last_offset)
        else:
            return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0,
                                                          self.pruned_weights, per_sample_weights,
                                                          compressed_indices_mapping, self.include_last_offset)

    def _get_name(self):
        return 'QuantizedEmbeddingBag'
    @classmethod
    def from_float(cls, mod):
        r"""Create a quantized embedding_bag module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user
        """
        if hasattr(mod, 'weight_fake_quant'):
            weight_observer = mod.weight_fake_quant
        else:
            assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
                nn.EmbeddingBag.__name__
            assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
            from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig
            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
        assert is_float_qparams_qconfig, \
            'EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig.'

        assert dtype == torch.quint8 or dtype == torch.quint4x2, \
            f'The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}'

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized EmbeddingBag module and pass in the quantized weight
        qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
        qembedding_bag.set_weight(qweight)
        return qembedding_bag
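
# A similar hedged sketch for EmbeddingBag.from_float (not part of the
# original file; assumes eager-mode quantization on CPU and mode='sum',
# matching the reduction hardcoded in the forward above):
#
#   >>> from torch.ao.quantization import float_qparams_weight_only_qconfig
#   >>> float_bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')
#   >>> float_bag.qconfig = float_qparams_weight_only_qconfig
#   >>> q_bag = EmbeddingBag.from_float(float_bag)
#   >>> indices = torch.tensor([0, 1, 2, 3, 4, 5])
#   >>> offsets = torch.tensor([0, 2, 4])
#   >>> q_bag(indices, offsets).shape  # one pooled row per bag
#   torch.Size([3, 12])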