local_kv_cache
- torchtune.modules.common_utils.local_kv_cache(model: Module, *, batch_size: int, device: device, dtype: dtype, encoder_max_seq_len: Optional[int] = None, decoder_max_seq_len: Optional[int] = None) → Generator[None, None, None] [source]
This context manager temporarily enables KV-caching on a given model which does not already have KV-caches set up. All forward passes using the model within this context manager will use KV-caches.
KV-caches will be set up with the given batch_size, dtype, and max_seq_len when entering the context manager, and will be deleted on exit.

Example
>>> from torchtune.models.llama3_2 import llama3_2_1b
>>> from torchtune.modules import local_kv_cache
>>> import torch
>>> model = llama3_2_1b()
>>> print(model.caches_are_setup())
False
>>> print(model.caches_are_enabled())
False
>>> print(model.layers[0].attn.kv_cache)
None
>>> # entering caching mode
>>> with local_kv_cache(model,
...                     batch_size=1,
...                     device=torch.device("cpu"),
...                     dtype=torch.float32,
...                     decoder_max_seq_len=1024):
...     print(model.caches_are_setup())
...     print(model.caches_are_enabled())
...     print(model.layers[0].attn.kv_cache)
True
True
KVCache()
>>> # exited caching mode
>>> print(model.caches_are_setup())
False
>>> print(model.caches_are_enabled())
False
>>> print(model.layers[0].attn.kv_cache)
None
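In practice the context manager typically wraps a decoding loop. Below is a minimal greedy-decoding sketch under the same setup as the example above. The prompt tokens, the generation length, and the explicit causal-mask construction (which mirrors torchtune's generation utilities, and whose exact requirements may vary across versions) are illustrative assumptions rather than part of this API.

import torch

from torchtune.models.llama3_2 import llama3_2_1b
from torchtune.modules import local_kv_cache

model = llama3_2_1b()
prompt = torch.tensor([[1, 2, 3]])  # hypothetical token ids, shape [batch, seq_len]
max_seq_len = 1024

# Boolean causal mask over the full cache length; True means "may attend".
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

with local_kv_cache(
    model,
    batch_size=1,
    device=torch.device("cpu"),
    dtype=torch.float32,
    decoder_max_seq_len=max_seq_len,
):
    # Prefill: score the whole prompt once, populating the KV-caches.
    input_pos = torch.arange(prompt.shape[1])
    logits = model(prompt, mask=causal_mask[None, input_pos], input_pos=input_pos)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)

    # Incremental decoding: one token per step, reusing cached keys/values.
    for pos in range(prompt.shape[1], prompt.shape[1] + 8):
        input_pos = torch.tensor([pos])
        logits = model(next_token, mask=causal_mask[None, input_pos], input_pos=input_pos)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)

# On exit the caches are deleted and the model is back in its original state.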
- Parameters:
model (nn.Module) – model to enable KV-caching for.
batch_size (int) – batch size for the caches.
device (torch.device) – device to set up caches on. This should be the same device the model is on.
dtype (torch.dtype) – dtype for the caches.
encoder_max_seq_len (Optional[int]) – maximum encoder cache sequence length.
decoder_max_seq_len (Optional[int]) – maximum decoder cache sequence length.
- Yields:
None – Returns control to the caller with KV-caches set up and enabled on the given model.
- Raises:
ValueError – If the model already has caches set up. You may use delete_kv_caches() to delete existing caches.
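If the model may already have caches from a prior setup_caches() call, one way to avoid this error is to delete them before entering the context manager. A minimal sketch, assuming delete_kv_caches is importable from torchtune.modules like local_kv_cache above:

import torch

from torchtune.models.llama3_2 import llama3_2_1b
from torchtune.modules import delete_kv_caches, local_kv_cache

model = llama3_2_1b()
# Suppose caches were set up earlier, e.g. by a previous inference pass.
model.setup_caches(batch_size=1, dtype=torch.float32, decoder_max_seq_len=1024)

# Entering local_kv_cache now would raise ValueError, so clear the caches first.
if model.caches_are_setup():
    delete_kv_caches(model)

with local_kv_cache(
    model,
    batch_size=1,
    device=torch.device("cpu"),
    dtype=torch.float32,
    decoder_max_seq_len=1024,
):
    pass  # forward passes here use fresh, temporary KV-caches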