MultiHeadAttention
- class torchtune.modules.MultiHeadAttention(*, embed_dim: int, num_heads: int, num_kv_heads: int, head_dim: int, q_proj: Module, k_proj: Module, v_proj: Module, output_proj: Module, pos_embeddings: Optional[Module] = None, q_norm: Optional[Module] = None, k_norm: Optional[Module] = None, kv_cache: Optional[KVCache] = None, max_seq_len: int = 4096, is_causal: bool = True, attn_dropout: float = 0.0)
Multi-headed attention layer with support for grouped query attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1.
GQA is a variant of multi-headed attention (MHA) that uses fewer key/value heads than query heads. The query heads are split into groups of num_heads // num_kv_heads, and each group shares one key head and one value head. Multi-Query Attention (MQA) is the extreme case, with a single key head and a single value head shared by all query heads.
The following diagram shows MHA, GQA and MQA with num_heads = 4 (credit for the documentation: litgpt.Config):
┌───┐┌───┐┌───┐┌───┐     ┌───┐     ┌───┐            ┌───┐
│ v ││ v ││ v ││ v │     │ v │     │ v │            │ v │
└───┘└───┘└───┘└───┘     └───┘     └───┘            └───┘
  │    │    │    │         │         │                │
┌───┐┌───┐┌───┐┌───┐     ┌───┐     ┌───┐            ┌───┐
│ k ││ k ││ k ││ k │     │ k │     │ k │            │ k │
└───┘└───┘└───┘└───┘     └───┘     └───┘            └───┘
  │    │    │    │      ┌──┴─┐    ┌──┴─┐      ┌────┬──┴─┬────┐
┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
│ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
└───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
        MHA                   GQA                   MQA
    n_kv_heads=4          n_kv_heads=2          n_kv_heads=1
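The grouping in the diagram can be expressed as an index mapping from query heads to key/value heads. This is a minimal sketch. It assumes, as the diagram shows, that consecutive query heads share a key/value head. `kv_head_index` is a hypothetical helper and is not part of torchtune.

```python
from typing import List


def kv_head_index(num_heads: int, num_kv_heads: int) -> List[int]:
    """Return, for each query head, the index of the key/value head it reads from.

    Hypothetical helper: assumes consecutive query heads are grouped onto the
    same key/value head, as in the MHA/GQA/MQA diagram.
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per key/value head
    return [q // group_size for q in range(num_heads)]


# num_heads = 4, matching the diagram
print("MHA:", kv_head_index(4, 4))  # MHA: [0, 1, 2, 3]
print("GQA:", kv_head_index(4, 2))  # GQA: [0, 0, 1, 1]
print("MQA:", kv_head_index(4, 1))  # MQA: [0, 0, 0, 0]
```

Tensor implementations commonly realize the same grouping by repeating each key/value head num_heads // num_kv_heads times along the head dimension (for example with torch.repeat_interleave). After the repeat, every query head has a matching key and value.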
- Parameters:
embed_dim (int) – embedding dimension for the model
num_heads (int) – number of query heads. For MHA this is also the number of heads for key and value
num_kv_heads (int) – number of key and value heads. User should ensure num_heads % num_kv_heads == 0. For standard MHA set num_kv_heads == num_heads, for GQA num_kv_heads < num_heads, and for MQA set num_kv_heads == 1.
head_dim (int) – dimension of each head, calculated by embed_dim // num_heads.
q_proj (nn.Module) – projection layer for query.
k_proj (nn.Module) – projection layer for key.
v_proj (nn.Module) – projection layer for value.
output_proj (nn.Module) – projection layer for output.
pos_embeddings (Optional[nn.Module]) – positional embeddings layer, e.g. RotaryPositionalEmbeddings.
q_norm (Optional[nn.Module]) – normalization layer for query, e.g. RMSNorm. During decoding, this is applied before updating from kv_cache, so only token-wise normalization is supported, not batch-wide or sequence-wide normalization.
k_norm (Optional[nn.Module]) – normalization layer for key, must be set if q_norm is.
kv_cache (Optional[KVCache]) – KVCache object used to cache key and value tensors during decoding.
max_seq_len (int) – maximum sequence length supported by the model. This is needed to compute the RoPE cache. Default: 4096.
is_causal (bool) – sets the default mask to causal when no mask is provided. Default: True.
attn_dropout (float) – dropout value passed to the scaled_dot_product_attention function. Default: 0.0.
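The caller supplies the projection modules, so this class does not fix their sizes. A common choice, assumed here, is bias-free nn.Linear layers sized to the head layout:

- num_heads * head_dim outputs for the query
- num_kv_heads * head_dim outputs for key and value
- an output projection from num_heads * head_dim back to embed_dim

The hypothetical helpers below only compute those (in_features, out_features) pairs and the resulting weight count. That is enough to show where GQA saves parameters.

```python
from typing import Dict, Tuple


def projection_shapes(
    embed_dim: int, num_heads: int, num_kv_heads: int, head_dim: int
) -> Dict[str, Tuple[int, int]]:
    """Return (in_features, out_features) for each projection.

    Hypothetical helper: assumes nn.Linear-style projections sized to the head
    layout; torchtune itself accepts any nn.Module for these arguments.
    """
    return {
        "q_proj": (embed_dim, num_heads * head_dim),
        "k_proj": (embed_dim, num_kv_heads * head_dim),
        "v_proj": (embed_dim, num_kv_heads * head_dim),
        "output_proj": (num_heads * head_dim, embed_dim),
    }


def weight_count(shapes: Dict[str, Tuple[int, int]]) -> int:
    # A bias-free linear layer has in_features * out_features weights.
    return sum(i * o for i, o in shapes.values())


mha = projection_shapes(embed_dim=4096, num_heads=32, num_kv_heads=32, head_dim=128)
gqa = projection_shapes(embed_dim=4096, num_heads=32, num_kv_heads=8, head_dim=128)
print(gqa["k_proj"])  # (4096, 1024)
print(weight_count(mha), weight_count(gqa))  # 67108864 41943040
```

Only the key and value projections shrink under GQA. The query and output projections are the same as in MHA.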
- Raises:
ValueError – If num_heads % num_kv_heads != 0
ValueError – If embed_dim % num_heads != 0
ValueError – If attn_dropout < 0 or attn_dropout > 1
ValueError – If q_norm is defined without k_norm or vice versa
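The documented error conditions amount to a few argument checks. The sketch below restates them in plain Python. `check_attention_args` is a hypothetical name, and the error messages are illustrative rather than torchtune's own.

```python
from typing import Optional


def check_attention_args(
    embed_dim: int,
    num_heads: int,
    num_kv_heads: int,
    attn_dropout: float = 0.0,
    q_norm: Optional[object] = None,
    k_norm: Optional[object] = None,
) -> None:
    """Raise ValueError for each argument combination listed under Raises."""
    if num_heads % num_kv_heads != 0:
        raise ValueError(f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
    if embed_dim % num_heads != 0:
        raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
    if attn_dropout < 0 or attn_dropout > 1:
        raise ValueError(f"attn_dropout ({attn_dropout}) must be between 0 and 1")
    if (q_norm is None) != (k_norm is None):
        raise ValueError("q_norm and k_norm must be set together")


check_attention_args(embed_dim=4096, num_heads=32, num_kv_heads=8)  # valid GQA config
for bad in (
    dict(embed_dim=4096, num_heads=32, num_kv_heads=6),  # 32 % 6 != 0
    dict(embed_dim=4100, num_heads=32, num_kv_heads=8),  # 4100 % 32 != 0
    dict(embed_dim=4096, num_heads=32, num_kv_heads=8, attn_dropout=1.5),
    dict(embed_dim=4096, num_heads=32, num_kv_heads=8, q_norm=object()),  # no k_norm
):
    try:
        check_attention_args(**bad)
    except ValueError as err:
        print("ValueError:", err)
```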
- forward(x: Tensor, y: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) → Tensor
- Parameters:
x (torch.Tensor) – input tensor with shape [b x s_x x d] for the query
y (Optional[torch.Tensor]) – second input tensor with shape [b x s_y x d], used as the input for k and v. For self-attention, x = y. Optional only when kv_cache is enabled.
mask (Optional[_MaskType]) – Used to mask the scores after the query-key multiplication and before the softmax. Either:
A boolean tensor with shape [b x s x s], [b x s x self.encoder_max_cache_seq_len], or [b x s x self.decoder_max_cache_seq_len] if using KV-caching with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.
A BlockMask for document masking in a packed sequence, created via create_block_mask. We use flex_attention() when computing attention with block masks. Default is None.
input_pos (Optional[torch.Tensor]) – Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. If None, assume the index of the token is its position id. Default is None.
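For a packed sequence, both the boolean mask and input_pos can be derived from the lengths of the packed samples. The sketch below builds them for a single packed row as nested Python lists. It assumes each sample is causal within itself and cannot see other samples. `packed_mask_and_positions` is a hypothetical helper. The result is a dense boolean picture of the same attention pattern, not the BlockMask / create_block_mask API that is used with flex_attention().

```python
from typing import List, Tuple


def packed_mask_and_positions(seq_lens: List[int]) -> Tuple[List[List[bool]], List[int]]:
    """Build an [s x s] document-causal mask and the [s] position ids for one packed row.

    mask[i][j] is True when token i may attend to token j, i.e. both tokens
    belong to the same sample and j is not after i.
    """
    doc_ids: List[int] = []
    positions: List[int] = []
    for doc, length in enumerate(seq_lens):
        doc_ids.extend([doc] * length)
        positions.extend(range(length))  # positions restart at 0 for each sample
    s = len(doc_ids)
    mask = [[doc_ids[i] == doc_ids[j] and j <= i for j in range(s)] for i in range(s)]
    return mask, positions


mask, input_pos = packed_mask_and_positions([2, 3])
print(input_pos)  # [0, 1, 0, 1, 2]
for row in mask:
    print("".join("1" if allowed else "." for allowed in row))
# 1....
# 11...
# ..1..
# ..11.
# ..111
```

Passing a single length, for example [5], yields the ordinary lower-triangular causal mask with positions 0..4. That matches the default described for the case where no mask or input_pos is given.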
- Raises:
ValueError – If no y input is provided and kv_cache is not enabled.
- Returns:
output tensor with attention applied
- Return type:
torch.Tensor
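Because the mask is applied to the query-key scores before the softmax, a position marked False receives zero attention weight. The single-head, unbatched sketch below shows that order of operations in plain Python. `masked_attention` is a hypothetical helper with a few simplifications:

- It omits dropout.
- It assumes every query row has at least one True entry in the mask.
- It is not how the layer itself is implemented. The layer works on batched tensors through scaled_dot_product_attention, or through flex_attention() for block masks.

```python
import math
from typing import List

Matrix = List[List[float]]


def masked_attention(q: Matrix, k: Matrix, v: Matrix, mask: List[List[bool]]) -> Matrix:
    """Single-head scaled dot-product attention for one sequence.

    q: [s_x x h_d], k and v: [s_y x h_d], mask: [s_x x s_y] with True meaning "attend".
    Assumes every row of mask contains at least one True entry.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out: Matrix = []
    for i, q_row in enumerate(q):
        # 1) query-key multiplication, 2) masking with -inf, before the softmax
        scores = [
            sum(a * b for a, b in zip(q_row, k_row)) * scale if mask[i][j] else -math.inf
            for j, k_row in enumerate(k)
        ]
        # 3) softmax (subtracting the max keeps exp() numerically stable)
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # 4) weighted sum of the value rows
        out.append([sum(w * v_row[c] for w, v_row in zip(weights, v)) for c in range(len(v[0]))])
    return out


causal = [[j <= i for j in range(3)] for i in range(3)]
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = masked_attention(x, x, x, causal)
print(out[0])  # [1.0, 0.0]: the first token can only attend to itself
```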
- Notation used for tensor shapes:
b: batch size
s_x: sequence length for x
s_y: sequence length for y
n_h: num heads
n_kv: num kv heads
d: embed dim
h_d: head dim
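The notation above can be used to trace tensor shapes through a forward pass. The sketch below only does the shape arithmetic. It assumes a common layout:

- Heads are split out of the projection outputs and moved to dimension 1.
- Key/value heads are repeated up to n_h before attention.
- output_proj maps n_h * h_d back to d.

`trace_shapes` is a hypothetical helper, not torchtune code.

```python
from typing import Dict, Tuple


def trace_shapes(
    b: int, s_x: int, s_y: int, d: int, n_h: int, n_kv: int, h_d: int
) -> Dict[str, Tuple[int, ...]]:
    """Return the shape after each step of a grouped-query attention forward pass."""
    if n_h % n_kv != 0:
        raise ValueError("n_h must be divisible by n_kv")
    return {
        "x": (b, s_x, d),
        "y": (b, s_y, d),
        "q after q_proj": (b, s_x, n_h * h_d),
        "k, v after k_proj/v_proj": (b, s_y, n_kv * h_d),
        "q split into heads": (b, n_h, s_x, h_d),
        "k, v split into heads": (b, n_kv, s_y, h_d),
        "k, v repeated to n_h heads": (b, n_h, s_y, h_d),
        "attention scores q @ k^T": (b, n_h, s_x, s_y),
        "attention output": (b, n_h, s_x, h_d),
        "heads merged": (b, s_x, n_h * h_d),
        "after output_proj": (b, s_x, d),  # assumes output_proj maps back to d
    }


for step, shape in trace_shapes(b=2, s_x=16, s_y=16, d=4096, n_h=32, n_kv=8, h_d=128).items():
    print(f"{step:28s} {shape}")
```

For self-attention, s_x equals s_y. For cross-attention, y may have a different length, which is why the score matrix is s_x by s_y.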