Source code for torch.ao.nn.quantizable.modules.activation
import torch
import torch.jit  # this is needed to avoid a circular import
from torch import nn
import torch.nn.functional as nnF

from torch import Tensor
from typing import Optional, Tuple

import warnings

__all__ = [
    "MultiheadAttention"
]
[docs]class MultiheadAttention(nn.MultiheadAttention):
    _FLOAT_MODULE = nn.MultiheadAttention

    r"""Quantizable implementation of the MultiheadAttention.

    Note::
        Please, refer to :class:`~torch.nn.MultiheadAttention` for more
        information

    Allows the model to jointly attend to information from different
    representation subspaces.
    See reference: Attention Is All You Need

    The original MHA module is not quantizable.
    This reimplements it by explicitly instantiating the linear layers.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    Note::
        Please, follow the quantization flow to convert the quantizable MHA.
    """
    __constants__ = ['batch_first']

    def __init__(self, embed_dim: int, num_heads: int,
                 dropout: float = 0., bias: bool = True,
                 add_bias_kv: bool = False, add_zero_attn: bool = False,
                 kdim: Optional[int] = None, vdim: Optional[int] = None,
                 batch_first: bool = False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(embed_dim, num_heads, dropout,
                         bias, add_bias_kv, add_zero_attn,
                         kdim, vdim, batch_first,
                         **factory_kwargs)
        self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
        self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
        self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
        # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)  # type: ignore[assignment]

        # Functionals
        self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional()
        # note: importing torch.ao.nn.quantized at top creates a circular import

        # Quant/Dequant
        self.quant_attn_output = torch.ao.quantization.QuantStub()
        self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
        self.dequant_q = torch.ao.quantization.DeQuantStub()
        self.dequant_k = torch.ao.quantization.DeQuantStub()
        self.dequant_v = torch.ao.quantization.DeQuantStub()

    def _get_name(self):
        return 'QuantizableMultiheadAttention'

    @classmethod
    def from_float(cls, other):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
        # Setting the dropout to 0.0!
        observed = cls(other.embed_dim, other.num_heads, other.dropout,
                       (other.in_proj_bias is not None),
                       (other.bias_k is not None),
                       other.add_zero_attn, other.kdim, other.vdim,
                       other.batch_first)
        observed.bias_k = other.bias_k
        observed.bias_v = other.bias_v
        observed.qconfig = other.qconfig

        # Set the linear weights
        # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
        observed.out_proj.weight = other.out_proj.weight  # type: ignore[has-type]
        observed.out_proj.bias = other.out_proj.bias  # type: ignore[has-type]
        if other._qkv_same_embed_dim:
            # Use separate params
            bias = other.in_proj_bias
            _start = 0
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_Q.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_Q.bias = bias

            bias = other.in_proj_bias
            _start = _end
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_K.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_K.bias = bias

            bias = other.in_proj_bias
            _start = _end
            weight = other.in_proj_weight[_start:, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
            observed.linear_V.weight = torch.nn.Parameter(weight,
                                                          weight.requires_grad)
            observed.linear_V.bias = bias
        else:
            observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
            observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
            observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
            if other.in_proj_bias is None:
                observed.linear_Q.bias = None  # type: ignore[assignment]
                observed.linear_K.bias = None  # type: ignore[assignment]
                observed.linear_V.bias = None  # type: ignore[assignment]
            else:
                observed.linear_Q.bias = nn.Parameter(other.in_proj_bias[0:other.embed_dim])
                observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)])
                observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):])
        observed.eval()
        # Explicit prepare
        observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed
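    # Note: `from_float` above is normally invoked for you by the eager-mode
    # quantization flow (see the class docstring). A minimal sketch of calling
    # it directly for the float -> observed step only; the sizes and qconfig
    # choice below are illustrative assumptions, not part of this module:
    #
    #     float_mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
    #     float_mha.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    #     observed_mha = MultiheadAttention.from_float(float_mha)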
[docs]    @torch.jit.unused
    def dequantize(self):
        r"""Utility to convert the quantized MHA back to float.

        The motivation for this is that it is not trivial to convert the
        weights from the format that is used in the quantized version back to
        float.
        """
        fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout,
                                (self.in_proj_bias is not None),
                                (self.bias_k is not None),
                                self.add_zero_attn, self.kdim, self.vdim,
                                self.batch_first)
        assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
        if self.bias_k is not None:
            fp.bias_k = nn.Parameter(self.bias_k.dequantize())
        if self.bias_v is not None:
            fp.bias_v = nn.Parameter(self.bias_v.dequantize())

        # Set the linear weights
        # Note: Because the linear layers are quantized, mypy does not know how
        # to deal with them -- might need to ignore the typing checks.
        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
        w, b = self.out_proj._weight_bias()  # type: ignore[operator, has-type]
        fp.out_proj.weight = nn.Parameter(w.dequantize())
        if b is not None:
            fp.out_proj.bias = nn.Parameter(b)

        wQ, bQ = self.linear_Q._weight_bias()  # type: ignore[operator]
        wQ = wQ.dequantize()
        wK, bK = self.linear_K._weight_bias()  # type: ignore[operator]
        wK = wK.dequantize()
        wV, bV = self.linear_V._weight_bias()  # type: ignore[operator]
        wV = wV.dequantize()
        if fp._qkv_same_embed_dim:
            # Use separate params
            _start = 0
            _end = _start + fp.embed_dim
            fp.in_proj_weight[_start:_end, :] = wQ
            if fp.in_proj_bias is not None:
                assert all(bQ == 0)
                fp.in_proj_bias[_start:_end] = bQ

            _start = _end
            _end = _start + fp.embed_dim
            fp.in_proj_weight[_start:_end, :] = wK
            if fp.in_proj_bias is not None:
                assert all(bK == 0)
                fp.in_proj_bias[_start:_end] = bK

            _start = _end
            fp.in_proj_weight[_start:, :] = wV
            if fp.in_proj_bias is not None:
                assert all(bV == 0)
                fp.in_proj_bias[_start:] = bV
        else:
            fp.q_proj_weight = nn.Parameter(wQ)
            fp.k_proj_weight = nn.Parameter(wK)
            fp.v_proj_weight = nn.Parameter(wV)
            if fp.in_proj_bias is None:
                self.linear_Q.bias = None
                self.linear_K.bias = None
                self.linear_V.bias = None
            else:
                fp.in_proj_bias[0:fp.embed_dim] = bQ
                fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK
                fp.in_proj_bias[(fp.embed_dim * 2):] = bV

        return fp
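    # Note: `dequantize` above is meant to be called on the converted module,
    # whose linear layers are already quantized (otherwise `_weight_bias()` is
    # not available). A hedged sketch, assuming `quantized_mha` came out of
    # `torch.ao.quantization.convert`:
    #
    #     float_mha = quantized_mha.dequantize()  # back to nn.MultiheadAttention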
    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        # See nn.quantized.MultiheadAttention
        raise NotImplementedError("It looks like you are trying to prepare an "
                                  "MHA module. Please, see "
                                  "the examples on quantizable MHAs.")
[docs]    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True,
                attn_mask: Optional[Tensor] = None,
                average_attn_weights: bool = True,
                is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Note::
            Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
            information

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows specifying a different mask for the entries of each batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while
              the positions with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
              positions. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.
            - is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
              Default: ``False``.
            - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
              heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
              effect when ``need_weights=True``. Default: True (i.e. average weights across heads)

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
              across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
              S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(N, num_heads, L, S)`.
        """
        return self._forward_impl(query, key, value, key_padding_mask,
                                  need_weights, attn_mask, average_attn_weights,
                                  is_causal)
    def _forward_impl(self,
                      query: Tensor,
                      key: Tensor,
                      value: Tensor,
                      key_padding_mask: Optional[Tensor] = None,
                      need_weights: bool = True,
                      attn_mask: Optional[Tensor] = None,
                      average_attn_weights: bool = True,
                      is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        # This version will not deal with the static key/value pairs.
        # Keeping it here for future changes.
        #
        # TODO: This method has some duplicate lines with the
        # `torch.nn.functional.multi_head_attention`. Will need to refactor.
        static_k = None
        static_v = None

        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")

        if is_causal:
            raise AssertionError("causal mask not supported by AO MHA module")

        if self.batch_first:
            query, key, value = (x.transpose(0, 1) for x in (query, key, value))

        tgt_len, bsz, embed_dim_to_check = query.size()
        assert self.embed_dim == embed_dim_to_check
        # allow MHA to have different sizes for the feature dimension
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = self.embed_dim // self.num_heads
        assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        scaling = float(head_dim) ** -0.5

        q = self.linear_Q(query)
        k = self.linear_K(key)
        v = self.linear_V(value)

        q = self.q_scaling_product.mul_scalar(q, scaling)

        if attn_mask is not None:
            if attn_mask.dtype == torch.uint8:
                warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
                attn_mask = attn_mask.to(torch.bool)
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f'Only float and bool types are supported for attn_mask, not {attn_mask.dtype}'

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 2D attn_mask is not correct.')
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 3D attn_mask is not correct.')
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
            # attn_mask's dim is 3 now.

        # convert ByteTensor key_padding_mask to bool
        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            key_padding_mask = key_padding_mask.to(torch.bool)
        if self.bias_k is not None and self.bias_v is not None:
            if static_k is None and static_v is None:

                # Explicitly assert that bias_k and bias_v are not None
                # in a way that TorchScript can understand.
                bias_k = self.bias_k
                assert bias_k is not None
                bias_v = self.bias_v
                assert bias_v is not None

                k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
                v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
                if attn_mask is not None:
                    attn_mask = nnF.pad(attn_mask, (0, 1))
                if key_padding_mask is not None:
                    key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
            else:
                assert static_k is None, "bias cannot be added to static key."
                assert static_v is None, "bias cannot be added to static value."
        else:
            assert self.bias_k is None
            assert self.bias_v is None

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)

        if static_k is not None:
            assert static_k.size(0) == bsz * self.num_heads
            assert static_k.size(2) == head_dim
            k = static_k

        if static_v is not None:
            assert static_v.size(0) == bsz * self.num_heads
            assert static_v.size(2) == head_dim
            v = static_v

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
            if k.is_quantized:
                k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(),
                                                    k.q_zero_point(), k.dtype)
            k = torch.cat([k, k_zeros], dim=1)
            v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
            if v.is_quantized:
                v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(),
                                                    v.q_zero_point(), v.dtype)
            v = torch.cat([v, v_zeros], dim=1)

            if attn_mask is not None:
                attn_mask = nnF.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = nnF.pad(key_padding_mask, (0, 1))

        # Leaving the quantized zone here
        q = self.dequant_q(q)
        k = self.dequant_k(k)
        v = self.dequant_v(v)
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_output_weights = nnF.softmax(attn_output_weights, dim=-1)
        attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout,
                                          training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
        if self.batch_first:
            attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
        else:
            attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

        # Reentering the quantized zone
        attn_output = self.quant_attn_output(attn_output)
        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
        attn_output = self.out_proj(attn_output)  # type: ignore[has-type]
        attn_output_weights = self.quant_attn_output_weights(attn_output_weights)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if average_attn_weights:
                attn_output_weights = attn_output_weights.mean(dim=1)
            return attn_output, attn_output_weights
        else:
            return attn_output, None
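The class above covers only the float -> observed half of the flow; the observed -> quantized step lives in `torch.ao.nn.quantized.MultiheadAttention` (see `from_observed` above). The following is a minimal, non-authoritative sketch of wiring the pieces together with eager-mode post-training static quantization; the wrapper module, tensor shapes, and qconfig choice are illustrative assumptions, and it relies on the default custom-module mappings of `prepare`/`convert` to perform the swaps.

    import torch
    import torch.ao.quantization
    from torch import nn

    class TinyAttnModel(nn.Module):
        """Hypothetical wrapper used only for this sketch."""
        def __init__(self):
            super().__init__()
            self.quant_q = torch.ao.quantization.QuantStub()
            self.quant_k = torch.ao.quantization.QuantStub()
            self.quant_v = torch.ao.quantization.QuantStub()
            self.mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, q, k, v):
            q, k, v = self.quant_q(q), self.quant_k(k), self.quant_v(v)
            out, _ = self.mha(q, k, v)
            return self.dequant(out)

    model = TinyAttnModel().eval()
    model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")

    # float -> observed: nn.MultiheadAttention should be swapped for the
    # quantizable class above, whose `from_float` runs the explicit prepare.
    prepared = torch.ao.quantization.prepare(model)

    # Calibrate the observers with representative data: (seq_len, batch, embed_dim).
    q = k = v = torch.randn(4, 1, 8)
    prepared(q, k, v)

    # observed -> quantized (torch.ao.nn.quantized.MultiheadAttention);
    # `dequantize()` above can recover a float module from the result.
    quantized = torch.ao.quantization.convert(prepared)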