
Using User-Defined Triton Kernels with torch.compile

Created On: Apr 19, 2024 | Last Updated: Oct 16, 2024 | Last Verified: Nov 05, 2024

Author: Oguz Ulgen

User-defined Triton kernels can be used to optimize specific parts of your model’s computation. These kernels are written in Triton’s language, which is designed to make it easier to achieve peak hardware performance. By using user-defined Triton kernels with torch.compile, you can integrate these optimized computations into your PyTorch model, potentially achieving significant performance improvements.

This recipe demonstrates how to use user-defined Triton kernels with torch.compile.

Prerequisites

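Before starting this recipe, make sure that you have the following:

  • PyTorch 2.3 or later

  • A GPU that supports Triton (the examples allocate tensors with device="cuda")

The examples below start with these imports: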

import torch
from torch.utils._triton import has_triton

Basic Usage

In this example, we use torch.compile with a simple vector addition kernel taken from the Triton documentation. For reference, see the Triton documentation.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([ 0.1940,  2.1614, -0.1721,  0.8491], device='cuda:0')
Y:      tensor([ 0.1391, -0.1082, -0.7174,  0.7566], device='cuda:0')
is equal to
tensor([ 0.3332,  2.0532, -0.8895,  1.6057], device='cuda:0')
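To make the launch logic concrete, the sketch below replays what add_kernel does, sequentially and on the CPU, in plain Python. It is an illustration only, not how Triton executes kernels. Each loop iteration stands in for one program instance (tl.program_id), the program count is the value the grid lambda returns, and the bounds check plays the role of mask. The helpers emulate_add_kernel and cdiv are hypothetical names; cdiv mirrors the ceiling division that triton.cdiv performs.

```python
from typing import List, Tuple


def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers (what triton.cdiv computes)."""
    return (a + b - 1) // b


def emulate_add_kernel(x: List[float], y: List[float], block_size: int) -> Tuple[List[float], int]:
    """Hypothetical CPU replay of add_kernel's per-program logic (illustration only)."""
    n_elements = len(x)
    out = [0.0] * n_elements
    num_programs = cdiv(n_elements, block_size)  # same value the grid lambda returns
    for pid in range(num_programs):  # one iteration per Triton program instance
        block_start = pid * block_size
        offsets = [block_start + i for i in range(block_size)]  # like tl.arange(0, BLOCK_SIZE)
        for offset in offsets:
            # mask = offsets < n_elements: lanes past the end neither load nor store.
            if offset < n_elements:
                out[offset] = x[offset] + y[offset]
    return out, num_programs


# 5 elements with BLOCK_SIZE=4: two programs; the second has 3 masked-off lanes.
result, programs = emulate_add_kernel([1.0, 2.0, 3.0, 4.0, 5.0], [10.0, 20.0, 30.0, 40.0, 50.0], 4)
print(programs, result)
```

With 5 elements and a block size of 4, two programs run, and the mask keeps the second one from touching the three out-of-range positions. This is why the kernel needs the mask at all: n_elements does not have to be a multiple of BLOCK_SIZE.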

Advanced Usage

Triton’s autotune feature tunes the configuration parameters of a Triton kernel automatically. Given a list of candidate configurations, it benchmarks each one and selects the configuration that runs fastest for your specific use case.

triton.autotune also works with torch.compile: you can call an autotuned kernel from inside a compiled function, just as you would call a plain @triton.jit kernel. Here is an example of using torch.compile and triton.autotune.

Note

torch.compile supports only the configs and key arguments of triton.autotune.
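To illustrate what configs and key mean, here is a simplified, hypothetical model of the autotuning loop. It is not Triton's implementation. Each candidate config is "benchmarked", the cheapest one is cached, and the cache is keyed on the values of the arguments named in key, so tuning reruns only when one of those values changes. Real autotuning times actual kernel launches on the GPU and also varies num_warps and num_stages. Here ToyAutotuner and fake_bench are assumed stand-ins so the sketch runs anywhere.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Config = Dict[str, int]
Args = Dict[str, int]


class ToyAutotuner:
    """Hypothetical model of the idea behind triton.autotune (not Triton's code)."""

    def __init__(self, configs: List[Config], key: Sequence[str],
                 bench: Callable[[Config, Args], float]) -> None:
        self.configs = configs
        self.key = list(key)
        self.bench = bench  # assumed cost function: lower means faster
        self.cache: Dict[Tuple[int, ...], Config] = {}
        self.tuning_runs = 0

    def pick(self, args: Args) -> Config:
        # Only the values of the arguments named in `key` decide whether to re-tune.
        cache_key = tuple(args[name] for name in self.key)
        if cache_key not in self.cache:
            self.tuning_runs += 1
            # "Benchmark" every candidate and remember the cheapest one.
            self.cache[cache_key] = min(self.configs, key=lambda c: self.bench(c, args))
        return self.cache[cache_key]


def fake_bench(config: Config, args: Args) -> float:
    # Hypothetical cost: pretend the best BLOCK_SIZE is half the input length.
    return float(abs(config["BLOCK_SIZE"] - args["n_elements"] // 2))


configs = [{"BLOCK_SIZE": 2}, {"BLOCK_SIZE": 4}]
per_size = ToyAutotuner(configs, key=["n_elements"], bench=fake_bench)
tune_once = ToyAutotuner(configs, key=[], bench=fake_bench)
for n in (4, 16, 4):
    print(n, per_size.pick({"n_elements": n}), tune_once.pick({"n_elements": n}))
print(per_size.tuning_runs, tune_once.tuning_runs)
```

In this model, key=["n_elements"] re-tunes for each new input length (two tuning runs for lengths 4 and 16). key=[] tunes once and keeps BLOCK_SIZE=2 even for the longer input.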

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([-0.5187,  1.2268,  0.6255, -0.9117], device='cuda:0')
Y:      tensor([-0.6974, -1.8688, -0.8832, -1.6627], device='cuda:0')
is equal to
tensor([-1.2161, -0.6421, -0.2577, -2.5744], device='cuda:0')
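In the autotuned call, BLOCK_SIZE is not passed explicitly. The autotuner supplies it from whichever config is being run, which is why grid is a callable of meta rather than a fixed tuple. The stdlib-only sketch below evaluates the same grid lambda for the BLOCK_SIZE of each of the four configs in the example. Here cdiv is a local stand-in for triton.cdiv, and meta is reduced to the single key the lambda reads.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, the same quantity triton.cdiv computes for positive ints.
    return (a + b - 1) // b


n_elements = 4
grid = lambda meta: (cdiv(n_elements, meta["BLOCK_SIZE"]),)

# BLOCK_SIZE values of the four triton.Config entries used by add_kernel_autotuned.
config_block_sizes = [4, 4, 2, 2]
grids = [grid({"BLOCK_SIZE": bs}) for bs in config_block_sizes]
print(grids)  # [(1,), (1,), (2,), (2,)]
```

The configs with BLOCK_SIZE=2 launch twice as many programs as those with BLOCK_SIZE=4, so a fixed grid could not serve every candidate.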

Composability and Limitations

As of PyTorch 2.3, support for user-defined Triton kernels in torch.compile includes dynamic shapes, torch.autograd.Function, JIT Inductor, and AOT Inductor. You can combine these features to build complex, high-performance models.

However, there are certain limitations to be aware of:

  • Tensor Subclasses: Currently, there is no support for tensor subclasses and other advanced features.

  • Triton Features: While triton.heuristics can be used either standalone or before triton.autotune, it cannot be used after triton.autotune. This implies that if triton.heuristics and triton.autotune are to be used together, triton.heuristics must be used first.

Conclusion

In this recipe, we explored how to use user-defined Triton kernels with torch.compile. We covered basic usage with a simple vector addition kernel and advanced usage with Triton’s autotune feature. We also discussed how user-defined Triton kernels compose with other PyTorch features and highlighted some current limitations.
