

class torchtune.modules.CausalSelfAttention(embed_dim: int, num_heads: int, num_kv_heads: int, head_dim: int, q_proj: Module, k_proj: Module, v_proj: Module, output_proj: Module, pos_embeddings: Module, kv_cache: Optional[KVCache] = None, max_seq_len: int = 4096, attn_dropout: float = 0.0)

Multi-headed grouped query self-attention (GQA) layer.

GQA is a variant of multi-headed attention (MHA) that uses fewer key/value heads than query heads: each key and value head is shared by a group of n query heads. Multi-Query Attention (MQA) is the extreme case in which a single key and value head is shared by all query heads.

The following diagram shows MHA, GQA and MQA with num_heads = 4.

(credit for the documentation:

┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
  │    │    │    │         │        │                 │
┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
  │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
│ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
└───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
        MHA                    GQA                   MQA
    n_kv_heads=4          n_kv_heads=2          n_kv_heads=1
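
The grouping in the diagram can be sketched in plain Python. This is only an illustration of the idea, assuming that consecutive query heads share a key/value head (as the diagram suggests); the helper name `query_to_kv_head` is hypothetical and is not part of torchtune.

```python
from typing import List


def query_to_kv_head(num_heads: int, num_kv_heads: int) -> List[int]:
    """For each query head index, return the index of the key/value head it uses.

    Hypothetical helper for illustration; assumes consecutive query heads
    are grouped onto the same key/value head.
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per key/value head
    return [q // group_size for q in range(num_heads)]


print(query_to_kv_head(4, 4))  # MHA: every query head has its own k/v head
print(query_to_kv_head(4, 2))  # GQA: two query heads per k/v head
print(query_to_kv_head(4, 1))  # MQA: all query heads share one k/v head
```

With num_heads = 4 the three calls correspond to the MHA, GQA and MQA columns of the diagram.
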

Parameters:
  • embed_dim (int) – embedding dimension for the model

  • num_heads (int) – number of query heads. For MHA this is also the number of heads for key and value

  • num_kv_heads (int) – number of key and value heads. User should ensure num_heads % num_kv_heads == 0. Setting num_kv_heads equal to num_heads gives standard MHA; setting it to 1 gives MQA.

  • head_dim (int) – dimension of each head, calculated by embed_dim // num_heads.

  • q_proj (nn.Module) – projection layer for query.

  • k_proj (nn.Module) – projection layer for key.

  • v_proj (nn.Module) – projection layer for value.

  • output_proj (nn.Module) – projection layer for output.

  • pos_embeddings (nn.Module) – positional embeddings layer, e.g. RotaryPositionalEmbeddings.

  • kv_cache (Optional[KVCache]) – KVCache object used to cache key and value. If not specified, then no caching is used.

  • max_seq_len (int) – maximum sequence length supported by the model. This is needed to compute the RoPE Cache. Default: 4096.

  • attn_dropout (float) – dropout value passed on to the scaled_dot_product_attention function. This argument is ignored when the module is not in training mode. Default value is 0.0.
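
The relationships between these arguments can be checked before any layers are built. The sketch below is a hypothetical helper, not torchtune code. Its divisibility checks follow the parameter descriptions above. The projection widths it returns (num_heads * head_dim for the query, num_kv_heads * head_dim for key and value) and the [0, 1] range for attn_dropout are assumptions based on how GQA layers are commonly sized.

```python
from typing import Dict


def check_attention_config(
    embed_dim: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    attn_dropout: float = 0.0,
) -> Dict[str, int]:
    """Validate the arguments and return the assumed projection output widths."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    if head_dim != embed_dim // num_heads:
        raise ValueError("head_dim must equal embed_dim // num_heads")
    if not 0.0 <= attn_dropout <= 1.0:
        raise ValueError("attn_dropout must be between 0.0 and 1.0")
    return {
        # Assumed widths: one head_dim-sized slice per head.
        "q_proj_out_features": num_heads * head_dim,
        "kv_proj_out_features": num_kv_heads * head_dim,
        "output_proj_out_features": embed_dim,
    }


print(check_attention_config(embed_dim=512, num_heads=8, num_kv_heads=2, head_dim=64))
```

Under these assumptions, using fewer key/value heads shrinks the key and value projections (128 features instead of 512 in this example), which is where the memory saving of GQA comes from.
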

forward(x: Tensor, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) → Tensor

Parameters:
  • x (Tensor) – input tensor with shape [batch_size x seq_length x embed_dim]

  • mask (Optional[Tensor]) – Optional boolean tensor which contains the attention mask with shape [batch_size x seq_length x seq_length]. This is applied after the query-key multiplication and before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None.

  • input_pos (Optional[Tensor]) – Optional tensor which contains the position ids of each token. During training, this is used to indicate the position of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. If None, the index of each token is assumed to be its position id. Default is None.
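
The mask and input_pos conventions described above can be illustrated with nested Python lists standing in for tensors. Both helpers are hypothetical and cover a single sequence (no batch dimension). They only show the boolean semantics of a causal mask and what per-sample position ids look like for a packed sequence.

```python
from typing import List


def causal_mask(seq_len: int) -> List[List[bool]]:
    """mask[i][j] is True when token i may attend to token j (j <= i)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def packed_positions(sample_lengths: List[int]) -> List[int]:
    """Position ids that restart at 0 for each sample packed into one sequence."""
    positions: List[int] = []
    for length in sample_lengths:
        positions.extend(range(length))
    return positions


for row in causal_mask(3):
    print(row)

# Two samples of length 3 and 2 packed into one sequence of length 5.
print(packed_positions([3, 2]))
```

Row i of the causal mask is True up to and including column i, so each token attends only to itself and earlier tokens. In the packed example the position ids are [0, 1, 2, 0, 1] rather than [0, 1, 2, 3, 4].
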


Returns:

output tensor with attention applied

Return type:

Tensor


Raises:

ValueError – if seq_len of x is greater than max_seq_len

Notation used for tensor shapes:
  • b: batch size

  • s: sequence length

  • n_h: num heads

  • n_kv: num kv heads

  • d: embed dim

  • h_d: head dim
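
Using this notation, the shapes involved in a GQA forward pass can be tabulated as below. This is a sketch of typical shape bookkeeping, assuming d = n_h * h_d and that key/value heads are expanded to match the query heads before attention. The function name and the exact intermediate layouts are illustrative assumptions and may differ from the library's internals.

```python
from typing import Dict, Tuple


def attention_shapes(b: int, s: int, n_h: int, n_kv: int, h_d: int) -> Dict[str, Tuple[int, ...]]:
    """Return the assumed tensor shapes at each stage of grouped query attention."""
    if n_h % n_kv != 0:
        raise ValueError("n_h must be divisible by n_kv")
    d = n_h * h_d  # assumes embed dim = num heads * head dim
    return {
        "x": (b, s, d),                  # input
        "q": (b, n_h, s, h_d),           # query, one slice per query head
        "k": (b, n_kv, s, h_d),          # key, one slice per kv head
        "v": (b, n_kv, s, h_d),          # value, one slice per kv head
        "k_expanded": (b, n_h, s, h_d),  # each kv head repeated n_h // n_kv times
        "v_expanded": (b, n_h, s, h_d),
        "scores": (b, n_h, s, s),        # query-key products, masked before softmax
        "output": (b, s, d),             # heads merged, then output projection
    }


for name, shape in attention_shapes(b=2, s=16, n_h=4, n_kv=2, h_d=8).items():
    print(f"{name:>11}: {shape}")
```
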


TODO:
  • Return the attention weights

  • Make application of positional embeddings optional


