
Custom GPU Kernels via Triton

PyTorch/XLA supports Triton kernels, so custom GPU kernels can run as part of a PyTorch/XLA program. Triton is a Python-based language and compiler for GPU programming. It lets developers write custom, high-performance kernels for the operations in deep learning models.

Given a Triton kernel defined as follows:

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
    pid = tl.program_id(axis=0)  # Index of this program instance in the 1-D grid.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # Guard the last block against out-of-bounds accesses.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
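Triton launches this kernel as a 1-D grid of program instances. Each instance handles one block of BLOCK_SIZE elements. When the vector length is not a multiple of BLOCK_SIZE, the mask switches off the lanes of the last block that fall past n_elements.

The sketch below emulates that execution model sequentially in plain Python, for illustration only. It does not use Triton, and the helper names cdiv, add_kernel_emulated and launch_add are hypothetical. Here cdiv mirrors the ceiling division that triton.cdiv performs.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


def add_kernel_emulated(x, y, output, n_elements, BLOCK_SIZE, pid):
    """Hypothetical CPU stand-in for one program instance of add_kernel."""
    block_start = pid * BLOCK_SIZE                          # first element this program owns
    offsets = [block_start + i for i in range(BLOCK_SIZE)]  # like tl.arange(0, BLOCK_SIZE)
    mask = [off < n_elements for off in offsets]            # False for lanes past the end
    for off, keep in zip(offsets, mask):
        if keep:  # masked-off lanes are never read or written
            output[off] = x[off] + y[off]


def launch_add(x, y, block_size):
    """Run every program of a 1-D grid of cdiv(n, block_size) instances, one after another."""
    n_elements = len(x)
    output = [0] * n_elements
    grid = (cdiv(n_elements, block_size),)
    for pid in range(grid[0]):  # a GPU runs these instances in parallel instead
        add_kernel_emulated(x, y, output, n_elements, block_size, pid)
    return output, grid


# 20 elements with blocks of 8 need 3 programs; the third covers offsets 16..23
# but only writes 16..19 because of the mask.
out, grid = launch_add(list(range(20)), list(range(20)), block_size=8)
print(grid)      # (3,)
print(out[-4:])  # [32, 34, 36, 38]
```

On a GPU the program instances are not executed in a loop. The masked tl.load and tl.store calls play the role of the if keep branch, keeping the last program from touching memory beyond n_elements.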

We can make add_kernel part of the PyTorch/XLA execution graph as follows:

import torch

import torch_xla.experimental.triton as xla_triton
import torch_xla

import triton
import triton.language as tl

size = 16
x = torch.arange(size, dtype=torch.int64).to("xla")
y = torch.arange(size, dtype=torch.int64).to("xla")
output = torch.empty_like(x)
block_size = 8
grid = (triton.cdiv(size, block_size),)

# triton_call takes the same arguments as the triton.jit function, in addition
# to the kernel itself and the grid that is used to execute the kernel.
# All the tl.constexpr terms are passed as kwargs at the end.
payload = xla_triton.triton_call(
    x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)

# To make the Triton kernel a part of the PyTorch/XLA graph, we create a
# custom call node with the expected inputs, the payload from triton_call,
# the output shapes and the output dtypes. The payload already contains information
# about how the GPU buffers will be loaded when this node is executed.
# The custom call returns a list of tensors, one per entry in the output shape list.
output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
                                              [output.shape], [torch.int64])

For more complex kernels, you can also refer to the Triton Flash Attention kernel test in PyTorch/XLA.

Dependencies

The Triton integration requires the triton package. It has been tested with triton==2.3.0, which you can install with:

pip install --no-deps triton==2.3.0
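You can check which triton version is installed before using the integration. The standard library's importlib.metadata can compare it against the tested 2.3.0.

The helper below is hypothetical and is not part of PyTorch/XLA. A mismatch does not necessarily mean the integration will fail, only that the combination is untested.

```python
from importlib.metadata import PackageNotFoundError, version

# Version the Triton integration is documented as tested with.
TESTED_TRITON_VERSION = "2.3.0"


def triton_version_status(expected: str = TESTED_TRITON_VERSION) -> str:
    """Hypothetical helper: compare the installed triton package to a version string."""
    try:
        installed = version("triton")  # reads the installed distribution's metadata
    except PackageNotFoundError:
        return "missing"
    if installed == expected:
        return "match"
    return f"mismatch (installed {installed}, tested {expected})"


if __name__ == "__main__":
    print("triton:", triton_version_status())
```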
