.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "intermediate/torch_compile_tutorial.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note Click :ref:`here ` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr_intermediate_torch_compile_tutorial.py: Introduction to ``torch.compile`` ================================= **Author:** William Wen .. GENERATED FROM PYTHON SOURCE LINES 10-33 ``torch.compile`` is the latest method to speed up your PyTorch code! ``torch.compile`` makes PyTorch code run faster by JIT-compiling PyTorch code into optimized kernels, all while requiring minimal code changes. In this tutorial, we cover basic ``torch.compile`` usage, and demonstrate the advantages of ``torch.compile`` over previous PyTorch compiler solutions, such as `TorchScript `__ and `FX Tracing `__. **Contents** .. contents:: :local: **Required pip Dependencies** - ``torch >= 2.0`` - ``torchvision`` - ``numpy`` - ``scipy`` - ``tabulate`` .. GENERATED FROM PYTHON SOURCE LINES 35-37 NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in order to reproduce the speedup numbers shown below and documented elsewhere. .. GENERATED FROM PYTHON SOURCE LINES 37-53 .. code-block:: default import torch import warnings gpu_ok = False if torch.cuda.is_available(): device_cap = torch.cuda.get_device_capability() if device_cap in ((7, 0), (8, 0), (9, 0)): gpu_ok = True if not gpu_ok: warnings.warn( "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower " "than expected." ) .. rst-class:: sphx-glr-script-out .. code-block:: none /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:48: UserWarning: GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected. .. 
GENERATED FROM PYTHON SOURCE LINES 54-66 Basic Usage ------------ ``torch.compile`` is included in the latest PyTorch. Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly binary. If Triton is still missing, try installing ``torchtriton`` via pip (``pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117"`` for CUDA 11.7). Arbitrary Python functions can be optimized by passing the callable to ``torch.compile``. We can then call the returned optimized function in place of the original function. .. GENERATED FROM PYTHON SOURCE LINES 66-74 .. code-block:: default def foo(x, y): a = torch.sin(x) b = torch.cos(y) return a + b opt_foo1 = torch.compile(foo) print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10))) .. rst-class:: sphx-glr-script-out .. code-block:: none tensor([[ 1.6850, 1.9924, 1.7090, 0.0034, 1.1414, -0.1822, 0.4861, -0.0536, -0.2252, 1.9398], [ 0.3693, -0.0695, 0.1748, 0.3436, 0.1939, 1.5721, 1.9882, -0.2235, 0.3161, 1.2642], [ 0.2480, 1.8793, 1.7152, 1.6772, 1.8881, 1.4748, 1.3466, 1.7763, 0.7469, 1.0407], [-0.1121, 1.6015, -0.0188, 0.2128, 0.5218, 1.9838, 0.8185, 0.5093, -0.3603, 0.1793], [-1.7890, 1.7532, -0.4040, 0.1222, -0.0029, 1.7975, -0.3877, 0.5123, 0.1673, 0.1330], [ 1.0627, 0.9609, 0.1019, 1.8814, 0.1142, -0.2338, -0.9621, 0.7631, 0.6506, 0.1853], [ 0.4584, 1.7648, -0.0444, 1.9610, 1.5884, 0.7353, 1.2190, 1.3662, 1.0938, -0.1587], [-0.7502, 1.6640, 0.3495, 1.3496, 0.8187, 1.1719, 0.5820, 0.1498, 0.0885, 0.1036], [ 0.3961, 0.6043, -0.0861, -0.3371, 0.8622, 1.4341, 1.2988, 0.5023, 0.3074, 0.1277], [ 0.9748, 0.4117, 1.2616, 1.6314, 0.4693, 0.4092, 0.0401, 1.1196, 1.2458, 1.3280]]) .. GENERATED FROM PYTHON SOURCE LINES 75-76 Alternatively, we can decorate the function. .. GENERATED FROM PYTHON SOURCE LINES 76-84 .. 
code-block:: default @torch.compile def opt_foo2(x, y): a = torch.sin(x) b = torch.cos(y) return a + b print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10))) .. rst-class:: sphx-glr-script-out .. code-block:: none tensor([[ 0.5360, 0.1697, -0.0561, 0.1890, -0.1310, 1.2276, 1.1739, 0.1944, -0.1561, 1.6990], [ 1.0421, 1.9472, 0.2682, 0.2701, 1.3346, 0.7651, 1.0897, 1.1730, 0.6161, 0.9223], [ 1.5756, 1.5294, 0.0112, -0.1522, -0.7674, 1.8515, -0.2443, 0.3696, 0.2693, 0.8735], [-0.3701, 1.1190, 1.4164, 1.8648, 1.2080, 0.0732, 1.5274, 0.6868, 1.2440, 1.0715], [-1.2454, -0.0159, 0.4315, 0.1317, 1.0530, -1.0603, -0.0532, 0.6661, 1.7101, -0.2076], [-0.7091, 0.7824, 1.7161, 1.2750, 0.6368, 1.2488, 0.4897, 1.2429, 1.3409, 1.3735], [ 0.8345, 0.0653, 0.3462, 1.2383, -0.4092, 1.6438, -0.0962, 0.4011, 0.2463, -0.5802], [ 1.6349, 0.7297, 1.2547, -0.3113, 0.9310, 0.1162, 1.7618, 0.4882, 0.7640, 0.2930], [ 1.1669, -0.7775, 1.2000, 0.6008, -0.2814, 0.5541, 0.5753, 1.4731, 1.6835, 0.7370], [ 1.5087, 0.6195, 0.1153, 1.2966, 1.8815, 1.1678, 1.5686, 1.6018, 0.2193, 1.3500]]) .. GENERATED FROM PYTHON SOURCE LINES 85-86 We can also optimize ``torch.nn.Module`` instances. .. GENERATED FROM PYTHON SOURCE LINES 86-99 .. code-block:: default class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.lin = torch.nn.Linear(100, 10) def forward(self, x): return torch.nn.functional.relu(self.lin(x)) mod = MyModule() opt_mod = torch.compile(mod) print(opt_mod(torch.randn(10, 100))) .. rst-class:: sphx-glr-script-out .. 
code-block:: none tensor([[0.0000, 0.0000, 0.2419, 0.0446, 0.9011, 0.2674, 0.3633, 0.4984, 0.0000, 0.0988], [0.6906, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8490, 0.0000, 0.0000, 0.5475], [0.0852, 0.2762, 0.7441, 0.0000, 0.0000, 0.1820, 0.0000, 0.0000, 0.0000, 0.0334], [0.3024, 0.0077, 1.2572, 0.0000, 0.0000, 0.6520, 0.0000, 0.0000, 0.0000, 0.8976], [0.1998, 0.3333, 0.0000, 0.7803, 0.4202, 0.0915, 0.0000, 1.2543, 0.0000, 0.4615], [0.2487, 0.4187, 0.0000, 0.0000, 0.5124, 0.0000, 0.2512, 0.0000, 0.5850, 0.0000], [0.0000, 0.0048, 0.0000, 0.0000, 0.0000, 0.2287, 0.0000, 0.4841, 0.3915, 0.0000], [0.2017, 0.0000, 0.0896, 1.4135, 0.0593, 0.3788, 0.0000, 0.0000, 0.0000, 0.4972], [0.0000, 0.0000, 1.6580, 0.6414, 0.0000, 0.0000, 0.0000, 0.0000, 0.6491, 0.7755], [0.0000, 0.0000, 0.6442, 0.0260, 0.7456, 0.1000, 0.0000, 0.0000, 0.5366, 0.1193]], grad_fn=) .. GENERATED FROM PYTHON SOURCE LINES 100-108 Demonstrating Speedups ----------------------- Let's now demonstrate that using ``torch.compile`` can speed up real models. We will compare standard eager mode and ``torch.compile`` by evaluating and training a ``torchvision`` model on random data. Before we start, we need to define some utility functions. .. GENERATED FROM PYTHON SOURCE LINES 108-135 .. code-block:: default # Returns the result of running `fn()` and the time it took for `fn()` to run, # in seconds. We use CUDA events and synchronization for the most accurate # measurements. def timed(fn): start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() result = fn() end.record() torch.cuda.synchronize() return result, start.elapsed_time(end) / 1000 # Generates random input and targets data for the model, where `b` is # batch size. def generate_data(b): return ( torch.randn(b, 3, 128, 128).to(torch.float32).cuda(), torch.randint(1000, (b,)).cuda(), ) N_ITERS = 10 from torchvision.models import densenet121 def init_model(): return densenet121().to(torch.float32).cuda() .. 
GENERATED FROM PYTHON SOURCE LINES 136-140

First, let's compare inference.

Note that in the call to ``torch.compile``, we have the additional
``mode`` argument, which we will discuss below.

.. GENERATED FROM PYTHON SOURCE LINES 140-154

.. code-block:: default

    model = init_model()

    # Reset since we are using a different mode.
    import torch._dynamo
    torch._dynamo.reset()

    model_opt = torch.compile(model, mode="reduce-overhead")

    inp = generate_data(16)[0]
    with torch.no_grad():
        print("eager:", timed(lambda: model(inp))[1])
        print("compile:", timed(lambda: model_opt(inp))[1])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    eager: 0.37502154541015625
    /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:124: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
    compile: 85.104828125

.. GENERATED FROM PYTHON SOURCE LINES 155-161

Notice that ``torch.compile`` takes a lot longer to complete
compared to eager. This is because ``torch.compile`` compiles
the model into optimized kernels as it executes. In our example, the
structure of the model doesn't change, and so recompilation is not
needed. So if we run our optimized model several more times, we should
see a significant improvement compared to eager.

.. GENERATED FROM PYTHON SOURCE LINES 161-189

.. code-block:: default

    eager_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)[0]
        with torch.no_grad():
            _, eager_time = timed(lambda: model(inp))
        eager_times.append(eager_time)
        print(f"eager eval time {i}: {eager_time}")

    print("~" * 10)

    compile_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)[0]
        with torch.no_grad():
            _, compile_time = timed(lambda: model_opt(inp))
        compile_times.append(compile_time)
        print(f"compile eval time {i}: {compile_time}")
    print("~" * 10)

    import numpy as np
    eager_med = np.median(eager_times)
    compile_med = np.median(compile_times)
    speedup = eager_med / compile_med
    assert(speedup > 1)
    print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
    print("~" * 10)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    eager eval time 0: 0.01860639953613281
    eager eval time 1: 0.016724479675292968
    eager eval time 2: 0.016695743560791017
    eager eval time 3: 0.01676595115661621
    eager eval time 4: 0.016648927688598634
    eager eval time 5: 0.016617376327514647
    eager eval time 6: 0.016504703521728517
    eager eval time 7: 0.016613887786865233
    eager eval time 8: 0.016550783157348634
    eager eval time 9: 0.016550559997558594
    ~~~~~~~~~~
    compile eval time 0: 0.6462002563476562
    compile eval time 1: 0.00868489646911621
    compile eval time 2: 0.008475359916687011
    compile eval time 3: 0.008475520133972169
    compile eval time 4: 0.008455327987670899
    compile eval time 5: 0.008516384124755859
    compile eval time 6: 0.008475456237792969
    compile eval time 7: 0.00848476791381836
    compile eval time 8: 0.008479455947875976
    compile eval time 9: 0.008501279830932617
    ~~~~~~~~~~
    (eval) eager median: 0.01663315200805664, compile median: 0.008482111930847167, speedup: 1.960968228627864x
    ~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 190-212

And indeed, we can see that running our model with ``torch.compile``
results in a significant speedup. Speedup mainly comes from reducing Python overhead and
GPU read/writes, and so the observed speedup may vary depending on factors such as model
architecture and batch size. For example, if a model's architecture is simple
and the amount of data is large, then the bottleneck would be
GPU compute and the observed speedup may be less significant.

You may also see different speedup results depending on the chosen ``mode``
argument. The ``"reduce-overhead"`` mode uses CUDA graphs to further reduce
the overhead of Python. For your own models,
you may need to experiment with different modes to maximize speedup. You can
read more about modes `here `__.

