CausalSelfAttention¶
- class torchtune.modules.CausalSelfAttention(embed_dim: int, num_heads: int, num_kv_heads: int, head_dim: int, q_proj: Module, k_proj: Module, v_proj: Module, output_proj: Module, pos_embeddings: Module, kv_cache: Optional[KVCache] = None, max_seq_len: int = 4096, attn_dropout: float = 0.0)[source]¶
Multi-headed grouped query self-attention (GQA) layer introduced in https://arxiv.org/pdf/2305.13245v1.pdf.
GQA is a version of multiheaded attention (MHA) which uses fewer key/value heads than query heads by grouping n query heads for each key and value head. Multi-Query Attention (MQA) is the extreme version in which a single key and value head is shared by all query heads; see the diagram and the sketch that follows it below.
The following is an example of MHA, GQA, and MQA with num_heads = 4 (credit for the documentation: https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/config.py).
    ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
    └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
      │    │    │    │         │        │                 │
    ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
    └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
      │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
    ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
    │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
    └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
    ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
            MHA                    GQA                   MQA
        n_kv_heads=4          n_kv_heads=2           n_kv_heads=1
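To make the grouping concrete, here is a minimal illustration (an assumption for exposition, not torchtune's internal code): num_kv_heads key/value heads are built, each is repeated so that every query head has a partner, and scaled dot-product attention is applied. The repeat_interleave expansion is an assumed implementation detail.

    import torch
    import torch.nn.functional as F

    # Illustrative GQA head grouping (not torchtune internals).
    b, s, head_dim = 2, 16, 64
    num_heads, num_kv_heads = 8, 2          # 8 % 2 == 0 -> 4 query heads per kv head

    q = torch.randn(b, num_heads, s, head_dim)
    k = torch.randn(b, num_kv_heads, s, head_dim)
    v = torch.randn(b, num_kv_heads, s, head_dim)

    # Expand k/v so every query head sees a matching key/value head.
    q_per_kv = num_heads // num_kv_heads
    k = k.repeat_interleave(q_per_kv, dim=1)    # [b, num_heads, s, head_dim]
    v = v.repeat_interleave(q_per_kv, dim=1)

    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)                            # torch.Size([2, 8, 16, 64])

With num_kv_heads == num_heads this reduces to MHA, and with num_kv_heads == 1 it reduces to MQA.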
- Parameters:
embed_dim (int) – embedding dimension for the model
num_heads (int) – number of query heads. For MHA this is also the number of heads for key and value
num_kv_heads (int) – number of key and value heads. The user should ensure num_heads % num_kv_heads == 0. Setting num_kv_heads equal to num_heads recovers standard MHA; setting it to 1 gives MQA.
head_dim (int) – dimension of each head, calculated as embed_dim // num_heads.
q_proj (nn.Module) – projection layer for query.
k_proj (nn.Module) – projection layer for key.
v_proj (nn.Module) – projection layer for value.
output_proj (nn.Module) – projection layer for output.
pos_embeddings (nn.Module) – positional embeddings layer, e.g. RotaryPositionalEmbeddings.
kv_cache (Optional[KVCache]) – KVCache object used to cache key and value. If not specified, then no caching is used.
max_seq_len (int) – maximum sequence length supported by the model. This is needed to compute the RoPE Cache. Default: 4096.
attn_dropout (float) – dropout value passed onto the scaled_dot_product_attention function. This argument is ignored if self.training is False. Default value is 0.0.
- Raises:
ValueError – If num_heads % num_kv_heads != 0
ValueError – If embed_dim % num_heads != 0
ValueError – If attn_dropout < 0 or > 1
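Putting the parameters above together, here is a hedged construction sketch that satisfies the documented constraints (num_heads % num_kv_heads == 0, head_dim = embed_dim // num_heads), using plain nn.Linear projections and RotaryPositionalEmbeddings. The Linear output shapes and the RotaryPositionalEmbeddings signature are assumptions for illustration rather than requirements stated on this page, and import paths may differ across torchtune versions.

    from torch import nn
    from torchtune.modules import CausalSelfAttention, RotaryPositionalEmbeddings

    embed_dim, num_heads, num_kv_heads = 4096, 32, 8     # 32 % 8 == 0
    head_dim = embed_dim // num_heads                    # 128
    max_seq_len = 4096

    attn = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len),
        max_seq_len=max_seq_len,
        attn_dropout=0.0,
    )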
- forward(x: Tensor, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) Tensor [source]¶
- Parameters:
x (Tensor) – input tensor with shape [batch_size x seq_length x embed_dim]
mask (Optional[Tensor]) – Optional tensor containing the attention mask. Only used during inference. Default is None.
input_pos (Optional[Tensor]) – Optional tensor containing the position of the current token. Only used during inference. Default is None.
- Returns:
output tensor with attention applied
- Return type:
Tensor
- Raises:
ValueError – If the sequence length of x is greater than max_seq_len
- Notation used for tensor shapes:
b: batch size
s: sequence length
n_h: num heads
n_kv: num kv heads
d: embed dim
h_d: head dim
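In this notation, forward maps an input of shape [b x s x d] to an output of the same shape. A brief usage sketch, continuing the attn module constructed in the earlier example:

    import torch

    b, s = 2, 128                      # batch size, sequence length (s <= max_seq_len)
    x = torch.randn(b, s, embed_dim)   # [b x s x d]
    y = attn(x)                        # [b x s x d]
    print(y.shape)                     # torch.Size([2, 128, 4096])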
Todo
- Return the attention weights
- Make application of positional embeddings optional