Source code for torchtune.modules.kv_cache
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch
from torch import nn
class KVCache(nn.Module):
"""
    Standalone ``nn.Module`` containing a kv-cache that caches past keys and values during inference.
Args:
batch_size (int): batch size model will be run with
max_seq_len (int): maximum sequence length model will be run with
num_kv_heads (int): number of key/value heads.
head_dim (int): per-attention head embedding dimension
dtype (torch.dtype): dtype for the caches
"""
def __init__(
self,
batch_size: int,
max_seq_len: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
) -> None:
super().__init__()
cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
)
self.register_buffer(
"v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
)
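        # cache_pos holds the next write positions (initially 0, 1, ..., max_seq_len - 1);
        # registering it as a non-persistent buffer keeps it on the module's device
        # without saving it in state dicts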
self.register_buffer(
"cache_pos", torch.arange(0, cache_shape[2]), persistent=False
)
self.batch_size = batch_size
    def reset(self) -> None:
"""Reset the cache to zero."""
self.k_cache.zero_()
self.v_cache.zero_()
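        # shift cache_pos back down by the current fill size, restoring it to
        # (0, 1, ..., max_seq_len - 1); the in-place update keeps the buffer intact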
self.cache_pos -= self.size
@property
def size(self) -> int:
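        """Number of positions currently filled in the cache (i.e. the next write index)."""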
return self.cache_pos[0].item()
    def update(
self, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update KV cache with the new ``k_val``, ``v_val`` and return the updated cache.
        Note:
            When updating the KV cache, it is assumed that subsequent updates write to
            consecutive sequence positions. If you wish to overwrite cache values which have
            already been filled, use ``.reset()``, which will reset the cache to the zero-th position.
Example:
>>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
>>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
>>> cache.update(keys, values)
>>> # now positions 0 through 7 are filled
>>> cache.size
            8
>>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32))
>>> cache.update(keys, values)
>>> # this will fill at position 8
>>> cache.size
            9
Args:
k_val (torch.Tensor): Current key tensor with shape [B, H, S, D]
v_val (torch.Tensor): Current value tensor with shape [B, H, S, D]
Returns:
Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.
Raises:
            AssertionError: if writing ``k_val`` at the current cache position would exceed
                the maximum cache sequence length.
ValueError: if the batch size of the new key (or value) tensor is greater than the batch size
used during cache setup.
"""
bsz, _, seq_len, _ = k_val.shape
if bsz > self.k_cache.shape[0]:
raise ValueError(
f"The current cache has been setup with a batch size of {self.k_cache.shape[0]}"
f", but found new key tensors with batch size {k_val.shape[0]}!"
)
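        # guard against overflowing the cache: the new tokens must fit between the
        # current position and max_seq_len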
assert (self.cache_pos[0] + seq_len) <= self.k_cache.shape[2]
k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, self.cache_pos[:seq_len]] = k_val
v_out[:, :, self.cache_pos[:seq_len]] = v_val
# forward cache_pos seq_len positions along
# cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
# an update of seq_len = 5 tokens brings it to
# (5, 6, 7, 8, 9, ...)
# this allows us to track the current position in the cache
# after the last update in a compile-friendly way without any dynamism
# e.g. relying on an int size tracker, or re-creating cache_pos every time
self.cache_pos += seq_len
return k_out, v_out
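
# A minimal usage sketch (not part of the torchtune module above): it shows the
# prefill-then-decode pattern described in the ``update`` docstring. The shapes and
# dtype below are illustrative assumptions, not torchtune defaults.
if __name__ == "__main__":
    cache = KVCache(
        batch_size=1, max_seq_len=32, num_kv_heads=4, head_dim=16, dtype=torch.float32
    )
    # prefill: write 8 consecutive positions in one call
    k = torch.randn(1, 4, 8, 16)
    v = torch.randn(1, 4, 8, 16)
    k_full, v_full = cache.update(k, v)  # full caches of shape (1, 4, 32, 16)
    assert cache.size == 8
    # decode: append one token per step at the next position
    k_step = torch.randn(1, 4, 1, 16)
    v_step = torch.randn(1, 4, 1, 16)
    cache.update(k_step, v_step)
    assert cache.size == 9
    # start a fresh sequence
    cache.reset()
    assert cache.size == 0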