Compile Time Caching in ``torch.compile``
=========================================================

**Author:** `Oguz Ulgen <https://github.com/oulgen>`_

Introduction
------------------

PyTorch Compiler provides several caching offerings to reduce compilation latency.
This recipe explains these offerings in detail to help users pick the best option for their use case.

Check out `Compile Time Caching Configurations <https://pytorch.org/tutorials/recipes/torch_compile_caching_configuration_tutorial.html>`__ for how to configure these caches.

Also check out our caching benchmark at `PT CacheBench Benchmarks <https://hud.pytorch.org/benchmark/llms?repoName=pytorch%2Fpytorch&benchmarkName=TorchCache+Benchmark>`__.

Prerequisites
-------------------

Before starting this recipe, make sure that you have the following:

* Basic understanding of ``torch.compile``. See:

  * `torch.compiler API documentation <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler>`__
  * `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__
  * `Triton language documentation <https://triton-lang.org/main/index.html>`__

* PyTorch 2.4 or later

Caching Offerings
---------------------

``torch.compile`` provides the following caching offerings:

* End-to-end caching (also known as ``Mega-Cache``)
* Modular caching of ``TorchDynamo``, ``TorchInductor``, and ``Triton``

Note that caching validates that the cache artifacts are used with the same PyTorch and Triton versions, as well as the same GPU when the device is CUDA.

``torch.compile`` end-to-end caching (``Mega-Cache``)
------------------------------------------------------------

End-to-end caching, from here onwards referred to as ``Mega-Cache``, is the ideal solution for users looking for a portable cache that can be stored in a database and later fetched, possibly on a separate machine.

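
Because the cache is only reused with the same PyTorch version, Triton version, and GPU, one way to organize such a database is to key each entry on those three properties. The following is a minimal sketch, assuming the artifacts are an opaque ``bytes`` object. It uses Python's built-in ``sqlite3`` module as a stand-in for a real database, and ``make_cache_key`` and ``ArtifactStore`` are hypothetical helpers, not part of PyTorch.

.. code-block:: python

    import sqlite3
    from typing import Optional


    # Hypothetical helper (not a PyTorch API): combines the properties the
    # cache is validated against into a single lookup key.
    def make_cache_key(torch_version: str, triton_version: str, gpu_name: str) -> str:
        return f"torch={torch_version}|triton={triton_version}|gpu={gpu_name}"


    class ArtifactStore:
        """Hypothetical SQLite-backed store for opaque cache artifact bytes."""

        def __init__(self, path: str = ":memory:") -> None:
            self.conn = sqlite3.connect(path)
            self.conn.execute(
                "CREATE TABLE IF NOT EXISTS artifacts "
                "(key TEXT PRIMARY KEY, data BLOB NOT NULL)"
            )

        def put(self, key: str, data: bytes) -> None:
            # "with self.conn" commits the transaction; an existing entry
            # under the same key is replaced.
            with self.conn:
                self.conn.execute(
                    "INSERT OR REPLACE INTO artifacts (key, data) VALUES (?, ?)",
                    (key, data),
                )

        def get(self, key: str) -> Optional[bytes]:
            # Returns None when no artifacts were stored for this key.
            row = self.conn.execute(
                "SELECT data FROM artifacts WHERE key = ?", (key,)
            ).fetchone()
            return None if row is None else bytes(row[0])

        def close(self) -> None:
            self.conn.close()

To use it, you could build the key from ``torch.__version__``, ``triton.__version__``, and ``torch.cuda.get_device_name()``, store the bytes returned by ``torch.compiler.save_cache_artifacts()`` with ``put``, and on the other machine pass the result of ``get`` to ``torch.compiler.load_cache_artifacts()`` when it is not ``None``. Passing a file path instead of ``":memory:"`` makes the store persist across processes.
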

``Mega-Cache`` provides two compiler APIs:

* ``torch.compiler.save_cache_artifacts()``
* ``torch.compiler.load_cache_artifacts()``

The intended use case is as follows: after compiling and executing a model, the user calls ``torch.compiler.save_cache_artifacts()``, which returns the compiler artifacts in a portable form. Later, potentially on a different machine, the user may call ``torch.compiler.load_cache_artifacts()`` with these artifacts to pre-populate the ``torch.compile`` caches and jump-start compilation.

Consider the following example. First, compile and save the cache artifacts.

.. code-block:: python

    import torch

    # Any device and dtype work; CUDA is used here when available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32

    @torch.compile
    def fn(x, y):
        return x.sin() @ y

    a = torch.rand(100, 100, dtype=dtype, device=device)
    b = torch.rand(100, 100, dtype=dtype, device=device)

    result = fn(a, b)

    artifacts = torch.compiler.save_cache_artifacts()

    assert artifacts is not None
    artifact_bytes, cache_info = artifacts

    # Now, potentially store artifact_bytes in a database
    # You can use cache_info for logging

Later, you can jump-start the cache as follows:

.. code-block:: python

    import torch

    # Potentially download/fetch the artifacts from the database,
    # then load them before calling any compiled functions
    torch.compiler.load_cache_artifacts(artifact_bytes)

This operation populates all the modular caches that are discussed in the next section, including ``PGO``, ``AOTAutograd``, ``Inductor``, ``Triton``, and ``Autotuning``.

Modular caching of ``TorchDynamo``, ``TorchInductor``, and ``Triton``
-----------------------------------------------------------------------

The aforementioned ``Mega-Cache`` is composed of individual components that can be used without any user intervention. By default, PyTorch Compiler comes with local on-disk caches for ``TorchDynamo``, ``TorchInductor``, and ``Triton``. These caches include:

* ``FXGraphCache``: A cache of graph-based IR components used in compilation.
* ``TritonCache``: A cache of Triton-compilation results, including ``cubin`` files generated by ``Triton`` and other caching artifacts.
* ``InductorCache``: A bundle of ``FXGraphCache`` and ``Triton`` cache.
* ``AOTAutogradCache``: A cache of joint graph artifacts.
* ``PGO-cache``: A cache of dynamic shape decisions to reduce the number of recompilations.

All of these cache artifacts are written to ``TORCHINDUCTOR_CACHE_DIR``, which by default looks like ``/tmp/torchinductor_myusername``.

Remote Caching
----------------

We also provide a remote caching option for users who would like to take advantage of a Redis-based cache. Check out `Compile Time Caching Configurations <https://pytorch.org/tutorials/recipes/torch_compile_caching_configuration_tutorial.html>`__ to learn more about how to enable Redis-based caching.

Conclusion
-------------

In this recipe, we have learned that PyTorch Inductor's caching mechanisms significantly reduce compilation latency by utilizing both local and remote caches. The modular local caches operate in the background without requiring user intervention, while ``Mega-Cache`` lets you save compiler artifacts and load them later, possibly on a different machine.