
Source code for torchtune.modules.attention

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from torch import nn, Tensor
from torchtune.modules.kv_cache import KVCache


class CausalSelfAttention(nn.Module):
    """Multi-headed grouped query self-attention (GQA) layer introduced in
    https://arxiv.org/abs/2305.13245v1.

    GQA is a version of multiheaded attention (MHA) which uses fewer key/value
    heads than query heads by grouping n query heads for each key and value head.
    Multi-Query Attention is an extreme version where we have a single key and value
    head shared by all query heads.

    Following is an example of MHA, GQA and MQA with num_heads = 4
    (credit for the documentation:
    https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/config.py).

    ::

        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
        │    │    │    │         │        │                 │
        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
        │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
        ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
        │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
        └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
        ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
                MHA                    GQA                   MQA
           n_kv_heads=4           n_kv_heads=2           n_kv_heads=1

    Args:
        embed_dim (int): embedding dimension for the model
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value
        num_kv_heads (int): number of key and value heads. User should ensure
            ``num_heads % num_kv_heads == 0``. For standard MHA set
            ``num_kv_heads == num_heads``, for GQA ``num_kv_heads < num_heads``,
            and for MQA set ``num_kv_heads == 1``.
        head_dim (int): dimension of each head, calculated by ``embed_dim`` // ``num_heads``.
        q_proj (nn.Module): projection layer for query.
        k_proj (nn.Module): projection layer for key.
        v_proj (nn.Module): projection layer for value.
        output_proj (nn.Module): projection layer for output.
        pos_embeddings (nn.Module): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
        kv_cache (Optional[KVCache]): KVCache object used to cache key and value.
            If not specified, then no caching is used.
        max_seq_len (int): maximum sequence length supported by the model.
            This is needed to compute the RoPE Cache. Default: 4096.
        attn_dropout (float): dropout value passed onto the scaled_dot_product_attention
            function. This argument is ignored if ``self.training`` is False.
            Default value is 0.0.

    Raises:
        ValueError: If ``num_heads % num_kv_heads != 0``
        ValueError: If ``embed_dim % num_heads != 0``
        ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        output_proj: nn.Module,
        pos_embeddings: nn.Module,
        kv_cache: Optional[KVCache] = None,
        max_seq_len: int = 4096,
        attn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})"
            )

        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        if attn_dropout < 0 or attn_dropout > 1:
            raise ValueError(
                f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0"
            )

        # Set attributes
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.embed_dim = embed_dim
        self.attn_dropout = attn_dropout
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len

        # Set layers
        self.kv_cache = kv_cache
        self.q_proj = q_proj
        self.k_proj = k_proj
        self.v_proj = v_proj
        self.output_proj = output_proj
        self.pos_embeddings = pos_embeddings
    def forward(
        self,
        x: Tensor,
        *,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            x (Tensor): input tensor with shape [batch_size x seq_length x embed_dim]
            mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask
                with shape [batch_size x seq_length x seq_length]. This is applied after
                the query-key multiplication and before the softmax. A value of True in
                row i and column j means token i attends to token j. A value of False means
                token i does not attend to token j. If no mask is specified, a causal mask
                is used by default. Default is None.
            input_pos (Optional[Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b x s].
                During inference, this indicates the position of the current token.
                If none, assume the index of the token is its position id. Default is None.

        Returns:
            Tensor: output tensor with attention applied

        Raises:
            ValueError: if seq_len of x is bigger than max_seq_len

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - n_h: num heads
            - n_kv: num kv heads
            - d: embed dim
            - h_d: head dim

        TODO:
            - Return the attention weights
            - Make application of positional embeddings optional
        """
        # input has shape [b, s, d]
        bsz, seq_len, _ = x.shape

        if seq_len > self.max_seq_len:
            raise ValueError(
                f"seq_len ({seq_len}) of input tensor should be smaller "
                f"than max_seq_len ({self.max_seq_len})"
            )

        # q has shape [b, s, num_heads * head_dim]
        # k has shape [b, s, num_kv_heads * head_dim]
        # v has shape [b, s, num_kv_heads * head_dim]
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # number of queries per key/value
        q_per_kv = self.num_heads // self.num_kv_heads

        # q: [b, s, n_kv, q_per_kv, h_d]
        # k: [b, s, n_kv, 1, h_d]
        # v: [b, s, n_kv, 1, h_d]
        q = q.view(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)
        k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
        v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)

        # if needed, expand the key and value tensors to have the same shape
        # as the query tensor by copying values across the relevant dim
        if self.num_heads != self.num_kv_heads:
            k = k.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)
            v = v.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)

        # llama2 applies the RoPE embeddings on tensors with shape
        # [b, s, n_h, h_d]
        # Reshape the tensors before we apply RoPE
        q = q.reshape(bsz, seq_len, -1, self.head_dim)
        k = k.reshape(bsz, seq_len, -1, self.head_dim)
        v = v.reshape(bsz, seq_len, -1, self.head_dim)

        # Apply positional embeddings
        q = self.pos_embeddings(q, input_pos=input_pos)
        k = self.pos_embeddings(k, input_pos=input_pos)

        # [b, n_h, s, h_d]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Update key-value cache
        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # shape: [b, 1, s, s]
        if mask is not None:
            mask = mask[:, None, :, :]

        # Flash attention from https://pytorch.org/blog/accelerating-large-language-models/
        output = nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=self.attn_dropout,
            is_causal=self.kv_cache is None and mask is None,
        )

        # reshape the output to be the same shape as the input
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.output_proj(output)
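As a usage sketch (not part of the module source), the layer can be wired up with plain nn.Linear projections and rotary positional embeddings. The import path for RotaryPositionalEmbeddings and the bias-free projection layers are illustrative assumptions; projection shapes follow from the Args docstring above.

# Minimal GQA usage sketch for CausalSelfAttention (assumed imports, small dims).
import torch
from torch import nn
from torchtune.modules import RotaryPositionalEmbeddings  # assumed import path
from torchtune.modules.attention import CausalSelfAttention

embed_dim, num_heads, num_kv_heads, max_seq_len = 256, 8, 2, 4096
head_dim = embed_dim // num_heads  # 32

attn = CausalSelfAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,  # GQA: 4 query heads share each key/value head
    head_dim=head_dim,
    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
    pos_embeddings=RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len),
    max_seq_len=max_seq_len,
)

x = torch.randn(2, 16, embed_dim)  # [batch_size, seq_length, embed_dim]
out = attn(x)                      # no mask / input_pos: causal attention by default
print(out.shape)                   # torch.Size([2, 16, 256])

With no kv_cache and no mask, the call above relies on is_causal=True inside scaled_dot_product_attention; passing an explicit boolean mask or a KVCache switches to the masked code path shown in forward().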
