delete_kv_caches

torchtune.modules.common_utils.delete_kv_caches(model: Module)

Deletes KV caches from all attention layers in a model and ensures cache_enabled is set to False.

Example

>>> from torchtune.models.llama3_2 import llama3_2_1b
>>> from torchtune.modules import delete_kv_caches
>>> import torch
>>> model = llama3_2_1b()
>>> model.setup_caches(batch_size=1,
...                    dtype=torch.float32,
...                    decoder_max_seq_len=1024)
>>> print(model.caches_are_setup())
True
>>> print(model.caches_are_enabled())
True
>>> print(model.layers[0].attn.kv_cache)
KVCache()
>>> delete_kv_caches(model)
>>> print(model.caches_are_setup())
False
>>> print(model.caches_are_enabled())
False
>>> print(model.layers[0].attn.kv_cache)
None
Parameters:

model (nn.Module) – model to delete KV caches from.

Raises:

ValueError – if this function is called on a model which does not have caches set up. Use setup_caches() to set up caches first.
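
A common pattern is to set up caches only for the duration of an inference pass and free the memory immediately afterwards. The sketch below wraps this in a hypothetical temporary_kv_caches context manager built only from the calls documented above; it is illustrative, not part of the library API (torchtune also provides a local_kv_cache context manager for a similar purpose).

import contextlib

import torch
from torchtune.models.llama3_2 import llama3_2_1b
from torchtune.modules import delete_kv_caches


@contextlib.contextmanager
def temporary_kv_caches(model, *, batch_size, dtype, decoder_max_seq_len):
    """Hypothetical helper: set up KV caches for the duration of a block, then free them."""
    model.setup_caches(
        batch_size=batch_size,
        dtype=dtype,
        decoder_max_seq_len=decoder_max_seq_len,
    )
    try:
        yield model
    finally:
        # Removes every layer's KVCache and sets cache_enabled back to False,
        # returning the model to its pre-setup_caches state.
        delete_kv_caches(model)


model = llama3_2_1b()
with temporary_kv_caches(
    model, batch_size=1, dtype=torch.float32, decoder_max_seq_len=1024
):
    assert model.caches_are_setup() and model.caches_are_enabled()
    # ... run generation / decoding here ...
assert not model.caches_are_setup()

Because delete_kv_caches raises ValueError on a model without caches, calling it in the finally block is safe here: by the time the block body runs, setup_caches has already succeeded.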
