.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "recipes/torch_export_aoti_python.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_recipes_torch_export_aoti_python.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_recipes_torch_export_aoti_python.py:

.. meta::
   :description: An end-to-end example of how to use AOTInductor for Python runtime.
   :keywords: torch.export, AOTInductor, torch._inductor.aoti_compile_and_package, aot_compile, torch._inductor.aoti_load_package

``torch.export`` AOTInductor Tutorial for Python runtime (Beta)
===============================================================
**Author:** Ankith Gunapal, Bin Bao, Angela Yi

.. GENERATED FROM PYTHON SOURCE LINES 14-33

.. warning::

    ``torch._inductor.aoti_compile_and_package`` and
    ``torch._inductor.aoti_load_package`` are in Beta status and are subject
    to backwards compatibility breaking changes. This tutorial provides an
    example of how to use these APIs for model deployment using Python
    runtime.

A `previous tutorial <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__
showed how AOTInductor can compile PyTorch exported models ahead of time
into an artifact that can be run in a non-Python environment.
This tutorial walks through an end-to-end example of using AOTInductor
with the Python runtime.

**Contents**

.. contents::
    :local:

.. GENERATED FROM PYTHON SOURCE LINES 36-41

Prerequisites
-------------
* PyTorch 2.6 or later
* Basic understanding of ``torch.export`` and AOTInductor
* Complete the `AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`_ tutorial

.. 
GENERATED FROM PYTHON SOURCE LINES 43-49

What you will learn
----------------------
* How to use AOTInductor for Python runtime.
* How to use :func:`torch._inductor.aoti_compile_and_package` along with :func:`torch.export.export` to generate a compiled artifact.
* How to load and run the artifact in a Python runtime using :func:`torch._inductor.aoti_load_package`.
* When to use AOTInductor with a Python runtime.

.. GENERATED FROM PYTHON SOURCE LINES 51-72

Model Compilation
-----------------

We will use the TorchVision pretrained ``ResNet18`` model as an example.

The first step is to export the model to a graph representation using
:func:`torch.export.export`. To learn more about using this function, you can
check out the `docs <https://pytorch.org/docs/main/export.html>`_ or the
`tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`_.

Once we have exported the PyTorch model and obtained an ``ExportedProgram``,
we can pass it to :func:`torch._inductor.aoti_compile_and_package`, which uses
AOTInductor to compile the program for a specified device and saves the
generated contents into a ".pt2" artifact.

.. note::

      This API supports the same available options that :func:`torch.compile`
      has, such as ``mode`` and ``max_autotune`` (for those who want to enable
      CUDA graphs and leverage Triton based matrix multiplications and convolutions)

.. GENERATED FROM PYTHON SOURCE LINES 72-103

.. 
code-block:: default import os import torch import torch._inductor from torchvision.models import ResNet18_Weights, resnet18 model = resnet18(weights=ResNet18_Weights.DEFAULT) model.eval() with torch.inference_mode(): inductor_configs = {} if torch.cuda.is_available(): device = "cuda" inductor_configs["max_autotune"] = True else: device = "cpu" model = model.to(device=device) example_inputs = (torch.randn(2, 3, 224, 224, device=device),) exported_program = torch.export.export( model, example_inputs, ) path = torch._inductor.aoti_compile_and_package( exported_program, package_path=os.path.join(os.getcwd(), "resnet18.pt2"), inductor_configs=inductor_configs ) .. rst-class:: sphx-glr-script-out .. code-block:: none Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth 0%| | 0.00/44.7M [00:00<?, ?B/s] 92%|#########2| 41.1M/44.7M [00:00<00:00, 431MB/s] 100%|##########| 44.7M/44.7M [00:00<00:00, 430MB/s] AUTOTUNE convolution(2x3x224x224, 64x3x7x7) convolution 0.0531 ms 100.0% triton_convolution2d_4 0.1393 ms 38.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_0 0.1497 ms 35.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_3 0.1834 ms 29.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_5 0.2474 ms 21.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_2 0.5345 ms 9.9% ALLOW_TF32=True, 
BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8 triton_convolution2d_1 0.8941 ms 5.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4 SingleProcess AUTOTUNE benchmarking takes 0.8979 seconds and 0.0077 seconds precompiling for 7 choices AUTOTUNE convolution(2x64x56x56, 64x64x3x3) triton_convolution2d_10 0.0355 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_6 0.0364 ms 97.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_11 0.0366 ms 96.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_9 0.0421 ms 84.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 convolution 0.0458 ms 77.4% triton_convolution2d_12 0.0650 ms 54.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_7 0.0769 ms 46.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_8 0.1265 ms 28.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, 
PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8 SingleProcess AUTOTUNE benchmarking takes 0.9425 seconds and 0.0007 seconds precompiling for 8 choices AUTOTUNE convolution(2x64x56x56, 128x64x3x3) triton_convolution2d_38 0.0290 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_39 0.0399 ms 72.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8 convolution 0.0456 ms 63.6% triton_convolution2d_34 0.0477 ms 60.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_37 0.0601 ms 48.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_40 0.0618 ms 46.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_35 0.0719 ms 40.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_36 0.1330 ms 21.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8 SingleProcess AUTOTUNE benchmarking takes 0.9405 seconds and 0.0006 seconds precompiling for 8 choices AUTOTUNE convolution(2x128x28x28, 128x128x3x3) convolution 0.0448 ms 100.0% triton_convolution2d_45 0.0498 ms 90.0% 
ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_46 0.0702 ms 63.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_41 0.0845 ms 53.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_44 0.1040 ms 43.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_47 0.1170 ms 38.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_42 0.1399 ms 32.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_43 0.2408 ms 18.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8 SingleProcess AUTOTUNE benchmarking takes 0.9516 seconds and 0.0008 seconds precompiling for 8 choices AUTOTUNE convolution(2x64x56x56, 128x64x1x1) triton_convolution2d_52 0.0086 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4 triton_convolution2d_53 0.0099 ms 86.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, 
PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8 triton_convolution2d_48 0.0107 ms 80.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4 triton_convolution2d_51 0.0131 ms 65.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8 triton_convolution2d_54 0.0133 ms 64.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8 convolution 0.0134 ms 64.1% triton_convolution2d_49 0.0147 ms 58.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4 triton_convolution2d_50 0.0218 ms 39.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8 SingleProcess AUTOTUNE benchmarking takes 0.9260 seconds and 0.0006 seconds precompiling for 8 choices AUTOTUNE convolution(2x128x28x28, 256x128x3x3) convolution 0.0368 ms 100.0% triton_convolution2d_73 0.0493 ms 74.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_74 0.1113 ms 33.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_72 0.1135 ms 32.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, 
num_stages=2, num_warps=8 triton_convolution2d_75 0.1160 ms 31.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_69 0.1325 ms 27.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_70 0.1343 ms 27.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_71 0.2053 ms 17.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8 SingleProcess AUTOTUNE benchmarking takes 0.9550 seconds and 0.0006 seconds precompiling for 8 choices AUTOTUNE convolution(2x256x14x14, 256x256x3x3) convolution 0.0564 ms 100.0% triton_convolution2d_80 0.0915 ms 61.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_79 0.2100 ms 26.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_81 0.2148 ms 26.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_82 0.2267 ms 24.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_76 0.2569 ms 22.0% 
ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_77 0.2720 ms 20.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_78 0.3744 ms 15.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8 SingleProcess AUTOTUNE benchmarking takes 0.9770 seconds and 0.0007 seconds precompiling for 8 choices AUTOTUNE convolution(2x128x28x28, 256x128x1x1) triton_convolution2d_87 0.0106 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4 triton_convolution2d_86 0.0189 ms 55.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8 triton_convolution2d_89 0.0190 ms 55.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8 triton_convolution2d_88 0.0193 ms 54.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8 triton_convolution2d_84 0.0219 ms 48.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4 triton_convolution2d_83 0.0220 ms 48.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, 
PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4 convolution 0.0257 ms 41.0% triton_convolution2d_85 0.0266 ms 39.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8 SingleProcess AUTOTUNE benchmarking takes 0.9271 seconds and 0.0006 seconds precompiling for 8 choices AUTOTUNE convolution(2x256x14x14, 512x256x3x3) convolution 0.0576 ms 100.0% triton_convolution2d_108 0.0929 ms 62.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_109 0.2170 ms 26.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_107 0.2213 ms 26.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_110 0.2248 ms 25.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_106 0.2410 ms 23.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8 triton_convolution2d_105 0.2598 ms 22.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_104 0.2603 ms 22.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, 
UNROLL=False, num_stages=2, num_warps=4 SingleProcess AUTOTUNE benchmarking takes 0.9733 seconds and 0.0006 seconds precompiling for 8 choices AUTOTUNE convolution(2x512x7x7, 512x512x3x3) convolution 0.0851 ms 100.0% triton_convolution2d_115 0.1805 ms 47.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_113 0.2153 ms 39.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8 triton_convolution2d_117 0.2669 ms 31.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_112 0.3220 ms 26.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 triton_convolution2d_114 0.4026 ms 21.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_116 0.4275 ms 19.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8 triton_convolution2d_111 0.5135 ms 16.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4 SingleProcess AUTOTUNE benchmarking takes 0.9952 seconds and 0.0007 seconds precompiling for 8 choices AUTOTUNE convolution(2x256x14x14, 512x256x1x1) triton_convolution2d_122 0.0150 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, 
GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4 triton_convolution2d_120 0.0264 ms 56.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8 convolution 0.0286 ms 52.5% triton_convolution2d_121 0.0304 ms 49.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8 triton_convolution2d_124 0.0308 ms 48.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8 triton_convolution2d_118 0.0310 ms 48.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4 triton_convolution2d_123 0.0315 ms 47.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8 triton_convolution2d_119 0.0332 ms 45.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4 SingleProcess AUTOTUNE benchmarking takes 0.9282 seconds and 0.0006 seconds precompiling for 8 choices AUTOTUNE addmm(2x1000, 2x512, 512x1000) triton_mm_146 0.0116 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_141 0.0118 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4 triton_mm_153 0.0120 ms 96.8% 
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4 triton_mm_142 0.0124 ms 93.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2 triton_mm_143 0.0126 ms 91.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2 triton_mm_152 0.0127 ms 91.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_140 0.0132 ms 87.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2 triton_mm_150 0.0133 ms 87.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_149 0.0134 ms 86.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8 triton_mm_148 0.0143 ms 81.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 SingleProcess AUTOTUNE benchmarking takes 1.7579 seconds and 0.0022 seconds precompiling for 18 choices .. GENERATED FROM PYTHON SOURCE LINES 104-163 The result of :func:`aoti_compile_and_package` is an artifact "resnet18.pt2" which can be loaded and executed in Python and C++. The artifact itself contains a bunch of AOTInductor generated code, such as a generated C++ runner file, a shared library compiled from the C++ file, and CUDA binary files, aka cubin files, if optimizing for CUDA. 
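As a quick sanity check, the package can also be examined from Python rather than from the shell. The sketch below relies only on the fact that the ``.pt2`` artifact is a zip archive (described next) and uses the standard-library ``zipfile`` module. ``summarize_package`` is a hypothetical helper written for this illustration; it is not part of PyTorch.

.. code-block:: python

    import zipfile
    from collections import Counter
    from pathlib import PurePosixPath


    def summarize_package(package_path):
        """Count the files in a zip-based package, grouped by file extension.

        Files without an extension (such as ``version``) are counted under
        their own name. Hypothetical helper, not a PyTorch API.
        """
        counts = Counter()
        with zipfile.ZipFile(package_path) as archive:
            for info in archive.infolist():
                if info.is_dir():
                    continue  # skip directory entries, count files only
                name = PurePosixPath(info.filename)
                counts[name.suffix or name.name] += 1
        return dict(counts)

Calling ``summarize_package("resnet18.pt2")`` after the compilation step should report how many ``.cubin``, ``.cpp``, ``.so``, and ``.json`` files the package holds. The exact counts depend on the device and the Inductor options used.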
Structure-wise, the artifact is a structured ``.zip`` file, with the following specification: .. code:: . ├── archive_format ├── version ├── data │ ├── aotinductor │ │ └── model │ │ ├── xxx.cpp # AOTInductor generated cpp file │ │ ├── xxx.so # AOTInductor generated shared library │ │ ├── xxx.cubin # Cubin files (if running on CUDA) │ │ └── xxx_metadata.json # Additional metadata to save │ ├── weights │ │ └── TBD │ └── constants │ └── TBD └── extra └── metadata.json We can use the following command to inspect the artifact contents: .. code:: bash $ unzip -l resnet18.pt2 .. code:: Archive: resnet18.pt2 Length Date Time Name --------- ---------- ----- ---- 1 01-08-2025 16:40 version 3 01-08-2025 16:40 archive_format 10088 01-08-2025 16:40 data/aotinductor/model/cagzt6akdaczvxwtbvqe34otfe5jlorktbqlojbzqjqvbfsjlge4.cubin 17160 01-08-2025 16:40 data/aotinductor/model/c6oytfjmt5w4c7onvtm6fray7clirxt7q5xjbwx3hdydclmwoujz.cubin 16616 01-08-2025 16:40 data/aotinductor/model/c7ydp7nocyz323hij4tmlf2kcedmwlyg6r57gaqzcsy3huneamu6.cubin 17776 01-08-2025 16:40 data/aotinductor/model/cyqdf46ordevqhiddvpdpp3uzwatfbzdpl3auj2nx23uxvplnne2.cubin 10856 01-08-2025 16:40 data/aotinductor/model/cpzfebfgrusqslui7fxsuoo4tvwulmrxirc5tmrpa4mvrbdno7kn.cubin 14608 01-08-2025 16:40 data/aotinductor/model/c5ukeoz5wmaszd7vczdz2qhtt6n7tdbl3b6wuy4rb2se24fjwfoy.cubin 11376 01-08-2025 16:40 data/aotinductor/model/csu3nstcp56tsjfycygaqsewpu64l5s6zavvz7537cm4s4cv2k3r.cubin 10984 01-08-2025 16:40 data/aotinductor/model/cp76lez4glmgq7gedf2u25zvvv6rksv5lav4q22dibd2zicbgwj3.cubin 14736 01-08-2025 16:40 data/aotinductor/model/c2bb5p6tnwz4elgujqelsrp3unvkgsyiv7xqxmpvuxcm4jfl7pc2.cubin 11376 01-08-2025 16:40 data/aotinductor/model/c6eopmb2b4ngodwsayae4r5q6ni3jlfogfbdk3ypg56tgpzhubfy.cubin 11624 01-08-2025 16:40 data/aotinductor/model/chmwe6lvoekzfowdbiizitm3haiiuad5kdm6sd2m6mv6dkn2zk32.cubin 15632 01-08-2025 16:40 data/aotinductor/model/c3jop5g344hj3ztsu4qm6ibxyaaerlhkzh2e6emak23rxfje6jam.cubin 25472 01-08-2025 
16:40 data/aotinductor/model/chaiixybeiuuitm2nmqnxzijzwgnn2n7uuss4qmsupgblfh3h5hk.cubin 139389 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.cpp 27 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t_metadata.json 47195424 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.so --------- ------- 47523148 18 files

.. GENERATED FROM PYTHON SOURCE LINES 166-171

Model Inference in Python
-------------------------

To load and run the artifact in Python, we can use :func:`torch._inductor.aoti_load_package`.

.. GENERATED FROM PYTHON SOURCE LINES 171-185

.. code-block:: default


    import os
    import torch
    import torch._inductor

    model_path = os.path.join(os.getcwd(), "resnet18.pt2")

    compiled_model = torch._inductor.aoti_load_package(model_path)
    # `device` was set in the compilation step above ("cuda" or "cpu")
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    with torch.inference_mode():
        output = compiled_model(example_inputs)


.. GENERATED FROM PYTHON SOURCE LINES 186-207

When to use AOTInductor with a Python Runtime
---------------------------------------------

There are two main reasons to use AOTInductor with a Python runtime:

-  ``torch._inductor.aoti_compile_and_package`` generates a single
   serialized artifact. This is useful for model versioning for deployments
   and tracking model performance over time.
-  With :func:`torch.compile` being a JIT compiler, there is a warmup
   cost associated with the first compilation. Your deployment needs to
   account for the compilation time taken for the first inference. With
   AOTInductor, the compilation is done ahead of time using
   ``torch.export.export`` and ``torch._inductor.aoti_compile_and_package``.
   At deployment time, after loading the model, running inference incurs
   no additional compilation cost.

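To make the versioning point concrete: because the package is a single file, one simple way to identify a particular build is a content hash. The sketch below is only an illustration of that idea. ``artifact_fingerprint`` is a hypothetical helper built on the standard-library ``hashlib`` module, and PyTorch does not prescribe this (or any) versioning scheme.

.. code-block:: python

    import hashlib


    def artifact_fingerprint(package_path, chunk_size=1 << 20):
        """Return the SHA-256 hex digest of a file, read in chunks.

        Hypothetical helper: reading in chunks avoids loading a large
        package (tens of megabytes or more) into memory at once.
        """
        digest = hashlib.sha256()
        with open(package_path, "rb") as f:
            # iter() with a sentinel stops when read() returns b"" at EOF
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

A deployment could record ``artifact_fingerprint("resnet18.pt2")`` alongside its performance metrics, so that any change in latency can be traced to a specific compiled package.
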
The section below shows the speedup achieved with AOTInductor for the first inference.

We define a utility function ``timed`` to measure the time taken for inference.

.. GENERATED FROM PYTHON SOURCE LINES 207-236

.. code-block:: default


    import time

    def timed(fn):
        # Returns the result of running `fn()` and the time it took for `fn()` to run,
        # in milliseconds. We use CUDA events and synchronization for accurate
        # measurement on CUDA enabled devices.
        if torch.cuda.is_available():
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
        else:
            # perf_counter() is a monotonic clock intended for measuring intervals
            start = time.perf_counter()

        result = fn()
        if torch.cuda.is_available():
            end.record()
            torch.cuda.synchronize()
        else:
            end = time.perf_counter()

        # Measure time taken to execute the function in milliseconds
        if torch.cuda.is_available():
            duration = start.elapsed_time(end)
        else:
            duration = (end - start) * 1000

        return result, duration


.. GENERATED FROM PYTHON SOURCE LINES 237-238

Let's measure the time for first inference using AOTInductor

.. GENERATED FROM PYTHON SOURCE LINES 238-249

.. code-block:: default


    torch._dynamo.reset()

    model = torch._inductor.aoti_load_package(model_path)
    example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

    with torch.inference_mode():
        _, time_taken = timed(lambda: model(example_inputs))
        print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Time taken for first inference for AOTInductor is 3.99 ms


.. GENERATED FROM PYTHON SOURCE LINES 250-251

Let's measure the time for first inference using ``torch.compile``

.. GENERATED FROM PYTHON SOURCE LINES 251-264

.. 
code-block:: default


    torch._dynamo.reset()

    model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
    model.eval()

    model = torch.compile(model)
    example_inputs = torch.randn(1, 3, 224, 224, device=device)

    with torch.inference_mode():
        _, time_taken = timed(lambda: model(example_inputs))
        print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Time taken for first inference for torch.compile is 5632.31 ms


.. GENERATED FROM PYTHON SOURCE LINES 265-267

The first inference is drastically faster with AOTInductor (3.99 ms) than
with ``torch.compile`` (5632.31 ms), because the ``torch.compile``
measurement includes JIT compilation.

.. GENERATED FROM PYTHON SOURCE LINES 269-277

Conclusion
----------

In this recipe, we learned how to use AOTInductor with the Python runtime by
compiling and loading a pretrained ``ResNet18`` model. This process
demonstrates how to generate a compiled artifact and run it within a Python
environment. We also looked at the advantage of using AOTInductor in model
deployments: a much faster first inference.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 1 minutes 3.895 seconds)


.. _sphx_glr_download_recipes_torch_export_aoti_python.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torch_export_aoti_python.py <torch_export_aoti_python.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torch_export_aoti_python.ipynb <torch_export_aoti_python.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_