# Custom Kernels via Pallas

With the rise of OpenAI [Triton](https://openai.com/research/triton), custom kernels have become more and more popular in the GPU community, for instance with the introduction of [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html). In order to provide feature parity in the TPU world, Google has introduced [Pallas](https://jax.readthedocs.io/en/latest/pallas/index.html). For PyTorch/XLA to continue pushing performance on TPU, we have to support custom kernels, and the best way to do that is through Pallas.

Let's assume you have a Pallas kernel defined as follows:

``` python3
from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
                        )(x, y)
```

Note that it is very important to run `jax_import_guard()` before importing any JAX modules. Otherwise, the program will hang on TPU because JAX will lock the TPU and torch-xla will not be able to access it.

## Adapt the above kernel to be compatible with PyTorch/XLA

Example usage:

``` python3
q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# Adapts any Pallas kernel
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)
```

For simple kernels, the adaptation is a one-liner. For more complicated kernels, you can refer to our Flash Attention implementation for details.

## Use built-in kernels

Besides manually wrapping external Pallas kernels, there are built-in kernels where the adaptation is already done by PyTorch/XLA. These built-in kernels can be used like any other torch.ops. The built-in kernels currently supported are:

- FlashAttention
- PagedAttention

### FlashAttention

#### Example usage

``` python3
# Use built-in kernels
from torch_xla.experimental.custom_kernel import flash_attention
output = flash_attention(q, k, v)
```

#### Integration Example

We have an example of [FlashAttention integration here](https://github.com/pytorch/xla/blob/master/examples/flash_attention/train_decoder_only_flash_attention.py) in our training test script.

### PagedAttention

#### Example usage

``` python3
# Use built-in kernels
import torch_xla.experimental.custom_kernel

output = torch.ops.xla.paged_attention(
    query.squeeze(dim=1),
    key_cache,
    value_cache,
    context_lens,
    block_tables,
    pages_per_compute_block,
    megacore_mode=None,
)
```

#### Integration Example

The vLLM TPU integration utilizes [PagedAttention here](https://github.com/vllm-project/vllm/blob/f5e1bf5d44877149eaabf9c04379a4e14a023145/vllm/attention/backends/pallas.py#L194) for effective memory management with the KV cache.

## Dependencies

The Pallas integration depends on JAX to function. However, not every JAX version is compatible with your installed PyTorch/XLA. To install a compatible JAX:

``` bash
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```
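
To sanity-check the setup end to end, here is a minimal, self-contained sketch (assuming a TPU host with `torch_xla` and a compatible JAX installed via the command above) that defines the `add_vectors` kernel from earlier, wraps it with `make_kernel_from_pallas`, and runs it on XLA tensors. The tensor shapes are illustrative only.

``` python3
# Minimal end-to-end sketch (assumes a TPU host with torch_xla and a
# compatible JAX installed as described above).
import torch

# Run the guard before any JAX import so JAX does not lock the TPU.
from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas
jax_import_guard()

import jax
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Element-wise add over the whole block.
  o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
                        )(x, y)

# Wrap the JAX-facing kernel so it can be called with PyTorch/XLA tensors.
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])

a = torch.randn(8, 128).to("xla")
b = torch.randn(8, 128).to("xla")
print(pt_kernel(a, b))
```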