TorchDynamo and TorchInductor Tutorial
======================================

TorchDynamo and TorchInductor are the latest way to speed up your PyTorch
code! Together, TorchDynamo and TorchInductor make PyTorch code run faster by
JIT-compiling it into optimized kernels, all while requiring minimal code
changes.

In this tutorial, we cover basic TorchDynamo/TorchInductor usage and
demonstrate the advantages of TorchDynamo/TorchInductor over previous PyTorch
compiler solutions, such as
`TorchScript <https://pytorch.org/docs/stable/jit.html>`__ and
`FX Tracing <https://pytorch.org/docs/stable/fx.html>`__.

TorchDynamo JIT-compiles arbitrary Python code into
`FX graphs <https://pytorch.org/docs/stable/fx.html>`__, which can then be
further compiled. TorchDynamo extracts FX graphs by inspecting Python bytecode
at runtime and detecting calls to PyTorch operations. Unlike previous attempts
at tracing PyTorch ops, TorchDynamo supports arbitrary Python code by breaking
the FX graph when it encounters unsupported Python features, such as
data-dependent control flow; previous attempts either silently produce
incorrect results or raise an error in these cases.

TorchInductor compiles the FX graphs generated by TorchDynamo into optimized
C++/`Triton <https://github.com/openai/triton>`__ kernels.

**Required pip Dependencies**

- ``torch >= 1.14``
- ``torchvision``
- ``numpy``
- ``scipy``
- ``tabulate``

Basic Usage
-----------

TorchDynamo and TorchInductor are included in the latest PyTorch nightlies.
Running TorchInductor on GPU requires Triton, which is included with the
nightly binary. If Triton is still missing, try installing ``torchtriton``
via pip.

.. code-block:: python

    import torch
    import torch._dynamo as dynamo

Arbitrary Python functions can be optimized by TorchDynamo/TorchInductor with
``dynamo.optimize("inductor")``. We can then call the returned optimized
function in place of the original function.

.. code-block:: python

    def foo(x, y):
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

    opt_foo1 = dynamo.optimize("inductor")(foo)
    print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))

Alternatively, we can decorate the function.

.. code-block:: python

    @dynamo.optimize("inductor")
    def opt_foo2(x, y):
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

    print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))

We can also optimize ``torch.nn.Module`` instances.

.. code-block:: python

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x):
            return torch.nn.functional.relu(self.lin(x))

    mod = MyModule()
    opt_mod = dynamo.optimize("inductor")(mod)
    print(opt_mod(torch.randn(10, 100)))

Demonstrating Speedups
----------------------

Let's now demonstrate that using TorchDynamo/TorchInductor can speed up real
models. We will compare standard eager mode and TorchDynamo/TorchInductor by
evaluating and training ResNet-18 on random data.

Before we start, we need to define some utility functions.

.. code-block:: python

    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in seconds. We use CUDA events and synchronization for the most accurate
    # measurements.
    def timed(fn):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn()
        end.record()
        torch.cuda.synchronize()
        return result, start.elapsed_time(end) / 1000
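
    # Not part of the original tutorial: a minimal alternative sketch of the
    # timing helper using wall-clock time. The name `timed_wallclock` is purely
    # illustrative. Explicit `torch.cuda.synchronize()` calls ensure that
    # `time.perf_counter` measures the full kernel execution, so results should
    # be comparable to the CUDA-event version above.
    import time

    def timed_wallclock(fn):
        torch.cuda.synchronize()
        start = time.perf_counter()
        result = fn()
        torch.cuda.synchronize()
        return result, time.perf_counter() - start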

    # Generates random input and targets data for the model, where `b` is
    # batch size.
    def generate_data(b):
        return (
            torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
            torch.randint(1000, (b,)).cuda(),
        )

    N_ITERS = 10

    from torchvision.models import resnet18
    def init_model():
        return resnet18().to(torch.float32).cuda()

First, let's compare inference.

.. code-block:: python

    def eval(mod, inp):
        return mod(inp)

    model = init_model()
    eval_opt = dynamo.optimize("inductor")(eval)

    inp = generate_data(16)[0]
    print("eager:", timed(lambda: eval(model, inp))[1])
    print("dynamo:", timed(lambda: eval_opt(model, inp))[1])

Notice that TorchDynamo/TorchInductor takes a lot longer to complete compared
to eager. This is because TorchDynamo/TorchInductor compiles the model into
optimized kernels as it executes. In our example, the structure of the model
doesn't change, so recompilation is not needed. If we run our optimized model
several more times, we should see a significant improvement compared to eager.

.. code-block:: python

    eager_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)[0]
        _, eager_time = timed(lambda: eval(model, inp))
        eager_times.append(eager_time)
        print(f"eager eval time {i}: {eager_time}")
    print("~" * 10)

    dynamo_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)[0]
        _, dynamo_time = timed(lambda: eval_opt(model, inp))
        dynamo_times.append(dynamo_time)
        print(f"dynamo eval time {i}: {dynamo_time}")
    print("~" * 10)

    import numpy as np
    eager_med = np.median(eager_times)
    dynamo_med = np.median(dynamo_times)
    speedup = eager_med / dynamo_med
    print(f"(eval) eager median: {eager_med}, dynamo median: {dynamo_med}, speedup: {speedup}x")
    print("~" * 10)

And indeed, we can see that running our model with TorchDynamo/TorchInductor
results in a significant speedup. On an NVIDIA A100 GPU, we observe a 2x
speedup. The speedup mainly comes from reducing Python overhead and GPU
reads/writes, so the observed speedup may vary with factors such as model
architecture and batch size. For example, if a model's architecture is simple
and the amount of data is large, then the bottleneck would be GPU compute and
the observed speedup may be less significant.

Now, let's compare training.

.. code-block:: python

    model = init_model()
    opt = torch.optim.Adam(model.parameters())

    def train(mod, data):
        pred = mod(data[0])
        loss = torch.nn.CrossEntropyLoss()(pred, data[1])
        loss.backward()

    eager_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)
        opt.zero_grad(True)
        _, eager_time = timed(lambda: train(model, inp))
        opt.step()
        eager_times.append(eager_time)
        print(f"eager train time {i}: {eager_time}")
    print("~" * 10)

    model = init_model()
    opt = torch.optim.Adam(model.parameters())
    train_opt = dynamo.optimize("inductor")(train)

    dynamo_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)
        opt.zero_grad(True)
        _, dynamo_time = timed(lambda: train_opt(model, inp))
        opt.step()
        dynamo_times.append(dynamo_time)
        print(f"dynamo train time {i}: {dynamo_time}")
    print("~" * 10)

    eager_med = np.median(eager_times)
    dynamo_med = np.median(dynamo_times)
    speedup = eager_med / dynamo_med
    print(f"(train) eager median: {eager_med}, dynamo median: {dynamo_med}, speedup: {speedup}x")
    print("~" * 10)

Again, we can see that TorchDynamo/TorchInductor takes longer in the first
iteration, as it must compile the model, but afterward we see significant
speedups compared to eager. On an NVIDIA A100 GPU, we observe a 2x speedup.

One thing to note is that, as of now, we cannot place optimizer code --
``opt.zero_grad`` and ``opt.step`` -- inside of an optimized function. The
rest of the training loop -- the forward pass and the backward pass -- can be
optimized. We are currently working on enabling optimizers to be compatible
with TorchDynamo/TorchInductor.
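
Below is a minimal sketch, not part of the benchmark above, that consolidates
this pattern into a single training step: the compiled function covers the
forward and backward passes, while the optimizer calls stay outside of it. The
names ``fwd_bwd`` and ``train_step`` are purely illustrative.

.. code-block:: python

    # Illustrative sketch: compile only the forward/backward passes and keep
    # the optimizer calls in eager mode.
    model = init_model()
    opt = torch.optim.Adam(model.parameters())

    def fwd_bwd(mod, data):
        pred = mod(data[0])
        loss = torch.nn.CrossEntropyLoss()(pred, data[1])
        loss.backward()

    opt_fwd_bwd = dynamo.optimize("inductor")(fwd_bwd)

    def train_step(data):
        opt.zero_grad(True)
        opt_fwd_bwd(model, data)  # compiled forward + backward
        opt.step()                # optimizer stays outside the compiled function

    train_step(generate_data(16))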

Comparison to TorchScript and FX Tracing
----------------------------------------

We have seen that TorchDynamo/TorchInductor can speed up PyTorch code. Why
else should we use TorchDynamo/TorchInductor over existing PyTorch compiler
solutions, such as TorchScript or FX Tracing? Primarily, the advantage of
TorchDynamo/TorchInductor lies in their ability to handle arbitrary Python
code with minimal changes to existing code.

One case that TorchDynamo/TorchInductor can handle that other compiler
solutions struggle with is data-dependent control flow (the line
``if x.sum() < 0:`` below).

.. code-block:: python

    def f1(x, y):
        if x.sum() < 0:
            return -y
        return y

    # Test that `fn1` and `fn2` return the same result, given the same arguments
    # `args`. Typically, `fn1` will be an eager function while `fn2` will be a
    # compiled function (TorchDynamo, TorchScript, or FX graph).
    def test_fns(fn1, fn2, args):
        out1 = fn1(*args)
        out2 = fn2(*args)
        return torch.allclose(out1, out2)

    inp1 = torch.randn(5, 5)
    inp2 = torch.randn(5, 5)

TorchScript tracing ``f1`` results in silently incorrect results, since only
the actual control flow path is traced.

.. code-block:: python

    traced_f1 = torch.jit.trace(f1, (inp1, inp2))
    print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
    print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

FX tracing ``f1`` results in an error due to the presence of data-dependent
control flow.

.. code-block:: python

    import traceback as tb
    try:
        torch.fx.symbolic_trace(f1)
    except:
        tb.print_exc()

If we provide a value for ``x`` as we try to FX trace ``f1``, then we run into
the same problem as TorchScript tracing, as the data-dependent control flow is
removed in the traced function.

.. code-block:: python

    fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
    print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
    print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))

Now we can see that TorchDynamo/TorchInductor correctly handles data-dependent
control flow.

.. code-block:: python

    dynamo_f1 = dynamo.optimize("inductor")(f1)
    print("dynamo 1, 1:", test_fns(f1, dynamo_f1, (inp1, inp2)))
    print("dynamo 1, 2:", test_fns(f1, dynamo_f1, (-inp1, inp2)))
    print("~" * 10)

TorchScript scripting can handle data-dependent control flow, but this
solution comes with its own set of problems. Namely, TorchScript scripting can
require major code changes and will raise errors when unsupported Python is
used. In the example below, we forget TorchScript type annotations, and we
receive a TorchScript error because the input type for argument ``y``, an
``int``, does not match the default argument type, ``torch.Tensor``.

.. code-block:: python

    def f2(x, y):
        return x + y

    inp1 = torch.randn(5, 5)
    inp2 = 3

    script_f2 = torch.jit.script(f2)
    try:
        script_f2(inp1, inp2)
    except:
        tb.print_exc()

However, TorchDynamo/TorchInductor is easily able to handle ``f2``.

.. code-block:: python

    dynamo_f2 = dynamo.optimize("inductor")(f2)
    print("dynamo 2:", test_fns(f2, dynamo_f2, (inp1, inp2)))
    print("~" * 10)
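
As an aside, here is a minimal sketch (not part of the original comparison) of
the kind of code change TorchScript scripting requires: annotating ``y`` as an
``int`` lets the scripted function accept the call. The name ``f2_annotated``
is purely illustrative.

.. code-block:: python

    # Illustrative sketch: TorchScript scripting works once `y` is annotated.
    def f2_annotated(x, y: int):
        return x + y

    script_f2_annotated = torch.jit.script(f2_annotated)
    print("script 2 (annotated):", test_fns(f2_annotated, script_f2_annotated, (inp1, inp2)))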

Another case that TorchDynamo/TorchInductor handles well compared to previous
compiler solutions is the usage of non-PyTorch functions.

.. code-block:: python

    import scipy
    def f3(x):
        x = x * 2
        x = scipy.fft.dct(x.numpy())
        x = torch.from_numpy(x)
        x = x * 2
        return x

TorchScript tracing treats results from non-PyTorch function calls as
constants, and so our results can be silently wrong.

.. code-block:: python

    inp1 = torch.randn(5, 5)
    inp2 = torch.randn(5, 5)
    traced_f3 = torch.jit.trace(f3, (inp1,))
    print("traced 3:", test_fns(f3, traced_f3, (inp2,)))

TorchScript scripting and FX tracing disallow non-PyTorch function calls.

.. code-block:: python

    try:
        torch.jit.script(f3)
    except:
        tb.print_exc()

    try:
        torch.fx.symbolic_trace(f3)
    except:
        tb.print_exc()

In comparison, TorchDynamo/TorchInductor is easily able to handle the
non-PyTorch function call.

.. code-block:: python

    dynamo_f3 = dynamo.optimize("inductor")(f3)
    print("dynamo 3:", test_fns(f3, dynamo_f3, (inp2,)))

TorchDynamo and FX Graphs
-------------------------

We now cover some topics involving TorchDynamo and FX graphs. In particular,
we will demonstrate how to view the FX graphs that TorchDynamo outputs,
discuss graph breaks and whole-program graph capture, and show how to export
graphs.

TorchDynamo is responsible for outputting FX graphs from traced Python code.
Normally, TorchInductor further compiles the FX graphs into optimized kernels,
but TorchDynamo allows different backends to be used. In order to inspect the
FX graphs that TorchDynamo outputs, let us create a custom backend that prints
the FX graph and simply returns the graph's unoptimized forward method.

.. code-block:: python

    from typing import List

    def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("custom backend called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward

    # Reset since we are using a different backend (a custom one).
    dynamo.reset()

    opt_model = dynamo.optimize(custom_backend)(init_model())
    opt_model(generate_data(16)[0])

Using our custom backend, we can now see how TorchDynamo handles
data-dependent control flow. Consider the function below, where the line
``if b.sum() < 0`` is the source of data-dependent control flow.

.. code-block:: python

    def bar(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    opt_bar = dynamo.optimize(custom_backend)(bar)
    inp1 = torch.randn(10)
    inp2 = torch.randn(10)
    opt_bar(inp1, inp2)
    opt_bar(inp1, -inp2)

The output reveals that TorchDynamo extracted 3 different FX graphs
corresponding to the following code (order may differ from the output above):

1. ``x = a / (torch.abs(a) + 1)``
2. ``b = b * -1; return x * b``
3. ``return x * b``

When TorchDynamo encounters unsupported Python features, such as
data-dependent control flow, it breaks the computation graph, lets the default
Python interpreter handle the unsupported code, then resumes capturing the
graph.

Let's investigate by example how TorchDynamo would step through ``bar``. If
``b.sum() < 0``, then TorchDynamo would run graph 1, let Python determine the
result of the conditional, then run graph 2. On the other hand, if
``not b.sum() < 0``, then TorchDynamo would run graph 1, let Python determine
the result of the conditional, then run graph 3.

This highlights a major difference between TorchDynamo and previous PyTorch
compiler solutions. When encountering unsupported Python features, previous
solutions either raise an error or silently fail. TorchDynamo, on the other
hand, will break the computation graph.

We can see where TorchDynamo breaks the graph by using ``dynamo.explain``:

.. code-block:: python

    explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = dynamo.explain(
        bar, torch.randn(10), torch.randn(10)
    )
    print(explanation_verbose)

In order to maximize speedup, graph breaks should be limited. We can force
TorchDynamo to raise an error upon the first graph break encountered by using
``nopython=True``:

.. code-block:: python

    opt_bar = dynamo.optimize("inductor", nopython=True)(bar)
    try:
        opt_bar(torch.randn(10), torch.randn(10))
    except:
        tb.print_exc()

And below, we demonstrate that TorchDynamo does not break the graph on the
model we used above for demonstrating speedups.

.. code-block:: python

    opt_model = dynamo.optimize("inductor", nopython=True)(init_model())
    print(opt_model(generate_data(16)[0]))

Finally, if we simply want TorchDynamo to output the FX graph for export, we
can use ``dynamo.export``. Note that ``dynamo.export``, like ``nopython=True``,
raises an error if TorchDynamo breaks the graph.

.. code-block:: python

    try:
        dynamo.export(bar, torch.randn(10), torch.randn(10))
    except:
        tb.print_exc()

    model_exp = dynamo.export(init_model(), generate_data(16)[0])
    print(model_exp[0](generate_data(16)[0]))

Conclusion
----------

In this tutorial, we introduced TorchDynamo and TorchInductor by covering
basic usage, demonstrating speedups over eager mode, comparing to previous
PyTorch compiler solutions, and briefly investigating interactions with FX
graphs. We hope that you will give TorchDynamo/TorchInductor a try!