Custom GPU Kernels via Triton
PyTorch/XLA now supports Triton kernels, so custom GPU kernels can run as part of the PyTorch/XLA execution graph on GPUs. Triton is a Python-based language and compiler for GPU programming. It lets developers write custom kernels for operations in deep learning models without writing CUDA directly.
Given a Triton kernel defined as follows:
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # Triton add kernel from
    pid = tl.program_id(axis=0)  # Index of the block this program instance handles.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # Skip out-of-range lanes in the last block.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
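To see what the program id, offsets and mask in add_kernel do, it can help to run the same indexing logic sequentially on the CPU. The sketch below is only an illustration: emulate_add_kernel is a hypothetical helper that uses plain Python lists, not Triton, and it runs the programs one after another instead of in parallel.

```python
def emulate_add_kernel(x, y, output, n_elements, block_size, grid):
    """Hypothetical sequential CPU stand-in for add_kernel (no Triton involved)."""
    for pid in range(grid[0]):  # plays the role of tl.program_id(axis=0)
        block_start = pid * block_size
        offsets = [block_start + i for i in range(block_size)]  # like tl.arange
        for off in offsets:
            if off < n_elements:  # the mask: lanes past the end are skipped
                output[off] = x[off] + y[off]


size, block_size = 13, 8
x = list(range(size))
y = list(range(size))
out = [0] * size
# Two programs of 8 lanes cover 13 elements; the last 3 lanes of program 1 are masked.
emulate_add_kernel(x, y, out, size, block_size, grid=(2,))
print(out)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24]
```

Without the mask check, program 1 would index positions 13 to 15 and the list access would fail, which is the CPU analogue of the out-of-bounds loads and stores the mask prevents on the GPU.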
We can make add_kernel part of the PyTorch/XLA execution graph as follows:
import torch
import torch_xla.experimental.triton as xla_triton
import torch_xla
import triton
import triton.language as tl
size = 16
x = torch.arange(size, dtype=torch.int64).to("xla")
y = torch.arange(size, dtype=torch.int64).to("xla")
output = torch.empty_like(x)
block_size = 8
grid = (triton.cdiv(size, block_size),)
# triton_call takes the same arguments as the triton.jit function, in addition
# to the kernel itself and the grid that is used to execute the kernel.
# All the tl.constexpr terms are passed as kwargs at the end.
payload = xla_triton.triton_call(
x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)
# To make the Triton kernel a part of the PyTorch/XLA graph, we create a
# custom call node with the expected inputs, the payload from triton_call,
# the output shapes and the output dtypes. The payload already contains information
# about how the GPU buffers will be loaded when this node is executed.
output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
[output.shape], [torch.int64])
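The grid in this example is computed with triton.cdiv, which divides and rounds up so that every element is covered by some program; any leftover lanes in the last block are handled by the mask inside add_kernel. The plain-Python sketch below reproduces that arithmetic. cdiv and launch_plan are hypothetical helpers written for illustration, not PyTorch/XLA or Triton APIs.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive integers, the same rounding triton.cdiv applies.
    return (a + b - 1) // b


def launch_plan(size: int, block_size: int):
    """Hypothetical helper: return the 1-D grid and the number of
    unmasked (in-range) lanes each program will process."""
    grid = (cdiv(size, block_size),)
    active = [min(block_size, size - pid * block_size) for pid in range(grid[0])]
    return grid, active


print(launch_plan(16, 8))  # ((2,), [8, 8])  -- the configuration used above
print(launch_plan(20, 8))  # ((3,), [8, 8, 4])  -- last program masks 4 lanes
```

With size = 16 and block_size = 8 the division is exact, so no lanes are masked; the mask only matters when size is not a multiple of BLOCK_SIZE.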
For more complex kernels, you can also refer to the Triton Flash Attention kernel test in PyTorch/XLA.
The Triton integration depends on the triton package to function. This code is tested with triton==2.3.0. To install:
pip install --no-deps triton==2.3.0
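Because the integration is documented as tested with triton==2.3.0, it may be worth confirming which version is installed before running the example. A minimal sketch using the standard library's importlib.metadata; check_triton_version is a hypothetical helper, not part of PyTorch/XLA.

```python
from importlib.metadata import PackageNotFoundError, version

TESTED_TRITON_VERSION = "2.3.0"  # Version the integration is documented as tested with.


def check_triton_version(expected: str = TESTED_TRITON_VERSION) -> bool:
    """Hypothetical helper: return True only if the installed triton matches `expected`."""
    try:
        installed = version("triton")
    except PackageNotFoundError:
        print(f"triton is not installed; try: pip install --no-deps triton=={expected}")
        return False
    if installed != expected:
        print(f"Found triton {installed}, but the integration was tested with {expected}.")
        return False
    return True


if __name__ == "__main__":
    check_triton_version()
```

A mismatch is reported rather than treated as fatal, since other versions may still work; they are simply not the version the documentation was tested against.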