disable_kv_cache
- torchtune.modules.common_utils.disable_kv_cache(model: Module) → Generator[None, None, None] [source]
This context manager temporarily disables KV-caching on a given model, which must already have KV-caches set up. All forward passes using the model within this context manager will not use KV-caches.
KV-caches will be disabled when entering the context manager and re-enabled upon exit, without their state being modified.
This is useful when we wish to alternate between model calls which use KV-caching and model calls which do not, without the overhead of deleting and re-initializing the caches every time.
Example
>>> from torchtune.models.llama3_2 import llama3_2_1b
>>> from torchtune.modules import disable_kv_cache
>>> import torch
>>> model = llama3_2_1b()
>>> # setup caches
>>> model.setup_caches(batch_size=1,
>>>                    dtype=torch.float32,
>>>                    decoder_max_seq_len=1024)
>>> print(model.caches_are_setup())
True
>>> print(model.caches_are_enabled())
True
>>> print(model.layers[0].attn.kv_cache)
KVCache()
>>> # now temporarily disable caches
>>> with disable_kv_cache(model):
>>>     print(model.caches_are_setup())
True
>>>     print(model.caches_are_enabled())
False
>>>     print(model.layers[0].attn.kv_cache)
KVCache()
>>> # caches are now re-enabled, and their state is untouched
>>> print(model.caches_are_setup())
True
>>> print(model.caches_are_enabled())
True
>>> print(model.layers[0].attn.kv_cache)
KVCache()
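A rough sketch of the alternating pattern described above (the token shapes and the full-sequence scoring call are illustrative assumptions, not part of this API; cache-backed decoding outside the context manager would follow the usual generation flow, passing input_pos and a causal mask):
>>> import torch
>>> from torchtune.models.llama3_2 import llama3_2_1b
>>> from torchtune.modules import disable_kv_cache
>>> model = llama3_2_1b()
>>> model.setup_caches(batch_size=1,
>>>                    dtype=torch.float32,
>>>                    decoder_max_seq_len=1024)
>>> tokens = torch.randint(0, 128_000, (1, 8))  # hypothetical prompt batch
>>> # score the full sequence without reading from or writing to the caches
>>> with disable_kv_cache(model):
>>>     logits = model(tokens)  # plain causal forward pass, caches untouched
>>> # caches are re-enabled here; incremental, cache-backed decoding
>>> # (with `input_pos` and a causal `mask`) can proceed as usual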
- Parameters:
model (nn.Module) – model to disable KV-caching for.
- Yields:
None – Returns control to the caller with KV-caches disabled on the given model.
- Raises:
ValueError – If the model does not have caches set up. Use setup_caches() to set up caches first (see the short illustration below).
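For instance, entering the context manager on a model whose caches were never set up fails immediately; a minimal illustration, with the exact error message elided:
>>> from torchtune.models.llama3_2 import llama3_2_1b
>>> from torchtune.modules import disable_kv_cache
>>> model = llama3_2_1b()  # note: setup_caches() has not been called
>>> with disable_kv_cache(model):
>>>     pass
ValueError: ...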