KVCache

class torchtune.modules.KVCache(batch_size: int, max_seq_len: int, num_heads: int, head_dim: int, dtype: dtype)[source]

Standalone nn.Module that caches past key and value tensors during inference.

Parameters:
  • batch_size (int) – batch size model will be run with

  • max_seq_len (int) – maximum sequence length model will be run with

  • num_heads (int) – number of attention heads. num_heads is used instead of num_kv_heads because the cache is created after the key and value tensors have been expanded to the same shape as the query tensor; see attention.py for details

  • head_dim (int) – per-attention head embedding dimension

  • dtype (torch.dtype) – dtype for the caches
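
As a sketch, instantiating a cache for a small decoder might look like the following (the dimensions are illustrative, not library defaults):

    import torch

    from torchtune.modules import KVCache

    # Illustrative dimensions -- not defaults of the class.
    kv_cache = KVCache(
        batch_size=1,       # sequences decoded per forward pass
        max_seq_len=2048,   # longest sequence the cache buffers can hold
        num_heads=32,       # query heads (KV heads are expanded to match)
        head_dim=128,       # embedding dimension per attention head
        dtype=torch.bfloat16,
    )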

reset() → None[source]

Reset the cache to zero.
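
For example, when reusing the same module across prompts (a minimal sketch, continuing the instantiation above):

    # Zero out cached keys/values before decoding a new prompt,
    # so stale entries from the previous sequence are discarded.
    kv_cache.reset()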

update(input_pos: Tensor, k_val: Tensor, v_val: Tensor) → Tuple[Tensor, Tensor][source]

Update the KV cache with the new k_val and v_val, and return the updated cache tensors.

Parameters:
  • input_pos (Tensor) – Current position tensor with shape [S], where S is the sequence length being written

  • k_val (Tensor) – Current key tensor with shape [B, H, S, D]

  • v_val (Tensor) – Current value tensor with shape [B, H, S, D], where B is batch_size, H is num_heads, and D is head_dim

Raises:

ValueError – if the length of input_pos exceeds max_seq_len

Returns:

The updated key and value caches, with the key cache returned first

Return type:

Tuple[Tensor, Tensor]
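
A minimal sketch of driving the cache during prefill and incremental decoding; the tensor names, random inputs, and the surrounding loop are illustrative assumptions, not part of the API:

    import torch

    from torchtune.modules import KVCache

    B, H, D = 1, 32, 128   # illustrative batch size, num_heads, head_dim
    cache = KVCache(batch_size=B, max_seq_len=2048,
                    num_heads=H, head_dim=D, dtype=torch.float32)

    # Prefill: write the whole prompt's keys/values at positions [0, S).
    S = 16
    input_pos = torch.arange(S)                   # shape [S]
    k = torch.randn(B, H, S, D)
    v = torch.randn(B, H, S, D)
    keys, values = cache.update(input_pos, k, v)  # full cache buffers, key first

    # Decode: append one token at a time at the next position.
    for pos in range(S, S + 4):
        input_pos = torch.tensor([pos])           # shape [1]
        k_new = torch.randn(B, H, 1, D)
        v_new = torch.randn(B, H, 1, D)
        keys, values = cache.update(input_pos, k_new, v_new)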
