Compile Time Caching in torch.compile
Created On: Jun 20, 2024 | Last Updated: Feb 27, 2025 | Last Verified: Nov 05, 2024
Author: Oguz Ulgen
Introduction
PyTorch Compiler provides several caching offerings to reduce compilation latency. This recipe will explain these offerings in detail to help users pick the best option for their use case.
Check out Compile Time Caching Configurations for how to configure these caches.
Also check out our caching benchmark at PT CacheBench Benchmarks.
Prerequisites
Before starting this recipe, make sure that you have the following:
- Basic understanding of torch.compile
- PyTorch 2.4 or later
Caching Offerings
torch.compile provides the following caching offerings:
- End-to-end caching (also known as Mega-Cache)
- Modular caching of TorchDynamo, TorchInductor, and Triton
It is important to note that caching validates that the cache artifacts are used with the same PyTorch and Triton versions, as well as the same GPU when the device is set to cuda.
torch.compile end-to-end caching (Mega-Cache)
End-to-end caching, from here onwards referred to as Mega-Cache, is the ideal solution for users looking for a portable caching solution that can be stored in a database and later fetched, possibly on a separate machine.
Mega-Cache provides two compiler APIs:
- torch.compiler.save_cache_artifacts()
- torch.compiler.load_cache_artifacts()
The intended use case is as follows: after compiling and executing a model, the user calls torch.compiler.save_cache_artifacts(), which returns the compiler artifacts in a portable form. Later, potentially on a different machine, the user may call torch.compiler.load_cache_artifacts() with these artifacts to pre-populate the torch.compile caches and jump-start compilation.
Consider the following example. First, compile and save the cache artifacts.
import torch

# Example choices for dtype and device, which the snippet assumes are defined
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32

@torch.compile
def fn(x, y):
    return x.sin() @ y

a = torch.rand(100, 100, dtype=dtype, device=device)
b = torch.rand(100, 100, dtype=dtype, device=device)

# Running the compiled function triggers compilation and populates the caches
result = fn(a, b)

# Collect the compiler artifacts accumulated so far in a portable form
artifacts = torch.compiler.save_cache_artifacts()
assert artifacts is not None
artifact_bytes, cache_info = artifacts
# Now, potentially store artifact_bytes in a database
# You can use cache_info for logging
Later, you can jump-start the cache as follows:
# Potentially download/fetch the artifacts from the database
torch.compiler.load_cache_artifacts(artifact_bytes)
This operation populates all the modular caches that will be discussed in the next section, including PGO, AOTAutograd, Inductor, Triton, and Autotuning.
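Continuing the example above, here is a minimal sketch of moving the artifacts between machines through a local file rather than a database. The file path is hypothetical, and artifact_bytes is the value returned by save_cache_artifacts():

# On the machine that compiled the model: persist the portable artifacts.
# The path below is only an illustration; any byte store (file, blob, database) works.
with open("/tmp/compile_cache_artifacts.bin", "wb") as f:
    f.write(artifact_bytes)

# On another machine, before compiling the same model: pre-populate the caches.
import torch

with open("/tmp/compile_cache_artifacts.bin", "rb") as f:
    torch.compiler.load_cache_artifacts(f.read())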
Modular caching of TorchDynamo, TorchInductor, and Triton
The aforementioned Mega-Cache is composed of individual components that can be used without any user intervention. By default, PyTorch Compiler comes with local on-disk caches for TorchDynamo, TorchInductor, and Triton. These caches include:
- FXGraphCache: A cache of graph-based IR components used in compilation.
- TritonCache: A cache of Triton compilation results, including cubin files generated by Triton and other caching artifacts.
- InductorCache: A bundle of FXGraphCache and the Triton cache.
- AOTAutogradCache: A cache of joint graph artifacts.
- PGO-cache: A cache of dynamic shape decisions to reduce the number of recompilations.
All these cache artifacts are written to TORCHINDUCTOR_CACHE_DIR, which by default will look like /tmp/torchinductor_myusername.
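For example, you can redirect these on-disk caches to a directory of your choosing (for instance, a persistent volume reused across jobs) by setting TORCHINDUCTOR_CACHE_DIR before compiling. The sketch below assumes a hypothetical directory path:

import os

# Point the local on-disk caches at a custom directory (hypothetical path).
# Setting this before the first torch.compile invocation in the process is safest.
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/path/to/shared/inductor_cache"

import torch

@torch.compile
def fn(x):
    return x.relu() + 1

fn(torch.randn(8, 8))
# The cache artifacts for this compilation are now written under the directory above.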
Remote Caching
We also provide a remote caching option for users who would like to take advantage of a Redis-based cache. Check out Compile Time Caching Configurations to learn more about how to enable Redis-based caching.
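As a rough sketch of what enabling the Redis-backed caches can look like, assuming the environment variable names below (they are assumptions here; Compile Time Caching Configurations is the authoritative reference):

import os

# Assumed environment variable names; verify them against
# Compile Time Caching Configurations before relying on them.
os.environ["TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE"] = "1"   # remote FX graph cache
os.environ["TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE"] = "1"   # remote autotuning cache
os.environ["TORCHINDUCTOR_REDIS_HOST"] = "my.redis.host"  # hypothetical Redis host
os.environ["TORCHINDUCTOR_REDIS_PORT"] = "6379"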