KVCache
- class torchtune.modules.KVCache(max_batch_size: int, max_seq_len: int, num_heads: int, head_dim: int, dtype: dtype)
Standalone nn.Module containing a KV cache that stores past keys and values during inference.
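To make the idea concrete, here is a minimal, framework-free sketch of what a KV cache does during autoregressive decoding. It is not torchtune's implementation: the class `ToyKVCache`, its `update` method, and the `[batch][head][position][head_dim]` list layout are all hypothetical, chosen for illustration. It assumes, consistent with the `max_*` parameters listed below, that buffers are allocated once at their maximum size and then filled one position per decoded token.

```python
from typing import List, Tuple

# Hypothetical layout: [batch][head][position][head_dim]
Cache = List[List[List[List[float]]]]


class ToyKVCache:
    """Hypothetical, list-based illustration of a KV cache (not torchtune's API)."""

    def __init__(self, max_batch_size: int, max_seq_len: int,
                 num_heads: int, head_dim: int) -> None:
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        # Preallocate both buffers once, so memory is fixed by the max_* arguments.
        self.k_cache: Cache = self._zeros(max_batch_size, num_heads, max_seq_len, head_dim)
        self.v_cache: Cache = self._zeros(max_batch_size, num_heads, max_seq_len, head_dim)
        self.size = 0  # number of positions written so far

    @staticmethod
    def _zeros(b: int, h: int, s: int, d: int) -> Cache:
        return [[[[0.0] * d for _ in range(s)] for _ in range(h)] for _ in range(b)]

    def update(self, k_new: List[List[List[float]]],
               v_new: List[List[List[float]]]) -> Tuple[Cache, Cache]:
        """Write one new token's keys/values ([batch][head][head_dim]) and
        return every cached position so far for those batch rows."""
        batch = len(k_new)
        if batch > self.max_batch_size:
            raise ValueError("batch exceeds max_batch_size")
        if self.size >= self.max_seq_len:
            raise ValueError("cache is full (max_seq_len reached)")
        pos = self.size
        for b in range(batch):
            for h, (k_vec, v_vec) in enumerate(zip(k_new[b], v_new[b])):
                self.k_cache[b][h][pos] = list(k_vec)
                self.v_cache[b][h][pos] = list(v_vec)
        self.size += 1
        return self._prefix(self.k_cache, batch), self._prefix(self.v_cache, batch)

    def _prefix(self, buf: Cache, batch: int) -> Cache:
        # Keep only the filled positions [0, size) for the active batch rows.
        return [[head[: self.size] for head in buf[b]] for b in range(batch)]


# Usage: one sequence, one head, head_dim 2, decoding two tokens.
cache = ToyKVCache(max_batch_size=2, max_seq_len=3, num_heads=1, head_dim=2)
cache.update([[[1.0, 2.0]]], [[[3.0, 4.0]]])        # first token
k, v = cache.update([[[5.0, 6.0]]], [[[7.0, 8.0]]])  # second token
print(k)  # [[[[1.0, 2.0], [5.0, 6.0]]]]
```

Each step writes only the new token's key and value, while attention can read all earlier positions from the cache instead of recomputing them; that reuse is the point of caching during inference.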
- Parameters:
max_batch_size (int) – maximum batch size the model will be run with
max_seq_len (int) – maximum sequence length the model will be run with
num_heads (int) – number of attention heads. This is num_heads rather than num_kv_heads because the cache is populated after the key and value tensors have been expanded to the same shape as the query tensor. See attention.py for more details.
head_dim (int) – embedding dimension of each attention head
dtype (torch.dtype) – dtype for the caches
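Because the cache is sized by these parameters, its memory cost can be estimated up front. The sketch below is a back-of-the-envelope calculation, not part of torchtune: `kv_cache_bytes` and `BYTES_PER_ELEMENT` are hypothetical names, and the example sizes (32 heads, `head_dim` 128, sequence length 4096, bfloat16) are assumed, roughly 7B-model-scale values. It assumes one key buffer and one value buffer, each holding `max_batch_size * num_heads * max_seq_len * head_dim` elements, and also shows what sizing by `num_heads` instead of `num_kv_heads` costs under grouped-query attention.

```python
# Hypothetical helper (not a torchtune function): bytes held by one cache instance,
# assuming one key buffer and one value buffer of identical size.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def kv_cache_bytes(max_batch_size: int, max_seq_len: int, num_heads: int,
                   head_dim: int, dtype: str = "bfloat16") -> int:
    # Elements in a single buffer (keys or values).
    elements_per_buffer = max_batch_size * num_heads * max_seq_len * head_dim
    # Factor 2: one buffer for keys, one for values.
    return 2 * elements_per_buffer * BYTES_PER_ELEMENT[dtype]


# Assumed sizes: batch 1, 4096 positions, 32 query heads, head_dim 128, bfloat16.
full = kv_cache_bytes(max_batch_size=1, max_seq_len=4096, num_heads=32, head_dim=128)
print(full // 2**20, "MiB")  # 64 MiB

# If only 8 KV heads were stored (grouped-query attention, before expansion),
# the cache would be num_heads / num_kv_heads = 4x smaller.
kv_only = kv_cache_bytes(max_batch_size=1, max_seq_len=4096, num_heads=8, head_dim=128)
print(full // kv_only)  # 4
```

The cost grows linearly in every parameter, so `max_batch_size` and `max_seq_len` are worth setting no larger than the workload needs. Storing already-expanded keys and values keeps the cached tensors shaped like the query, at the price of the extra memory computed above.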