torch.export AOTInductor Tutorial for Python runtime (Beta)

Authors: Ankith Gunapal, Bin Bao, Angela Yi

Warning

torch._inductor.aot_compile and torch._export.aot_load are in Beta status and are subject to backwards compatibility breaking changes. This tutorial provides an example of how to use these APIs for model deployment using the Python runtime.

It has been shown previously that AOTInductor can compile PyTorch exported models ahead of time, producing a shared library that can run in a non-Python environment.

In this tutorial, you will walk through an end-to-end example of using AOTInductor with the Python runtime. We will look at how to use torch._inductor.aot_compile() along with torch.export.export() to generate a shared library. Additionally, we will examine how to execute the shared library in the Python runtime using torch._export.aot_load(). You will also see the speedup in first inference time with AOTInductor, especially when using max-autotune mode, which can take some time to execute.

Prerequisites

What you will learn

  • How to use AOTInductor for the Python runtime.

  • How to use torch._inductor.aot_compile() along with torch.export.export() to generate a shared library.

  • How to run a shared library in the Python runtime using torch._export.aot_load().

  • When to use AOTInductor for the Python runtime.

Model Compilation

We will use the TorchVision pretrained ResNet18 model and apply TorchInductor to the exported PyTorch program using torch._inductor.aot_compile().

Note

This API also supports torch.compile() options like mode. This means that on a CUDA-enabled device you can, for example, set "max_autotune": True, which leverages Triton-based matrix multiplications and convolutions and enables CUDA graphs by default.

We also specify dynamic_shapes for the batch dimension. In this example, min=2 is not a bug and is explained in The 0/1 Specialization Problem.

import os
import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():

    # Specify the generated shared library path
    aot_compile_options = {
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    exported_program = torch.export.export(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )
    so_path = torch._inductor.aot_compile(
        exported_program.module(),
        example_inputs,
        # Pass the output path and any other options set above
        options=aot_compile_options,
    )
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


AUTOTUNE convolution(2x3x224x224, 64x3x7x7)
  convolution 0.0464 ms 100.0%
  triton_convolution2d_0 0.1039 ms 44.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_4 0.1066 ms 43.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_3 0.1276 ms 36.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_1 0.1409 ms 33.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_5 0.1869 ms 24.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_2 0.2194 ms 21.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.7965 seconds and 0.0091 seconds precompiling
AUTOTUNE convolution(2x64x56x56, 64x64x3x3)
  convolution 0.0451 ms 100.0%
  triton_convolution2d_6 0.0738 ms 61.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_9 0.0751 ms 60.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_12 0.0777 ms 58.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_11 0.0831 ms 54.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_10 0.0844 ms 53.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_7 0.1010 ms 44.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_8 0.1409 ms 32.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9516 seconds and 0.0008 seconds precompiling
AUTOTUNE convolution(2x64x56x56, 128x64x3x3)
  convolution 0.0344 ms 100.0%
  triton_convolution2d_38 0.0633 ms 54.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_40 0.0823 ms 41.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_34 0.0865 ms 39.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_39 0.0911 ms 37.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_37 0.1070 ms 32.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_35 0.1102 ms 31.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_36 0.3066 ms 11.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9631 seconds and 0.0006 seconds precompiling
AUTOTUNE convolution(2x64x56x56, 128x64x1x1)
  triton_convolution2d_52 0.0110 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  convolution 0.0129 ms 85.6%
  triton_convolution2d_53 0.0129 ms 85.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_48 0.0131 ms 84.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_54 0.0150 ms 73.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_49 0.0158 ms 70.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_51 0.0163 ms 67.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_50 0.0466 ms 23.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9588 seconds and 0.0006 seconds precompiling
AUTOTUNE convolution(2x128x28x28, 128x128x3x3)
  convolution 0.0436 ms 100.0%
  triton_convolution2d_59 0.1173 ms 37.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_61 0.1362 ms 32.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_55 0.1659 ms 26.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_60 0.1757 ms 24.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_56 0.1900 ms 23.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_58 0.1946 ms 22.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_57 0.2665 ms 16.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9528 seconds and 0.0006 seconds precompiling
AUTOTUNE convolution(2x128x28x28, 256x128x3x3)
  convolution 0.0372 ms 100.0%
  triton_convolution2d_73 0.0998 ms 37.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_75 0.1597 ms 23.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_72 0.2037 ms 18.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_70 0.2147 ms 17.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_71 0.2658 ms 14.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_74 0.2849 ms 13.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_69 0.3378 ms 11.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9639 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x128x28x28, 256x128x1x1)
  triton_convolution2d_87 0.0124 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  convolution 0.0208 ms 59.6%
  triton_convolution2d_88 0.0214 ms 58.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_86 0.0217 ms 57.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_89 0.0224 ms 55.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_85 0.0333 ms 37.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
  triton_convolution2d_83 0.0533 ms 23.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_84 0.0658 ms 18.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9972 seconds and 0.0006 seconds precompiling
AUTOTUNE convolution(2x256x14x14, 256x256x3x3)
  convolution 0.0535 ms 100.0%
  triton_convolution2d_94 0.1865 ms 28.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_92 0.2598 ms 20.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_96 0.2629 ms 20.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_91 0.3745 ms 14.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_93 0.3758 ms 14.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_95 0.5476 ms 9.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_90 0.6542 ms 8.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9437 seconds and 0.0004 seconds precompiling
AUTOTUNE convolution(2x256x14x14, 512x256x3x3)
  convolution 0.0531 ms 100.0%
  triton_convolution2d_108 0.1937 ms 27.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_106 0.2826 ms 18.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_110 0.2939 ms 18.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_105 0.3821 ms 13.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_107 0.3896 ms 13.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_109 0.5600 ms 9.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_104 0.6858 ms 7.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9560 seconds and 0.0007 seconds precompiling
AUTOTUNE convolution(2x256x14x14, 512x256x1x1)
  triton_convolution2d_122 0.0183 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  convolution 0.0260 ms 70.1%
  triton_convolution2d_120 0.0336 ms 54.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
  triton_convolution2d_124 0.0343 ms 53.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_123 0.0353 ms 51.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_121 0.0635 ms 28.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_118 0.1922 ms 9.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_119 0.1972 ms 9.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9982 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x512x7x7, 512x512x3x3)
  convolution 0.0858 ms 100.0%
  triton_convolution2d_127 0.2787 ms 30.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_129 0.3616 ms 23.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_131 0.4244 ms 20.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_126 0.4851 ms 17.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_128 0.7236 ms 11.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_130 1.0993 ms 7.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_125 1.4547 ms 5.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9573 seconds and 0.0006 seconds precompiling
AUTOTUNE addmm(2x1000, 2x512, 512x1000)
  addmm 0.0157 ms 100.0%
  triton_mm_142 0.0220 ms 71.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
  triton_mm_152 0.0304 ms 51.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_153 0.0309 ms 50.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_141 0.0310 ms 50.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_146 0.0312 ms 50.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_139 0.0340 ms 46.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=1, num_warps=2
  triton_mm_145 0.0373 ms 42.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_144 0.0460 ms 34.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_148 0.0499 ms 31.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 1.8137 seconds and 0.0020 seconds precompiling

Model Inference in Python

Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3, we added a new API called torch._export.aot_load() to load the shared library in the Python runtime. The API follows a structure similar to the torch.jit.load() API. You need to specify the path of the shared library and the device where it should be loaded.

Note

In the example below, we use a batch size of 1 for inference, and it still functions correctly even though we specified min=2 in torch.export.export().

import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(example_inputs)

When to use AOTInductor for Python Runtime

One of the requirements for using AOTInductor is that the model shouldn’t have any graph breaks. Once this requirement is met, the primary use case for AOTInductor Python Runtime is model deployment using Python. There are two main reasons to use AOTInductor Python Runtime:

  • torch._inductor.aot_compile generates a shared library. This is useful for model versioning for deployments and tracking model performance over time.

  • With torch.compile() being a JIT compiler, there is a warmup cost associated with the first compilation. Your deployment needs to account for the compilation time taken for the first inference. With AOTInductor, the compilation is done offline using torch.export.export and torch._inductor.aot_compile. The deployment would only load the shared library using torch._export.aot_load and run inference.
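
As one possible way to act on the versioning point above, the sketch below records a content hash for a compiled shared library in a small JSON manifest. The describe_artifact helper, the manifest layout, and the placeholder file contents are assumptions made for illustration; they are not part of any PyTorch API.

```python
import hashlib
import json
import tempfile
from pathlib import Path


def describe_artifact(so_path: Path, model_name: str) -> dict:
    """Build a manifest entry (hypothetical layout) for a compiled shared library."""
    digest = hashlib.sha256(so_path.read_bytes()).hexdigest()
    return {
        "model": model_name,
        "file": so_path.name,
        "sha256": digest,
        "size_bytes": so_path.stat().st_size,
    }


with tempfile.TemporaryDirectory() as tmp:
    # Placeholder bytes stand in for a real resnet18_pt2.so produced by aot_compile.
    fake_so = Path(tmp) / "resnet18_pt2.so"
    fake_so.write_bytes(b"placeholder shared library contents")

    manifest = describe_artifact(fake_so, "resnet18")
    manifest_path = Path(tmp) / "manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2))

    # Read the manifest back, as a deployment step might do before loading the library.
    reloaded = json.loads(manifest_path.read_text())

print(reloaded["file"], reloaded["sha256"][:12])
```

A deployment could compare the stored hash against the file it is about to load, so that performance numbers tracked over time are tied to a specific compiled artifact.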

The section below shows the speedup achieved with AOTInductor for first inference.

We define a utility function timed to measure the time taken for inference.

import time

def timed(fn):
    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in milliseconds. We use CUDA events and synchronization for accurate
    # measurement on CUDA enabled devices.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.time()

    result = fn()
    if torch.cuda.is_available():
        end.record()
        torch.cuda.synchronize()
    else:
        end = time.time()

    # Measure time taken to execute the function in milliseconds
    if torch.cuda.is_available():
        duration = start.elapsed_time(end)
    else:
        duration = (end - start) * 1000

    return result, duration
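
The following is a torch-free sketch of how a timing helper like this separates first-call cost from later calls. It is a simulation under stated assumptions: timed_cpu is a CPU-only stand-in for timed, and LazyModel is a hypothetical class whose one-time sleep imitates JIT compilation on the first call.

```python
import time


def timed_cpu(fn):
    # CPU-only stand-in for timed(): returns (result, elapsed milliseconds).
    start = time.perf_counter()
    result = fn()
    return result, (time.perf_counter() - start) * 1000


class LazyModel:
    """Hypothetical stand-in for a JIT-compiled model with a one-time setup cost."""

    def __init__(self, setup_seconds=0.05):
        self._setup_seconds = setup_seconds
        self._ready = False

    def __call__(self, x):
        if not self._ready:
            time.sleep(self._setup_seconds)  # simulated compilation on first call
            self._ready = True
        return x * 2


model = LazyModel()
first_out, first_ms = timed_cpu(lambda: model(21))
second_out, second_ms = timed_cpu(lambda: model(21))
print(f"first call: {first_ms:.2f} ms, second call: {second_ms:.2f} ms")
```

Only the first call pays the simulated setup cost, which is the pattern the measurements in this section look for.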

Let's measure the time for first inference using AOTInductor.

torch._dynamo.reset()

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")
Time taken for first inference for AOTInductor is 3.02 ms

Let's measure the time for first inference using torch.compile.

torch._dynamo.reset()

model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()

model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")
Time taken for first inference for torch.compile is 6874.33 ms

We see that there is a drastic speedup in first inference time using AOTInductor compared to torch.compile: 3.02 ms versus 6874.33 ms in this run.
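
To put a number on that difference, the short calculation below uses the two timings printed above. The figures come from this particular run and will differ across hardware and PyTorch versions.

```python
# First-inference timings reported in this tutorial's run, in milliseconds.
aot_first_ms = 3.02
compile_first_ms = 6874.33

speedup = compile_first_ms / aot_first_ms
saved_s = (compile_first_ms - aot_first_ms) / 1000

print(f"speedup: {speedup:.0f}x, time saved on first request: {saved_s:.2f} s")
# prints "speedup: 2276x, time saved on first request: 6.87 s"
```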

Conclusion

In this tutorial, we have learned how to use AOTInductor for the Python runtime by compiling and loading a pretrained ResNet18 model using the torch._inductor.aot_compile and torch._export.aot_load APIs. This process demonstrates the practical application of generating a shared library and running it within a Python environment, even with dynamic shape considerations and device-specific optimizations. We also looked at the advantage of using AOTInductor in model deployments with regard to the speedup in first inference time.

Total running time of the script: (1 minute 26.867 seconds)
