

class torchtune.modules.KVCache(batch_size: int, max_seq_len: int, num_heads: int, head_dim: int, dtype: dtype)[source]

Standalone nn.Module containing a KV cache that stores past keys and values during inference.

Parameters:

  • batch_size (int) – batch size the model will be run with

  • max_seq_len (int) – maximum sequence length the model will be run with

  • num_heads (int) – number of heads. This is num_heads rather than num_kv_heads because the cache is created after the key and value tensors have been expanded to the same shape as the query tensor.

  • head_dim (int) – per-attention head embedding dimension

  • dtype (torch.dtype) – dtype for the caches
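
The constructor arguments together determine how much memory the cache reserves. The following is a rough estimate only, assuming the cache holds one key buffer and one value buffer, each of shape [batch_size, num_heads, max_seq_len, head_dim] (inferred from the [B, H, S, D] shapes that update() accepts). The helper kv_cache_bytes and the dtype size table are hypothetical names used for illustration and are not part of torchtune.

```python
# Bytes per element for a few common dtypes (illustrative subset, keyed by name).
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def kv_cache_bytes(
    batch_size: int,
    max_seq_len: int,
    num_heads: int,
    head_dim: int,
    dtype_name: str,
) -> int:
    """Estimate the bytes reserved by one key buffer plus one value buffer.

    Assumption: each buffer has shape
    [batch_size, num_heads, max_seq_len, head_dim].
    """
    elements_per_buffer = batch_size * num_heads * max_seq_len * head_dim
    # Factor of 2: one buffer for keys, one for values.
    return 2 * elements_per_buffer * BYTES_PER_ELEMENT[dtype_name]


# Same configuration as the usage example in update():
small = kv_cache_bytes(2, 16, 4, 32, "bfloat16")
print(small)  # 16384 bytes, i.e. 16 KiB
```

Because the estimate scales linearly with every argument, doubling max_seq_len or batch_size doubles the reserved memory under these assumptions. The figure covers a single KVCache instance.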

reset() → None[source]

Reset the cache to zero.

update(k_val: Tensor, v_val: Tensor) → Tuple[Tensor, Tensor][source]

Update KV cache with the new k_val, v_val and return the updated cache.


Successive calls to update() are assumed to write to consecutive sequence positions. To overwrite positions that have already been filled, call .reset() first, which returns the cache to the zero-th position.


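>>> import torch
>>> from torchtune.modules import KVCache
>>> cache = KVCache(batch_size=2, max_seq_len=16, num_heads=4, head_dim=32, dtype=torch.bfloat16)
>>> # create the new keys and values in the same dtype as the cache
>>> keys = torch.ones((2, 4, 8, 32), dtype=torch.bfloat16)
>>> values = torch.ones((2, 4, 8, 32), dtype=torch.bfloat16)
>>> k_out, v_out = cache.update(keys, values)
>>> # positions 0 through 7 are now filled
>>> cache.size
8
>>> keys = torch.ones((2, 4, 1, 32), dtype=torch.bfloat16)
>>> values = torch.ones((2, 4, 1, 32), dtype=torch.bfloat16)
>>> k_out, v_out = cache.update(keys, values)
>>> # this update filled position 8
>>> cache.size
9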

Parameters:

  • k_val (torch.Tensor) – Current key tensor with shape [B, H, S, D]

  • v_val (torch.Tensor) – Current value tensor with shape [B, H, S, D]


Returns:

Updated key and value cache tensors, respectively.

Return type:

Tuple[torch.Tensor, torch.Tensor]

Raises:

  • ValueError – if the sequence length of k_val is longer than the maximum cache sequence length.

  • ValueError – if the batch size of the new key (or value) tensor is greater than the batch size used during cache setup.
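
The position bookkeeping and error conditions described for update() and reset() can be sketched without any tensor storage. The class below is a hypothetical toy (ToyCachePositions is not a torchtune class) that tracks only which sequence positions a [B, H, S, D] update would write to. It also raises ValueError when an update would run past the remaining space, which is an assumption of this sketch and not documented behavior.

```python
class ToyCachePositions:
    """Tracks only the fill position of a KV cache; stores no tensors."""

    def __init__(self, batch_size: int, max_seq_len: int) -> None:
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.size = 0  # number of sequence positions filled so far

    def reset(self) -> None:
        """Return to position zero so the next update starts from the beginning."""
        self.size = 0

    def update(self, batch_size: int, seq_len: int) -> range:
        """Return the positions an update with this batch size and length fills."""
        if batch_size > self.batch_size:
            raise ValueError("batch size exceeds the size used at cache setup")
        if seq_len > self.max_seq_len:
            raise ValueError("sequence length exceeds the maximum cache length")
        if self.size + seq_len > self.max_seq_len:
            # Assumption of this sketch: overflowing the remaining space is an error.
            raise ValueError("not enough free positions left in the cache")
        written = range(self.size, self.size + seq_len)
        self.size += seq_len  # the next update continues where this one ended
        return written


toy = ToyCachePositions(batch_size=2, max_seq_len=16)
prefill = toy.update(batch_size=2, seq_len=8)  # fills positions 0 through 7
decode = toy.update(batch_size=2, seq_len=1)   # fills position 8
print(list(prefill), list(decode), toy.size)
```

Calling toy.reset() sets size back to 0, so a following update starts writing at position 0 again, which mirrors the documented way to overwrite filled positions.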

