local_kv_cache

torchtune.modules.common_utils.local_kv_cache(model: Module, *, batch_size: int, device: device, dtype: dtype, encoder_max_seq_len: Optional[int] = None, decoder_max_seq_len: Optional[int] = None) → Generator[None, None, None]

This context manager temporarily enables KV-caching on a given model that does not already have KV-caches set up. All forward passes using the model within this context manager will use KV-caches.

KV-caches will be set up with the given batch_size, dtype, and encoder_max_seq_len / decoder_max_seq_len when entering the context manager, and will be deleted on exit.

Example

>>> from torchtune.models.llama3_2 import llama3_2_1b
>>> from torchtune.modules import local_kv_cache
>>> import torch
>>> model = llama3_2_1b()
>>> print(model.caches_are_setup())
False
>>> print(model.caches_are_enabled())
False
>>> print(model.layers[0].attn.kv_cache)
None
>>> # entering caching mode
>>> with local_kv_cache(model,
...                     batch_size=1,
...                     device=torch.device("cpu"),
...                     dtype=torch.float32,
...                     decoder_max_seq_len=1024):
...     print(model.caches_are_setup())
...     print(model.caches_are_enabled())
...     print(model.layers[0].attn.kv_cache)
True
True
KVCache()
>>> # exited caching mode
>>> print(model.caches_are_setup())
False
>>> print(model.caches_are_enabled())
False
>>> print(model.layers[0].attn.kv_cache)
None
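
The example above only inspects cache state. As a minimal sketch of actually using the caches, the snippet below runs a prefill pass followed by a single cached decode step inside the context manager. It assumes the model's forward accepts input_pos and a boolean causal mask, as torchtune's TransformerDecoder does; the prompt tokens and the 128_256 vocabulary size are illustrative.

import torch
from torchtune.models.llama3_2 import llama3_2_1b
from torchtune.modules import local_kv_cache

model = llama3_2_1b()
max_seq_len = 1024
prompt = torch.randint(0, 128_256, (1, 8))  # illustrative token ids

# Boolean causal mask over the full cache length; rows are selected by input_pos.
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

with local_kv_cache(
    model,
    batch_size=1,
    device=torch.device("cpu"),
    dtype=torch.float32,
    decoder_max_seq_len=max_seq_len,
):
    # Prefill: process all prompt tokens at once; the caches fill positions 0..7.
    input_pos = torch.arange(8)
    logits = model(prompt, input_pos=input_pos, mask=causal_mask[None, input_pos])
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)

    # Decode one step: only the new token is passed; attention reads the
    # previously cached keys/values instead of recomputing them.
    input_pos = torch.tensor([8])
    logits = model(next_token, input_pos=input_pos, mask=causal_mask[None, input_pos])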
Parameters:
  • model (nn.Module) – model to enable KV-caching for.

  • batch_size (int) – batch size for the caches.

  • device (torch.device) – device to set up caches on. This should be the same device the model is on.

  • dtype (torch.dtype) – dtype for the caches.

  • encoder_max_seq_len (Optional[int]) – maximum encoder cache sequence length.

  • decoder_max_seq_len (Optional[int]) – maximum decoder cache sequence length.
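
To see these parameters take effect, the short sketch below inspects a layer's cache from inside the context manager. It assumes, as in torchtune.modules.KVCache, that the cache exposes k_cache / v_cache buffers laid out as [batch_size, num_kv_heads, max_seq_len, head_dim].

import torch
from torchtune.models.llama3_2 import llama3_2_1b
from torchtune.modules import local_kv_cache

model = llama3_2_1b()
with local_kv_cache(
    model,
    batch_size=2,
    device=torch.device("cpu"),
    dtype=torch.bfloat16,
    decoder_max_seq_len=512,
):
    kv = model.layers[0].attn.kv_cache
    # Assumed buffer layout: [batch_size, num_kv_heads, max_seq_len, head_dim],
    # so batch_size=2 and decoder_max_seq_len=512 appear directly in the shape.
    print(kv.k_cache.shape)
    print(kv.k_cache.dtype)  # torch.bfloat16, as requested above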

Yields:

None – Returns control to the caller with KV-caches set up and enabled on the given model.

Raises:

ValueError – If the model already has caches set up. You may use delete_kv_caches() to delete existing caches.
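
Since this docstring points to delete_kv_caches() for that case, one way to recover from pre-existing caches is to delete them before entering the context manager, as sketched below:

import torch
from torchtune.models.llama3_2 import llama3_2_1b
from torchtune.modules import delete_kv_caches, local_kv_cache

model = llama3_2_1b()
# Suppose caches were set up earlier (e.g. via model.setup_caches(...));
# entering local_kv_cache now would raise ValueError, so delete them first.
if model.caches_are_setup():
    delete_kv_caches(model)

with local_kv_cache(
    model,
    batch_size=1,
    device=torch.device("cpu"),
    dtype=torch.float32,
    decoder_max_seq_len=1024,
):
    ...  # forward passes here use the temporary caches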