You might also notice that the second time we run our model with ``torch.compile`` is significantly
slower than the other runs, although it is much faster than the first run. This is because the ``"reduce-overhead"``
mode runs a few warm-up iterations for CUDA graphs.

For general PyTorch benchmarking, you can try using ``torch.utils.benchmark`` instead of the ``timed``
function we defined above. We wrote our own timing function in this tutorial to show
``torch.compile``'s compilation latency.

Now, let's consider comparing training.

.. GENERATED FROM PYTHON SOURCE LINES 212-250

.. code-block:: default

    model = init_model()
    opt = torch.optim.Adam(model.parameters())

    def train(mod, data):
        opt.zero_grad(True)
        pred = mod(data[0])
        loss = torch.nn.CrossEntropyLoss()(pred, data[1])
        loss.backward()
        opt.step()

    eager_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)
        _, eager_time = timed(lambda: train(model, inp))
        eager_times.append(eager_time)
        print(f"eager train time {i}: {eager_time}")
    print("~" * 10)

    model = init_model()
    opt = torch.optim.Adam(model.parameters())
    train_opt = torch.compile(train, mode="reduce-overhead")

    compile_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)
        _, compile_time = timed(lambda: train_opt(model, inp))
        compile_times.append(compile_time)
        print(f"compile train time {i}: {compile_time}")
    print("~" * 10)

    eager_med = np.median(eager_times)
    compile_med = np.median(compile_times)
    speedup = eager_med / compile_med
    assert(speedup > 1)
    print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
    print("~" * 10)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    eager train time 0: 0.3913607177734375
    eager train time 1: 0.05421667098999024
    eager train time 2: 0.052566078186035155
    eager train time 3: 0.052336830139160156
    eager train time 4: 0.05221148681640625
    eager train time 5: 0.05256249618530273
    eager train time 6: 0.05228499221801758
    eager train time 7: 0.053034526824951175
    eager train time 8: 0.05213008117675781
    eager train time 9: 0.0523721923828125
    ~~~~~~~~~~
    compile train time 0: 209.464
    compile train time 1: 5.25901025390625
    compile train time 2: 0.032414337158203126
    compile train time 3: 0.024853727340698243
    compile train time 4: 0.024753440856933595
    compile train time 5: 0.02476995277404785
    compile train time 6: 0.024747007369995116
    compile train time 7: 0.025815391540527345
    compile train time 8: 0.024771167755126954
    compile train time 9: 0.0252161922454834
    ~~~~~~~~~~
    (train) eager median: 0.05246734428405762, compile median: 0.02503495979309082, speedup: 2.0957630736254513x
    ~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 251-258

Again, we can see that ``torch.compile`` takes longer in the first
iteration, as it must compile the model, but in subsequent iterations, we see
significant speedups compared to eager.

We remark that the speedup numbers presented in this tutorial are for
demonstration purposes only. Official speedup values can be seen at the
`TorchInductor performance dashboard `__.

.. GENERATED FROM PYTHON SOURCE LINES 260-272

Comparison to TorchScript and FX Tracing
-----------------------------------------

We have seen that ``torch.compile`` can speed up PyTorch code.
Why else should we use ``torch.compile`` over existing PyTorch
compiler solutions, such as TorchScript or FX Tracing? Primarily, the
advantage of ``torch.compile`` lies in its ability to handle
arbitrary Python code with minimal changes to existing code.

One case that ``torch.compile`` can handle that other compiler
solutions struggle with is data-dependent control flow (the
``if x.sum() < 0:`` line below).

.. GENERATED FROM PYTHON SOURCE LINES 272-289

.. code-block:: default

    def f1(x, y):
        if x.sum() < 0:
            return -y
        return y

    # Test that `fn1` and `fn2` return the same result, given
    # the same arguments `args`. Typically, `fn1` will be an eager function
    # while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
    def test_fns(fn1, fn2, args):
        out1 = fn1(*args)
        out2 = fn2(*args)
        return torch.allclose(out1, out2)

    inp1 = torch.randn(5, 5)
    inp2 = torch.randn(5, 5)

.. GENERATED FROM PYTHON SOURCE LINES 290-293

TorchScript tracing ``f1`` results in
silently incorrect results, since only the actual
control flow path is traced.

.. GENERATED FROM PYTHON SOURCE LINES 293-298

.. code-block:: default

    traced_f1 = torch.jit.trace(f1, (inp1, inp2))
    print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
    print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:274: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    traced 1, 1: True
    traced 1, 2: False

.. GENERATED FROM PYTHON SOURCE LINES 299-301

FX tracing ``f1`` results in an error due to the presence of
data-dependent control flow.

.. GENERATED FROM PYTHON SOURCE LINES 301-308

.. code-block:: default

    import traceback as tb
    try:
        torch.fx.symbolic_trace(f1)
    except:
        tb.print_exc()

.. rst-class:: sphx-glr-script-out

.. 
code-block:: none Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 304, in torch.fx.symbolic_trace(f1) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1193, in symbolic_trace graph = tracer.trace(root, concrete_args) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 793, in trace (self.create_arg(fn(*args)),), File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 274, in f1 if x.sum() < 0: File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 443, in __bool__ return self.tracer.to_bool(self) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 303, in to_bool raise TraceError('symbolically traced variables cannot be used as inputs to control flow') torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow .. GENERATED FROM PYTHON SOURCE LINES 309-312 If we provide a value for ``x`` as we try to FX trace ``f1``, then we run into the same problem as TorchScript tracing, as the data-dependent control flow is removed in the traced function. .. GENERATED FROM PYTHON SOURCE LINES 312-317 .. code-block:: default fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1}) print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2))) print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2))) .. rst-class:: sphx-glr-script-out .. code-block:: none /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:862: UserWarning: Was not able to add assertion to guarantee correct input x to specialized function. 
It is up to the user to make sure that your inputs match the inputs you specialized the function with. fx 1, 1: True fx 1, 2: False .. GENERATED FROM PYTHON SOURCE LINES 318-320 Now we can see that ``torch.compile`` correctly handles data-dependent control flow. .. GENERATED FROM PYTHON SOURCE LINES 320-329 .. code-block:: default # Reset since we are using a different mode. torch._dynamo.reset() compile_f1 = torch.compile(f1) print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2))) print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2))) print("~" * 10) .. rst-class:: sphx-glr-script-out .. code-block:: none compile 1, 1: True compile 1, 2: True ~~~~~~~~~~ .. GENERATED FROM PYTHON SOURCE LINES 330-338 TorchScript scripting can handle data-dependent control flow, but this solution comes with its own set of problems. Namely, TorchScript scripting can require major code changes and will raise errors when unsupported Python is used. In the example below, we forget TorchScript type annotations and we receive a TorchScript error because the input type for argument ``y``, an ``int``, does not match with the default argument type, ``torch.Tensor``. .. GENERATED FROM PYTHON SOURCE LINES 338-351 .. code-block:: default def f2(x, y): return x + y inp1 = torch.randn(5, 5) inp2 = 3 script_f2 = torch.jit.script(f2) try: script_f2(inp1, inp2) except: tb.print_exc() .. rst-class:: sphx-glr-script-out .. code-block:: none Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 347, in script_f2(inp1, inp2) RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'. Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type. Position: 1 Value: 3 Declaration: f2(Tensor x, Tensor y) -> Tensor Cast error details: Unable to cast 3 to Tensor .. 
GENERATED FROM PYTHON SOURCE LINES 352-353

However, ``torch.compile`` is easily able to handle ``f2``.

.. GENERATED FROM PYTHON SOURCE LINES 353-358

.. code-block:: default

    compile_f2 = torch.compile(f2)
    print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
    print("~" * 10)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    compile 2: True
    ~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 359-361

Another case that ``torch.compile`` handles well compared to
previous compiler solutions is the usage of non-PyTorch functions.

.. GENERATED FROM PYTHON SOURCE LINES 361-370

.. code-block:: default

    import scipy
    def f3(x):
        x = x * 2
        x = scipy.fft.dct(x.numpy())
        x = torch.from_numpy(x)
        x = x * 2
        return x

.. GENERATED FROM PYTHON SOURCE LINES 371-373

TorchScript tracing treats results from non-PyTorch function calls
as constants, and so our results can be silently wrong.

.. GENERATED FROM PYTHON SOURCE LINES 373-379

.. code-block:: default

    inp1 = torch.randn(5, 5)
    inp2 = torch.randn(5, 5)
    traced_f3 = torch.jit.trace(f3, (inp1,))
    print("traced 3:", test_fns(f3, traced_f3, (inp2,)))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:365: TracerWarning: Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:366: TracerWarning: torch.from_numpy results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
    traced 3: False

.. GENERATED FROM PYTHON SOURCE LINES 380-381

