KVCache¶
- class torchtune.modules.KVCache(batch_size: int, max_seq_len: int, num_heads: int, head_dim: int, dtype: dtype)[source]¶
Standalone nn.Module containing a kv-cache to cache past keys and values during inference.
- Parameters:
batch_size (int) – batch size model will be run with
max_seq_len (int) – maximum sequence length model will be run with
num_heads (int) – number of heads. We take num_heads instead of num_kv_heads because the cache is created after we’ve expanded the key and value tensors to have the same shape as the query tensor. See attention.py for more details
head_dim (int) – per-attention head embedding dimension
dtype (torch.dtype) – dtype for the caches
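For illustration, a minimal sketch of constructing a cache; the sizes below are hypothetical placeholders, not defaults:

```python
import torch
from torchtune.modules import KVCache

# Hypothetical sizes chosen for illustration only
batch_size, max_seq_len, num_heads, head_dim = 2, 2048, 8, 64

kv_cache = KVCache(
    batch_size=batch_size,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    head_dim=head_dim,
    dtype=torch.bfloat16,
)
```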
- update(input_pos: Tensor, k_val: Tensor, v_val: Tensor) Tuple[Tensor, Tensor] [source]¶
Update KV cache with the new k_val, v_val and return the updated cache.
- Parameters:
input_pos (Tensor) – Current position tensor with shape [S]
k_val (Tensor) – Current key tensor with shape [B, H, S, D]
v_val (Tensor) – Current value tensor with shape [B, H, S, D]
- Raises:
ValueError – if input_pos is longer than the maximum sequence length
- Returns:
Updated KV cache with key first
- Return type:
Tuple[Tensor, Tensor]
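A minimal sketch of updating the cache with a prompt's keys and values, assuming the tensor shapes documented above and the cache constructed in the earlier example; the sizes are placeholders:

```python
import torch

# Assumes kv_cache was built as in the construction example above
seq_len = 16  # number of positions written this step; must not exceed max_seq_len
input_pos = torch.arange(seq_len)  # positions being cached, shape [S]

# New key/value projections for the current step, shape [B, H, S, D]
k_val = torch.randn(2, 8, seq_len, 64, dtype=torch.bfloat16)
v_val = torch.randn(2, 8, seq_len, 64, dtype=torch.bfloat16)

k_out, v_out = kv_cache.update(input_pos, k_val, v_val)
# k_out, v_out are the updated cache tensors (key first, then value)
```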