# Source code for torchtune.modules.attention
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from torch import nn, Tensor

from torchtune.modules.kv_cache import KVCache


class CausalSelfAttention(nn.Module):
"""Multi-headed grouped query self-attention (GQA) layer introduced
in https://arxiv.org/abs/2305.13245v1.
GQA is a version of multiheaded attention (MHA) which uses fewer
key/value heads than query heads by grouping n query heads for each
key and value head. Multi-Query Attention is an extreme
version where we have a single key and value head shared by all
query heads.
Following is an example of MHA, GQA and MQA with num_heads = 4
(credit for the documentation:
https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/config.py).
    ::

        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
          │    │    │    │         │        │                 │
        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
          │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
        ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
        │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
        └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
        ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
                MHA                    GQA                   MQA
            n_kv_heads=4          n_kv_heads=2          n_kv_heads=1

Args:
embed_dim (int): embedding dimension for the model
num_heads (int): number of query heads. For MHA this is also the
number of heads for key and value
        num_kv_heads (int): number of key and value heads. The user should
            ensure ``num_heads % num_kv_heads == 0``. For standard MHA set
            ``num_kv_heads == num_heads``, for GQA ``num_kv_heads < num_heads``,
            and for MQA set ``num_kv_heads == 1``.
head_dim (int): dimension of each head, calculated by ``embed_dim`` // ``num_heads``.
q_proj (nn.Module): projection layer for query.
k_proj (nn.Module): projection layer for key.
v_proj (nn.Module): projection layer for value.
output_proj (nn.Module): projection layer for output.
pos_embeddings (nn.Module): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
kv_cache (Optional[KVCache]): KVCache object used to cache key and value.
If not specified, then no caching is used.
max_seq_len (int): maximum sequence length supported by the model.
This is needed to compute the RoPE Cache. Default: 4096.
        attn_dropout (float): dropout value passed onto the
            scaled_dot_product_attention function. This argument is ignored if
            self.training is False. Default value is 0.0.
Raises:
ValueError: If `num_heads` % `num_kv_heads` != 0
ValueError: If `embed_dim` % `num_heads` != 0
ValueError: If `attn_dropout` < 0 or > 1
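
    Example:
        A minimal construction sketch for a GQA configuration. The concrete
        dimensions, the use of ``nn.Linear`` for the projections, and the import
        paths below are illustrative assumptions rather than requirements of
        this module.

        >>> import torch
        >>> from torch import nn
        >>> from torchtune.modules import CausalSelfAttention, RotaryPositionalEmbeddings
        >>> # illustrative sizes; any projection modules with matching shapes work
        >>> embed_dim, num_heads, num_kv_heads, max_seq_len = 4096, 32, 8, 4096
        >>> head_dim = embed_dim // num_heads
        >>> attn = CausalSelfAttention(
        ...     embed_dim=embed_dim,
        ...     num_heads=num_heads,
        ...     num_kv_heads=num_kv_heads,
        ...     head_dim=head_dim,
        ...     q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        ...     k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        ...     v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        ...     output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        ...     pos_embeddings=RotaryPositionalEmbeddings(
        ...         dim=head_dim, max_seq_len=max_seq_len
        ...     ),
        ...     max_seq_len=max_seq_len,
        ... )
        >>> x = torch.randn(2, 16, embed_dim)
        >>> attn(x).shape
        torch.Size([2, 16, 4096])
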
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
q_proj: nn.Module,
k_proj: nn.Module,
v_proj: nn.Module,
output_proj: nn.Module,
pos_embeddings: nn.Module,
kv_cache: Optional[KVCache] = None,
max_seq_len: int = 4096,
attn_dropout: float = 0.0,
) -> None:
super().__init__()
if num_heads % num_kv_heads != 0:
raise ValueError(
f"num_heads ({num_heads}) must be divisible by "
f"num_kv_heads ({num_kv_heads})"
)
if embed_dim % num_heads != 0:
raise ValueError(
f"embed_dim ({embed_dim}) must be divisible by "
f"num_heads ({num_heads})"
)
if attn_dropout < 0 or attn_dropout > 1:
raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0")
# Set attributes
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.embed_dim = embed_dim
self.attn_dropout = attn_dropout
self.head_dim = head_dim
self.max_seq_len = max_seq_len
# Set layers
self.kv_cache = kv_cache
self.q_proj = q_proj
self.k_proj = k_proj
self.v_proj = v_proj
self.output_proj = output_proj
self.pos_embeddings = pos_embeddings

    def forward(
self,
x: Tensor,
*,
mask: Optional[Tensor] = None,
input_pos: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
x (Tensor): input tensor with shape
[batch_size x seq_length x embed_dim]
mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask
with shape [batch_size x seq_length x seq_length]. This is applied after
the query-key multiplication and before the softmax. A value of True in row i
and column j means token i attends to token j. A value of False means token i
does not attend to token j. If no mask is specified, a causal mask
is used by default. Default is None.
input_pos (Optional[Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape [b x s].
During inference, this indicates the position of the current token.
                If None, assume the index of the token is its position id.
                Default is None. See the example below.
Returns:
Tensor: output tensor with attention applied
Raises:
ValueError: if seq_len of x is bigger than max_seq_len
Notation used for tensor shapes:
- b: batch size
- s: sequence length
- n_h: num heads
- n_kv: num kv heads
- d: embed dim
- h_d: head dim
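
        Example:
            As an illustrative sketch of the packed case (not taken from the
            library docs), two samples of lengths 3 and 2 packed into one row
            would use position ids that restart at each sample boundary::

                input_pos = torch.tensor([[0, 1, 2, 0, 1]])  # shape [b x s]
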
TODO:
- Return the attention weights
- Make application of positional embeddings optional
"""
# input has shape [b, s, d]
bsz, seq_len, _ = x.shape
if seq_len > self.max_seq_len:
raise ValueError(
f"seq_len ({seq_len}) of input tensor should be smaller "
f"than max_seq_len ({self.max_seq_len})"
)
# q has shape [b, s, num_heads * head_dim]
# k has shape [b, s, num_kv_heads * head_dim]
# v has shape [b, s, num_kv_heads * head_dim]
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
# number of queries per key/value
q_per_kv = self.num_heads // self.num_kv_heads
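        # e.g. with num_heads=32 and num_kv_heads=8 each k/v head serves 4 query
        # heads; q_per_kv == 1 recovers MHA and q_per_kv == num_heads recovers MQA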
# q: [b, s, n_kv, q_per_kv, h_d]
# k: [b, s, n_kv, 1, h_d]
# v: [b, s, n_kv, 1, h_d]
q = q.view(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)
k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
        # if needed, broadcast the key and value tensors to the same shape as the
        # query tensor by expanding along the q_per_kv dim (a view, not a copy)
if self.num_heads != self.num_kv_heads:
k = k.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)
v = v.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)
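            # e.g. for num_kv_heads=8 and q_per_kv=4, k goes from [b, s, 8, 1, h_d]
            # to a broadcast view of [b, s, 8, 4, h_d] without allocating new memory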
# llama2 applies the RoPE embeddings on tensors with shape
# [b, s, n_h, h_d]
# Reshape the tensors before we apply RoPE
q = q.reshape(bsz, seq_len, -1, self.head_dim)
k = k.reshape(bsz, seq_len, -1, self.head_dim)
v = v.reshape(bsz, seq_len, -1, self.head_dim)
# Apply positional embeddings
q = self.pos_embeddings(q, input_pos=input_pos)
k = self.pos_embeddings(k, input_pos=input_pos)
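        # note that RoPE rotates only the query and key tensors; v is left unchanged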
# [b, n_h, s, h_d]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Update key-value cache
if self.kv_cache is not None:
k, v = self.kv_cache.update(input_pos, k, v)
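            # the returned k and v also cover previously cached positions, so their
            # sequence dim can be longer than seq_len of the current input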
# shape: [b, 1, s, s]
if mask is not None:
mask = mask[:, None, :, :]
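            # the singleton head dim lets the [b, s, s] mask broadcast across heads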
        # Use scaled_dot_product_attention, which can dispatch to flash attention:
        # https://pytorch.org/blog/accelerating-large-language-models/
output = nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
is_causal=self.kv_cache is None and mask is None,
)
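        # attention output has shape [b, n_h, s, h_d], matching q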
# reshape the output to be the same shape as the input
output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
return self.output_proj(output)