
CausalSelfAttention

class torchtune.modules.CausalSelfAttention(embed_dim: int, num_heads: int, num_kv_heads: int, head_dim: int, q_proj: Module, k_proj: Module, v_proj: Module, output_proj: Module, pos_embeddings: Module, kv_cache: Optional[KVCache] = None, max_seq_len: int = 4096, attn_dropout: float = 0.0)

Multi-headed grouped query self-attention (GQA) layer introduced in https://arxiv.org/abs/2305.13245v1.

GQA is a variant of multi-head attention (MHA) that uses fewer key/value heads than query heads: the query heads are split into groups of n, and each group shares a single key and value head. Multi-query attention (MQA) is the extreme case, with a single key and value head shared by all query heads.

The following diagram shows MHA, GQA, and MQA with num_heads = 4 (diagram credit: https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/config.py).

┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
│    │    │    │         │        │                 │
┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
│    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
│ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
└───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
        MHA                    GQA                   MQA
n_kv_heads=4           n_kv_heads=2           n_kv_heads=1

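The grouping in the diagram can be written as an index mapping: with q_per_kv = num_heads // num_kv_heads, query head i reads from key/value head i // q_per_kv. The sketch below assumes contiguous grouping (consecutive query heads share one key/value head, as drawn above); `kv_head_for_query` and `group_assignments` are hypothetical helper names, not torchtune APIs.

```python
def kv_head_for_query(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Return the key/value head used by query head `q_head` (hypothetical helper).

    Assumes contiguous grouping: consecutive query heads share one key/value head.
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    q_per_kv = num_heads // num_kv_heads  # query heads per key/value head
    return q_head // q_per_kv


def group_assignments(num_heads: int, num_kv_heads: int) -> list[int]:
    """Key/value head index for every query head, in order."""
    return [kv_head_for_query(i, num_heads, num_kv_heads) for i in range(num_heads)]


# num_heads = 4, as in the diagram
for name, n_kv in [("MHA", 4), ("GQA", 2), ("MQA", 1)]:
    print(name, group_assignments(4, n_kv))
# MHA [0, 1, 2, 3]
# GQA [0, 0, 1, 1]
# MQA [0, 0, 0, 0]
```

In a tensor implementation this mapping is usually realized by repeating each key/value head q_per_kv times along the head dimension, so that keys and values line up with the num_heads query heads.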
Parameters:
  • embed_dim (int) – embedding dimension for the model

  • num_heads (int) – number of query heads. For MHA, this is also the number of key and value heads.

  • num_kv_heads (int) – number of key and value heads. The user should ensure num_heads % num_kv_heads == 0. Setting num_kv_heads = num_heads gives standard MHA, and num_kv_heads = 1 gives MQA.

  • head_dim (int) – dimension of each head, computed as embed_dim // num_heads.

  • q_proj (nn.Module) – projection layer for query.

  • k_proj (nn.Module) – projection layer for key.

  • v_proj (nn.Module) – projection layer for value.

  • output_proj (nn.Module) – projection layer for output.

  • pos_embeddings (nn.Module) – positional embeddings layer, e.g. RotaryPositionalEmbeddings.

  • kv_cache (Optional[KVCache]) – KVCache object used to cache keys and values. If not specified, no caching is used.

  • max_seq_len (int) – maximum sequence length supported by the model. This is needed to compute the RoPE cache. Default: 4096.

  • attn_dropout (float) – dropout probability passed to the scaled_dot_product_attention function. This argument is ignored when self.training is False. Default: 0.0.
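A minimal sketch of the size bookkeeping implied by these parameters. It assumes the common GQA layout in which q_proj maps embed_dim to num_heads * head_dim, k_proj and v_proj each map embed_dim to num_kv_heads * head_dim, and output_proj maps num_heads * head_dim back to embed_dim. It also assumes keys and values are cached with num_kv_heads heads (before any expansion to num_heads). The helper `gqa_sizes` and the `GQASizes` record are hypothetical, not part of torchtune.

```python
from dataclasses import dataclass


@dataclass
class GQASizes:
    head_dim: int
    q_out: int        # output features of q_proj
    kv_out: int       # output features of k_proj (and of v_proj)
    cache_elems: int  # elements in one layer's key + value cache


def gqa_sizes(embed_dim: int, num_heads: int, num_kv_heads: int,
              batch_size: int, max_seq_len: int = 4096) -> GQASizes:
    """Hypothetical helper: derive projection and KV-cache sizes for one layer."""
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    head_dim = embed_dim // num_heads
    kv_out = num_kv_heads * head_dim
    # Assumption: keys and values are each cached as [b, n_kv, max_seq_len, h_d]
    cache_elems = 2 * batch_size * num_kv_heads * max_seq_len * head_dim
    return GQASizes(head_dim, num_heads * head_dim, kv_out, cache_elems)


mha = gqa_sizes(embed_dim=4096, num_heads=32, num_kv_heads=32, batch_size=1)
gqa = gqa_sizes(embed_dim=4096, num_heads=32, num_kv_heads=8, batch_size=1)
print(mha.head_dim, mha.kv_out, gqa.kv_out)  # 128 4096 1024
print(mha.cache_elems // gqa.cache_elems)    # 4
```

Under these assumptions, the key/value projections and the KV cache shrink by a factor of num_heads / num_kv_heads relative to MHA, while the query and output projections are unchanged.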

Raises:

forward(x: Tensor, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) → Tensor
Parameters:
  • x (Tensor) – input tensor with shape [batch_size x seq_length x embed_dim]

  • mask (Optional[Tensor]) – Optional boolean tensor which contains the attention mask with shape [batch_size x seq_length x seq_length]. This is applied after the query-key multiplication and before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None.

  • input_pos (Optional[Tensor]) – Optional tensor that contains the position id of each token. During training, it indicates the position of each token relative to its own sample when samples are packed, with shape [b x s]. During inference, it indicates the position of the current token. If None, the index of each token is used as its position id. Default is None.
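To make the mask convention concrete, the sketch below builds boolean masks for a single batch element in the documented orientation (mask[i][j] is True when token i may attend to token j). `causal_mask` reproduces the default causal behaviour. `packed_causal_mask` is a hypothetical helper showing one way to derive a block-causal mask from input_pos-style position ids when several samples are packed, assuming each sample's positions restart at 0. `masked_softmax` shows where the mask acts: disallowed scores are set to -inf after the query-key product and before the softmax. None of these names are torchtune APIs.

```python
import math


def causal_mask(seq_len: int) -> list[list[bool]]:
    """mask[i][j] is True when token i may attend to token j (j <= i)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def packed_causal_mask(input_pos: list[int]) -> list[list[bool]]:
    """Block-causal mask for packed samples (hypothetical helper).

    Assumes a new sample starts wherever the position id resets to 0, so a
    token only attends to earlier tokens of its own sample.
    """
    sample_id, current = [], -1
    for pos in input_pos:
        if pos == 0:
            current += 1
        sample_id.append(current)
    n = len(input_pos)
    return [[j <= i and sample_id[i] == sample_id[j] for j in range(n)]
            for i in range(n)]


def masked_softmax(scores: list[float], mask_row: list[bool]) -> list[float]:
    """Apply one mask row to query-key scores, then softmax over the row."""
    masked = [s if keep else -math.inf for s, keep in zip(scores, mask_row)]
    m = max(masked)  # a causal row always keeps at least the diagonal entry
    exps = [math.exp(s - m) for s in masked]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


# Two samples of lengths 3 and 2 packed into one sequence of length 5
mask = packed_causal_mask([0, 1, 2, 0, 1])
print([int(v) for v in mask[4]])                           # [0, 0, 0, 1, 1]
print(masked_softmax([1.0, 1.0, 1.0], causal_mask(3)[1]))  # [0.5, 0.5, 0.0]
```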

Returns:

output tensor with attention applied

Return type:

Tensor

Raises:

ValueError – if the sequence length of x is greater than max_seq_len

Notation used for tensor shapes:
  • b: batch size

  • s: sequence length

  • n_h: num heads

  • n_kv: num kv heads

  • d: embed dim

  • h_d: head dim
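Using this notation, the sketch below traces tensor shapes through a typical grouped-query attention forward pass, including the max_seq_len check listed under Raises. The order of steps (split into heads, move heads first, expand keys/values to n_h heads, attend, merge heads, project) is an assumption about a standard GQA implementation rather than a transcription of torchtune's code, and `trace_shapes` is a hypothetical helper. Positional embeddings are applied to q and k but do not change their shapes, so they are omitted.

```python
def trace_shapes(b: int, s: int, d: int, n_h: int, n_kv: int,
                 max_seq_len: int = 4096) -> list[tuple[str, tuple[int, ...]]]:
    """Return (step, shape) pairs for one GQA forward pass (hypothetical helper)."""
    if s > max_seq_len:
        raise ValueError(f"seq_len ({s}) of input exceeds max_seq_len ({max_seq_len})")
    h_d = d // n_h
    q_per_kv = n_h // n_kv  # query heads sharing each key/value head
    return [
        ("x", (b, s, d)),
        ("q = q_proj(x), split into heads", (b, s, n_h, h_d)),
        ("k, v = k_proj(x), v_proj(x), split", (b, s, n_kv, h_d)),
        ("q transposed", (b, n_h, s, h_d)),
        ("k, v transposed", (b, n_kv, s, h_d)),
        ("k, v expanded to n_h heads", (b, n_kv * q_per_kv, s, h_d)),
        ("attention scores q @ k^T", (b, n_h, s, s)),
        ("softmax(scores) @ v", (b, n_h, s, h_d)),
        ("merge heads", (b, s, n_h * h_d)),
        ("output_proj", (b, s, d)),
    ]


for step, shape in trace_shapes(b=2, s=16, d=512, n_h=8, n_kv=2):
    print(f"{step:40} {shape}")
```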

Todo

  • Return the attention weights

  • Make application of positional embeddings optional
