
torch.compile Tutorial

Author: William Wen

torch.compile is the latest method to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling it into optimized kernels, all while requiring minimal code changes.

In this tutorial, we cover basic torch.compile usage, and demonstrate the advantages of torch.compile over previous PyTorch compiler solutions, such as TorchScript and FX Tracing.

Contents

  • Basic Usage

  • Demonstrating Speedups

  • Comparison to TorchScript and FX Tracing

  • TorchDynamo and FX Graphs

  • Conclusion

Required pip Dependencies

  • torch >= 2.0

  • torchvision

  • numpy

  • scipy

  • tabulate

Note: a modern NVIDIA GPU (Volta or Ampere) is recommended for this tutorial.

Basic Usage

torch.compile is included in PyTorch 2.0 and in the latest PyTorch nightlies. Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly binary. If Triton is still missing, try installing torchtriton via pip (pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117" for CUDA 11.7).
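
You can optionally verify that Triton is importable before running the GPU examples. The snippet below is a small sanity-check sketch, not part of the original tutorial.

import importlib.util

if importlib.util.find_spec("triton") is None:
    print("Triton not found; install it as described above to use TorchInductor on GPU.")
else:
    import triton
    print("Triton version:", triton.__version__)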

Arbitrary Python functions can be optimized by passing the callable to torch.compile. We can then call the returned optimized function in place of the original function.

import torch

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
tensor([[ 0.1798,  1.2844,  1.1928, -0.6577,  1.2001,  1.2458,  1.3295,  1.0679,
         -0.5651,  1.2174],
        [ 0.4180,  0.8621,  0.4904,  1.2344,  1.3135,  1.0205,  1.1573,  1.4137,
          0.6963,  0.1998],
        [ 0.8186,  1.3638,  1.1769,  1.1858,  0.1157,  0.5435,  1.2660,  0.0452,
          1.4071,  1.2790],
        [ 0.1720,  0.7816,  1.1149,  0.2946,  1.0323,  1.3251,  1.2137,  0.7413,
          1.3243,  1.3410],
        [-0.5763, -0.8852,  0.0997,  0.5206,  1.2721, -0.8215,  1.1307,  0.6280,
          0.9548, -0.4519],
        [ 0.5277,  1.1179,  0.8299,  1.1068,  1.1056, -0.9118,  1.1369,  0.1787,
          1.4070,  1.2695],
        [ 1.1402,  1.3995,  1.3867,  1.4065, -0.1616,  1.3379,  0.5723,  1.4052,
          0.9187,  0.7396],
        [ 1.4136,  0.9392,  1.1151, -0.1921,  0.6561,  1.3462, -0.8778,  0.0633,
         -0.8164,  1.3842],
        [-0.2066,  1.3659,  1.4083,  1.1795,  0.8535, -1.2620,  0.5423, -0.1378,
          0.9112,  1.4097],
        [-0.7677,  0.8684, -0.4427,  1.2253,  1.3321,  1.2597,  1.1444,  0.5330,
          1.2131,  0.6155]])

Alternatively, we can decorate the function.

@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
tensor([[ 1.3819, -0.0023,  0.3889, -0.3917,  0.3476,  0.8949, -0.1183,  1.2301,
         -0.7860,  0.5604],
        [ 1.2832,  0.9980, -1.3994,  0.3217,  1.4017,  1.1034,  1.3659,  1.2822,
          0.8117, -1.4010],
        [ 0.0598,  1.3736,  1.1738,  1.4119,  0.9839,  1.2038,  1.4142, -0.9450,
         -0.6137,  1.3810],
        [ 0.1338,  1.4103,  0.9247,  0.5043,  1.4142, -0.2322,  1.0996,  0.8513,
          0.9817, -0.7160],
        [ 0.8685,  1.1138, -0.5582,  0.7401,  1.0455,  0.6791,  1.3404,  0.9539,
          0.6790,  0.5284],
        [ 0.8021,  1.3120, -1.0966,  1.2934, -0.0197, -1.1680,  0.5190,  1.3328,
         -0.6529, -0.8327],
        [ 0.2134, -0.8958,  1.2135, -0.7837,  0.8079,  0.9630,  1.3625,  0.7420,
          1.1731,  1.3631],
        [ 1.3967, -0.9851,  0.8375,  1.0817, -0.0596, -0.4823,  1.0790,  0.1602,
         -0.3130,  1.2681],
        [ 1.3737,  1.0730,  1.3989,  0.8458, -1.1168,  0.7682,  1.2389,  0.8003,
          1.2728,  1.3231],
        [-0.3265,  1.1508, -0.0631,  1.3986,  1.0830,  0.8830,  1.1943, -0.0505,
         -0.8125,  0.9741]])

We can also optimize torch.nn.Module instances.

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))
tensor([[0.6872, -0.0000, 0.6799, 0.8060, -0.0000, -0.0000, 0.4238, 0.1070, 0.1104,
         -0.0000],
        [-0.0000, 0.0631, -0.0000, 1.3911, 0.0685, 0.0367, -0.0000, 0.4435, -0.0000,
         -0.0000],
        [-0.0000, 0.8475, -0.0000, -0.0000, 1.2670, 0.0052, -0.0000, -0.0000, 1.4696,
         0.7722],
        [-0.0000, 0.3604, 0.7189, 0.1741, -0.0000, 0.0089, 0.4830, -0.0000, -0.0000,
         -0.0000],
        [0.2388, -0.0000, 0.3756, 0.1995, -0.0000, 0.3563, 1.1686, -0.0000, -0.0000,
         -0.0000],
        [-0.0000, 0.3833, 0.2159, -0.0000, 0.0479, -0.0000, -0.0000, -0.0000, 0.3953,
         0.7013],
        [-0.0000, 0.6361, -0.0000, -0.0000, 1.6845, 0.3640, -0.0000, 0.0496, 0.7504,
         0.4540],
        [-0.0000, -0.0000, 0.8502, -0.0000, 0.1043, -0.0000, -0.0000, 0.1842, 0.0391,
         0.3200],
        [-0.0000, -0.0000, -0.0000, 0.7361, -0.0000, -0.0000, -0.0000, 0.4034, -0.0000,
         1.0752],
        [0.7645, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, 1.0489,
         -0.0000]], grad_fn=<CompiledFunctionBackward>)

Demonstrating Speedups

Let’s now demonstrate that using torch.compile can speed up real models. We will compare standard eager mode and torch.compile by evaluating and training ResNet-18 on random data.

Before we start, we need to define some utility functions.

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

# Generates random input and target data for the model, where `b` is
# the batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )

N_ITERS = 10

from torchvision.models import resnet18
def init_model():
    return resnet18().to(torch.float32).cuda()

First, let’s compare inference.

Note that in the call to torch.compile, we pass an additional mode kwarg, which we will discuss below.

def evaluate(mod, inp):
    return mod(inp)

model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")

inp = generate_data(16)[0]
print("eager:", timed(lambda: evaluate(model, inp))[1])
print("compile:", timed(lambda: evaluate_opt(model, inp))[1])
eager: 1.95551904296875
compile: 10.07233984375

Notice that torch.compile takes much longer to complete compared to eager. This is because torch.compile compiles the model into optimized kernels as it executes, so the first call includes compilation time. In our example, the structure of the model doesn't change, so recompilation is not needed. If we run our optimized model several more times, we should see a significant improvement compared to eager.

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    _, eager_time = timed(lambda: evaluate(model, inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    _, compile_time = timed(lambda: evaluate_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
eager eval time 0: 0.009356320381164551
eager eval time 1: 0.008964096069335938
eager eval time 2: 0.008949248313903809
eager eval time 3: 0.008943615913391113
eager eval time 4: 0.008949151992797852
eager eval time 5: 0.008933343887329102
eager eval time 6: 0.008934271812438966
eager eval time 7: 0.008949376106262208
eager eval time 8: 0.008939295768737794
eager eval time 9: 0.008945919990539552
~~~~~~~~~~
compile eval time 0: 0.007935200214385986
compile eval time 1: 0.007463263988494873
compile eval time 2: 0.007408415794372559
compile eval time 3: 0.007355264186859131
compile eval time 4: 0.0073528318405151364
compile eval time 5: 0.007414783954620361
compile eval time 6: 0.007352287769317627
compile eval time 7: 0.007393663883209229
compile eval time 8: 0.007352287769317627
compile eval time 9: 0.00738918399810791
~~~~~~~~~~
(eval) eager median: 0.008947535991668702, compile median: 0.0073914239406585695, speedup: 1.2105294004921443x
~~~~~~~~~~

And indeed, we can see that running our model with torch.compile results in a significant speedup. Speedup mainly comes from reducing Python overhead and GPU read/writes, so the observed speedup may vary depending on factors such as model architecture and batch size. For example, if a model's architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.

You may also see different speedup results depending on the chosen mode kwarg. Since our model and data are small, we want to reduce overhead as much as possible, so we chose "reduce-overhead". For your own models, you may need to experiment with different modes to maximize speedup; you can read more about the available modes in the torch.compile documentation.
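
For example, the following is a minimal sketch (using the evaluate function, model, and inp defined above) of timing each built-in mode. Compilation for "max-autotune" can take a while, and the exact timings will differ on your hardware.

for mode in ("default", "reduce-overhead", "max-autotune"):
    torch._dynamo.reset()  # clear previously compiled code between modes
    compiled_evaluate = torch.compile(evaluate, mode=mode)
    compiled_evaluate(model, inp)  # warm up; the first call includes compilation
    _, t = timed(lambda: compiled_evaluate(model, inp))
    print(f"mode={mode}: {t}")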

For general PyTorch benchmarking, you can try using torch.utils.benchmark instead of the timed function we defined above. We wrote our own timing function in this tutorial to show torch.compile’s compilation latency.
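
As a rough sketch (assuming the evaluate_opt, model, and inp defined above), timing with torch.utils.benchmark could look like the following; the Timer handles CUDA synchronization and chooses the number of runs for us.

import torch.utils.benchmark as benchmark

timer = benchmark.Timer(
    stmt="evaluate_opt(model, inp)",
    globals={"evaluate_opt": evaluate_opt, "model": model, "inp": inp},
)
# blocked_autorange repeats the statement enough times for a stable measurement.
print(timer.blocked_autorange())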

Now, let's compare training.

model = init_model()
opt = torch.optim.Adam(model.parameters())

def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
eager train time 0: 0.41696514892578124
eager train time 1: 0.026487199783325196
eager train time 2: 0.026447872161865234
eager train time 3: 0.026425567626953125
eager train time 4: 0.026419200897216798
eager train time 5: 0.02647859191894531
eager train time 6: 0.026426143646240234
eager train time 7: 0.026437631607055666
eager train time 8: 0.025902912139892577
eager train time 9: 0.02117238426208496
~~~~~~~~~~
compile train time 0: 22.5079921875
compile train time 1: 0.021698528289794922
compile train time 2: 0.02067647933959961
compile train time 3: 0.020686656951904296
compile train time 4: 0.020949024200439453
compile train time 5: 0.020941055297851563
compile train time 6: 0.02081376075744629
compile train time 7: 0.020818016052246095
compile train time 8: 0.020946943283081054
compile train time 9: 0.02089593505859375
~~~~~~~~~~
(train) eager median: 0.02643188762664795, compile median: 0.020918495178222654, speedup: 1.2635654429944392x
~~~~~~~~~~

Again, we can see that torch.compile takes longer in the first iteration, as it must compile the model, but in subsequent iterations, we see significant speedups compared to eager.

Comparison to TorchScript and FX Tracing

We have seen that torch.compile can speed up PyTorch code. Why else should we use torch.compile over existing PyTorch compiler solutions, such as TorchScript or FX Tracing? Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.

One case that torch.compile can handle that other compiler solutions struggle with is data-dependent control flow (the if x.sum() < 0: line below).

def f1(x, y):
    if x.sum() < 0:
        return -y
    return y

# Test that `fn1` and `fn2` return the same result, given
# the same arguments `args`. Typically, `fn1` will be an eager function
# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)

inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

TorchScript tracing f1 results in silently incorrect results, since only the actual control flow path is traced.

traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))
/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py:254: TracerWarning:

Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

traced 1, 1: True
traced 1, 2: False

FX tracing f1 results in an error due to the presence of data-dependent control flow.

import traceback as tb
try:
    torch.fx.symbolic_trace(f1)
except:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 284, in <module>
    torch.fx.symbolic_trace(f1)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1109, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 254, in f1
    if x.sum() < 0:
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 413, in __bool__
    return self.tracer.to_bool(self)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 276, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow

If we provide a value for x as we try to FX trace f1, then we run into the same problem as TorchScript tracing, as the data-dependent control flow is removed in the traced function.

fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))
/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:602: UserWarning:

Was not able to add assertion to guarantee correct input x to specialized function. It is up to the user to make sure that your inputs match the inputs you specialized the function with.

fx 1, 1: True
fx 1, 2: False

Now we can see that torch.compile correctly handles data-dependent control flow.

# Reset since we are using a different mode.
torch._dynamo.reset()

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~

TorchScript scripting can handle data-dependent control flow, but this solution comes with its own set of problems. Namely, TorchScript scripting can require major code changes and will raise errors when unsupported Python is used.
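
For instance, scripting f1 from above does capture the data-dependent branch. The snippet below is a small sketch for comparison and is not part of the original tutorial output.

script_f1 = torch.jit.script(f1)
print("scripted 1, 1:", test_fns(f1, script_f1, (inp1, inp2)))
print("scripted 1, 2:", test_fns(f1, script_f1, (-inp1, inp2)))  # both should print True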

In the example below, we omit the TorchScript type annotations and receive a TorchScript error because the type of the argument y, an int, does not match the inferred argument type, torch.Tensor.

def f2(x, y):
    return x + y

inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 327, in <module>
    script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
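
For comparison, here is a hedged sketch of the kind of code change scripting requires in this case: explicitly annotating the non-Tensor argument. The name f2_annotated is introduced only for this example.

def f2_annotated(x: torch.Tensor, y: int):
    return x + y

script_f2_annotated = torch.jit.script(f2_annotated)
print("script annotated 2:", test_fns(f2_annotated, script_f2_annotated, (inp1, inp2)))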

However, torch.compile is easily able to handle the original, unannotated f2.

compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
compile 2: True
~~~~~~~~~~

Another case that torch.compile handles well compared to previous compiler solutions is the use of non-PyTorch functions.

import scipy
def f3(x):
    x = x * 2
    x = scipy.fft.dct(x.numpy())
    x = torch.from_numpy(x)
    x = x * 2
    return x

TorchScript tracing treats results from non-PyTorch function calls as constants, and so our results can be silently wrong.

inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
traced_f3 = torch.jit.trace(f3, (inp1,))
print("traced 3:", test_fns(f3, traced_f3, (inp2,)))
/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py:345: TracerWarning:

Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py:346: TracerWarning:

torch.from_numpy results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.

traced 3: False

TorchScript scripting and FX tracing disallow non-PyTorch function calls.

try:
    torch.jit.script(f3)
except:
    tb.print_exc()

try:
    torch.fx.symbolic_trace(f3)
except:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 363, in <module>
    torch.jit.script(f3)
  File "/opt/conda/lib/python3.10/site-packages/torch/jit/_script.py", line 1341, in script
    fn = torch._C._jit_script_compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_jit_internal.py", line 1198, in _try_get_dispatched_fn
    return boolean_dispatched.get(fn)
  File "/opt/conda/lib/python3.10/weakref.py", line 453, in get
    return self.data.get(ref(key),default)
TypeError: cannot create weak reference to 'uarray._Function' object
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 368, in <module>
    torch.fx.symbolic_trace(f3)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1109, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
    (self.create_arg(fn(*args)),),
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 345, in f3
    x = scipy.fft.dct(x.numpy())
  File "/opt/conda/lib/python3.10/site-packages/scipy/fft/_backend.py", line 25, in __ua_function__
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/scipy/fft/_pocketfft/realtransforms.py", line 19, in _r2r
    tmp = _asfarray(x)
  File "/opt/conda/lib/python3.10/site-packages/scipy/fft/_pocketfft/helper.py", line 89, in _asfarray
    if x.dtype == np.float16:
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 518, in impl
    return tracer.create_proxy('call_function', target, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 151, in create_proxy
    args_ = self.create_arg(args)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 373, in create_arg
    return super().create_arg(a)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 239, in create_arg
    return type(a)(self.create_arg(elem) for elem in a)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 239, in <genexpr>
    return type(a)(self.create_arg(elem) for elem in a)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 373, in create_arg
    return super().create_arg(a)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 267, in create_arg
    raise NotImplementedError(f"argument of type: {type(a)}")
NotImplementedError: argument of type: <class 'type'>

In comparison, torch.compile is easily able to handle the non-PyTorch function call.

compile_f3 = torch.compile(f3)
print("compile 3:", test_fns(f3, compile_f3, (inp2,)))
compile 3: True

TorchDynamo and FX Graphs

One important component of torch.compile is TorchDynamo. TorchDynamo is responsible for JIT compiling arbitrary Python code into FX graphs, which can then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode during runtime and detecting calls to PyTorch operations.

Normally, TorchInductor, another component of torch.compile, further compiles the FX graphs into optimized kernels, but TorchDynamo allows for different backends to be used. In order to inspect the FX graphs that TorchDynamo outputs, let us create a custom backend that outputs the FX graph and simply returns the graph’s unoptimized forward method.

from typing import List
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward

# Reset since we are using a different backend.
torch._dynamo.reset()

opt_model = torch.compile(init_model(), backend=custom_backend)
opt_model(generate_data(16)[0])
custom backend called with FX graph:
opcode         name                        target                                                      args                                             kwargs
-------------  --------------------------  ----------------------------------------------------------  -----------------------------------------------  --------
placeholder    x                           x                                                           ()                                               {}
call_module    self_conv1                  self_conv1                                                  (x,)                                             {}
call_module    self_bn1                    self_bn1                                                    (self_conv1,)                                    {}
call_module    self_relu                   self_relu                                                   (self_bn1,)                                      {}
call_module    self_maxpool                self_maxpool                                                (self_relu,)                                     {}
call_module    self_layer1_0_conv1         self_layer1_0_conv1                                         (self_maxpool,)                                  {}
call_module    self_layer1_0_bn1           self_layer1_0_bn1                                           (self_layer1_0_conv1,)                           {}
call_module    self_layer1_0_relu          self_layer1_0_relu                                          (self_layer1_0_bn1,)                             {}
call_module    self_layer1_0_conv2         self_layer1_0_conv2                                         (self_layer1_0_relu,)                            {}
call_module    self_layer1_0_bn2           self_layer1_0_bn2                                           (self_layer1_0_conv2,)                           {}
call_function  iadd                        <built-in function iadd>                                    (self_layer1_0_bn2, self_maxpool)                {}
call_module    self_layer1_0_relu_1        self_layer1_0_relu                                          (iadd,)                                          {}
call_module    self_layer1_1_conv1         self_layer1_1_conv1                                         (self_layer1_0_relu_1,)                          {}
call_module    self_layer1_1_bn1           self_layer1_1_bn1                                           (self_layer1_1_conv1,)                           {}
call_module    self_layer1_1_relu          self_layer1_1_relu                                          (self_layer1_1_bn1,)                             {}
call_module    self_layer1_1_conv2         self_layer1_1_conv2                                         (self_layer1_1_relu,)                            {}
call_module    self_layer1_1_bn2           self_layer1_1_bn2                                           (self_layer1_1_conv2,)                           {}
call_function  iadd_1                      <built-in function iadd>                                    (self_layer1_1_bn2, self_layer1_0_relu_1)        {}
call_module    self_layer1_1_relu_1        self_layer1_1_relu                                          (iadd_1,)                                        {}
call_module    self_layer2_0_conv1         self_layer2_0_conv1                                         (self_layer1_1_relu_1,)                          {}
call_module    self_layer2_0_bn1           self_layer2_0_bn1                                           (self_layer2_0_conv1,)                           {}
call_module    self_layer2_0_relu          self_layer2_0_relu                                          (self_layer2_0_bn1,)                             {}
call_module    self_layer2_0_conv2         self_layer2_0_conv2                                         (self_layer2_0_relu,)                            {}
call_module    self_layer2_0_bn2           self_layer2_0_bn2                                           (self_layer2_0_conv2,)                           {}
call_module    self_layer2_0_downsample_0  self_layer2_0_downsample_0                                  (self_layer1_1_relu_1,)                          {}
call_module    self_layer2_0_downsample_1  self_layer2_0_downsample_1                                  (self_layer2_0_downsample_0,)                    {}
call_function  iadd_2                      <built-in function iadd>                                    (self_layer2_0_bn2, self_layer2_0_downsample_1)  {}
call_module    self_layer2_0_relu_1        self_layer2_0_relu                                          (iadd_2,)                                        {}
call_module    self_layer2_1_conv1         self_layer2_1_conv1                                         (self_layer2_0_relu_1,)                          {}
call_module    self_layer2_1_bn1           self_layer2_1_bn1                                           (self_layer2_1_conv1,)                           {}
call_module    self_layer2_1_relu          self_layer2_1_relu                                          (self_layer2_1_bn1,)                             {}
call_module    self_layer2_1_conv2         self_layer2_1_conv2                                         (self_layer2_1_relu,)                            {}
call_module    self_layer2_1_bn2           self_layer2_1_bn2                                           (self_layer2_1_conv2,)                           {}
call_function  iadd_3                      <built-in function iadd>                                    (self_layer2_1_bn2, self_layer2_0_relu_1)        {}
call_module    self_layer2_1_relu_1        self_layer2_1_relu                                          (iadd_3,)                                        {}
call_module    self_layer3_0_conv1         self_layer3_0_conv1                                         (self_layer2_1_relu_1,)                          {}
call_module    self_layer3_0_bn1           self_layer3_0_bn1                                           (self_layer3_0_conv1,)                           {}
call_module    self_layer3_0_relu          self_layer3_0_relu                                          (self_layer3_0_bn1,)                             {}
call_module    self_layer3_0_conv2         self_layer3_0_conv2                                         (self_layer3_0_relu,)                            {}
call_module    self_layer3_0_bn2           self_layer3_0_bn2                                           (self_layer3_0_conv2,)                           {}
call_module    self_layer3_0_downsample_0  self_layer3_0_downsample_0                                  (self_layer2_1_relu_1,)                          {}
call_module    self_layer3_0_downsample_1  self_layer3_0_downsample_1                                  (self_layer3_0_downsample_0,)                    {}
call_function  iadd_4                      <built-in function iadd>                                    (self_layer3_0_bn2, self_layer3_0_downsample_1)  {}
call_module    self_layer3_0_relu_1        self_layer3_0_relu                                          (iadd_4,)                                        {}
call_module    self_layer3_1_conv1         self_layer3_1_conv1                                         (self_layer3_0_relu_1,)                          {}
call_module    self_layer3_1_bn1           self_layer3_1_bn1                                           (self_layer3_1_conv1,)                           {}
call_module    self_layer3_1_relu          self_layer3_1_relu                                          (self_layer3_1_bn1,)                             {}
call_module    self_layer3_1_conv2         self_layer3_1_conv2                                         (self_layer3_1_relu,)                            {}
call_module    self_layer3_1_bn2           self_layer3_1_bn2                                           (self_layer3_1_conv2,)                           {}
call_function  iadd_5                      <built-in function iadd>                                    (self_layer3_1_bn2, self_layer3_0_relu_1)        {}
call_module    self_layer3_1_relu_1        self_layer3_1_relu                                          (iadd_5,)                                        {}
call_module    self_layer4_0_conv1         self_layer4_0_conv1                                         (self_layer3_1_relu_1,)                          {}
call_module    self_layer4_0_bn1           self_layer4_0_bn1                                           (self_layer4_0_conv1,)                           {}
call_module    self_layer4_0_relu          self_layer4_0_relu                                          (self_layer4_0_bn1,)                             {}
call_module    self_layer4_0_conv2         self_layer4_0_conv2                                         (self_layer4_0_relu,)                            {}
call_module    self_layer4_0_bn2           self_layer4_0_bn2                                           (self_layer4_0_conv2,)                           {}
call_module    self_layer4_0_downsample_0  self_layer4_0_downsample_0                                  (self_layer3_1_relu_1,)                          {}
call_module    self_layer4_0_downsample_1  self_layer4_0_downsample_1                                  (self_layer4_0_downsample_0,)                    {}
call_function  iadd_6                      <built-in function iadd>                                    (self_layer4_0_bn2, self_layer4_0_downsample_1)  {}
call_module    self_layer4_0_relu_1        self_layer4_0_relu                                          (iadd_6,)                                        {}
call_module    self_layer4_1_conv1         self_layer4_1_conv1                                         (self_layer4_0_relu_1,)                          {}
call_module    self_layer4_1_bn1           self_layer4_1_bn1                                           (self_layer4_1_conv1,)                           {}
call_module    self_layer4_1_relu          self_layer4_1_relu                                          (self_layer4_1_bn1,)                             {}
call_module    self_layer4_1_conv2         self_layer4_1_conv2                                         (self_layer4_1_relu,)                            {}
call_module    self_layer4_1_bn2           self_layer4_1_bn2                                           (self_layer4_1_conv2,)                           {}
call_function  iadd_7                      <built-in function iadd>                                    (self_layer4_1_bn2, self_layer4_0_relu_1)        {}
call_module    self_layer4_1_relu_1        self_layer4_1_relu                                          (iadd_7,)                                        {}
call_module    self_avgpool                self_avgpool                                                (self_layer4_1_relu_1,)                          {}
call_function  flatten                     <built-in method flatten of type object at 0x7f0da2de1540>  (self_avgpool, 1)                                {}
call_module    self_fc                     self_fc                                                     (flatten,)                                       {}
output         output                      output                                                      ((self_fc,),)                                    {}

tensor([[-1.3284, -0.5548, -0.2535,  ...,  0.0955, -0.5363, -0.3258],
        [-1.3473, -0.4153, -0.1871,  ...,  0.0867, -0.4691, -0.0939],
        [-1.0568, -0.3973, -0.1293,  ...,  0.1281, -0.6533, -0.5252],
        ...,
        [-0.9553, -0.1680, -0.1153,  ...,  0.0278, -0.7660, -0.4412],
        [-1.1365, -0.3513, -0.3417,  ..., -0.1511, -0.8385, -0.4293],
        [-1.2228, -0.3285, -0.4461,  ...,  0.0856, -0.3251, -0.3988]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

Using our custom backend, we can now see how TorchDynamo is able to handle data-dependent control flow. Consider the function below, where the line if b.sum() < 0 is the source of data-dependent control flow.

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

opt_bar = torch.compile(bar, backend=custom_backend)
inp1 = torch.randn(10)
inp2 = torch.randn(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
custom backend called with FX graph:
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f0da2de1540>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}
custom backend called with FX graph:
opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    b       b                        ()           {}
placeholder    x       x                        ()           {}
call_function  mul     <built-in function mul>  (b, -1)      {}
call_function  mul_1   <built-in function mul>  (x, mul)     {}
output         output  output                   ((mul_1,),)  {}
custom backend called with FX graph:
opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    b       b                        ()         {}
placeholder    x       x                        ()         {}
call_function  mul     <built-in function mul>  (x, b)     {}
output         output  output                   ((mul,),)  {}

tensor([-0.1855, -0.0932,  0.0438, -0.1475, -0.0601,  0.3393, -1.4583,  0.0015,
         0.7143,  0.0911])

The output reveals that TorchDynamo extracted 3 different FX graphs corresponding to the following code (order may differ from the output above):

  1. x = a / (torch.abs(a) + 1)

  2. b = b * -1; return x * b

  3. return x * b

When TorchDynamo encounters unsupported Python features, such as data-dependent control flow, it breaks the computation graph, lets the default Python interpreter handle the unsupported code, then resumes capturing the graph.

Let's investigate by example how TorchDynamo would step through bar. If b.sum() < 0, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 2. On the other hand, if b.sum() >= 0, TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 3. A purely conceptual sketch of this split is shown below.
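
The following sketch is an illustration only, not actual TorchDynamo output: graph 1 computes everything up to the condition, plain Python evaluates the branch, and graph 2 or graph 3 finishes the computation.

def graph_1(a, b):
    x = a / (torch.abs(a) + 1)
    return x, b.sum() < 0

def graph_2(x, b):
    # taken when the condition is True
    b = b * -1
    return x * b

def graph_3(x, b):
    # taken when the condition is False
    return x * b

def bar_resumed(a, b):
    x, cond = graph_1(a, b)
    if cond:  # plain Python decides which graph runs next
        return graph_2(x, b)
    return graph_3(x, b)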

This highlights a major difference between TorchDynamo and previous PyTorch compiler solutions. When encountering unsupported Python features, previous solutions either raise an error or silently fail. TorchDynamo, on the other hand, will break the computation graph.

We can see where TorchDynamo breaks the graph by using torch._dynamo.explain:

# Reset since we are using a different backend.
torch._dynamo.reset()
explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = torch._dynamo.explain(
    bar, torch.randn(10), torch.randn(10)
)
print(explanation_verbose)
Dynamo produced 2 graphs with 1 graph break and 6 ops
 Break reasons:

1. generic_jump TensorVariable()
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 414, in bar
    if b.sum() < 0:

2. return_value
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 416, in <graph break in bar>
    return x * b

TorchDynamo compilation metrics:
Function                        Runtimes (s)
------------------------------  --------------
_compile                        0.0127, 0.0057
OutputGraph.call_user_compiler  0.0001, 0.0000

In order to maximize speedup, graph breaks should be limited. We can force TorchDynamo to raise an error upon the first graph break encountered by using fullgraph=True:

opt_bar = torch.compile(bar, fullgraph=True)
try:
    opt_bar(torch.randn(10), torch.randn(10))
except:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 464, in <module>
    opt_bar(torch.randn(10), torch.randn(10))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 327, in inner
    unimplemented(f"generic_jump {typestr(value)}")
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 71, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: generic_jump TensorVariable()

from user code:
   File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 414, in bar
    if b.sum() < 0:

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
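
As an aside, one possible way to avoid this particular graph break is to express the branch with tensor ops instead of a Python if statement. The sketch below is not from the original tutorial; it compiles under fullgraph=True and is equivalent to bar for these inputs.

def bar_fullgraph(a, b):
    x = a / (torch.abs(a) + 1)
    # torch.where keeps the condition inside the graph instead of branching in Python.
    b = torch.where(b.sum() < 0, -b, b)
    return x * b

opt_bar_fullgraph = torch.compile(bar_fullgraph, fullgraph=True)
print(opt_bar_fullgraph(torch.randn(10), torch.randn(10)))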

And below, we demonstrate that TorchDynamo does not break the graph on the model we used above for demonstrating speedups.

opt_model = torch.compile(init_model(), fullgraph=True)
print(opt_model(generate_data(16)[0]))
tensor([[-0.0031, -0.6601,  0.7708,  ...,  0.5529, -0.0965,  0.2389],
        [ 0.5157, -0.6799,  0.6711,  ...,  0.4190, -0.0511,  0.3566],
        [ 0.4879, -0.5442,  0.6752,  ...,  0.2638, -0.2817,  0.6400],
        ...,
        [ 0.0579, -0.2516,  0.5776,  ...,  0.2413, -0.0513,  0.4131],
        [ 0.2299, -0.3535,  0.3686,  ...,  0.4281, -0.1155,  0.4612],
        [ 0.1491, -0.5038,  0.5811,  ...,  0.4087, -0.0058,  0.4349]],
       device='cuda:0', grad_fn=<CompiledFunctionBackward>)

Finally, if we simply want TorchDynamo to output the FX graph for export, we can use torch._dynamo.export. Note that torch._dynamo.export, like fullgraph=True, raises an error if TorchDynamo breaks the graph.

try:
    torch._dynamo.export(bar, torch.randn(10), torch.randn(10))
except:
    tb.print_exc()

model_exp = torch._dynamo.export(init_model(), generate_data(16)[0])
print(model_exp[0](generate_data(16)[0]))
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 481, in <module>
    torch._dynamo.export(bar, torch.randn(10), torch.randn(10))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 601, in export
    result_traced = opt_f(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 327, in inner
    unimplemented(f"generic_jump {typestr(value)}")
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 71, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: generic_jump TensorVariable()

from user code:
   File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 414, in bar
    if b.sum() < 0:

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

tensor([[ 0.1920, -0.1323,  0.9817,  ...,  0.5318, -0.5725,  0.6707],
        [ 0.1287, -0.2137,  0.7767,  ...,  0.4573, -0.3710,  0.4077],
        [ 0.4535, -0.0487,  0.8373,  ...,  0.5423, -0.3619,  0.3945],
        ...,
        [ 0.1012, -0.2960,  0.6673,  ...,  0.4942, -0.4441,  0.5463],
        [ 0.0787, -0.0200,  0.7137,  ...,  0.5139, -0.4842,  0.4904],
        [ 0.2168, -0.1528,  0.7756,  ...,  0.7245, -0.3392,  0.6314]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

Conclusion

In this tutorial, we introduced torch.compile by covering basic usage, demonstrating speedups over eager mode, comparing to previous PyTorch compiler solutions, and briefly investigating TorchDynamo and its interactions with FX graphs. We hope that you will give torch.compile a try!

