MultiHeadAttention

class torchtune.modules.MultiHeadAttention(*, embed_dim: int, num_heads: int, num_kv_heads: int, head_dim: int, q_proj: Module, k_proj: Module, v_proj: Module, output_proj: Module, pos_embeddings: Optional[Module] = None, q_norm: Optional[Module] = None, k_norm: Optional[Module] = None, kv_cache: Optional[KVCache] = None, max_seq_len: int = 4096, is_causal: bool = True, attn_dropout: float = 0.0)

Multi-headed attention layer with support for grouped query attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1.

GQA is a variant of multi-head attention (MHA) that uses fewer key/value heads than query heads: each key/value head is shared by a group of n query heads. Multi-query attention (MQA) is the extreme case, with a single key/value head shared by all query heads.

The following diagram shows MHA, GQA, and MQA with num_heads = 4

(credit for the documentation: litgpt.Config).

┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
  │    │    │    │         │        │                 │
┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
  │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
│ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
└───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
        MHA                    GQA                   MQA
   n_kv_heads=4           n_kv_heads=2           n_kv_heads=1
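
The grouping in the diagram can be sketched in plain Python. This is a minimal illustration, not torchtune's implementation: the helper name kv_head_for_query_head is hypothetical, and it assumes that consecutive query heads share a key/value head, as the diagram suggests.

```python
def kv_head_for_query_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Return the index of the key/value head used by query head `q_head`.

    Hypothetical helper: assumes consecutive query heads form one group.
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    queries_per_kv = num_heads // num_kv_heads  # the "n" in "grouping n query heads"
    return q_head // queries_per_kv


num_heads = 4
for name, num_kv_heads in [("MHA", 4), ("GQA", 2), ("MQA", 1)]:
    mapping = [kv_head_for_query_head(q, num_heads, num_kv_heads) for q in range(num_heads)]
    print(f"{name}: {mapping}")
# MHA: [0, 1, 2, 3]
# GQA: [0, 0, 1, 1]
# MQA: [0, 0, 0, 0]
```

Each list gives, for query heads 0 to 3, the key/value head they read from, matching the three columns of the diagram.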
Parameters:
  • embed_dim (int) – embedding dimension for the model

  • num_heads (int) – number of query heads. For MHA this is also the number of heads for key and value

  • num_kv_heads (int) – number of key and value heads. Must satisfy num_heads % num_kv_heads == 0. For standard MHA set num_kv_heads == num_heads, for GQA set num_kv_heads < num_heads, and for MQA set num_kv_heads == 1.

  • head_dim (int) – dimension of each head, calculated by embed_dim // num_heads.

  • q_proj (nn.Module) – projection layer for query.

  • k_proj (nn.Module) – projection layer for key.

  • v_proj (nn.Module) – projection layer for value.

  • output_proj (nn.Module) – projection layer for output.

  • pos_embeddings (Optional[nn.Module]) – positional embeddings layer, e.g. RotaryPositionalEmbeddings.

  • q_norm (Optional[nn.Module]) – normalization layer for query, e.g. RMSNorm. For decoding, this is applied before updating from kv_cache, so only token-wide normalization is supported, not batch-wide or sequence-wide normalization.

  • k_norm (Optional[nn.Module]) – normalization layer for key. Must be set if q_norm is set.

  • kv_cache (Optional[KVCache]) – KVCache object used to cache keys and values.

  • max_seq_len (int) – maximum sequence length supported by the model. This is needed to compute the RoPE cache. Default: 4096.

  • is_causal (bool) – sets the default mask to causal when no mask is provided. Default: True.

  • attn_dropout (float) – dropout value passed on to the scaled_dot_product_attention function. Default: 0.0.

Raises:
  • ValueError – If num_heads % num_kv_heads != 0

  • ValueError – If embed_dim % num_heads != 0

  • ValueError – If attn_dropout < 0 or attn_dropout > 1

  • ValueError – If q_norm is defined without k_norm or vice versa
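
The four checks listed above can be mirrored in a small standalone function. This is a sketch only: check_attention_config is a hypothetical name and the error messages are illustrative, not the ones torchtune emits.

```python
from typing import Optional


def check_attention_config(
    embed_dim: int,
    num_heads: int,
    num_kv_heads: int,
    attn_dropout: float = 0.0,
    q_norm: Optional[object] = None,
    k_norm: Optional[object] = None,
) -> int:
    """Apply the documented constraints and return head_dim on success."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    if attn_dropout < 0 or attn_dropout > 1:
        raise ValueError("attn_dropout must be between 0.0 and 1.0")
    # q_norm and k_norm must be provided together or not at all.
    if (q_norm is None) != (k_norm is None):
        raise ValueError("q_norm and k_norm must be set together")
    return embed_dim // num_heads


print(check_attention_config(embed_dim=4096, num_heads=32, num_kv_heads=8))  # 128

try:
    check_attention_config(embed_dim=4096, num_heads=32, num_kv_heads=5)
except ValueError as err:
    print(f"rejected: {err}")
```

Running such a check before building the projection layers surfaces a bad head configuration early, rather than as a shape mismatch inside forward.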

forward(x: Tensor, y: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) → Tensor
Parameters:
  • x (torch.Tensor) – input tensor with shape [b x s_x x d] for the query

  • y (Optional[torch.Tensor]) – second input tensor with shape [b x s_y x d], used as the input for k and v. For self-attention, x=y. Optional only when kv_cache is enabled.

  • mask (Optional[_MaskType]) –

    Used to mask the scores after the query-key multiplication and before the softmax. Either:

    A boolean tensor with shape [b x s x s], or [b x s x self.encoder_max_cache_seq_len] if using KV-caching with encoder/decoder layers. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default.

    A BlockMask for document masking in a packed sequence created via create_block_mask. We use flex_attention() when computing attention with block masks. Default is None.

  • input_pos (Optional[torch.Tensor]) – Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. If None, assume the index of the token is its position id. Default is None.
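
To make the mask and input_pos conventions concrete, the sketch below builds both for one packed row holding two samples of lengths 3 and 2. It uses nested Python lists instead of tensors, and the helper names are hypothetical. The mask is a plain boolean equivalent of causal document masking and does not use BlockMask or flex_attention(); a real call would also need a leading batch dimension.

```python
from typing import List


def packed_input_pos(seq_lens: List[int]) -> List[int]:
    """Position ids that restart at 0 for every sample in the packed row."""
    return [pos for n in seq_lens for pos in range(n)]


def packed_causal_mask(seq_lens: List[int]) -> List[List[bool]]:
    """mask[i][j] is True when token i attends to token j.

    Token i attends to token j only if both belong to the same sample
    and j is not after i (causal within each sample).
    """
    doc_ids = [doc for doc, n in enumerate(seq_lens) for _ in range(n)]
    s = len(doc_ids)
    return [[doc_ids[i] == doc_ids[j] and j <= i for j in range(s)] for i in range(s)]


seq_lens = [3, 2]
print(packed_input_pos(seq_lens))  # [0, 1, 2, 0, 1]
for row in packed_causal_mask(seq_lens):
    print("".join("x" if allowed else "." for allowed in row))
# x....
# xx...
# xxx..
# ...x.
# ...xx
```

With a single sample, seq_lens = [5], the same helper reduces to the ordinary lower-triangular causal mask.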

Raises:

ValueError – If no y input and kv_cache is not enabled.

Returns:

output tensor with attention applied

Return type:

torch.Tensor

Notation used for tensor shapes:
  • b: batch size

  • s_x: sequence length for x

  • s_y: sequence length for y

  • n_h: num heads

  • n_kv: num kv heads

  • d: embed dim

  • h_d: head dim
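
Using this notation, the shapes at each stage of the layer can be traced without running any tensor code. The sketch below rests on assumptions about the internals that this page does not spell out: that q_proj maps d to n_h * h_d, that k_proj and v_proj map d to n_kv * h_d, and that the output has the same shape as x. The function name attention_shapes is hypothetical.

```python
from typing import Dict, Tuple


def attention_shapes(
    b: int, s_x: int, s_y: int, embed_dim: int, num_heads: int, num_kv_heads: int
) -> Dict[str, Tuple[int, ...]]:
    """Trace assumed tensor shapes through one attention layer."""
    h_d = embed_dim // num_heads  # head dim, as documented for head_dim
    return {
        "x": (b, s_x, embed_dim),
        "y": (b, s_y, embed_dim),
        # Assumed projection widths: queries use n_h heads, keys/values use n_kv.
        "q_proj(x)": (b, s_x, num_heads * h_d),
        "k_proj(y)": (b, s_y, num_kv_heads * h_d),
        "v_proj(y)": (b, s_y, num_kv_heads * h_d),
        # After splitting the last dimension into heads.
        "q heads": (b, num_heads, s_x, h_d),
        "k/v heads": (b, num_kv_heads, s_y, h_d),
        # One score per (query token, key token) pair and query head.
        "scores": (b, num_heads, s_x, s_y),
        "output": (b, s_x, embed_dim),
    }


for name, shape in attention_shapes(2, 16, 16, 4096, 32, 8).items():
    print(f"{name:>10}: {shape}")
```

With num_heads = 32 and num_kv_heads = 8, the key and value projections are a quarter of the width of the query projection, which is where GQA saves parameters and cache memory.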

reset_cache()

Reset the key value caches.

setup_cache(batch_size: int, dtype: dtype, max_seq_len: int) → None

Set up key-value caches for attention calculation. If called after kv_cache is already set up, this will be skipped.

Parameters:
  • batch_size (int) – batch size for the caches.

  • dtype (torch.dtype) – dtype for the caches.

  • max_seq_len (int) – maximum sequence length the model will be run with.
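
The three arguments, together with num_kv_heads and head_dim, determine how much memory the caches for one attention layer occupy. The estimate below is a rough sketch that assumes the cache holds one key tensor and one value tensor, each of shape [b x n_kv x max_seq_len x h_d]; that layout is an assumption, not something stated on this page, and kv_cache_num_bytes is a hypothetical helper.

```python
def kv_cache_num_bytes(
    batch_size: int,
    max_seq_len: int,
    num_kv_heads: int,
    head_dim: int,
    bytes_per_element: int,
) -> int:
    """Estimated size in bytes of the key and value caches for one layer."""
    elements_per_tensor = batch_size * num_kv_heads * max_seq_len * head_dim
    return 2 * elements_per_tensor * bytes_per_element  # key cache + value cache


# 2 bytes per element corresponds to a 16-bit dtype such as torch.bfloat16.
mha = kv_cache_num_bytes(batch_size=1, max_seq_len=4096, num_kv_heads=32, head_dim=128, bytes_per_element=2)
gqa = kv_cache_num_bytes(batch_size=1, max_seq_len=4096, num_kv_heads=8, head_dim=128, bytes_per_element=2)
print(mha // 2**20, "MiB")  # 64 MiB
print(gqa // 2**20, "MiB")  # 16 MiB
```

Under these assumptions the cache size scales linearly with batch_size, max_seq_len, and num_kv_heads, so reducing num_kv_heads from 32 to 8 cuts it by a factor of four.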
