
Custom Kernels via Pallas

With the rise of OpenAI Triton, custom kernels have become more and more popular in the GPU community, as shown by the introduction of FlashAttention and PagedAttention. To provide feature parity in the TPU world, Google has introduced Pallas. For PyTorch/XLA to keep pushing TPU performance, we have to support custom kernels, and the best way to do so is through Pallas.

Let’s assume you have a Pallas kernel defined as follows:

from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
                        )(x, y)

Note that it is very important to call jax_import_guard() before importing any JAX modules. Otherwise, the program will hang on the TPU because JAX will lock the TPU and torch_xla will not be able to access it.

Adapt the above kernel to be compatible with PyTorch/XLA

Example usage:

import torch

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# Adapt any Pallas kernel: the second argument maps the input tensors
# to the output's (shape, dtype).
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)

For simple kernels, the adaptation is as simple as a one-liner. For more complicated kernels, you can refer to our FlashAttention implementation for details.
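
As a slightly more involved, purely illustrative sketch, the snippet below wraps a hypothetical kernel whose output dtype differs from its inputs. The kernel name add_vectors_f32 and the shapes are assumptions made for this example; the key point is that the output_shape_dtype_fn lambda passed to make_kernel_from_pallas must describe every output's shape and dtype.

from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas
jax_import_guard()

import torch
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_f32_kernel(x_ref, y_ref, o_ref):
  # Accumulate in float32 regardless of the input dtype.
  o_ref[...] = x_ref[...].astype(jnp.float32) + y_ref[...].astype(jnp.float32)

@jax.jit
def add_vectors_f32(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(add_vectors_f32_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32)
                        )(x, y)

# The lambda receives the torch tensors and returns one (shape, dtype) per output.
pt_kernel_f32 = make_kernel_from_pallas(
    add_vectors_f32, lambda x, y: [(x.shape, torch.float32)])

x = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
y = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
output = pt_kernel_f32(x, y)  # float32 output from bfloat16 inputs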

Use built-in kernels

Besides manually wrapping external Pallas kernels, there are built-in kernels whose adaptations have already been done by PyTorch/XLA. These built-in kernels can be used like any other torch.ops. The built-in kernels currently supported are:

- FlashAttention
- PagedAttention

FlashAttention

Example usage

# Use the built-in FlashAttention kernel
from torch_xla.experimental.custom_kernel import flash_attention
output = flash_attention(q, k, v)
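
Putting it together, here is a self-contained sketch. The [batch, num_heads, seq_len, head_dim] layout mirrors the tensors defined earlier in this doc, and the causal flag is an assumption about the current signature; check torch_xla.experimental.custom_kernel for the exact arguments your version supports.

import torch
from torch_xla.experimental.custom_kernel import flash_attention

# Assumed layout: [batch, num_heads, seq_len, head_dim]
q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

output = flash_attention(q, k, v)
# Causal masking, assuming your version exposes the `causal` argument.
causal_output = flash_attention(q, k, v, causal=True)
print(output.shape)  # torch.Size([3, 2, 128, 4])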

Integration Example

We have an example of the FlashAttention integration in our training test script.

PagedAttention

Example usage

# Importing this module registers the built-in kernels, including torch.ops.xla.paged_attention
import torch_xla.experimental.custom_kernel
output = torch.ops.xla.paged_attention(
    query.squeeze(dim=1),
    key_cache,
    value_cache,
    context_lens,
    block_tables,
    pages_per_compute_block,
    megacore_mode=None,
)
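
For a runnable sketch, the dummy inputs below use shapes assumed from the underlying JAX paged_attention kernel (query: [batch_size, num_heads, head_size] after the squeeze, key/value cache: [num_kv_heads, total_num_pages, page_size, head_size], int32 context lengths and block tables). Treat the sizes and layout as illustrative and consult the vLLM TPU integration for the exact conventions.

import torch
import torch_xla.experimental.custom_kernel  # registers torch.ops.xla.paged_attention

# Illustrative sizes; these are assumptions for the sketch, not API requirements.
batch_size, num_heads, num_kv_heads, head_size = 4, 8, 8, 128
total_num_pages, page_size, pages_per_sequence = 32, 16, 8
pages_per_compute_block = 4

query = torch.randn(batch_size, 1, num_heads, head_size).to("xla")
key_cache = torch.randn(num_kv_heads, total_num_pages, page_size, head_size).to("xla")
value_cache = torch.randn(num_kv_heads, total_num_pages, page_size, head_size).to("xla")
context_lens = torch.randint(
    1, page_size * pages_per_sequence, (batch_size,), dtype=torch.int32).to("xla")
block_tables = torch.randint(
    0, total_num_pages, (batch_size, pages_per_sequence), dtype=torch.int32).to("xla")

output = torch.ops.xla.paged_attention(
    query.squeeze(dim=1),  # [batch_size, num_heads, head_size]
    key_cache,
    value_cache,
    context_lens,
    block_tables,
    pages_per_compute_block,
    megacore_mode=None,
)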

Integration Example

The vLLM TPU integration utilizes PagedAttention for effective memory management of the KV cache.

Dependencies

The Pallas integration depends on JAX. However, not every JAX version is compatible with your installed PyTorch/XLA. To install a compatible JAX:

pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
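
A quick, hedged way to verify the installation is to import JAX behind the guard and print its version; if the import ordering is wrong, this is also where a hang would show up.

import torch_xla
from torch_xla.experimental.custom_kernel import jax_import_guard

jax_import_guard()  # let torch_xla claim the TPU before JAX is imported
import jax

print("jax", jax.__version__)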
