KVCache

class torchtune.modules.KVCache(batch_size: int, max_seq_len: int, num_heads: int, head_dim: int, dtype: dtype)

Standalone nn.Module containing a KV cache that stores past keys and values during inference.

Parameters:
  • batch_size (int) – batch size the model will be run with

  • max_seq_len (int) – maximum sequence length the model will be run with

  • num_heads (int) – number of heads. We take num_heads instead of num_kv_heads because the cache is created after we’ve expanded the key and value tensors to have the same shape as the query tensor. See attention.py for more details

  • head_dim (int) – per-attention head embedding dimension

  • dtype (torch.dtype) – dtype for the caches
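
The parameters above determine how much memory the cache occupies. The sketch below estimates that footprint, assuming the cache holds one key buffer and one value buffer, each shaped [batch_size, num_heads, max_seq_len, head_dim]. The helper name kv_cache_bytes and its bytes_per_element argument are hypothetical and not part of torchtune.

```python
def kv_cache_bytes(
    batch_size: int,
    max_seq_len: int,
    num_heads: int,
    head_dim: int,
    bytes_per_element: int,
) -> int:
    """Estimate the size of a key buffer plus a value buffer, in bytes.

    Assumption: each buffer has shape
    [batch_size, num_heads, max_seq_len, head_dim].
    """
    elements_per_buffer = batch_size * num_heads * max_seq_len * head_dim
    # Factor of 2: one buffer for keys, one for values.
    return 2 * elements_per_buffer * bytes_per_element


# bfloat16 uses 2 bytes per element.
example_bytes = kv_cache_bytes(
    batch_size=2, max_seq_len=16, num_heads=4, head_dim=32, bytes_per_element=2
)
print(example_bytes)  # 16384
```

The estimate grows linearly in every argument, so doubling max_seq_len or batch_size doubles the memory reserved up front.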

reset() → None

Reset the cache to zero.

update(k_val: Tensor, v_val: Tensor) → Tuple[Tensor, Tensor]

Update KV cache with the new k_val, v_val and return the updated cache.

Note

Each update is assumed to write to the sequence positions that immediately follow those filled by the previous update. To overwrite positions that have already been filled, call .reset() first, which returns the cache to position zero.

Example

>>> cache = KVCache(batch_size=2, max_seq_len=16, num_heads=4, head_dim=32, dtype=torch.bfloat16)
>>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
>>> cache.update(keys, values)
>>> # now positions 0 through 7 are filled
>>> cache.size
8
>>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32))
>>> cache.update(keys, values)
>>> # this will fill at position 8
>>> cache.size
9

Parameters:
  • k_val (torch.Tensor) – Current key tensor with shape [B, H, S, D]

  • v_val (torch.Tensor) – Current value tensor with shape [B, H, S, D]

Returns:

Updated key and value cache tensors, respectively.

Return type:

Tuple[torch.Tensor, torch.Tensor]

Raises:
  • ValueError – if the sequence length of k_val is longer than the maximum cache sequence length.

  • ValueError – if the batch size of the new key (or value) tensor is greater than the batch size used during cache setup.
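
The position bookkeeping and error conditions described for update() and reset() can be illustrated without any tensors. The following is a toy model only, not torchtune's implementation: the class ToyCachePositions is hypothetical and tracks just the fill position. Rejecting an update that would run past the remaining capacity is an assumption of this sketch; the documentation above does not say how that case is handled.

```python
class ToyCachePositions:
    """Hypothetical stand-in for a KV cache that tracks fill positions only."""

    def __init__(self, batch_size: int, max_seq_len: int) -> None:
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.size = 0  # number of sequence positions filled so far

    def reset(self) -> None:
        # Return to position zero so the next update starts from the beginning.
        self.size = 0

    def update(self, batch_size: int, seq_len: int) -> range:
        if seq_len > self.max_seq_len:
            raise ValueError("sequence length exceeds the maximum cache sequence length")
        if batch_size > self.batch_size:
            raise ValueError("batch size exceeds the batch size used during cache setup")
        if self.size + seq_len > self.max_seq_len:
            # Assumption of this sketch only: refuse to write past the end.
            raise ValueError("not enough free positions; call reset() first")
        # New entries land in the positions right after the filled ones.
        written = range(self.size, self.size + seq_len)
        self.size += seq_len
        return written


cache = ToyCachePositions(batch_size=2, max_seq_len=16)
first = cache.update(batch_size=2, seq_len=8)   # fills positions 0 through 7
second = cache.update(batch_size=2, seq_len=1)  # fills position 8
size_before_reset = cache.size
cache.reset()
print(list(first), list(second), size_before_reset, cache.size)
```

Returning the range of written positions is a convenience of the toy model; the real update() returns the updated key and value cache tensors.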
