Custom Kernels via Pallas¶
With the rise of OpenAI Triton, custom kernels have become increasingly popular in the GPU community, as shown by the introduction of FlashAttention and PagedAttention. To provide feature parity in the TPU world, Google has introduced Pallas. For PyTorch/XLA to keep pushing performance on TPU, we have to support custom kernels, and the best way to do so is through Pallas.
Let's assume you have a Pallas kernel defined as follows:
from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_vectors_kernel(x_ref, y_ref, o_ref):
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(add_vectors_kernel,
                          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                          )(x, y)
Note that it is very important to run jax_import_guard() before importing any JAX modules. Otherwise, the program will hang on TPU, because JAX will lock the TPU and torch-xla will not be able to access it.
Adapt the above kernel to be compatible with PyTorch/XLA¶
Example usage:
import torch
import torch_xla

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# Adapts any Pallas kernel into a PyTorch callable.
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)
For simple kernels, the adaptation is just a one-liner. For more complicated kernels, you can refer to our FlashAttention implementation for details.
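The second argument to make_kernel_from_pallas is a callback that maps the input tensors to the output specification, i.e. a list of (shape, dtype) pairs, one per kernel output. The sketch below spells out the lambda from the example above; output_shape_dtype is a hypothetical helper name used for illustration, not part of the API.

import torch
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

def output_shape_dtype(x, y):
    # add_vectors produces a single output with the same shape and dtype as x,
    # so the callback returns a one-element list of (shape, dtype) pairs.
    return [(x.shape, x.dtype)]

pt_add = make_kernel_from_pallas(add_vectors, output_shape_dtype)

a = torch.randn(128, 128).to("xla")
b = torch.randn(128, 128).to("xla")
out = pt_add(a, b)  # same result as a + b, computed by the Pallas kernel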
Use built-in kernels¶
Besides manually wrapping external Pallas kernels, there are built-in kernels whose adaptation is already done by PyTorch/XLA. These built-in kernels can be used like any other torch.ops. The built-in kernels currently supported are:
- FlashAttention
- PagedAttention
FlashAttention¶
Example usage¶
# Use built-in kernels
from torch_xla.experimental.custom_kernel import flash_attention

output = flash_attention(q, k, v)
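The q, k, and v tensors above are the ones created earlier, shaped (3, 2, 128, 4), which corresponds to a (batch, num_heads, seq_len, head_dim) layout. The following is a self-contained sketch; the causal keyword is an assumption about your installed torch_xla build and may not be available in older versions.

import torch
import torch_xla
from torch_xla.experimental.custom_kernel import flash_attention

# Tensors shaped as in the example above: (batch, num_heads, seq_len, head_dim).
q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# causal=True is an assumption about your torch_xla version; drop the flag if
# your build of flash_attention does not accept it.
output = flash_attention(q, k, v, causal=True)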
Integration Example¶
We have an example of FlashAttention integration here in our training test script.
PagedAttention¶
Example usage¶
# Use built-in kernels; importing the module registers torch.ops.xla.paged_attention.
import torch_xla.experimental.custom_kernel

output = torch.ops.xla.paged_attention(
    query.squeeze(dim=1),     # query tensor with its singleton sequence dim removed
    key_cache,                # paged key cache
    value_cache,              # paged value cache
    context_lens,             # context length of each sequence
    block_tables,             # maps each sequence to its pages in the KV cache
    pages_per_compute_block,  # number of cache pages processed per compute block
    megacore_mode=None,
)
Integration Example¶
The vLLM TPU integration utilizes PagedAttention here for effective memory management of the KV cache.
Dependencies¶
The Pallas integration depends on JAX to function. However, not every JAX version is compatible with your installed PyTorch/XLA. To install the proper JAX:
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
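After installing, a quick way to confirm that the JAX build works with your PyTorch/XLA install is to run the guard-then-import sequence and print the versions and the XLA device. This is just an informal sanity check, not an official test.

import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import jax_import_guard

jax_import_guard()  # must run before any JAX import, as noted above
import jax
import jaxlib

print("jax:", jax.__version__, "jaxlib:", jaxlib.__version__)
print("xla device:", xm.xla_device())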