# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from torch import nn, Tensor

from torchtune.modules.kv_cache import KVCache
class CausalSelfAttention(nn.Module):
    """Multi-headed grouped query self-attention (GQA) layer introduced in
    https://arxiv.org/abs/2305.13245v1.

    GQA is a version of multiheaded attention (MHA) which uses fewer key/value
    heads than query heads by grouping n query heads for each key and value head.
    Multi-Query Attention is an extreme version where we have a single key and value
    head shared by all query heads.

    Following is an example of MHA, GQA and MQA with num_heads = 4
    (credit for the documentation:
    https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/config.py).
    A usage sketch illustrating a GQA configuration is included at the end of
    this file.

    ::

        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
        │    │    │    │         │        │                 │
        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
        │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
        ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
        │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
        └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
        ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
                MHA                    GQA                   MQA
            n_kv_heads=4           n_kv_heads=2          n_kv_heads=1

    Args:
        embed_dim (int): embedding dimension for the model
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value
        num_kv_heads (int): number of key and value heads. User should ensure
            ``num_heads % num_kv_heads == 0``. For standard MHA set
            ``num_kv_heads == num_heads``, for GQA ``num_kv_heads < num_heads``,
            and for MQA set ``num_kv_heads == 1``.
        head_dim (int): dimension of each head, calculated by ``embed_dim`` // ``num_heads``.
        q_proj (nn.Module): projection layer for query.
        k_proj (nn.Module): projection layer for key.
        v_proj (nn.Module): projection layer for value.
        output_proj (nn.Module): projection layer for output.
        pos_embeddings (nn.Module): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
        kv_cache (Optional[KVCache]): KVCache object used to cache key and value.
            If not specified, then no caching is used.
        max_seq_len (int): maximum sequence length supported by the model.
            This is needed to compute the RoPE Cache. Default: 4096.
        attn_dropout (float): dropout value passed onto the
            scaled_dot_product_attention function. This argument is ignored if
            self.training is False. Default value is 0.0.

    Raises:
        ValueError: If `num_heads` % `num_kv_heads` != 0
        ValueError: If `embed_dim` % `num_heads` != 0
        ValueError: If `attn_dropout` < 0 or > 1
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        output_proj: nn.Module,
        pos_embeddings: nn.Module,
        kv_cache: Optional[KVCache] = None,
        max_seq_len: int = 4096,
        attn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})"
            )

        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        if attn_dropout < 0 or attn_dropout > 1:
            raise ValueError(
                f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0"
            )

        # Set attributes
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.embed_dim = embed_dim
        self.attn_dropout = attn_dropout
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len

        # Set layers
        self.kv_cache = kv_cache
        self.q_proj = q_proj
        self.k_proj = k_proj
        self.v_proj = v_proj
        self.output_proj = output_proj
        self.pos_embeddings = pos_embeddings
    def forward(
        self,
        x: Tensor,
        *,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            x (Tensor): input tensor with shape
                [batch_size x seq_length x embed_dim]
            mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask
                with shape [batch_size x seq_length x seq_length]. This is applied after
                the query-key multiplication and before the softmax. A value of True in row i
                and column j means token i attends to token j. A value of False means token i
                does not attend to token j. If no mask is specified, a causal mask
                is used by default. Default is None.
            input_pos (Optional[Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b x s].
                During inference, this indicates the position of the current token.
                If none, assume the index of the token is its position id. Default is None.

        Returns:
            Tensor: output tensor with attention applied

        Raises:
            ValueError: if seq_len of x is bigger than max_seq_len

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - n_h: num heads
            - n_kv: num kv heads
            - d: embed dim
            - h_d: head dim

        TODO:
            - Return the attention weights
            - Make application of positional embeddings optional
        """
        # input has shape [b, s, d]
        bsz, seq_len, _ = x.shape

        if seq_len > self.max_seq_len:
            raise ValueError(
                f"seq_len ({seq_len}) of input tensor should not be larger "
                f"than max_seq_len ({self.max_seq_len})"
            )

        # q has shape [b, s, num_heads * head_dim]
        # k has shape [b, s, num_kv_heads * head_dim]
        # v has shape [b, s, num_kv_heads * head_dim]
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # number of queries per key/value
        q_per_kv = self.num_heads // self.num_kv_heads

        # q: [b, s, n_kv, q_per_kv, h_d]
        # k: [b, s, n_kv, 1, h_d]
        # v: [b, s, n_kv, 1, h_d]
        q = q.view(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)
        k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
        v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)

        # if needed, expand the key and value tensors to have the same shape
        # as the query tensor by copying values across the relevant dim
        if self.num_heads != self.num_kv_heads:
            k = k.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)
            v = v.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)

        # llama2 applies the RoPE embeddings on tensors with shape
        # [b, s, n_h, h_d]
        # Reshape the tensors before we apply RoPE
        q = q.reshape(bsz, seq_len, -1, self.head_dim)
        k = k.reshape(bsz, seq_len, -1, self.head_dim)
        v = v.reshape(bsz, seq_len, -1, self.head_dim)

        # Apply positional embeddings
        q = self.pos_embeddings(q, input_pos=input_pos)
        k = self.pos_embeddings(k, input_pos=input_pos)

        # [b, n_h, s, h_d]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Update key-value cache
        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # shape: [b, 1, s, s]
        if mask is not None:
            mask = mask[:, None, :, :]

        # Flash attention from https://pytorch.org/blog/accelerating-large-language-models/
        output = nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=self.attn_dropout,
            is_causal=self.kv_cache is None and mask is None,
        )

        # reshape the output to be the same shape as the input
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.output_proj(output)
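
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original torchtune module): a minimal example
# of wiring up CausalSelfAttention for GQA with 8 query heads sharing 2
# key/value heads. The projection sizes and the no-op positional-embedding
# module below are illustrative assumptions; in practice you would pass a
# real positional embedding such as RotaryPositionalEmbeddings, which consumes
# `input_pos`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    class _NoOpPosEmbeddings(nn.Module):
        """Placeholder positional embedding that returns its input unchanged."""

        def forward(self, x: Tensor, *, input_pos: Optional[Tensor] = None) -> Tensor:
            return x

    embed_dim, num_heads, num_kv_heads = 256, 8, 2
    head_dim = embed_dim // num_heads  # 32

    attn = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        # q projects to num_heads * head_dim; k/v project to num_kv_heads * head_dim
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=_NoOpPosEmbeddings(),
        max_seq_len=128,
    )

    # With mask=None and kv_cache=None, scaled_dot_product_attention applies a
    # causal mask internally.
    bsz, seq_len = 2, 16
    x = torch.randn(bsz, seq_len, embed_dim)
    out = attn(x)
    print(out.shape)  # torch.Size([2, 16, 256])

    # Explicit mask / input_pos, following the conventions documented in
    # forward(): mask is boolean with shape [b, s, s] where True at (i, j)
    # means token i attends to token j; input_pos holds per-token position ids
    # with shape [b, s]. With the no-op embedding above, input_pos has no
    # effect; a real RoPE module would use it.
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    mask = causal_mask.unsqueeze(0).expand(bsz, -1, -1)
    input_pos = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)
    out = attn(x, mask=mask, input_pos=input_pos)
    print(out.shape)  # torch.Size([2, 16, 256])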