
# Source code for torch.ao.nn.quantizable.modules.activation

# mypy: allow-untyped-defs
import warnings
from typing import Optional, Tuple

import torch
import torch.jit  # this is needed to avoid a circular import
import torch.nn.functional as F
from torch import nn, Tensor

__all__ = ["MultiheadAttention"]

class MultiheadAttention(nn.MultiheadAttention):
    r"""Quantizable implementation of the MultiheadAttention.

    Note::
        Please, refer to :class:`~torch.nn.MultiheadAttention` for more
        information

    Allows the model to jointly attend to information from different
    representation subspaces.
    See reference: Attention Is All You Need

    The original MHA module is not quantizable.
    This reimplements it by explicitly instantiating the linear layers.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    Note::
        Please, follow the quantization flow to convert the quantizable MHA.
    """

    # Float module type accepted by `from_float` and produced by `dequantize`.
    _FLOAT_MODULE = nn.MultiheadAttention

    __constants__ = ["batch_first"]

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        """Build the parent MHA, then add explicit Q/K/V/out linear layers,
        a float functional for the query scaling and quant/dequant stubs."""
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            embed_dim,
            num_heads,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            kdim,
            vdim,
            batch_first,
            **factory_kwargs,
        )
        self.linear_Q = nn.Linear(
            self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs
        )
        self.linear_K = nn.Linear(
            self.kdim, self.embed_dim, bias=bias, **factory_kwargs
        )
        self.linear_V = nn.Linear(
            self.vdim, self.embed_dim, bias=bias, **factory_kwargs
        )
        # Replaces the `out_proj` created by the parent constructor with a
        # plain linear layer; the re-assignment is what mypy flags.
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)  # type: ignore[assignment]

        # Functionals
        # note: importing torch.ao.nn.quantized at top creates a circular import
        self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional()

        # Quant/Dequant
        self.quant_attn_output = torch.ao.quantization.QuantStub()
        self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
        self.dequant_q = torch.ao.quantization.DeQuantStub()
        self.dequant_k = torch.ao.quantization.DeQuantStub()
        self.dequant_v = torch.ao.quantization.DeQuantStub()

    def _get_name(self):
        """Return the display name used for this module."""
        return "QuantizableMultiheadAttention"

    @classmethod
    def from_float(cls, other):
        """Create an observed (prepared) quantizable MHA from a float MHA.

        Args:
            other: an instance of exactly ``cls._FLOAT_MODULE`` carrying a
                ``qconfig`` attribute.

        Returns:
            A new instance of ``cls`` in eval mode whose projection weights
            and biases are taken from ``other`` and on which
            ``torch.ao.quantization.prepare`` has been run in place.
        """
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
        # NOTE: the dropout probability of the float module is carried over
        # unchanged; it only takes effect while the module is in training mode.
        observed = cls(
            other.embed_dim,
            other.num_heads,
            other.dropout,
            (other.in_proj_bias is not None),
            (other.bias_k is not None),
            other.add_zero_attn,
            other.kdim,
            other.vdim,
            other.batch_first,
        )
        observed.bias_k = other.bias_k
        observed.bias_v = other.bias_v
        observed.qconfig = other.qconfig

        # Set the linear weights (the output projection is shared as-is).
        observed.out_proj.weight = other.out_proj.weight  # type: ignore[has-type]
        observed.out_proj.bias = other.out_proj.bias  # type: ignore[has-type]
        if other._qkv_same_embed_dim:
            # The float module stores Q, K and V stacked along dim 0 of
            # `in_proj_weight` / `in_proj_bias`; split them into separate params.
            bias = other.in_proj_bias
            _start = 0
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_Q.weight = torch.nn.Parameter(weight, weight.requires_grad)
            observed.linear_Q.bias = bias

            bias = other.in_proj_bias
            _start = _end
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_K.weight = torch.nn.Parameter(weight, weight.requires_grad)
            observed.linear_K.bias = bias

            bias = other.in_proj_bias
            _start = _end
            weight = other.in_proj_weight[_start:, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
            observed.linear_V.weight = torch.nn.Parameter(weight, weight.requires_grad)
            observed.linear_V.bias = bias
        else:
            observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
            observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
            observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
            if other.in_proj_bias is None:
                observed.linear_Q.bias = None  # type: ignore[assignment]
                observed.linear_K.bias = None  # type: ignore[assignment]
                observed.linear_V.bias = None  # type: ignore[assignment]
            else:
                observed.linear_Q.bias = nn.Parameter(
                    other.in_proj_bias[0 : other.embed_dim]
                )
                observed.linear_K.bias = nn.Parameter(
                    other.in_proj_bias[other.embed_dim : (other.embed_dim * 2)]
                )
                observed.linear_V.bias = nn.Parameter(
                    other.in_proj_bias[(other.embed_dim * 2) :]
                )
        observed.eval()
        # Explicit prepare
        observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @torch.jit.unused
    def dequantize(self):
        r"""Utility to convert the quantized MHA back to float.

        The motivation for this is that it is not trivial to convert the weights
        from the format that is used in the quantized version back to the
        float.

        Returns:
            A new ``_FLOAT_MODULE`` instance populated with the dequantized
            weights of this module. ``self`` is left unmodified.
        """
        fp = self._FLOAT_MODULE(
            self.embed_dim,
            self.num_heads,
            self.dropout,
            (self.linear_Q._weight_bias()[1] is not None),
            (self.bias_k is not None),
            self.add_zero_attn,
            self.kdim,
            self.vdim,
            self.batch_first,
        )
        assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
        if self.bias_k is not None:
            fp.bias_k = nn.Parameter(self.bias_k.dequantize())
        if self.bias_v is not None:
            fp.bias_v = nn.Parameter(self.bias_v.dequantize())

        # Set the linear weights
        # Note: Because the linear layers are quantized, mypy does not know how
        # to deal with them -- might need to ignore the typing checks.
        w, b = self.out_proj._weight_bias()  # type: ignore[operator, has-type]
        fp.out_proj.weight = nn.Parameter(w.dequantize())
        if b is not None:
            fp.out_proj.bias = nn.Parameter(b)

        wQ, bQ = self.linear_Q._weight_bias()  # type: ignore[operator]
        wQ = wQ.dequantize()
        wK, bK = self.linear_K._weight_bias()  # type: ignore[operator]
        wK = wK.dequantize()
        wV, bV = self.linear_V._weight_bias()  # type: ignore[operator]
        wV = wV.dequantize()
        if fp._qkv_same_embed_dim:
            # Write Q, K and V back into the stacked `in_proj_*` parameters.
            # These are leaf Parameters that require grad, so the in-place
            # slice writes must run with autograd recording disabled.
            # NOTE(review): the `all(b == 0)` asserts reject any non-zero
            # projection bias -- confirm whether that restriction is intended.
            with torch.no_grad():
                _start = 0
                _end = _start + fp.embed_dim
                fp.in_proj_weight[_start:_end, :] = wQ
                if fp.in_proj_bias is not None:
                    assert all(bQ == 0)
                    fp.in_proj_bias[_start:_end] = bQ

                _start = _end
                _end = _start + fp.embed_dim
                fp.in_proj_weight[_start:_end, :] = wK
                if fp.in_proj_bias is not None:
                    assert all(bK == 0)
                    fp.in_proj_bias[_start:_end] = bK

                _start = _end
                fp.in_proj_weight[_start:, :] = wV
                if fp.in_proj_bias is not None:
                    assert all(bV == 0)
                    fp.in_proj_bias[_start:] = bV
        else:
            fp.q_proj_weight = nn.Parameter(wQ)
            fp.k_proj_weight = nn.Parameter(wK)
            fp.v_proj_weight = nn.Parameter(wV)
            # When the float module has no `in_proj_bias` there is nothing to
            # copy; `self` must not be modified by this conversion.
            if fp.in_proj_bias is not None:
                # Same leaf-Parameter restriction as above.
                with torch.no_grad():
                    fp.in_proj_bias[0 : fp.embed_dim] = bQ
                    fp.in_proj_bias[fp.embed_dim : (fp.embed_dim * 2)] = bK
                    fp.in_proj_bias[(fp.embed_dim * 2) :] = bV

        return fp

    @classmethod
    def from_observed(cls, other):
        """Not supported on this class; always raises ``NotImplementedError``."""
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        # See nn.quantized.MultiheadAttention
        raise NotImplementedError(
            "It looks like you are trying to prepare an "
            "MHA module. Please, see "
            "the examples on quantizable MHAs."
        )

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Note::
            Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
            information

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
              positions. If a BoolTensor is provided, positions with ``True``
              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.
            - is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
              Default: ``False``.
            - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
              heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
              effect when ``need_weights=True.``. Default: True (i.e. average weights across heads)

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
              across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
              S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(N, num_heads, L, S)`.
        """
        return self._forward_impl(
            query,
            key,
            value,
            key_padding_mask,
            need_weights,
            attn_mask,
            average_attn_weights,
            is_causal,
        )

    def _forward_impl(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Attention computation behind :meth:`forward` (same arguments).

        Raises:
            AssertionError: if ``is_causal`` is set (causal masks are not
                supported here), or on inconsistent input sizes / mask dtypes.
            RuntimeError: if ``attn_mask`` has an unsupported shape or rank.
        """
        # This version will not deal with the static key/value pairs.
        # Keeping it here for future changes.
        #
        # TODO: This method has some duplicate lines with the
        # `torch.nn.functional.multi_head_attention`. Will need to refactor.
        static_k = None
        static_v = None

        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")

        if is_causal:
            raise AssertionError("causal mask not supported by AO MHA module")

        # Internally everything is computed in (seq, batch, feature) layout.
        if self.batch_first:
            query, key, value = (x.transpose(0, 1) for x in (query, key, value))

        tgt_len, bsz, embed_dim_to_check = query.size()
        assert self.embed_dim == embed_dim_to_check
        # allow MHA to have different sizes for the feature dimension
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = self.embed_dim // self.num_heads
        assert (
            head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        scaling = float(head_dim) ** -0.5

        q = self.linear_Q(query)
        k = self.linear_K(key)
        v = self.linear_V(value)

        q = self.q_scaling_product.mul_scalar(q, scaling)

        if attn_mask is not None:
            if attn_mask.dtype == torch.uint8:
                warnings.warn(
                    "Byte tensor for `attn_mask` in `nn.MultiheadAttention` is deprecated. "
                    "Use bool tensor instead.",
                    stacklevel=3,
                )
                attn_mask = attn_mask.to(torch.bool)
            assert (
                attn_mask.is_floating_point() or attn_mask.dtype == torch.bool
            ), f"Only float and bool types are supported for attn_mask, not {attn_mask.dtype}"

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                    bsz * self.num_heads,
                    query.size(0),
                    key.size(0),
                ]:
                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
            else:
                raise RuntimeError(
                    f"attn_mask's dimension {attn_mask.dim()} is not supported"
                )
            # attn_mask's dim is 3 now.

        # convert ByteTensor key_padding_mask to bool
        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
            warnings.warn(
                "Byte tensor for `key_padding_mask` in `nn.MultiheadAttention` is deprecated. "
                "Use bool tensor instead.",
                stacklevel=3,
            )
            key_padding_mask = key_padding_mask.to(torch.bool)
        if self.bias_k is not None and self.bias_v is not None:
            if static_k is None and static_v is None:
                # Explicitly assert that bias_k and bias_v are not None
                # in a way that TorchScript can understand.
                bias_k = self.bias_k
                assert bias_k is not None
                bias_v = self.bias_v
                assert bias_v is not None

                # Append one extra key/value position and widen the masks.
                k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
                v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
                if attn_mask is not None:
                    attn_mask = F.pad(attn_mask, (0, 1))
                if key_padding_mask is not None:
                    key_padding_mask = F.pad(key_padding_mask, (0, 1))
            else:
                assert static_k is None, "bias cannot be added to static key."
                assert static_v is None, "bias cannot be added to static value."
        else:
            assert self.bias_k is None
            assert self.bias_v is None

        # (len, bsz, E) -> (bsz * num_heads, len, head_dim)
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)

        if static_k is not None:
            assert static_k.size(0) == bsz * self.num_heads
            assert static_k.size(2) == head_dim
            k = static_k

        if static_v is not None:
            assert static_v.size(0) == bsz * self.num_heads
            assert static_v.size(2) == head_dim
            v = static_v

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            # Pads are created on the same device as the tensor they extend.
            k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:], device=k.device)
            if k.is_quantized:
                k_zeros = torch.quantize_per_tensor(
                    k_zeros, k.q_scale(), k.q_zero_point(), k.dtype
                )
            k = torch.cat([k, k_zeros], dim=1)
            v_zeros = torch.zeros((v.size(0), 1) + v.size()[2:], device=v.device)
            if v.is_quantized:
                v_zeros = torch.quantize_per_tensor(
                    v_zeros, v.q_scale(), v.q_zero_point(), v.dtype
                )
            v = torch.cat([v, v_zeros], dim=1)

            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))

        # Leaving the quantized zone here
        q = self.dequant_q(q)
        k = self.dequant_k(k)
        v = self.dequant_v(v)
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [
            bsz * self.num_heads,
            tgt_len,
            src_len,
        ]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float("-inf"),
            )
            attn_output_weights = attn_output_weights.view(
                bsz * self.num_heads, tgt_len, src_len
            )

        attn_output_weights = F.softmax(attn_output_weights, dim=-1)
        # `F.dropout` defaults to training=True, so the module's mode has to be
        # passed explicitly for dropout to be disabled in eval mode.
        attn_output_weights = F.dropout(
            attn_output_weights,
            p=self.dropout,
            training=self.training,
        )

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
        if self.batch_first:
            # The leading dim is (batch, head) flattened, so the heads have to
            # be moved next to head_dim before merging them into embed_dim:
            # (bsz*H, L, D) -> (bsz, H, L, D) -> (bsz, L, H, D) -> (bsz, L, E)
            attn_output = (
                attn_output.view(bsz, self.num_heads, tgt_len, head_dim)
                .transpose(1, 2)
                .contiguous()
                .view(bsz, tgt_len, self.embed_dim)
            )
        else:
            attn_output = (
                attn_output.transpose(0, 1)
                .contiguous()
                .view(tgt_len, bsz, self.embed_dim)
            )

        # Reentering the quantized zone
        attn_output = self.quant_attn_output(attn_output)
        attn_output = self.out_proj(attn_output)  # type: ignore[has-type]
        attn_output_weights = self.quant_attn_output_weights(attn_output_weights)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            if average_attn_weights:
                attn_output_weights = attn_output_weights.mean(dim=1)
            return attn_output, attn_output_weights
        else:
            return attn_output, None