TorchScript scripting and FX tracing disallow non-PyTorch function calls.
.. GENERATED FROM PYTHON SOURCE LINES 381-392 .. code-block:: default try: torch.jit.script(f3) except: tb.print_exc() try: torch.fx.symbolic_trace(f3) except: tb.print_exc() .. rst-class:: sphx-glr-script-out .. code-block:: none Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 383, in torch.jit.script(f3) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/jit/_script.py", line 1395, in script fn = torch._C._jit_script_compile( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_jit_internal.py", line 1216, in _try_get_dispatched_fn return boolean_dispatched.get(fn) File "/opt/conda/envs/py_3.10/lib/python3.10/weakref.py", line 453, in get return self.data.get(ref(key),default) TypeError: cannot create weak reference to 'uarray._Function' object Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 388, in torch.fx.symbolic_trace(f3) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1193, in symbolic_trace graph = tracer.trace(root, concrete_args) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 793, in trace (self.create_arg(fn(*args)),), File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 365, in f3 x = scipy.fft.dct(x.numpy()) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/scipy/fft/_backend.py", line 25, in __ua_function__ return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/scipy/fft/_pocketfft/realtransforms.py", line 19, in _r2r tmp = _asfarray(x) File 
"/opt/conda/envs/py_3.10/lib/python3.10/site-packages/scipy/fft/_pocketfft/helper.py", line 89, in _asfarray if x.dtype == np.float16: File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 548, in impl return tracer.create_proxy('call_function', target, args, kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 187, in create_proxy args_ = self.create_arg(args) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 401, in create_arg return super().create_arg(a) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 258, in create_arg return type(a)(self.create_arg(elem) for elem in a) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 258, in return type(a)(self.create_arg(elem) for elem in a) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 401, in create_arg return super().create_arg(a) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 294, in create_arg raise NotImplementedError(f"argument of type: {type(a)}") NotImplementedError: argument of type: .. GENERATED FROM PYTHON SOURCE LINES 393-395 In comparison, ``torch.compile`` is easily able to handle the non-PyTorch function call. .. GENERATED FROM PYTHON SOURCE LINES 395-399 .. code-block:: default compile_f3 = torch.compile(f3) print("compile 3:", test_fns(f3, compile_f3, (inp2,))) .. rst-class:: sphx-glr-script-out .. code-block:: none compile 3: True .. GENERATED FROM PYTHON SOURCE LINES 400-414 TorchDynamo and FX Graphs -------------------------- One important component of ``torch.compile`` is TorchDynamo. TorchDynamo is responsible for JIT compiling arbitrary Python code into `FX graphs `__, which can then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode during runtime and detecting calls to PyTorch operations. 
Normally, TorchInductor, another component of ``torch.compile``, further compiles the FX graphs into optimized kernels, but TorchDynamo allows for different backends to be used. In order to inspect the FX graphs that TorchDynamo outputs, let us create a custom backend that outputs the FX graph and simply returns the graph's unoptimized forward method. .. GENERATED FROM PYTHON SOURCE LINES 414-427 .. code-block:: default from typing import List def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): print("custom backend called with FX graph:") gm.graph.print_tabular() return gm.forward # Reset since we are using a different backend. torch._dynamo.reset() opt_model = torch.compile(init_model(), backend=custom_backend) opt_model(generate_data(16)[0]) .. rst-class:: sphx-glr-script-out .. code-block:: none custom backend called with FX graph: opcode name target args kwargs ------------- ------------------------------------------------- ---------------------------------------------------------- ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ ----------------- placeholder l_x_ L_x_ () {} call_module l__self___features_conv0 L__self___features_conv0 (l_x_,) {} call_module l__self___features_norm0 L__self___features_norm0 (l__self___features_conv0,) {} call_module l__self___features_relu0 L__self___features_relu0 (l__self___features_norm0,) {} call_module l__self___features_pool0 L__self___features_pool0 (l__self___features_relu0,) {} call_function concated_features ([l__self___features_pool0], 1) {} call_module 
l__self___features_denseblock1_denselayer1_norm1 L__self___features_denseblock1_denselayer1_norm1 (concated_features,) {} call_module l__self___features_denseblock1_denselayer1_relu1 L__self___features_denseblock1_denselayer1_relu1 (l__self___features_denseblock1_denselayer1_norm1,) {} call_module bottleneck_output L__self___features_denseblock1_denselayer1_conv1 (l__self___features_denseblock1_denselayer1_relu1,) {} call_module l__self___features_denseblock1_denselayer1_norm2 L__self___features_denseblock1_denselayer1_norm2 (bottleneck_output,) {} call_module l__self___features_denseblock1_denselayer1_relu2 L__self___features_denseblock1_denselayer1_relu2 (l__self___features_denseblock1_denselayer1_norm2,) {} call_module new_features L__self___features_denseblock1_denselayer1_conv2 (l__self___features_denseblock1_denselayer1_relu2,) {} call_function concated_features_1 ([l__self___features_pool0, new_features], 1) {} call_module l__self___features_denseblock1_denselayer2_norm1 L__self___features_denseblock1_denselayer2_norm1 (concated_features_1,) {} call_module l__self___features_denseblock1_denselayer2_relu1 L__self___features_denseblock1_denselayer2_relu1 (l__self___features_denseblock1_denselayer2_norm1,) {} call_module bottleneck_output_2 L__self___features_denseblock1_denselayer2_conv1 (l__self___features_denseblock1_denselayer2_relu1,) {} call_module l__self___features_denseblock1_denselayer2_norm2 L__self___features_denseblock1_denselayer2_norm2 (bottleneck_output_2,) {} call_module l__self___features_denseblock1_denselayer2_relu2 L__self___features_denseblock1_denselayer2_relu2 (l__self___features_denseblock1_denselayer2_norm2,) {} call_module new_features_2 L__self___features_denseblock1_denselayer2_conv2 (l__self___features_denseblock1_denselayer2_relu2,) {} call_function concated_features_2 ([l__self___features_pool0, new_features, new_features_2], 1) {} call_module l__self___features_denseblock1_denselayer3_norm1 
L__self___features_denseblock1_denselayer3_norm1 (concated_features_2,) {} call_module l__self___features_denseblock1_denselayer3_relu1 L__self___features_denseblock1_denselayer3_relu1 (l__self___features_denseblock1_denselayer3_norm1,) {} call_module bottleneck_output_4 L__self___features_denseblock1_denselayer3_conv1 (l__self___features_denseblock1_denselayer3_relu1,) {} call_module l__self___features_denseblock1_denselayer3_norm2 L__self___features_denseblock1_denselayer3_norm2 (bottleneck_output_4,) {} call_module l__self___features_denseblock1_denselayer3_relu2 L__self___features_denseblock1_denselayer3_relu2 (l__self___features_denseblock1_denselayer3_norm2,) {} call_module new_features_4 L__self___features_denseblock1_denselayer3_conv2 (l__self___features_denseblock1_denselayer3_relu2,) {} call_function concated_features_3 ([l__self___features_pool0, new_features, new_features_2, new_features_4], 1) {} call_module l__self___features_denseblock1_denselayer4_norm1 L__self___features_denseblock1_denselayer4_norm1 (concated_features_3,) {} call_module l__self___features_denseblock1_denselayer4_relu1 L__self___features_denseblock1_denselayer4_relu1 (l__self___features_denseblock1_denselayer4_norm1,) {} call_module bottleneck_output_6 L__self___features_denseblock1_denselayer4_conv1 (l__self___features_denseblock1_denselayer4_relu1,) {} call_module l__self___features_denseblock1_denselayer4_norm2 L__self___features_denseblock1_denselayer4_norm2 (bottleneck_output_6,) {} call_module l__self___features_denseblock1_denselayer4_relu2 L__self___features_denseblock1_denselayer4_relu2 (l__self___features_denseblock1_denselayer4_norm2,) {} call_module new_features_6 L__self___features_denseblock1_denselayer4_conv2 (l__self___features_denseblock1_denselayer4_relu2,) {} call_function concated_features_4 ([l__self___features_pool0, new_features, new_features_2, new_features_4, new_features_6], 1) {} call_module l__self___features_denseblock1_denselayer5_norm1 
L__self___features_denseblock1_denselayer5_norm1 (concated_features_4,) {} call_module l__self___features_denseblock1_denselayer5_relu1 L__self___features_denseblock1_denselayer5_relu1 (l__self___features_denseblock1_denselayer5_norm1,) {} call_module bottleneck_output_8 L__self___features_denseblock1_denselayer5_conv1 (l__self___features_denseblock1_denselayer5_relu1,) {} call_module l__self___features_denseblock1_denselayer5_norm2 L__self___features_denseblock1_denselayer5_norm2 (bottleneck_output_8,) {} call_module l__self___features_denseblock1_denselayer5_relu2 L__self___features_denseblock1_denselayer5_relu2 (l__self___features_denseblock1_denselayer5_norm2,) {} call_module new_features_8 L__self___features_denseblock1_denselayer5_conv2 (l__self___features_denseblock1_denselayer5_relu2,) {} call_function concated_features_5 ([l__self___features_pool0, new_features, new_features_2, new_features_4, new_features_6, new_features_8], 1) {} call_module l__self___features_denseblock1_denselayer6_norm1 L__self___features_denseblock1_denselayer6_norm1 (concated_features_5,) {} call_module l__self___features_denseblock1_denselayer6_relu1 L__self___features_denseblock1_denselayer6_relu1 (l__self___features_denseblock1_denselayer6_norm1,) {} call_module bottleneck_output_10 L__self___features_denseblock1_denselayer6_conv1 (l__self___features_denseblock1_denselayer6_relu1,) {} call_module l__self___features_denseblock1_denselayer6_norm2 L__self___features_denseblock1_denselayer6_norm2 (bottleneck_output_10,) {} call_module l__self___features_denseblock1_denselayer6_relu2 L__self___features_denseblock1_denselayer6_relu2 (l__self___features_denseblock1_denselayer6_norm2,) {} call_module new_features_10 L__self___features_denseblock1_denselayer6_conv2 (l__self___features_denseblock1_denselayer6_relu2,) {} call_function cat_6 ([l__self___features_pool0, new_features, new_features_2, new_features_4, new_features_6, new_features_8, new_features_10], 1) {} call_module 
l__self___features_transition1_norm L__self___features_transition1_norm (cat_6,) {} call_module l__self___features_transition1_relu L__self___features_transition1_relu (l__self___features_transition1_norm,) {} call_module l__self___features_transition1_conv L__self___features_transition1_conv (l__self___features_transition1_relu,) {} call_module l__self___features_transition1_pool L__self___features_transition1_pool (l__self___features_transition1_conv,) {} call_function concated_features_6 ([l__self___features_transition1_pool], 1) {} call_module l__self___features_denseblock2_denselayer1_norm1 L__self___features_denseblock2_denselayer1_norm1 (concated_features_6,) {} call_module l__self___features_denseblock2_denselayer1_relu1 L__self___features_denseblock2_denselayer1_relu1 (l__self___features_denseblock2_denselayer1_norm1,) {} call_module bottleneck_output_12 L__self___features_denseblock2_denselayer1_conv1 (l__self___features_denseblock2_denselayer1_relu1,) {} call_module l__self___features_denseblock2_denselayer1_norm2 L__self___features_denseblock2_denselayer1_norm2 (bottleneck_output_12,) {} call_module l__self___features_denseblock2_denselayer1_relu2 L__self___features_denseblock2_denselayer1_relu2 (l__self___features_denseblock2_denselayer1_norm2,) {} call_module new_features_12 L__self___features_denseblock2_denselayer1_conv2 (l__self___features_denseblock2_denselayer1_relu2,) {} call_function concated_features_7 ([l__self___features_transition1_pool, new_features_12], 1) {} call_module l__self___features_denseblock2_denselayer2_norm1 L__self___features_denseblock2_denselayer2_norm1 (concated_features_7,) {} call_module l__self___features_denseblock2_denselayer2_relu1 L__self___features_denseblock2_denselayer2_relu1 (l__self___features_denseblock2_denselayer2_norm1,) {} call_module bottleneck_output_14 L__self___features_denseblock2_denselayer2_conv1 (l__self___features_denseblock2_denselayer2_relu1,) {} call_module 
l__self___features_denseblock2_denselayer2_norm2 L__self___features_denseblock2_denselayer2_norm2 (bottleneck_output_14,) {} call_module l__self___features_denseblock2_denselayer2_relu2 L__self___features_denseblock2_denselayer2_relu2 (l__self___features_denseblock2_denselayer2_norm2,) {} call_module new_features_14 L__self___features_denseblock2_denselayer2_conv2 (l__self___features_denseblock2_denselayer2_relu2,) {} call_function concated_features_8 ([l__self___features_transition1_pool, new_features_12, new_features_14], 1) {} call_module l__self___features_denseblock2_denselayer3_norm1 L__self___features_denseblock2_denselayer3_norm1 (concated_features_8,) {} call_module l__self___features_denseblock2_denselayer3_relu1 L__self___features_denseblock2_denselayer3_relu1 (l__self___features_denseblock2_denselayer3_norm1,) {} call_module bottleneck_output_16 L__self___features_denseblock2_denselayer3_conv1 (l__self___features_denseblock2_denselayer3_relu1,) {} call_module l__self___features_denseblock2_denselayer3_norm2 L__self___features_denseblock2_denselayer3_norm2 (bottleneck_output_16,) {} call_module l__self___features_denseblock2_denselayer3_relu2 L__self___features_denseblock2_denselayer3_relu2 (l__self___features_denseblock2_denselayer3_norm2,) {} call_module new_features_16 L__self___features_denseblock2_denselayer3_conv2 (l__self___features_denseblock2_denselayer3_relu2,) {} call_function concated_features_9 ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16], 1) {} call_module l__self___features_denseblock2_denselayer4_norm1 L__self___features_denseblock2_denselayer4_norm1 (concated_features_9,) {} call_module l__self___features_denseblock2_denselayer4_relu1 L__self___features_denseblock2_denselayer4_relu1 (l__self___features_denseblock2_denselayer4_norm1,) {} call_module bottleneck_output_18 L__self___features_denseblock2_denselayer4_conv1 (l__self___features_denseblock2_denselayer4_relu1,) {} call_module 
l__self___features_denseblock2_denselayer4_norm2 L__self___features_denseblock2_denselayer4_norm2 (bottleneck_output_18,) {} call_module l__self___features_denseblock2_denselayer4_relu2 L__self___features_denseblock2_denselayer4_relu2 (l__self___features_denseblock2_denselayer4_norm2,) {} call_module new_features_18 L__self___features_denseblock2_denselayer4_conv2 (l__self___features_denseblock2_denselayer4_relu2,) {} call_function concated_features_10 ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18], 1) {} call_module l__self___features_denseblock2_denselayer5_norm1 L__self___features_denseblock2_denselayer5_norm1 (concated_features_10,) {} call_module l__self___features_denseblock2_denselayer5_relu1 L__self___features_denseblock2_denselayer5_relu1 (l__self___features_denseblock2_denselayer5_norm1,) {} call_module bottleneck_output_20 L__self___features_denseblock2_denselayer5_conv1 (l__self___features_denseblock2_denselayer5_relu1,) {} call_module l__self___features_denseblock2_denselayer5_norm2 L__self___features_denseblock2_denselayer5_norm2 (bottleneck_output_20,) {} call_module l__self___features_denseblock2_denselayer5_relu2 L__self___features_denseblock2_denselayer5_relu2 (l__self___features_denseblock2_denselayer5_norm2,) {} call_module new_features_20 L__self___features_denseblock2_denselayer5_conv2 (l__self___features_denseblock2_denselayer5_relu2,) {} call_function concated_features_11 ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20], 1) {} call_module l__self___features_denseblock2_denselayer6_norm1 L__self___features_denseblock2_denselayer6_norm1 (concated_features_11,) {} call_module l__self___features_denseblock2_denselayer6_relu1 L__self___features_denseblock2_denselayer6_relu1 (l__self___features_denseblock2_denselayer6_norm1,) {} call_module bottleneck_output_22 L__self___features_denseblock2_denselayer6_conv1 
(l__self___features_denseblock2_denselayer6_relu1,) {} call_module l__self___features_denseblock2_denselayer6_norm2 L__self___features_denseblock2_denselayer6_norm2 (bottleneck_output_22,) {} call_module l__self___features_denseblock2_denselayer6_relu2 L__self___features_denseblock2_denselayer6_relu2 (l__self___features_denseblock2_denselayer6_norm2,) {} call_module new_features_22 L__self___features_denseblock2_denselayer6_conv2 (l__self___features_denseblock2_denselayer6_relu2,) {} call_function concated_features_12 ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22], 1) {} call_module l__self___features_denseblock2_denselayer7_norm1 L__self___features_denseblock2_denselayer7_norm1 (concated_features_12,) {} call_module l__self___features_denseblock2_denselayer7_relu1 L__self___features_denseblock2_denselayer7_relu1 (l__self___features_denseblock2_denselayer7_norm1,) {} call_module bottleneck_output_24 L__self___features_denseblock2_denselayer7_conv1 (l__self___features_denseblock2_denselayer7_relu1,) {} call_module l__self___features_denseblock2_denselayer7_norm2 L__self___features_denseblock2_denselayer7_norm2 (bottleneck_output_24,) {} call_module l__self___features_denseblock2_denselayer7_relu2 L__self___features_denseblock2_denselayer7_relu2 (l__self___features_denseblock2_denselayer7_norm2,) {} call_module new_features_24 L__self___features_denseblock2_denselayer7_conv2 (l__self___features_denseblock2_denselayer7_relu2,) {} call_function concated_features_13 ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24], 1) {} call_module l__self___features_denseblock2_denselayer8_norm1 L__self___features_denseblock2_denselayer8_norm1 (concated_features_13,) {} call_module l__self___features_denseblock2_denselayer8_relu1 L__self___features_denseblock2_denselayer8_relu1 
(l__self___features_denseblock2_denselayer8_norm1,) {} call_module bottleneck_output_26 L__self___features_denseblock2_denselayer8_conv1 (l__self___features_denseblock2_denselayer8_relu1,) {} call_module l__self___features_denseblock2_denselayer8_norm2 L__self___features_denseblock2_denselayer8_norm2 (bottleneck_output_26,) {} call_module l__self___features_denseblock2_denselayer8_relu2 L__self___features_denseblock2_denselayer8_relu2 (l__self___features_denseblock2_denselayer8_norm2,) {} call_module new_features_26 L__self___features_denseblock2_denselayer8_conv2 (l__self___features_denseblock2_denselayer8_relu2,) {} call_function concated_features_14 ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24, new_features_26], 1) {} call_module l__self___features_denseblock2_denselayer9_norm1 L__self___features_denseblock2_denselayer9_norm1 (concated_features_14,) {} call_module l__self___features_denseblock2_denselayer9_relu1 L__self___features_denseblock2_denselayer9_relu1 (l__self___features_denseblock2_denselayer9_norm1,) {} call_module bottleneck_output_28 L__self___features_denseblock2_denselayer9_conv1 (l__self___features_denseblock2_denselayer9_relu1,) {} call_module l__self___features_denseblock2_denselayer9_norm2 L__self___features_denseblock2_denselayer9_norm2 (bottleneck_output_28,) {} call_module l__self___features_denseblock2_denselayer9_relu2 L__self___features_denseblock2_denselayer9_relu2 (l__self___features_denseblock2_denselayer9_norm2,) {} call_module new_features_28 L__self___features_denseblock2_denselayer9_conv2 (l__self___features_denseblock2_denselayer9_relu2,) {} call_function concated_features_15 ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24, new_features_26, new_features_28], 1) {} call_module 
l__self___features_denseblock2_denselayer10_norm1 L__self___features_denseblock2_denselayer10_norm1 (concated_features_15,) {} call_module l__self___features_denseblock2_denselayer10_relu1 L__self___features_denseblock2_denselayer10_relu1 (l__self___features_denseblock2_denselayer10_norm1,) {} call_module bottleneck_output_30 L__self___features_denseblock2_denselayer10_conv1 (l__self___features_denseblock2_denselayer10_relu1,) {} call_module l__self___features_denseblock2_denselayer10_norm2 L__self___features_denseblock2_denselayer10_norm2 (bottleneck_output_30,) {} call_module l__self___features_denseblock2_denselayer10_relu2 L__self___features_denseblock2_denselayer10_relu2 (l__self___features_denseblock2_denselayer10_norm2,) {} call_module new_features_30 L__self___features_denseblock2_denselayer10_conv2 (l__self___features_denseblock2_denselayer10_relu2,) {} call_function concated_features_16 ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24, new_features_26, new_features_28, new_features_30], 1) {} call_module l__self___features_denseblock2_denselayer11_norm1 L__self___features_denseblock2_denselayer11_norm1 (concated_features_16,) {} call_module l__self___features_denseblock2_denselayer11_relu1 L__self___features_denseblock2_denselayer11_relu1 (l__self___features_denseblock2_denselayer11_norm1,) {} call_module bottleneck_output_32 L__self___features_denseblock2_denselayer11_conv1 (l__self___features_denseblock2_denselayer11_relu1,) {} call_module l__self___features_denseblock2_denselayer11_norm2 L__self___features_denseblock2_denselayer11_norm2 (bottleneck_output_32,) {} call_module l__self___features_denseblock2_denselayer11_relu2 L__self___features_denseblock2_denselayer11_relu2 (l__self___features_denseblock2_denselayer11_norm2,) {} call_module new_features_32 L__self___features_denseblock2_denselayer11_conv2 
(l__self___features_denseblock2_denselayer11_relu2,) {} call_function concated_features_17 ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24, new_features_26, new_features_28, new_features_30, new_features_32], 1) {} call_module l__self___features_denseblock2_denselayer12_norm1 L__self___features_denseblock2_denselayer12_norm1 (concated_features_17,) {} call_module l__self___features_denseblock2_denselayer12_relu1 L__self___features_denseblock2_denselayer12_relu1 (l__self___features_denseblock2_denselayer12_norm1,) {} call_module bottleneck_output_34 L__self___features_denseblock2_denselayer12_conv1 (l__self___features_denseblock2_denselayer12_relu1,) {} call_module l__self___features_denseblock2_denselayer12_norm2 L__self___features_denseblock2_denselayer12_norm2 (bottleneck_output_34,) {} call_module l__self___features_denseblock2_denselayer12_relu2 L__self___features_denseblock2_denselayer12_relu2 (l__self___features_denseblock2_denselayer12_norm2,) {} call_module new_features_34 L__self___features_denseblock2_denselayer12_conv2 (l__self___features_denseblock2_denselayer12_relu2,) {} call_function cat_19 ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24, new_features_26, new_features_28, new_features_30, new_features_32, new_features_34], 1) {} call_module l__self___features_transition2_norm L__self___features_transition2_norm (cat_19,) {} call_module l__self___features_transition2_relu L__self___features_transition2_relu (l__self___features_transition2_norm,) {} call_module l__self___features_transition2_conv L__self___features_transition2_conv (l__self___features_transition2_relu,) {} call_module l__self___features_transition2_pool L__self___features_transition2_pool (l__self___features_transition2_conv,) {} call_function concated_features_18 
([l__self___features_transition2_pool], 1) {} call_module l__self___features_denseblock3_denselayer1_norm1 L__self___features_denseblock3_denselayer1_norm1 (concated_features_18,) {} call_module l__self___features_denseblock3_denselayer1_relu1 L__self___features_denseblock3_denselayer1_relu1 (l__self___features_denseblock3_denselayer1_norm1,) {} call_module bottleneck_output_36 L__self___features_denseblock3_denselayer1_conv1 (l__self___features_denseblock3_denselayer1_relu1,) {} call_module l__self___features_denseblock3_denselayer1_norm2 L__self___features_denseblock3_denselayer1_norm2 (bottleneck_output_36,) {} call_module l__self___features_denseblock3_denselayer1_relu2 L__self___features_denseblock3_denselayer1_relu2 (l__self___features_denseblock3_denselayer1_norm2,) {} call_module new_features_36 L__self___features_denseblock3_denselayer1_conv2 (l__self___features_denseblock3_denselayer1_relu2,) {} call_function concated_features_19 ([l__self___features_transition2_pool, new_features_36], 1) {} call_module l__self___features_denseblock3_denselayer2_norm1 L__self___features_denseblock3_denselayer2_norm1 (concated_features_19,) {} call_module l__self___features_denseblock3_denselayer2_relu1 L__self___features_denseblock3_denselayer2_relu1 (l__self___features_denseblock3_denselayer2_norm1,) {} call_module bottleneck_output_38 L__self___features_denseblock3_denselayer2_conv1 (l__self___features_denseblock3_denselayer2_relu1,) {} call_module l__self___features_denseblock3_denselayer2_norm2 L__self___features_denseblock3_denselayer2_norm2 (bottleneck_output_38,) {} call_module l__self___features_denseblock3_denselayer2_relu2 L__self___features_denseblock3_denselayer2_relu2 (l__self___features_denseblock3_denselayer2_norm2,) {} call_module new_features_38 L__self___features_denseblock3_denselayer2_conv2 (l__self___features_denseblock3_denselayer2_relu2,) {} call_function concated_features_20 ([l__self___features_transition2_pool, new_features_36, new_features_38], 
1) {} call_module l__self___features_denseblock3_denselayer3_norm1 L__self___features_denseblock3_denselayer3_norm1 (concated_features_20,) {} call_module l__self___features_denseblock3_denselayer3_relu1 L__self___features_denseblock3_denselayer3_relu1 (l__self___features_denseblock3_denselayer3_norm1,) {} call_module bottleneck_output_40 L__self___features_denseblock3_denselayer3_conv1 (l__self___features_denseblock3_denselayer3_relu1,) {} call_module l__self___features_denseblock3_denselayer3_norm2 L__self___features_denseblock3_denselayer3_norm2 (bottleneck_output_40,) {} call_module l__self___features_denseblock3_denselayer3_relu2 L__self___features_denseblock3_denselayer3_relu2 (l__self___features_denseblock3_denselayer3_norm2,) {} call_module new_features_40 L__self___features_denseblock3_denselayer3_conv2 (l__self___features_denseblock3_denselayer3_relu2,) {} call_function concated_features_21 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40], 1) {} call_module l__self___features_denseblock3_denselayer4_norm1 L__self___features_denseblock3_denselayer4_norm1 (concated_features_21,) {} call_module l__self___features_denseblock3_denselayer4_relu1 L__self___features_denseblock3_denselayer4_relu1 (l__self___features_denseblock3_denselayer4_norm1,) {} call_module bottleneck_output_42 L__self___features_denseblock3_denselayer4_conv1 (l__self___features_denseblock3_denselayer4_relu1,) {} call_module l__self___features_denseblock3_denselayer4_norm2 L__self___features_denseblock3_denselayer4_norm2 (bottleneck_output_42,) {} call_module l__self___features_denseblock3_denselayer4_relu2 L__self___features_denseblock3_denselayer4_relu2 (l__self___features_denseblock3_denselayer4_norm2,) {} call_module new_features_42 L__self___features_denseblock3_denselayer4_conv2 (l__self___features_denseblock3_denselayer4_relu2,) {} call_function concated_features_22 ([l__self___features_transition2_pool, new_features_36, new_features_38, 
new_features_40, new_features_42], 1) {} call_module l__self___features_denseblock3_denselayer5_norm1 L__self___features_denseblock3_denselayer5_norm1 (concated_features_22,) {} call_module l__self___features_denseblock3_denselayer5_relu1 L__self___features_denseblock3_denselayer5_relu1 (l__self___features_denseblock3_denselayer5_norm1,) {} call_module bottleneck_output_44 L__self___features_denseblock3_denselayer5_conv1 (l__self___features_denseblock3_denselayer5_relu1,) {} call_module l__self___features_denseblock3_denselayer5_norm2 L__self___features_denseblock3_denselayer5_norm2 (bottleneck_output_44,) {} call_module l__self___features_denseblock3_denselayer5_relu2 L__self___features_denseblock3_denselayer5_relu2 (l__self___features_denseblock3_denselayer5_norm2,) {} call_module new_features_44 L__self___features_denseblock3_denselayer5_conv2 (l__self___features_denseblock3_denselayer5_relu2,) {} call_function concated_features_23 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44], 1) {} call_module l__self___features_denseblock3_denselayer6_norm1 L__self___features_denseblock3_denselayer6_norm1 (concated_features_23,) {} call_module l__self___features_denseblock3_denselayer6_relu1 L__self___features_denseblock3_denselayer6_relu1 (l__self___features_denseblock3_denselayer6_norm1,) {} call_module bottleneck_output_46 L__self___features_denseblock3_denselayer6_conv1 (l__self___features_denseblock3_denselayer6_relu1,) {} call_module l__self___features_denseblock3_denselayer6_norm2 L__self___features_denseblock3_denselayer6_norm2 (bottleneck_output_46,) {} call_module l__self___features_denseblock3_denselayer6_relu2 L__self___features_denseblock3_denselayer6_relu2 (l__self___features_denseblock3_denselayer6_norm2,) {} call_module new_features_46 L__self___features_denseblock3_denselayer6_conv2 (l__self___features_denseblock3_denselayer6_relu2,) {} call_function concated_features_24 
([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46], 1) {} call_module l__self___features_denseblock3_denselayer7_norm1 L__self___features_denseblock3_denselayer7_norm1 (concated_features_24,) {} call_module l__self___features_denseblock3_denselayer7_relu1 L__self___features_denseblock3_denselayer7_relu1 (l__self___features_denseblock3_denselayer7_norm1,) {} call_module bottleneck_output_48 L__self___features_denseblock3_denselayer7_conv1 (l__self___features_denseblock3_denselayer7_relu1,) {} call_module l__self___features_denseblock3_denselayer7_norm2 L__self___features_denseblock3_denselayer7_norm2 (bottleneck_output_48,) {} call_module l__self___features_denseblock3_denselayer7_relu2 L__self___features_denseblock3_denselayer7_relu2 (l__self___features_denseblock3_denselayer7_norm2,) {} call_module new_features_48 L__self___features_denseblock3_denselayer7_conv2 (l__self___features_denseblock3_denselayer7_relu2,) {} call_function concated_features_25 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48], 1) {} call_module l__self___features_denseblock3_denselayer8_norm1 L__self___features_denseblock3_denselayer8_norm1 (concated_features_25,) {} call_module l__self___features_denseblock3_denselayer8_relu1 L__self___features_denseblock3_denselayer8_relu1 (l__self___features_denseblock3_denselayer8_norm1,) {} call_module bottleneck_output_50 L__self___features_denseblock3_denselayer8_conv1 (l__self___features_denseblock3_denselayer8_relu1,) {} call_module l__self___features_denseblock3_denselayer8_norm2 L__self___features_denseblock3_denselayer8_norm2 (bottleneck_output_50,) {} call_module l__self___features_denseblock3_denselayer8_relu2 L__self___features_denseblock3_denselayer8_relu2 (l__self___features_denseblock3_denselayer8_norm2,) {} call_module new_features_50 
L__self___features_denseblock3_denselayer8_conv2 (l__self___features_denseblock3_denselayer8_relu2,) {} call_function concated_features_26 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50], 1) {} call_module l__self___features_denseblock3_denselayer9_norm1 L__self___features_denseblock3_denselayer9_norm1 (concated_features_26,) {} call_module l__self___features_denseblock3_denselayer9_relu1 L__self___features_denseblock3_denselayer9_relu1 (l__self___features_denseblock3_denselayer9_norm1,) {} call_module bottleneck_output_52 L__self___features_denseblock3_denselayer9_conv1 (l__self___features_denseblock3_denselayer9_relu1,) {} call_module l__self___features_denseblock3_denselayer9_norm2 L__self___features_denseblock3_denselayer9_norm2 (bottleneck_output_52,) {} call_module l__self___features_denseblock3_denselayer9_relu2 L__self___features_denseblock3_denselayer9_relu2 (l__self___features_denseblock3_denselayer9_norm2,) {} call_module new_features_52 L__self___features_denseblock3_denselayer9_conv2 (l__self___features_denseblock3_denselayer9_relu2,) {} call_function concated_features_27 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52], 1) {} call_module l__self___features_denseblock3_denselayer10_norm1 L__self___features_denseblock3_denselayer10_norm1 (concated_features_27,) {} call_module l__self___features_denseblock3_denselayer10_relu1 L__self___features_denseblock3_denselayer10_relu1 (l__self___features_denseblock3_denselayer10_norm1,) {} call_module bottleneck_output_54 L__self___features_denseblock3_denselayer10_conv1 (l__self___features_denseblock3_denselayer10_relu1,) {} call_module l__self___features_denseblock3_denselayer10_norm2 L__self___features_denseblock3_denselayer10_norm2 
(bottleneck_output_54,) {} call_module l__self___features_denseblock3_denselayer10_relu2 L__self___features_denseblock3_denselayer10_relu2 (l__self___features_denseblock3_denselayer10_norm2,) {} call_module new_features_54 L__self___features_denseblock3_denselayer10_conv2 (l__self___features_denseblock3_denselayer10_relu2,) {} call_function concated_features_28 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54], 1) {} call_module l__self___features_denseblock3_denselayer11_norm1 L__self___features_denseblock3_denselayer11_norm1 (concated_features_28,) {} call_module l__self___features_denseblock3_denselayer11_relu1 L__self___features_denseblock3_denselayer11_relu1 (l__self___features_denseblock3_denselayer11_norm1,) {} call_module bottleneck_output_56 L__self___features_denseblock3_denselayer11_conv1 (l__self___features_denseblock3_denselayer11_relu1,) {} call_module l__self___features_denseblock3_denselayer11_norm2 L__self___features_denseblock3_denselayer11_norm2 (bottleneck_output_56,) {} call_module l__self___features_denseblock3_denselayer11_relu2 L__self___features_denseblock3_denselayer11_relu2 (l__self___features_denseblock3_denselayer11_norm2,) {} call_module new_features_56 L__self___features_denseblock3_denselayer11_conv2 (l__self___features_denseblock3_denselayer11_relu2,) {} call_function concated_features_29 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56], 1) {} call_module l__self___features_denseblock3_denselayer12_norm1 L__self___features_denseblock3_denselayer12_norm1 (concated_features_29,) {} call_module l__self___features_denseblock3_denselayer12_relu1 L__self___features_denseblock3_denselayer12_relu1 
(l__self___features_denseblock3_denselayer12_norm1,) {} call_module bottleneck_output_58 L__self___features_denseblock3_denselayer12_conv1 (l__self___features_denseblock3_denselayer12_relu1,) {} call_module l__self___features_denseblock3_denselayer12_norm2 L__self___features_denseblock3_denselayer12_norm2 (bottleneck_output_58,) {} call_module l__self___features_denseblock3_denselayer12_relu2 L__self___features_denseblock3_denselayer12_relu2 (l__self___features_denseblock3_denselayer12_norm2,) {} call_module new_features_58 L__self___features_denseblock3_denselayer12_conv2 (l__self___features_denseblock3_denselayer12_relu2,) {} call_function concated_features_30 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58], 1) {} call_module l__self___features_denseblock3_denselayer13_norm1 L__self___features_denseblock3_denselayer13_norm1 (concated_features_30,) {} call_module l__self___features_denseblock3_denselayer13_relu1 L__self___features_denseblock3_denselayer13_relu1 (l__self___features_denseblock3_denselayer13_norm1,) {} call_module bottleneck_output_60 L__self___features_denseblock3_denselayer13_conv1 (l__self___features_denseblock3_denselayer13_relu1,) {} call_module l__self___features_denseblock3_denselayer13_norm2 L__self___features_denseblock3_denselayer13_norm2 (bottleneck_output_60,) {} call_module l__self___features_denseblock3_denselayer13_relu2 L__self___features_denseblock3_denselayer13_relu2 (l__self___features_denseblock3_denselayer13_norm2,) {} call_module new_features_60 L__self___features_denseblock3_denselayer13_conv2 (l__self___features_denseblock3_denselayer13_relu2,) {} call_function concated_features_31 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, 
new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60], 1) {} call_module l__self___features_denseblock3_denselayer14_norm1 L__self___features_denseblock3_denselayer14_norm1 (concated_features_31,) {} call_module l__self___features_denseblock3_denselayer14_relu1 L__self___features_denseblock3_denselayer14_relu1 (l__self___features_denseblock3_denselayer14_norm1,) {} call_module bottleneck_output_62 L__self___features_denseblock3_denselayer14_conv1 (l__self___features_denseblock3_denselayer14_relu1,) {} call_module l__self___features_denseblock3_denselayer14_norm2 L__self___features_denseblock3_denselayer14_norm2 (bottleneck_output_62,) {} call_module l__self___features_denseblock3_denselayer14_relu2 L__self___features_denseblock3_denselayer14_relu2 (l__self___features_denseblock3_denselayer14_norm2,) {} call_module new_features_62 L__self___features_denseblock3_denselayer14_conv2 (l__self___features_denseblock3_denselayer14_relu2,) {} call_function concated_features_32 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62], 1) {} call_module l__self___features_denseblock3_denselayer15_norm1 L__self___features_denseblock3_denselayer15_norm1 (concated_features_32,) {} call_module l__self___features_denseblock3_denselayer15_relu1 L__self___features_denseblock3_denselayer15_relu1 (l__self___features_denseblock3_denselayer15_norm1,) {} call_module bottleneck_output_64 L__self___features_denseblock3_denselayer15_conv1 (l__self___features_denseblock3_denselayer15_relu1,) {} call_module l__self___features_denseblock3_denselayer15_norm2 L__self___features_denseblock3_denselayer15_norm2 (bottleneck_output_64,) {} call_module l__self___features_denseblock3_denselayer15_relu2 
L__self___features_denseblock3_denselayer15_relu2 (l__self___features_denseblock3_denselayer15_norm2,) {} call_module new_features_64 L__self___features_denseblock3_denselayer15_conv2 (l__self___features_denseblock3_denselayer15_relu2,) {} call_function concated_features_33 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64], 1) {} call_module l__self___features_denseblock3_denselayer16_norm1 L__self___features_denseblock3_denselayer16_norm1 (concated_features_33,) {} call_module l__self___features_denseblock3_denselayer16_relu1 L__self___features_denseblock3_denselayer16_relu1 (l__self___features_denseblock3_denselayer16_norm1,) {} call_module bottleneck_output_66 L__self___features_denseblock3_denselayer16_conv1 (l__self___features_denseblock3_denselayer16_relu1,) {} call_module l__self___features_denseblock3_denselayer16_norm2 L__self___features_denseblock3_denselayer16_norm2 (bottleneck_output_66,) {} call_module l__self___features_denseblock3_denselayer16_relu2 L__self___features_denseblock3_denselayer16_relu2 (l__self___features_denseblock3_denselayer16_norm2,) {} call_module new_features_66 L__self___features_denseblock3_denselayer16_conv2 (l__self___features_denseblock3_denselayer16_relu2,) {} call_function concated_features_34 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66], 1) {} call_module l__self___features_denseblock3_denselayer17_norm1 L__self___features_denseblock3_denselayer17_norm1 (concated_features_34,) {} call_module 
l__self___features_denseblock3_denselayer17_relu1 L__self___features_denseblock3_denselayer17_relu1 (l__self___features_denseblock3_denselayer17_norm1,) {} call_module bottleneck_output_68 L__self___features_denseblock3_denselayer17_conv1 (l__self___features_denseblock3_denselayer17_relu1,) {} call_module l__self___features_denseblock3_denselayer17_norm2 L__self___features_denseblock3_denselayer17_norm2 (bottleneck_output_68,) {} call_module l__self___features_denseblock3_denselayer17_relu2 L__self___features_denseblock3_denselayer17_relu2 (l__self___features_denseblock3_denselayer17_norm2,) {} call_module new_features_68 L__self___features_denseblock3_denselayer17_conv2 (l__self___features_denseblock3_denselayer17_relu2,) {} call_function concated_features_35 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68], 1) {} call_module l__self___features_denseblock3_denselayer18_norm1 L__self___features_denseblock3_denselayer18_norm1 (concated_features_35,) {} call_module l__self___features_denseblock3_denselayer18_relu1 L__self___features_denseblock3_denselayer18_relu1 (l__self___features_denseblock3_denselayer18_norm1,) {} call_module bottleneck_output_70 L__self___features_denseblock3_denselayer18_conv1 (l__self___features_denseblock3_denselayer18_relu1,) {} call_module l__self___features_denseblock3_denselayer18_norm2 L__self___features_denseblock3_denselayer18_norm2 (bottleneck_output_70,) {} call_module l__self___features_denseblock3_denselayer18_relu2 L__self___features_denseblock3_denselayer18_relu2 (l__self___features_denseblock3_denselayer18_norm2,) {} call_module new_features_70 L__self___features_denseblock3_denselayer18_conv2 (l__self___features_denseblock3_denselayer18_relu2,) {} 
call_function concated_features_36 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70], 1) {} call_module l__self___features_denseblock3_denselayer19_norm1 L__self___features_denseblock3_denselayer19_norm1 (concated_features_36,) {} call_module l__self___features_denseblock3_denselayer19_relu1 L__self___features_denseblock3_denselayer19_relu1 (l__self___features_denseblock3_denselayer19_norm1,) {} call_module bottleneck_output_72 L__self___features_denseblock3_denselayer19_conv1 (l__self___features_denseblock3_denselayer19_relu1,) {} call_module l__self___features_denseblock3_denselayer19_norm2 L__self___features_denseblock3_denselayer19_norm2 (bottleneck_output_72,) {} call_module l__self___features_denseblock3_denselayer19_relu2 L__self___features_denseblock3_denselayer19_relu2 (l__self___features_denseblock3_denselayer19_norm2,) {} call_module new_features_72 L__self___features_denseblock3_denselayer19_conv2 (l__self___features_denseblock3_denselayer19_relu2,) {} call_function concated_features_37 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72], 1) {} call_module l__self___features_denseblock3_denselayer20_norm1 L__self___features_denseblock3_denselayer20_norm1 (concated_features_37,) {} call_module l__self___features_denseblock3_denselayer20_relu1 L__self___features_denseblock3_denselayer20_relu1 (l__self___features_denseblock3_denselayer20_norm1,) {} call_module 
bottleneck_output_74 L__self___features_denseblock3_denselayer20_conv1 (l__self___features_denseblock3_denselayer20_relu1,) {} call_module l__self___features_denseblock3_denselayer20_norm2 L__self___features_denseblock3_denselayer20_norm2 (bottleneck_output_74,) {} call_module l__self___features_denseblock3_denselayer20_relu2 L__self___features_denseblock3_denselayer20_relu2 (l__self___features_denseblock3_denselayer20_norm2,) {} call_module new_features_74 L__self___features_denseblock3_denselayer20_conv2 (l__self___features_denseblock3_denselayer20_relu2,) {} call_function concated_features_38 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72, new_features_74], 1) {} call_module l__self___features_denseblock3_denselayer21_norm1 L__self___features_denseblock3_denselayer21_norm1 (concated_features_38,) {} call_module l__self___features_denseblock3_denselayer21_relu1 L__self___features_denseblock3_denselayer21_relu1 (l__self___features_denseblock3_denselayer21_norm1,) {} call_module bottleneck_output_76 L__self___features_denseblock3_denselayer21_conv1 (l__self___features_denseblock3_denselayer21_relu1,) {} call_module l__self___features_denseblock3_denselayer21_norm2 L__self___features_denseblock3_denselayer21_norm2 (bottleneck_output_76,) {} call_module l__self___features_denseblock3_denselayer21_relu2 L__self___features_denseblock3_denselayer21_relu2 (l__self___features_denseblock3_denselayer21_norm2,) {} call_module new_features_76 L__self___features_denseblock3_denselayer21_conv2 (l__self___features_denseblock3_denselayer21_relu2,) {} call_function concated_features_39 ([l__self___features_transition2_pool, new_features_36, new_features_38, 
new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72, new_features_74, new_features_76], 1) {} call_module l__self___features_denseblock3_denselayer22_norm1 L__self___features_denseblock3_denselayer22_norm1 (concated_features_39,) {} call_module l__self___features_denseblock3_denselayer22_relu1 L__self___features_denseblock3_denselayer22_relu1 (l__self___features_denseblock3_denselayer22_norm1,) {} call_module bottleneck_output_78 L__self___features_denseblock3_denselayer22_conv1 (l__self___features_denseblock3_denselayer22_relu1,) {} call_module l__self___features_denseblock3_denselayer22_norm2 L__self___features_denseblock3_denselayer22_norm2 (bottleneck_output_78,) {} call_module l__self___features_denseblock3_denselayer22_relu2 L__self___features_denseblock3_denselayer22_relu2 (l__self___features_denseblock3_denselayer22_norm2,) {} call_module new_features_78 L__self___features_denseblock3_denselayer22_conv2 (l__self___features_denseblock3_denselayer22_relu2,) {} call_function concated_features_40 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72, new_features_74, new_features_76, new_features_78], 1) {} call_module l__self___features_denseblock3_denselayer23_norm1 L__self___features_denseblock3_denselayer23_norm1 (concated_features_40,) {} call_module l__self___features_denseblock3_denselayer23_relu1 L__self___features_denseblock3_denselayer23_relu1 (l__self___features_denseblock3_denselayer23_norm1,) {} call_module 
bottleneck_output_80 L__self___features_denseblock3_denselayer23_conv1 (l__self___features_denseblock3_denselayer23_relu1,) {} call_module l__self___features_denseblock3_denselayer23_norm2 L__self___features_denseblock3_denselayer23_norm2 (bottleneck_output_80,) {} call_module l__self___features_denseblock3_denselayer23_relu2 L__self___features_denseblock3_denselayer23_relu2 (l__self___features_denseblock3_denselayer23_norm2,) {} call_module new_features_80 L__self___features_denseblock3_denselayer23_conv2 (l__self___features_denseblock3_denselayer23_relu2,) {} call_function concated_features_41 ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72, new_features_74, new_features_76, new_features_78, new_features_80], 1) {} call_module l__self___features_denseblock3_denselayer24_norm1 L__self___features_denseblock3_denselayer24_norm1 (concated_features_41,) {} call_module l__self___features_denseblock3_denselayer24_relu1 L__self___features_denseblock3_denselayer24_relu1 (l__self___features_denseblock3_denselayer24_norm1,) {} call_module bottleneck_output_82 L__self___features_denseblock3_denselayer24_conv1 (l__self___features_denseblock3_denselayer24_relu1,) {} call_module l__self___features_denseblock3_denselayer24_norm2 L__self___features_denseblock3_denselayer24_norm2 (bottleneck_output_82,) {} call_module l__self___features_denseblock3_denselayer24_relu2 L__self___features_denseblock3_denselayer24_relu2 (l__self___features_denseblock3_denselayer24_norm2,) {} call_module new_features_82 L__self___features_denseblock3_denselayer24_conv2 (l__self___features_denseblock3_denselayer24_relu2,) {} call_function cat_44 ([l__self___features_transition2_pool, 
new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72, new_features_74, new_features_76, new_features_78, new_features_80, new_features_82], 1) {} call_module l__self___features_transition3_norm L__self___features_transition3_norm (cat_44,) {} call_module l__self___features_transition3_relu L__self___features_transition3_relu (l__self___features_transition3_norm,) {} call_module l__self___features_transition3_conv L__self___features_transition3_conv (l__self___features_transition3_relu,) {} call_module l__self___features_transition3_pool L__self___features_transition3_pool (l__self___features_transition3_conv,) {} call_function concated_features_42 ([l__self___features_transition3_pool], 1) {} call_module l__self___features_denseblock4_denselayer1_norm1 L__self___features_denseblock4_denselayer1_norm1 (concated_features_42,) {} call_module l__self___features_denseblock4_denselayer1_relu1 L__self___features_denseblock4_denselayer1_relu1 (l__self___features_denseblock4_denselayer1_norm1,) {} call_module bottleneck_output_84 L__self___features_denseblock4_denselayer1_conv1 (l__self___features_denseblock4_denselayer1_relu1,) {} call_module l__self___features_denseblock4_denselayer1_norm2 L__self___features_denseblock4_denselayer1_norm2 (bottleneck_output_84,) {} call_module l__self___features_denseblock4_denselayer1_relu2 L__self___features_denseblock4_denselayer1_relu2 (l__self___features_denseblock4_denselayer1_norm2,) {} call_module new_features_84 L__self___features_denseblock4_denselayer1_conv2 (l__self___features_denseblock4_denselayer1_relu2,) {} call_function concated_features_43 ([l__self___features_transition3_pool, new_features_84], 1) {} call_module 
l__self___features_denseblock4_denselayer2_norm1 L__self___features_denseblock4_denselayer2_norm1 (concated_features_43,) {} call_module l__self___features_denseblock4_denselayer2_relu1 L__self___features_denseblock4_denselayer2_relu1 (l__self___features_denseblock4_denselayer2_norm1,) {} call_module bottleneck_output_86 L__self___features_denseblock4_denselayer2_conv1 (l__self___features_denseblock4_denselayer2_relu1,) {} call_module l__self___features_denseblock4_denselayer2_norm2 L__self___features_denseblock4_denselayer2_norm2 (bottleneck_output_86,) {} call_module l__self___features_denseblock4_denselayer2_relu2 L__self___features_denseblock4_denselayer2_relu2 (l__self___features_denseblock4_denselayer2_norm2,) {} call_module new_features_86 L__self___features_denseblock4_denselayer2_conv2 (l__self___features_denseblock4_denselayer2_relu2,) {} call_function concated_features_44 ([l__self___features_transition3_pool, new_features_84, new_features_86], 1) {} call_module l__self___features_denseblock4_denselayer3_norm1 L__self___features_denseblock4_denselayer3_norm1 (concated_features_44,) {} call_module l__self___features_denseblock4_denselayer3_relu1 L__self___features_denseblock4_denselayer3_relu1 (l__self___features_denseblock4_denselayer3_norm1,) {} call_module bottleneck_output_88 L__self___features_denseblock4_denselayer3_conv1 (l__self___features_denseblock4_denselayer3_relu1,) {} call_module l__self___features_denseblock4_denselayer3_norm2 L__self___features_denseblock4_denselayer3_norm2 (bottleneck_output_88,) {} call_module l__self___features_denseblock4_denselayer3_relu2 L__self___features_denseblock4_denselayer3_relu2 (l__self___features_denseblock4_denselayer3_norm2,) {} call_module new_features_88 L__self___features_denseblock4_denselayer3_conv2 (l__self___features_denseblock4_denselayer3_relu2,) {} call_function concated_features_45 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88], 1) {} call_module 
l__self___features_denseblock4_denselayer4_norm1 L__self___features_denseblock4_denselayer4_norm1 (concated_features_45,) {} call_module l__self___features_denseblock4_denselayer4_relu1 L__self___features_denseblock4_denselayer4_relu1 (l__self___features_denseblock4_denselayer4_norm1,) {} call_module bottleneck_output_90 L__self___features_denseblock4_denselayer4_conv1 (l__self___features_denseblock4_denselayer4_relu1,) {} call_module l__self___features_denseblock4_denselayer4_norm2 L__self___features_denseblock4_denselayer4_norm2 (bottleneck_output_90,) {} call_module l__self___features_denseblock4_denselayer4_relu2 L__self___features_denseblock4_denselayer4_relu2 (l__self___features_denseblock4_denselayer4_norm2,) {} call_module new_features_90 L__self___features_denseblock4_denselayer4_conv2 (l__self___features_denseblock4_denselayer4_relu2,) {} call_function concated_features_46 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90], 1) {} call_module l__self___features_denseblock4_denselayer5_norm1 L__self___features_denseblock4_denselayer5_norm1 (concated_features_46,) {} call_module l__self___features_denseblock4_denselayer5_relu1 L__self___features_denseblock4_denselayer5_relu1 (l__self___features_denseblock4_denselayer5_norm1,) {} call_module bottleneck_output_92 L__self___features_denseblock4_denselayer5_conv1 (l__self___features_denseblock4_denselayer5_relu1,) {} call_module l__self___features_denseblock4_denselayer5_norm2 L__self___features_denseblock4_denselayer5_norm2 (bottleneck_output_92,) {} call_module l__self___features_denseblock4_denselayer5_relu2 L__self___features_denseblock4_denselayer5_relu2 (l__self___features_denseblock4_denselayer5_norm2,) {} call_module new_features_92 L__self___features_denseblock4_denselayer5_conv2 (l__self___features_denseblock4_denselayer5_relu2,) {} call_function concated_features_47 ([l__self___features_transition3_pool, new_features_84, new_features_86, 
new_features_88, new_features_90, new_features_92], 1) {} call_module l__self___features_denseblock4_denselayer6_norm1 L__self___features_denseblock4_denselayer6_norm1 (concated_features_47,) {} call_module l__self___features_denseblock4_denselayer6_relu1 L__self___features_denseblock4_denselayer6_relu1 (l__self___features_denseblock4_denselayer6_norm1,) {} call_module bottleneck_output_94 L__self___features_denseblock4_denselayer6_conv1 (l__self___features_denseblock4_denselayer6_relu1,) {} call_module l__self___features_denseblock4_denselayer6_norm2 L__self___features_denseblock4_denselayer6_norm2 (bottleneck_output_94,) {} call_module l__self___features_denseblock4_denselayer6_relu2 L__self___features_denseblock4_denselayer6_relu2 (l__self___features_denseblock4_denselayer6_norm2,) {} call_module new_features_94 L__self___features_denseblock4_denselayer6_conv2 (l__self___features_denseblock4_denselayer6_relu2,) {} call_function concated_features_48 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94], 1) {} call_module l__self___features_denseblock4_denselayer7_norm1 L__self___features_denseblock4_denselayer7_norm1 (concated_features_48,) {} call_module l__self___features_denseblock4_denselayer7_relu1 L__self___features_denseblock4_denselayer7_relu1 (l__self___features_denseblock4_denselayer7_norm1,) {} call_module bottleneck_output_96 L__self___features_denseblock4_denselayer7_conv1 (l__self___features_denseblock4_denselayer7_relu1,) {} call_module l__self___features_denseblock4_denselayer7_norm2 L__self___features_denseblock4_denselayer7_norm2 (bottleneck_output_96,) {} call_module l__self___features_denseblock4_denselayer7_relu2 L__self___features_denseblock4_denselayer7_relu2 (l__self___features_denseblock4_denselayer7_norm2,) {} call_module new_features_96 L__self___features_denseblock4_denselayer7_conv2 (l__self___features_denseblock4_denselayer7_relu2,) {} 
call_function concated_features_49 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96], 1) {} call_module l__self___features_denseblock4_denselayer8_norm1 L__self___features_denseblock4_denselayer8_norm1 (concated_features_49,) {} call_module l__self___features_denseblock4_denselayer8_relu1 L__self___features_denseblock4_denselayer8_relu1 (l__self___features_denseblock4_denselayer8_norm1,) {} call_module bottleneck_output_98 L__self___features_denseblock4_denselayer8_conv1 (l__self___features_denseblock4_denselayer8_relu1,) {} call_module l__self___features_denseblock4_denselayer8_norm2 L__self___features_denseblock4_denselayer8_norm2 (bottleneck_output_98,) {} call_module l__self___features_denseblock4_denselayer8_relu2 L__self___features_denseblock4_denselayer8_relu2 (l__self___features_denseblock4_denselayer8_norm2,) {} call_module new_features_98 L__self___features_denseblock4_denselayer8_conv2 (l__self___features_denseblock4_denselayer8_relu2,) {} call_function concated_features_50 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98], 1) {} call_module l__self___features_denseblock4_denselayer9_norm1 L__self___features_denseblock4_denselayer9_norm1 (concated_features_50,) {} call_module l__self___features_denseblock4_denselayer9_relu1 L__self___features_denseblock4_denselayer9_relu1 (l__self___features_denseblock4_denselayer9_norm1,) {} call_module bottleneck_output_100 L__self___features_denseblock4_denselayer9_conv1 (l__self___features_denseblock4_denselayer9_relu1,) {} call_module l__self___features_denseblock4_denselayer9_norm2 L__self___features_denseblock4_denselayer9_norm2 (bottleneck_output_100,) {} call_module l__self___features_denseblock4_denselayer9_relu2 L__self___features_denseblock4_denselayer9_relu2 
(l__self___features_denseblock4_denselayer9_norm2,) {} call_module new_features_100 L__self___features_denseblock4_denselayer9_conv2 (l__self___features_denseblock4_denselayer9_relu2,) {} call_function concated_features_51 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100], 1) {} call_module l__self___features_denseblock4_denselayer10_norm1 L__self___features_denseblock4_denselayer10_norm1 (concated_features_51,) {} call_module l__self___features_denseblock4_denselayer10_relu1 L__self___features_denseblock4_denselayer10_relu1 (l__self___features_denseblock4_denselayer10_norm1,) {} call_module bottleneck_output_102 L__self___features_denseblock4_denselayer10_conv1 (l__self___features_denseblock4_denselayer10_relu1,) {} call_module l__self___features_denseblock4_denselayer10_norm2 L__self___features_denseblock4_denselayer10_norm2 (bottleneck_output_102,) {} call_module l__self___features_denseblock4_denselayer10_relu2 L__self___features_denseblock4_denselayer10_relu2 (l__self___features_denseblock4_denselayer10_norm2,) {} call_module new_features_102 L__self___features_denseblock4_denselayer10_conv2 (l__self___features_denseblock4_denselayer10_relu2,) {} call_function concated_features_52 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102], 1) {} call_module l__self___features_denseblock4_denselayer11_norm1 L__self___features_denseblock4_denselayer11_norm1 (concated_features_52,) {} call_module l__self___features_denseblock4_denselayer11_relu1 L__self___features_denseblock4_denselayer11_relu1 (l__self___features_denseblock4_denselayer11_norm1,) {} call_module bottleneck_output_104 L__self___features_denseblock4_denselayer11_conv1 
(l__self___features_denseblock4_denselayer11_relu1,) {} call_module l__self___features_denseblock4_denselayer11_norm2 L__self___features_denseblock4_denselayer11_norm2 (bottleneck_output_104,) {} call_module l__self___features_denseblock4_denselayer11_relu2 L__self___features_denseblock4_denselayer11_relu2 (l__self___features_denseblock4_denselayer11_norm2,) {} call_module new_features_104 L__self___features_denseblock4_denselayer11_conv2 (l__self___features_denseblock4_denselayer11_relu2,) {} call_function concated_features_53 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104], 1) {} call_module l__self___features_denseblock4_denselayer12_norm1 L__self___features_denseblock4_denselayer12_norm1 (concated_features_53,) {} call_module l__self___features_denseblock4_denselayer12_relu1 L__self___features_denseblock4_denselayer12_relu1 (l__self___features_denseblock4_denselayer12_norm1,) {} call_module bottleneck_output_106 L__self___features_denseblock4_denselayer12_conv1 (l__self___features_denseblock4_denselayer12_relu1,) {} call_module l__self___features_denseblock4_denselayer12_norm2 L__self___features_denseblock4_denselayer12_norm2 (bottleneck_output_106,) {} call_module l__self___features_denseblock4_denselayer12_relu2 L__self___features_denseblock4_denselayer12_relu2 (l__self___features_denseblock4_denselayer12_norm2,) {} call_module new_features_106 L__self___features_denseblock4_denselayer12_conv2 (l__self___features_denseblock4_denselayer12_relu2,) {} call_function concated_features_54 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104, new_features_106], 1) {} call_module 
l__self___features_denseblock4_denselayer13_norm1 L__self___features_denseblock4_denselayer13_norm1 (concated_features_54,) {} call_module l__self___features_denseblock4_denselayer13_relu1 L__self___features_denseblock4_denselayer13_relu1 (l__self___features_denseblock4_denselayer13_norm1,) {} call_module bottleneck_output_108 L__self___features_denseblock4_denselayer13_conv1 (l__self___features_denseblock4_denselayer13_relu1,) {} call_module l__self___features_denseblock4_denselayer13_norm2 L__self___features_denseblock4_denselayer13_norm2 (bottleneck_output_108,) {} call_module l__self___features_denseblock4_denselayer13_relu2 L__self___features_denseblock4_denselayer13_relu2 (l__self___features_denseblock4_denselayer13_norm2,) {} call_module new_features_108 L__self___features_denseblock4_denselayer13_conv2 (l__self___features_denseblock4_denselayer13_relu2,) {} call_function concated_features_55 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104, new_features_106, new_features_108], 1) {} call_module l__self___features_denseblock4_denselayer14_norm1 L__self___features_denseblock4_denselayer14_norm1 (concated_features_55,) {} call_module l__self___features_denseblock4_denselayer14_relu1 L__self___features_denseblock4_denselayer14_relu1 (l__self___features_denseblock4_denselayer14_norm1,) {} call_module bottleneck_output_110 L__self___features_denseblock4_denselayer14_conv1 (l__self___features_denseblock4_denselayer14_relu1,) {} call_module l__self___features_denseblock4_denselayer14_norm2 L__self___features_denseblock4_denselayer14_norm2 (bottleneck_output_110,) {} call_module l__self___features_denseblock4_denselayer14_relu2 L__self___features_denseblock4_denselayer14_relu2 (l__self___features_denseblock4_denselayer14_norm2,) {} call_module new_features_110 
L__self___features_denseblock4_denselayer14_conv2 (l__self___features_denseblock4_denselayer14_relu2,) {} call_function concated_features_56 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104, new_features_106, new_features_108, new_features_110], 1) {} call_module l__self___features_denseblock4_denselayer15_norm1 L__self___features_denseblock4_denselayer15_norm1 (concated_features_56,) {} call_module l__self___features_denseblock4_denselayer15_relu1 L__self___features_denseblock4_denselayer15_relu1 (l__self___features_denseblock4_denselayer15_norm1,) {} call_module bottleneck_output_112 L__self___features_denseblock4_denselayer15_conv1 (l__self___features_denseblock4_denselayer15_relu1,) {} call_module l__self___features_denseblock4_denselayer15_norm2 L__self___features_denseblock4_denselayer15_norm2 (bottleneck_output_112,) {} call_module l__self___features_denseblock4_denselayer15_relu2 L__self___features_denseblock4_denselayer15_relu2 (l__self___features_denseblock4_denselayer15_norm2,) {} call_module new_features_112 L__self___features_denseblock4_denselayer15_conv2 (l__self___features_denseblock4_denselayer15_relu2,) {} call_function concated_features_57 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104, new_features_106, new_features_108, new_features_110, new_features_112], 1) {} call_module l__self___features_denseblock4_denselayer16_norm1 L__self___features_denseblock4_denselayer16_norm1 (concated_features_57,) {} call_module l__self___features_denseblock4_denselayer16_relu1 L__self___features_denseblock4_denselayer16_relu1 (l__self___features_denseblock4_denselayer16_norm1,) {} call_module 
bottleneck_output_114 L__self___features_denseblock4_denselayer16_conv1 (l__self___features_denseblock4_denselayer16_relu1,) {} call_module l__self___features_denseblock4_denselayer16_norm2 L__self___features_denseblock4_denselayer16_norm2 (bottleneck_output_114,) {} call_module l__self___features_denseblock4_denselayer16_relu2 L__self___features_denseblock4_denselayer16_relu2 (l__self___features_denseblock4_denselayer16_norm2,) {} call_module new_features_114 L__self___features_denseblock4_denselayer16_conv2 (l__self___features_denseblock4_denselayer16_relu2,) {} call_function cat_61 ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104, new_features_106, new_features_108, new_features_110, new_features_112, new_features_114], 1) {} call_module features L__self___features_norm5 (cat_61,) {} call_function out (features,) {'inplace': True} call_function out_1 (out, (1, 1)) {} call_function out_2 (out_1, 1) {} call_module out_3 L__self___classifier (out_2,) {} output output output ((out_3,),) {} tensor([[ 0.0614, -0.4023, -0.2792, ..., -0.5549, 0.0976, -0.0634], [-0.2032, -0.2706, -0.0935, ..., -0.4815, 0.0758, -0.1038], [ 0.0637, -0.3492, -0.1492, ..., -0.4841, 0.1776, -0.0723], ..., [-0.1050, -0.3393, 0.0092, ..., -0.4862, 0.0555, -0.1058], [ 0.0018, -0.2431, -0.1656, ..., -0.5072, 0.0977, -0.1387], [ 0.1192, -0.3563, -0.1147, ..., -0.4839, 0.1770, -0.0659]], device='cuda:0', grad_fn=) .. GENERATED FROM PYTHON SOURCE LINES 428-431 Using our custom backend, we can now see how TorchDynamo is able to handle data-dependent control flow. Consider the function below, where the line ``if b.sum() < 0`` is the source of data-dependent control flow. .. GENERATED FROM PYTHON SOURCE LINES 431-444 .. 
code-block:: default def bar(a, b): x = a / (torch.abs(a) + 1) if b.sum() < 0: b = b * -1 return x * b opt_bar = torch.compile(bar, backend=custom_backend) inp1 = torch.randn(10) inp2 = torch.randn(10) opt_bar(inp1, inp2) opt_bar(inp1, -inp2) .. rst-class:: sphx-glr-script-out .. code-block:: none custom backend called with FX graph: opcode name target args kwargs ------------- ------ ------------------------------------------------------ ----------- -------- placeholder l_a_ L_a_ () {} placeholder l_b_ L_b_ () {} call_function abs_1 (l_a_,) {} call_function add (abs_1, 1) {} call_function x (l_a_, add) {} call_method sum_1 sum (l_b_,) {} call_function lt (sum_1, 0) {} output output output ((x, lt),) {} custom backend called with FX graph: opcode name target args kwargs ------------- ------ ----------------------- ------------ -------- placeholder l_x_ L_x_ () {} placeholder l_b_ L_b_ () {} call_function mul (l_x_, l_b_) {} output output output ((mul,),) {} custom backend called with FX graph: opcode name target args kwargs ------------- ------ ----------------------- ----------- -------- placeholder l_b_ L_b_ () {} placeholder l_x_ L_x_ () {} call_function b (l_b_, -1) {} call_function mul_1 (l_x_, b) {} output output output ((mul_1,),) {} tensor([-0.0176, 1.0753, 0.0282, 0.0756, -0.0176, 0.0633, -0.9161, 0.1333, -0.1971, -0.3406]) .. GENERATED FROM PYTHON SOURCE LINES 445-469 The output reveals that TorchDynamo extracted 3 different FX graphs corresponding to the following code (order may differ from the output above): 1. ``x = a / (torch.abs(a) + 1)`` 2. ``b = b * -1; return x * b`` 3. ``return x * b`` When TorchDynamo encounters unsupported Python features, such as data-dependent control flow, it breaks the computation graph, lets the default Python interpreter handle the unsupported code, then resumes capturing the graph. Let's investigate by example how TorchDynamo would step through ``bar``.
If ``b.sum() < 0``, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 2. On the other hand, if ``not b.sum() < 0``, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 3. This highlights a major difference between TorchDynamo and previous PyTorch compiler solutions. When encountering unsupported Python features, previous solutions either raise an error or silently fail. TorchDynamo, on the other hand, will break the computation graph. We can see where TorchDynamo breaks the graph by using ``torch._dynamo.explain``: .. GENERATED FROM PYTHON SOURCE LINES 469-475 .. code-block:: default # Reset since we are using a different backend. torch._dynamo.reset() explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10)) print(explain_output) .. rst-class:: sphx-glr-script-out .. code-block:: none Graph Count: 2 Graph Break Count: 1 Op Count: 6 Break Reasons: Break Reason 1: Reason: generic_jump TensorVariable() User Stack: Ops per Graph: Ops 1: Ops 2: Out Guards: Guard 1: Name: '' Source: global Create Function: DETERMINISTIC_ALGORITHMS Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Guard 2: Name: "L['a']" Source: local Create Function: TENSOR_MATCH Guard Types: ['TENSOR_MATCH'] Code List: ["hasattr(L['a'], '_dynamo_dynamic_indices') == False"] Object Weakref: Guarded Class Weakref: Guard 3: Name: '' Source: global Create Function: TORCH_FUNCTION_STATE Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Guard 4: Name: '' Source: global Create Function: DEFAULT_DEVICE Guard Types: ['DEFAULT_DEVICE'] Code List: ['utils_device.CURRENT_DEVICE == None'] Object Weakref: None Guarded Class Weakref: None Guard 5: Name: '' Source: global Create Function: BACKEND_MATCH Guard Types: ['BACKEND_MATCH'] Code List: ['___check_current_backend(140681667747200)'] Object Weakref: None Guarded Class Weakref: 
None Guard 6: Name: "L['b']" Source: local Create Function: TENSOR_MATCH Guard Types: ['TENSOR_MATCH'] Code List: ["hasattr(L['b'], '_dynamo_dynamic_indices') == False"] Object Weakref: Guarded Class Weakref: Guard 7: Name: '' Source: global Create Function: GRAD_MODE Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Guard 8: Name: "G['torch']" Source: global Create Function: FUNCTION_MATCH Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Guard 9: Name: '' Source: shape_env Create Function: SHAPE_ENV Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Guard 10: Name: "G['torch'].abs" Source: global Create Function: FUNCTION_MATCH Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Guard 11: Name: '' Source: global Create Function: DETERMINISTIC_ALGORITHMS Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Guard 12: Name: "L['x']" Source: local Create Function: TENSOR_MATCH Guard Types: ['TENSOR_MATCH'] Code List: ["hasattr(L['x'], '_dynamo_dynamic_indices') == False"] Object Weakref: Guarded Class Weakref: Guard 13: Name: '' Source: global Create Function: TORCH_FUNCTION_STATE Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Guard 14: Name: '' Source: global Create Function: DEFAULT_DEVICE Guard Types: ['DEFAULT_DEVICE'] Code List: ['utils_device.CURRENT_DEVICE == None'] Object Weakref: None Guarded Class Weakref: None Guard 15: Name: '' Source: global Create Function: BACKEND_MATCH Guard Types: ['BACKEND_MATCH'] Code List: ['___check_current_backend(140681667747200)'] Object Weakref: None Guarded Class Weakref: None Guard 16: Name: '' Source: global Create Function: GRAD_MODE Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Guard 17: Name: "L['b']" Source: local Create Function: TENSOR_MATCH Guard Types: ['TENSOR_MATCH'] Code List: 
["hasattr(L['b'], '_dynamo_dynamic_indices') == False"] Object Weakref: Guarded Class Weakref: Guard 18: Name: '' Source: shape_env Create Function: SHAPE_ENV Guard Types: None Code List: None Object Weakref: None Guarded Class Weakref: None Compile Times: TorchDynamo compilation metrics: Function Runtimes (s) ------------------------------- -------------- _compile..compile_inner 0.0112, 0.0068 OutputGraph.call_user_compiler 0.0001, 0.0000

