AOTInductor Tutorial for Python runtime (Beta)
Author: Ankith Gunapal, Bin Bao, Angela Yi
Warning
torch._inductor.aot_compile and torch._export.aot_load are in Beta status and are subject to backwards-compatibility-breaking changes. This tutorial provides an example of how to use these APIs for model deployment using the Python runtime.
It has been shown previously how AOTInductor can be used to do Ahead-of-Time compilation of PyTorch exported models by creating a shared library that can be run in a non-Python environment.
In this tutorial, you will work through an end-to-end example of how to use AOTInductor for the Python runtime. We will look at how to use torch._inductor.aot_compile() along with torch.export.export() to generate a shared library. Additionally, we will examine how to execute the shared library in the Python runtime using torch._export.aot_load(). You will also see how much AOTInductor speeds up first inference compared with torch.compile. The gain is especially large when using max-autotune mode, because compiling in that mode can take some time.
Prerequisites
PyTorch 2.4 or later
Basic understanding of torch.export and AOTInductor
Complete the AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models tutorial
What you will learn
How to use AOTInductor for Python runtime.
How to use torch._inductor.aot_compile() along with torch.export.export() to generate a shared library.
How to run a shared library in Python runtime using torch._export.aot_load().
When to use AOTInductor for Python runtime.
Model Compilation
We will use the TorchVision pretrained ResNet18 model and apply TorchInductor to the exported PyTorch program using torch._inductor.aot_compile().
Note
This API also supports torch.compile() options such as mode. For example, on a CUDA-enabled device you can set "max_autotune": True, which leverages Triton-based matrix multiplications and convolutions and enables CUDA graphs by default.
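As a minimal sketch of how the options described in this Note might be assembled, the hypothetical helper build_aot_compile_options below returns the same options dict used in this tutorial's compilation step. It adds "max_autotune": True only when CUDA is requested. It is plain Python and calls no PyTorch API. The helper name and its default file name are assumptions made for illustration.

```python
import os
from typing import Any, Dict


def build_aot_compile_options(output_dir: str, use_cuda: bool,
                              library_name: str = "resnet18_pt2.so") -> Dict[str, Any]:
    """Hypothetical helper: build the options dict passed to aot_compile.

    The keys mirror the ones used in this tutorial's compilation step.
    """
    options: Dict[str, Any] = {
        # Where AOTInductor should write the generated shared library
        "aot_inductor.output_path": os.path.join(output_dir, library_name),
    }
    if use_cuda:
        # Enable Triton-based autotuning only on CUDA devices
        options["max_autotune"] = True
    return options
```

In the compilation step you would call it as build_aot_compile_options(os.getcwd(), torch.cuda.is_available()) and pass the result as options=. This keeps the device-specific choice in one place.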
We also specify dynamic_shapes for the batch dimension. In this example, min=2 is not a bug. The reason is explained in The 0/1 Specialization Problem.
import os
import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():

    # Specify the generated shared library path
    aot_compile_options = {
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    exported_program = torch.export.export(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )
    so_path = torch._inductor.aot_compile(
        exported_program.module(),
        example_inputs,
        # Options include the output path and, on CUDA, max_autotune
        options=aot_compile_options,
    )
AUTOTUNE convolution(2x3x224x224, 64x3x7x7)
convolution 0.0455 ms 100.0%
triton_convolution_0 0.1035 ms 44.0%
triton_convolution_4 0.1066 ms 42.7%
triton_convolution_3 0.1273 ms 35.8%
triton_convolution_1 0.1399 ms 32.6%
triton_convolution_5 0.1852 ms 24.6%
triton_convolution_2 0.2191 ms 20.8%
SingleProcess AUTOTUNE benchmarking takes 0.8572 seconds and 0.0067 seconds precompiling
AUTOTUNE convolution(2x64x56x56, 64x64x3x3)
convolution 0.0436 ms 100.0%
triton_convolution_6 0.0743 ms 58.7%
triton_convolution_9 0.0746 ms 58.4%
triton_convolution_12 0.0771 ms 56.6%
triton_convolution_10 0.0837 ms 52.1%
triton_convolution_11 0.0840 ms 51.9%
triton_convolution_7 0.1409 ms 31.0%
triton_convolution_8 0.1421 ms 30.7%
SingleProcess AUTOTUNE benchmarking takes 0.9702 seconds and 0.0004 seconds precompiling
AUTOTUNE convolution(2x64x56x56, 128x64x3x3)
convolution 0.0338 ms 100.0%
triton_convolution_38 0.0630 ms 53.6%
triton_convolution_40 0.0817 ms 41.3%
triton_convolution_34 0.0860 ms 39.3%
triton_convolution_39 0.0910 ms 37.1%
triton_convolution_37 0.1068 ms 31.6%
triton_convolution_35 0.1551 ms 21.8%
triton_convolution_36 0.3043 ms 11.1%
SingleProcess AUTOTUNE benchmarking takes 0.9807 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x64x56x56, 128x64x1x1)
triton_convolution_52 0.0107 ms 100.0%
triton_convolution_53 0.0123 ms 87.0%
triton_convolution_48 0.0126 ms 85.0%
convolution 0.0132 ms 81.1%
triton_convolution_54 0.0148 ms 72.4%
triton_convolution_51 0.0156 ms 68.6%
triton_convolution_50 0.0458 ms 23.4%
triton_convolution_49 0.0742 ms 14.4%
SingleProcess AUTOTUNE benchmarking takes 0.9813 seconds and 0.0004 seconds precompiling
AUTOTUNE convolution(2x128x28x28, 128x128x3x3)
convolution 0.0435 ms 100.0%
triton_convolution_59 0.1172 ms 37.1%
triton_convolution_61 0.1354 ms 32.1%
triton_convolution_55 0.1653 ms 26.3%
triton_convolution_60 0.1759 ms 24.7%
triton_convolution_56 0.1897 ms 22.9%
triton_convolution_58 0.1947 ms 22.3%
triton_convolution_57 0.2677 ms 16.2%
SingleProcess AUTOTUNE benchmarking takes 0.9639 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x128x28x28, 256x128x3x3)
convolution 0.0370 ms 100.0%
triton_convolution_73 0.0991 ms 37.3%
triton_convolution_75 0.1593 ms 23.2%
triton_convolution_72 0.2042 ms 18.1%
triton_convolution_70 0.2147 ms 17.2%
triton_convolution_71 0.2649 ms 14.0%
triton_convolution_74 0.2850 ms 13.0%
triton_convolution_69 0.3372 ms 11.0%
SingleProcess AUTOTUNE benchmarking takes 0.9716 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x128x28x28, 256x128x1x1)
triton_convolution_87 0.0121 ms 100.0%
convolution 0.0203 ms 59.6%
triton_convolution_88 0.0211 ms 57.2%
triton_convolution_89 0.0276 ms 43.6%
triton_convolution_85 0.0324 ms 37.3%
triton_convolution_86 0.0451 ms 26.8%
triton_convolution_83 0.1226 ms 9.8%
triton_convolution_84 0.1424 ms 8.5%
SingleProcess AUTOTUNE benchmarking takes 1.0169 seconds and 0.0004 seconds precompiling
AUTOTUNE convolution(2x256x14x14, 256x256x3x3)
convolution 0.0528 ms 100.0%
triton_convolution_94 0.1864 ms 28.3%
triton_convolution_92 0.2603 ms 20.3%
triton_convolution_96 0.2627 ms 20.1%
triton_convolution_91 0.3710 ms 14.2%
triton_convolution_93 0.3743 ms 14.1%
triton_convolution_95 0.5475 ms 9.6%
triton_convolution_90 0.6541 ms 8.1%
SingleProcess AUTOTUNE benchmarking takes 0.9500 seconds and 0.0006 seconds precompiling
AUTOTUNE convolution(2x256x14x14, 512x256x3x3)
convolution 0.0527 ms 100.0%
triton_convolution_108 0.1923 ms 27.4%
triton_convolution_106 0.2811 ms 18.7%
triton_convolution_110 0.2926 ms 18.0%
triton_convolution_105 0.3823 ms 13.8%
triton_convolution_107 0.3891 ms 13.5%
triton_convolution_109 0.5592 ms 9.4%
triton_convolution_104 0.6857 ms 7.7%
SingleProcess AUTOTUNE benchmarking takes 0.9525 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x256x14x14, 512x256x1x1)
triton_convolution_122 0.0179 ms 100.0%
convolution 0.0254 ms 70.4%
triton_convolution_120 0.0330 ms 54.2%
triton_convolution_124 0.0872 ms 20.5%
triton_convolution_123 0.0964 ms 18.5%
triton_convolution_121 0.1267 ms 14.1%
triton_convolution_118 0.2755 ms 6.5%
triton_convolution_119 0.2881 ms 6.2%
SingleProcess AUTOTUNE benchmarking takes 1.0190 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x512x7x7, 512x512x3x3)
convolution 0.0851 ms 100.0%
triton_convolution_127 0.2815 ms 30.2%
triton_convolution_129 0.3596 ms 23.7%
triton_convolution_131 0.4240 ms 20.1%
triton_convolution_126 0.4842 ms 17.6%
triton_convolution_128 0.7241 ms 11.7%
triton_convolution_130 1.0996 ms 7.7%
triton_convolution_125 1.4565 ms 5.8%
SingleProcess AUTOTUNE benchmarking takes 0.9560 seconds and 0.0004 seconds precompiling
AUTOTUNE addmm(2x1000, 2x512, 512x1000)
addmm 0.0152 ms 100.0%
triton_mm_142 0.0218 ms 70.0%
triton_mm_152 0.0300 ms 50.8%
triton_mm_153 0.0304 ms 50.0%
triton_mm_141 0.0306 ms 49.8%
triton_mm_146 0.0306 ms 49.7%
triton_mm_139 0.0342 ms 44.5%
triton_mm_145 0.0373 ms 40.9%
triton_mm_144 0.0453 ms 33.6%
triton_mm_148 0.0499 ms 30.5%
SingleProcess AUTOTUNE benchmarking takes 1.8361 seconds and 0.0011 seconds precompiling
Model Inference in Python
Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3, we added a new API called torch._export.aot_load() to load the shared library in the Python runtime. The API follows a structure similar to the torch.jit.load() API. You specify the path of the shared library and the device where it should be loaded.
Note
In the example below, we use a batch size of 1 for inference. It still works correctly even though we specified min=2 in torch.export.export().
import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(example_inputs)
When to use AOTInductor for Python Runtime
One of the requirements for using AOTInductor is that the model must not have any graph breaks. Once this requirement is met, the primary use case for AOTInductor Python Runtime is model deployment using Python. There are two main reasons to use AOTInductor Python Runtime:
torch._inductor.aot_compile generates a shared library. This is useful for model versioning in deployments and for tracking model performance over time.
torch.compile() is a JIT compiler, so there is a warmup cost associated with the first compilation. Your deployment needs to account for the compilation time taken by the first inference. With AOTInductor, the compilation is done offline using torch.export.export and torch._inductor.aot_compile. The deployment only loads the shared library using torch._export.aot_load and runs inference.
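As a sketch of the versioning point above, the hypothetical helper below copies a compiled shared library to a versioned file name. It also writes a small JSON manifest with the library's SHA-256 digest, so a deployment can verify which artifact it loads. The helper name, the "-v<version>" naming scheme and the manifest format are assumptions. They are not part of AOTInductor.

```python
import hashlib
import json
import shutil
import tempfile
from pathlib import Path


def publish_versioned_library(so_path: str, dest_dir: str, version: str) -> dict:
    """Hypothetical helper: copy a compiled .so under a versioned name and
    write a JSON manifest containing its SHA-256 digest."""
    src = Path(so_path)
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)

    # e.g. resnet18_pt2.so -> resnet18_pt2-v1.0.so
    versioned = dest / f"{src.stem}-v{version}{src.suffix}"
    shutil.copy2(src, versioned)

    # Hash the artifact in 1 MiB chunks so large libraries need not fit in memory
    digest = hashlib.sha256()
    with open(versioned, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)

    manifest = {"version": version, "file": versioned.name, "sha256": digest.hexdigest()}
    (dest / f"{versioned.name}.json").write_text(json.dumps(manifest, indent=2))
    return manifest


if __name__ == "__main__":
    # Demo with a stand-in file; in practice so_path is the path returned by aot_compile
    with tempfile.TemporaryDirectory() as tmp:
        fake_so = Path(tmp) / "resnet18_pt2.so"
        fake_so.write_bytes(b"not a real library")
        print(publish_versioned_library(str(fake_so), str(Path(tmp) / "releases"), "1.0"))
```

At deployment time you could recompute the digest of the file passed to torch._export.aot_load and compare it against the manifest before serving traffic.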
The section below shows the speedup that AOTInductor achieves for first inference.
We define a utility function, timed, to measure the time taken for inference.
import time

import torch

def timed(fn):
    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in milliseconds. We use CUDA events and synchronization for accurate
    # measurement on CUDA enabled devices, and a monotonic clock otherwise.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.perf_counter()

    result = fn()
    if torch.cuda.is_available():
        end.record()
        torch.cuda.synchronize()
    else:
        end = time.perf_counter()

    # Measure time taken to execute the function in milliseconds
    if torch.cuda.is_available():
        duration = start.elapsed_time(end)
    else:
        duration = (end - start) * 1000

    return result, duration
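The timed helper above measures a single call. To make the warmup cost discussed earlier visible, you can time the first call separately from the median of later calls. The sketch below does this. The function name first_vs_steady_state is hypothetical. It uses only time.perf_counter, so for CUDA work the callable you pass should itself call torch.cuda.synchronize(), as timed does, before returning.

```python
import statistics
import time
from typing import Any, Callable, Tuple


def first_vs_steady_state(fn: Callable[[], Any], repeats: int = 10) -> Tuple[float, float]:
    """Hypothetical helper: return (first_call_ms, median_later_call_ms).

    The first call includes any one-time cost (for example, JIT compilation
    with torch.compile); later calls show the steady-state latency.
    """
    start = time.perf_counter()
    fn()
    first_ms = (time.perf_counter() - start) * 1000

    later_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        later_ms.append((time.perf_counter() - start) * 1000)
    return first_ms, statistics.median(later_ms)
```

Inside torch.inference_mode(), you could call it as first_vs_steady_state(lambda: model(example_inputs)). For a torch.compile model you would expect a large gap between the two numbers. For a model loaded with torch._export.aot_load you would expect a much smaller gap.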
Let's measure the first-inference time using AOTInductor.
torch._dynamo.reset()

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")
Time taken for first inference for AOTInductor is 2.88 ms
Let's measure the first-inference time using torch.compile.
torch._dynamo.reset()

model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()

model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")
Time taken for first inference for torch.compile is 7021.33 ms
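In this run, the first inference took 2.88 ms with AOTInductor and 7021.33 ms with torch.compile, roughly a 2,400x difference. torch.compile pays its JIT compilation cost on the first call. AOTInductor only loads a shared library that was compiled ahead of time.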
Conclusion
In this tutorial, we learned how to use AOTInductor for the Python runtime by compiling and loading a pretrained ResNet18 model with the torch._inductor.aot_compile and torch._export.aot_load APIs. This process demonstrates how to generate a shared library and run it within a Python environment, including dynamic shapes and device-specific optimizations. We also looked at the advantage of using AOTInductor in model deployments: a much faster first inference.