KVCache
- class torchtune.modules.KVCache(batch_size: int, max_seq_len: int, num_kv_heads: int, head_dim: int, dtype: dtype)
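Standalone nn.Module containing a KV cache that stores past keys and values during inference.
- Parameters:
batch_size (int) – Batch size the cache is allocated for
max_seq_len (int) – Maximum number of sequence positions the cache can hold
num_kv_heads (int) – Number of key/value heads
head_dim (int) – Embedding dimension of each attention head
dtype (torch.dtype) – Data type of the cached tensors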
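Because the cache is sized entirely from the constructor arguments, its memory footprint can be estimated before inference. The sketch below assumes one key buffer and one value buffer, each of shape [batch_size, num_kv_heads, max_seq_len, head_dim]; `kv_cache_bytes` is a hypothetical helper for illustration, not part of torchtune.

```python
def kv_cache_bytes(batch_size, max_seq_len, num_kv_heads, head_dim, bytes_per_elem):
    """Hypothetical helper: bytes used by one key buffer plus one value buffer.

    Assumes each buffer has shape [batch_size, num_kv_heads, max_seq_len, head_dim].
    """
    elems_per_buffer = batch_size * num_kv_heads * max_seq_len * head_dim
    # Factor of 2: one buffer for keys, one for values.
    return 2 * elems_per_buffer * bytes_per_elem


# Settings from the update() example; torch.bfloat16 uses 2 bytes per element.
print(kv_cache_bytes(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, bytes_per_elem=2))  # 16384
```

The cost grows linearly in every argument, so max_seq_len and batch_size are the usual levers for trading context length against memory.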
- update(k_val: Tensor, v_val: Tensor) → Tuple[Tensor, Tensor]
Update the KV cache with the new k_val and v_val and return the updated cache.
Note
When updating the KV cache, it is assumed that successive updates write key-value pairs to consecutive sequence positions. To overwrite cache values that have already been filled, call .reset(), which resets the cache to the zeroth position.
Example
>>> import torch
>>> from torchtune.modules import KVCache
>>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
>>> keys = torch.ones((2, 4, 8, 32), dtype=torch.bfloat16)
>>> values = torch.ones((2, 4, 8, 32), dtype=torch.bfloat16)
>>> k_out, v_out = cache.update(keys, values)
>>> # now positions 0 through 7 are filled
>>> cache.size
8
>>> keys = torch.ones((2, 4, 1, 32), dtype=torch.bfloat16)
>>> values = torch.ones((2, 4, 1, 32), dtype=torch.bfloat16)
>>> k_out, v_out = cache.update(keys, values)
>>> # this fills position 8
>>> cache.size
9
- Parameters:
k_val (torch.Tensor) – Current key tensor with shape [B, H, S, D]
v_val (torch.Tensor) – Current value tensor with shape [B, H, S, D]
- Returns:
Updated key and value cache tensors, respectively.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- Raises:
AssertionError – if the sequence length of k_val is longer than the maximum cache sequence length.
ValueError – if the batch size of the new key (or value) tensor is greater than the batch size used during cache setup.
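The Note and Raises entries describe a write position that advances along the sequence axis and can only be rewound with reset(). The following is a minimal sketch of that behavior, assuming preallocated key/value buffers of shape [B, H, max_seq_len, D] and a plain integer write position. `SketchKVCache` is a hypothetical class for illustration, not torchtune's actual implementation; in particular, zeroing the buffers in reset() and writing into the first B rows when the new batch is smaller are choices made here.

```python
import torch
from torch import nn


class SketchKVCache(nn.Module):
    """Hypothetical minimal KV cache; not torchtune's implementation."""

    def __init__(self, batch_size, max_seq_len, num_kv_heads, head_dim, dtype):
        super().__init__()
        shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
        # Allocate zero-filled key and value buffers once, up front.
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype), persistent=False)
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype), persistent=False)
        # Next free position along the sequence axis (dim 2).
        self.size = 0

    def reset(self):
        # Zero the buffers and rewind the write position to 0.
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.size = 0

    def update(self, k_val, v_val):
        bsz, _, seq_len, _ = k_val.shape
        if bsz > self.k_cache.shape[0]:
            raise ValueError(
                f"batch size {bsz} exceeds cache batch size {self.k_cache.shape[0]}"
            )
        # Refuse writes that would run past max_seq_len.
        assert self.size + seq_len <= self.k_cache.shape[2], "KV cache is full"
        end = self.size + seq_len
        # Write into consecutive positions [size, end); slice assignment casts to the buffer dtype.
        self.k_cache[:bsz, :, self.size:end] = k_val
        self.v_cache[:bsz, :, self.size:end] = v_val
        self.size = end
        return self.k_cache, self.v_cache
```

Because every call writes at [size, size + S), there is no way to revisit an earlier slot except by calling reset() first, which is the constraint the Note describes. The returned tensors always span all max_seq_len positions; unfilled positions remain zero.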