.. GENERATED FROM PYTHON SOURCE LINES 476-479

To maximize speedup, keep graph breaks to a minimum.
Passing ``fullgraph=True`` forces TorchDynamo to raise an error at the
first graph break it encounters:

.. GENERATED FROM PYTHON SOURCE LINES 479-486

.. code-block:: default

    opt_bar = torch.compile(bar, fullgraph=True)
    try:
        opt_bar(torch.randn(10), torch.randn(10))
    except Exception:
        tb.print_exc()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 482, in opt_bar(torch.randn(10), torch.randn(10)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors return callback(frame, cache_entry, hooks, frame_state, skip=1) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert return _compile( File "/opt/conda/envs/py_3.10/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper r = func(*args, **kwargs) File
"/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner out_code = transform_code_object(code, transform) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object transformations(instructions, code_options) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform tracer.run() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run super().run() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run and self.step() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step getattr(self, inst.opname)(inst) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 464, in inner raise exc.UserError( torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands from user code: File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 434, in bar if b.sum() < 0: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information You can suppress this exception and fall back to eager by setting: import torch._dynamo torch._dynamo.config.suppress_errors = True .. GENERATED FROM PYTHON SOURCE LINES 487-489 And below, we demonstrate that TorchDynamo does not break the graph on the model we used above for demonstrating speedups. .. GENERATED FROM PYTHON SOURCE LINES 489-493 .. 
code-block:: default

    opt_model = torch.compile(init_model(), fullgraph=True)
    print(opt_model(generate_data(16)[0]))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[ 0.1344,  0.1978, -0.2056,  ...,  0.1862, -0.1632, -0.1640],
            [ 0.2683,  0.2348, -0.1902,  ...,  0.2169, -0.0055,  0.0310],
            [ 0.0164,  0.2181, -0.1114,  ...,  0.1707, -0.1680, -0.0635],
            ...,
            [ 0.0806,  0.2679, -0.1889,  ...,  0.0383, -0.2074, -0.1447],
            [-0.0207,  0.0858, -0.2457,  ...,  0.1862, -0.1281, -0.0283],
            [-0.0386,  0.0730, -0.1961,  ...,  0.0863, -0.2199, -0.1485]],
           device='cuda:0', grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 494-500

We can use ``torch.export`` (available in PyTorch 2.1+) to extract a single,
exportable FX graph from the input PyTorch program. The exported graph is
intended to run in other environments, including ones without a Python runtime.
One important restriction is that ``torch.export`` does not support graph breaks.
Please check `this tutorial `__ for more details on ``torch.export``.

.. GENERATED FROM PYTHON SOURCE LINES 502-509

Conclusion
------------

In this tutorial, we introduced ``torch.compile`` by covering basic usage,
demonstrating speedups over eager mode, comparing it to previous PyTorch
compiler solutions, and briefly investigating TorchDynamo and its interactions
with FX graphs. We hope that you will give ``torch.compile`` a try!

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 6 minutes 10.052 seconds)

.. _sphx_glr_download_intermediate_torch_compile_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torch_compile_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torch_compile_tutorial.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_