torch.compile Tutorial

Author: William Wen

torch.compile is the latest method to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling PyTorch code into optimized kernels, all while requiring minimal code changes.

In this tutorial, we cover basic torch.compile usage, and demonstrate the advantages of torch.compile over previous PyTorch compiler solutions, such as TorchScript and FX Tracing.

Required pip Dependencies

  • torch >= 2.0

  • torchvision

  • numpy

  • scipy

  • tabulate

NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in order to reproduce the speedup numbers shown below and documented elsewhere.

import torch
import warnings

gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )
/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py:48: UserWarning:

GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected.

Basic Usage

torch.compile is included in the latest PyTorch. Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly binary. If Triton is still missing, try installing torchtriton via pip (pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117" for CUDA 11.7).
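Before relying on TorchInductor's GPU path, it can help to confirm that Triton is importable at all. The sketch below is our own helper, not part of PyTorch. It assumes that Triton installs a top-level module named `triton`, and it uses the standard-library `importlib.util.find_spec` to look for that module without importing it.

```python
import importlib.util

def triton_available() -> bool:
    """Return True if a top-level `triton` module can be found (hypothetical helper)."""
    # find_spec only locates the module on the import path; it does not import it.
    return importlib.util.find_spec("triton") is not None

print("Triton importable:", triton_available())
```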

Arbitrary Python functions can be optimized by passing the callable to torch.compile. We can then call the returned optimized function in place of the original function.

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
tensor([[ 1.6850,  1.9924,  1.7090,  0.0034,  1.1414, -0.1822,  0.4861, -0.0536,
         -0.2252,  1.9398],
        [ 0.3693, -0.0695,  0.1748,  0.3436,  0.1939,  1.5721,  1.9882, -0.2235,
          0.3161,  1.2642],
        [ 0.2480,  1.8793,  1.7152,  1.6772,  1.8881,  1.4748,  1.3466,  1.7763,
          0.7469,  1.0407],
        [-0.1121,  1.6015, -0.0188,  0.2128,  0.5218,  1.9838,  0.8185,  0.5093,
         -0.3603,  0.1793],
        [-1.7890,  1.7532, -0.4040,  0.1222, -0.0029,  1.7975, -0.3877,  0.5123,
          0.1673,  0.1330],
        [ 1.0627,  0.9609,  0.1019,  1.8814,  0.1142, -0.2338, -0.9621,  0.7631,
          0.6506,  0.1853],
        [ 0.4584,  1.7648, -0.0444,  1.9610,  1.5884,  0.7353,  1.2190,  1.3662,
          1.0938, -0.1587],
        [-0.7502,  1.6640,  0.3495,  1.3496,  0.8187,  1.1719,  0.5820,  0.1498,
          0.0885,  0.1036],
        [ 0.3961,  0.6043, -0.0861, -0.3371,  0.8622,  1.4341,  1.2988,  0.5023,
          0.3074,  0.1277],
        [ 0.9748,  0.4117,  1.2616,  1.6314,  0.4693,  0.4092,  0.0401,  1.1196,
          1.2458,  1.3280]])

Alternatively, we can use torch.compile as a decorator.

@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
tensor([[ 0.5360,  0.1697, -0.0561,  0.1890, -0.1310,  1.2276,  1.1739,  0.1944,
         -0.1561,  1.6990],
        [ 1.0421,  1.9472,  0.2682,  0.2701,  1.3346,  0.7651,  1.0897,  1.1730,
          0.6161,  0.9223],
        [ 1.5756,  1.5294,  0.0112, -0.1522, -0.7674,  1.8515, -0.2443,  0.3696,
          0.2693,  0.8735],
        [-0.3701,  1.1190,  1.4164,  1.8648,  1.2080,  0.0732,  1.5274,  0.6868,
          1.2440,  1.0715],
        [-1.2454, -0.0159,  0.4315,  0.1317,  1.0530, -1.0603, -0.0532,  0.6661,
          1.7101, -0.2076],
        [-0.7091,  0.7824,  1.7161,  1.2750,  0.6368,  1.2488,  0.4897,  1.2429,
          1.3409,  1.3735],
        [ 0.8345,  0.0653,  0.3462,  1.2383, -0.4092,  1.6438, -0.0962,  0.4011,
          0.2463, -0.5802],
        [ 1.6349,  0.7297,  1.2547, -0.3113,  0.9310,  0.1162,  1.7618,  0.4882,
          0.7640,  0.2930],
        [ 1.1669, -0.7775,  1.2000,  0.6008, -0.2814,  0.5541,  0.5753,  1.4731,
          1.6835,  0.7370],
        [ 1.5087,  0.6195,  0.1153,  1.2966,  1.8815,  1.1678,  1.5686,  1.6018,
          0.2193,  1.3500]])
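torch.compile can also be used as a decorator that takes keyword arguments. In the sketch below (the names opt_foo3 and foo_eager are ours), backend="eager" runs TorchDynamo's graph capture but executes the captured graph with ordinary eager kernels, and fullgraph=True makes compilation raise an error instead of splitting the function at a graph break. This makes it a lightweight correctness check against the uncompiled function, even on a machine without a GPU.

```python
import torch

def foo_eager(x, y):
    # Uncompiled reference implementation.
    return torch.sin(x) + torch.cos(y)

# Passing keyword arguments makes torch.compile return a decorator.
@torch.compile(backend="eager", fullgraph=True)
def opt_foo3(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

x = torch.randn(10, 10)
y = torch.randn(10, 10)
# Compiled and eager results should agree up to floating-point tolerance.
print(torch.allclose(opt_foo3(x, y), foo_eager(x, y)))
```

In practice you would drop backend="eager" to get the default TorchInductor backend once the function is known to compile as a single graph.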

We can also optimize torch.nn.Module instances.

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))
tensor([[0.0000, 0.0000, 0.2419, 0.0446, 0.9011, 0.2674, 0.3633, 0.4984, 0.0000,
         0.0988],
        [0.6906, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8490, 0.0000, 0.0000,
         0.5475],
        [0.0852, 0.2762, 0.7441, 0.0000, 0.0000, 0.1820, 0.0000, 0.0000, 0.0000,
         0.0334],
        [0.3024, 0.0077, 1.2572, 0.0000, 0.0000, 0.6520, 0.0000, 0.0000, 0.0000,
         0.8976],
        [0.1998, 0.3333, 0.0000, 0.7803, 0.4202, 0.0915, 0.0000, 1.2543, 0.0000,
         0.4615],
        [0.2487, 0.4187, 0.0000, 0.0000, 0.5124, 0.0000, 0.2512, 0.0000, 0.5850,
         0.0000],
        [0.0000, 0.0048, 0.0000, 0.0000, 0.0000, 0.2287, 0.0000, 0.4841, 0.3915,
         0.0000],
        [0.2017, 0.0000, 0.0896, 1.4135, 0.0593, 0.3788, 0.0000, 0.0000, 0.0000,
         0.4972],
        [0.0000, 0.0000, 1.6580, 0.6414, 0.0000, 0.0000, 0.0000, 0.0000, 0.6491,
         0.7755],
        [0.0000, 0.0000, 0.6442, 0.0260, 0.7456, 0.1000, -0.0000, -0.0000, 0.5366,
         0.1193]], grad_fn=<CompiledFunctionBackward>)
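When torch.compile is given an nn.Module, it returns a wrapper module rather than a copy. As far as we understand the current implementation, the wrapper holds the original module as a submodule, so both objects expose the same Parameter tensors and optimizer updates made through one are visible through the other. The check below is a minimal sketch; backend="eager" is used only to keep it lightweight.

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod, backend="eager")

# Compare parameter objects by identity, not just by value.
shared = all(p is q for p, q in zip(opt_mod.parameters(), mod.parameters()))
print("parameters shared:", shared)

# Calling the wrapper runs the original module's forward through TorchDynamo.
x = torch.randn(4, 100)
print("same output:", torch.allclose(opt_mod(x), mod(x)))
```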

Demonstrating Speedups

Let’s now demonstrate that using torch.compile can speed up real models. We will compare standard eager mode and torch.compile by evaluating and training a torchvision model on random data.

Before we start, we need to define some utility functions.

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )

N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    return densenet121().to(torch.float32).cuda()
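The timed helper relies on CUDA events because GPU kernels are launched asynchronously; a plain wall-clock timer without synchronization would mostly measure launch overhead. If you only want to experiment on a CPU, a hypothetical timed_cpu counterpart (our name, not a PyTorch API) can use time.perf_counter instead:

```python
import time

def timed_cpu(fn):
    """Hypothetical CPU-only variant of `timed`: returns (fn(), seconds)."""
    start = time.perf_counter()
    result = fn()
    elapsed = time.perf_counter() - start
    return result, elapsed

result, seconds = timed_cpu(lambda: sum(i * i for i in range(100_000)))
print(f"result={result}, took {seconds:.6f} s")
```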

First, let’s compare inference.

Note that in the call to torch.compile, we have the additional mode argument, which we will discuss below.

model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])
eager: 0.3460065307617187
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:135: UserWarning:

TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.

compile: 53.44630859375

Notice that torch.compile takes a lot longer to complete compared to eager. This is because torch.compile compiles the model into optimized kernels as it executes. In our example, the structure of the model doesn’t change, and so recompilation is not needed. So if we run our optimized model several more times, we should see a significant improvement compared to eager.

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
eager eval time 0: 0.017955839157104494
eager eval time 1: 0.01620479965209961
eager eval time 2: 0.016089120864868165
eager eval time 3: 0.016129024505615236
eager eval time 4: 0.01606553649902344
eager eval time 5: 0.016013311386108398
eager eval time 6: 0.015993887901306153
eager eval time 7: 0.01602764892578125
eager eval time 8: 0.016029695510864257
eager eval time 9: 0.016111616134643555
~~~~~~~~~~
compile eval time 0: 0.5039564819335938
compile eval time 1: 0.006875135898590088
compile eval time 2: 0.006702079772949219
compile eval time 3: 0.006713344097137451
compile eval time 4: 0.006708159923553467
compile eval time 5: 0.006700032234191895
compile eval time 6: 0.006719488143920899
compile eval time 7: 0.006708223819732666
compile eval time 8: 0.006716351985931396
compile eval time 9: 0.006706175804138184
~~~~~~~~~~
(eval) eager median: 0.0160773286819458, compile median: 0.006710783958435058, speedup: 2.395745233571042x
~~~~~~~~~~

And indeed, we can see that running our model with torch.compile results in a significant speedup. The speedup comes mainly from reducing Python overhead and GPU memory reads/writes, so the observed speedup may vary depending on factors such as model architecture and batch size. For example, if a model’s architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.
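Because the first compiled call pays the full compilation cost, a compiled model only comes out ahead after enough iterations. The sketch below (the helper name is ours) estimates that break-even point from the inference numbers printed above. It treats the first compiled call minus one eager call as one-off overhead and ignores the slower warm-up call, so it is a rough estimate rather than a measurement.

```python
import math

def break_even_iters(first_compiled_call, eager_per_iter, compiled_per_iter):
    """Rough number of iterations after which compilation pays for itself.

    All arguments are in seconds. Every call after the first compiled call is
    assumed to take `compiled_per_iter` seconds.
    """
    saving = eager_per_iter - compiled_per_iter
    if saving <= 0:
        return math.inf  # the compiled version never catches up
    overhead = first_compiled_call - eager_per_iter
    return math.ceil(overhead / saving)

# First compiled call and the eager/compiled medians from the inference run above.
n = break_even_iters(53.44630859375, 0.0160773286819458, 0.006710783958435058)
print(n)
```

With these numbers the estimate is a little over 5,700 iterations; on other hardware both the compilation overhead and the per-iteration saving will differ.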

You may also see different speedup results depending on the chosen mode argument. The "reduce-overhead" mode uses CUDA graphs to further reduce the overhead of Python. For your own models, you may need to experiment with different modes to maximize speedup. You can read more about modes here.

You might also notice that the second time we run our model with torch.compile is significantly slower than the subsequent runs, although it is much faster than the first run. This is because the "reduce-overhead" mode runs a few warm-up iterations for CUDA graphs.

For general PyTorch benchmarking, you can try using torch.utils.benchmark instead of the timed function we defined above. We wrote our own timing function in this tutorial to show torch.compile’s compilation latency.
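As a sketch of the torch.utils.benchmark alternative, the snippet below times the foo function from earlier on CPU with Timer.blocked_autorange. According to its documentation, Timer performs warm-up runs and synchronizes asynchronous CUDA work when necessary, which our timed helper handles by hand. The min_run_time value is an arbitrary choice for illustration.

```python
import torch
import torch.utils.benchmark as benchmark

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

x = torch.randn(10, 10)
y = torch.randn(10, 10)

# blocked_autorange chooses how many runs go into each block and keeps
# measuring for at least min_run_time seconds.
timer = benchmark.Timer(
    stmt="foo(x, y)",
    globals={"foo": foo, "x": x, "y": y},
    label="foo",
    description="eager",
)
measurement = timer.blocked_autorange(min_run_time=0.2)
print(measurement)
print(f"median: {measurement.median * 1e6:.1f} us")
```

A compiled callable can be timed the same way by placing it in globals; calling it once beforehand keeps compilation out of the measurement.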

Now, let’s compare training.

model = init_model()
opt = torch.optim.Adam(model.parameters())

def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
eager train time 0: 0.37154815673828123
eager train time 1: 0.05470207977294922
eager train time 2: 0.05233667373657227
eager train time 3: 0.052013057708740235
eager train time 4: 0.05195161437988281
eager train time 5: 0.05251379013061523
eager train time 6: 0.052342784881591796
eager train time 7: 0.05218099212646484
eager train time 8: 0.052691967010498046
eager train time 9: 0.0521267204284668
~~~~~~~~~~
skipping cudagraphs due to input mutation
compile train time 0: 285.266125
compile train time 1: 2.988264404296875
compile train time 2: 0.04312063980102539
compile train time 3: 0.032029632568359376
compile train time 4: 0.03120742416381836
compile train time 5: 0.03134566307067871
compile train time 6: 0.031138816833496095
compile train time 7: 0.03116646385192871
compile train time 8: 0.031185920715332032
compile train time 9: 0.031178752899169923
~~~~~~~~~~
(train) eager median: 0.05233972930908203, compile median: 0.031276543617248534, speedup: 1.6734499166403245x
~~~~~~~~~~

Again, we can see that torch.compile takes longer in the first iteration, as it must compile the model, but in subsequent iterations, we see significant speedups compared to eager.

We remark that the speedup numbers presented in this tutorial are for demonstration purposes only. Official speedup values can be seen at the TorchInductor performance dashboard.

Comparison to TorchScript and FX Tracing

We have seen that torch.compile can speed up PyTorch code. Why else should we use torch.compile over existing PyTorch compiler solutions, such as TorchScript or FX Tracing? Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.

One case that torch.compile can handle that other compiler solutions struggle with is data-dependent control flow (the if x.sum() < 0: line below).

def f1(x, y):
    if x.sum() < 0:
        return -y
    return y

# Test that `fn1` and `fn2` return the same result, given
# the same arguments `args`. Typically, `fn1` will be an eager function
# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)

inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

Tracing f1 with TorchScript produces silently incorrect results, since only the control flow path taken by the example inputs is traced.

traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))
/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py:274: TracerWarning:

Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

traced 1, 1: True
traced 1, 2: False
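To see why tracing bakes in one branch, here is a deliberately tiny pure-Python imitation of a tracer. Recorder and toy_trace are hypothetical names, and this is not how torch.jit.trace is implemented internally. The toy tracer runs the function once on example inputs and records the operations applied to y; the if statement is evaluated by Python on the concrete example x, so only the branch taken at trace time is recorded.

```python
class Recorder:
    """Hypothetical stand-in for a traced tensor: holds a concrete value and
    appends every operation applied to it to a shared tape."""

    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def __neg__(self):
        self.tape.append(lambda v: -v)
        return Recorder(-self.value, self.tape)


def toy_trace(fn, example_x, example_y):
    """Run `fn` once on example inputs and return a replay function."""
    tape = []
    fn(example_x, Recorder(example_y, tape))

    def replay(x, y):
        # `x` is ignored: its effect on control flow was frozen at trace time.
        for op in tape:
            y = op(y)
        return y

    return replay


def f1(x, y):
    if sum(x) < 0:
        return -y
    return y


traced = toy_trace(f1, [1.0, 2.0], 5.0)
print(traced([1.0, 2.0], 5.0), f1([1.0, 2.0], 5.0))      # prints 5.0 5.0
print(traced([-1.0, -2.0], 5.0), f1([-1.0, -2.0], 5.0))  # prints 5.0 -5.0
```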

FX tracing f1 results in an error due to the presence of data-dependent control flow.

import traceback as tb
try:
    torch.fx.symbolic_trace(f1)
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 304, in <module>
    torch.fx.symbolic_trace(f1)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1150, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 817, in trace
    (self.create_arg(fn(*args)),),
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 274, in f1
    if x.sum() < 0:
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 437, in __bool__
    return self.tracer.to_bool(self)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 300, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow

If we provide a value for x as we try to FX trace f1, then we run into the same problem as TorchScript tracing, as the data-dependent control flow is removed in the traced function.

fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:634: UserWarning:

Was not able to add assertion to guarantee correct input x to specialized function. It is up to the user to make sure that your inputs match the inputs you specialized the function with.

fx 1, 1: True
fx 1, 2: False

Now we can see that torch.compile correctly handles data-dependent control flow.

# Reset since we are using a different mode.
torch._dynamo.reset()

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~
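How does torch.compile manage this? As far as we understand, TorchDynamo cannot capture an if statement whose condition depends on tensor values, so it inserts a graph break: the comparison is compiled as one graph, Python evaluates the branch, and the code after the branch is captured separately. A custom backend that only records the graphs it receives (count_graphs_backend is our own name) makes this visible. The exact number of graphs may vary across PyTorch versions, so the sketch only prints it.

```python
import torch
import torch._dynamo
import torch.fx

def f1(x, y):
    if x.sum() < 0:
        return -y
    return y

captured_graphs = []

def count_graphs_backend(gm: torch.fx.GraphModule, example_inputs):
    # Record every FX graph TorchDynamo hands us, then run it unoptimized.
    captured_graphs.append(gm)
    return gm.forward

torch._dynamo.reset()
compiled_f1 = torch.compile(f1, backend=count_graphs_backend)

x = torch.randn(5, 5)
y = torch.randn(5, 5)
out_pos = compiled_f1(x.abs(), y)   # sum is not negative, so f1 returns y
out_neg = compiled_f1(-x.abs(), y)  # sum is negative, so f1 returns -y
print("results correct:", torch.equal(out_pos, y) and torch.equal(out_neg, -y))
print("graphs captured:", len(captured_graphs))
for gm in captured_graphs:
    print(gm.code)
```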

TorchScript scripting can handle data-dependent control flow, but this solution comes with its own set of problems. Namely, TorchScript scripting can require major code changes and will raise errors when unsupported Python is used.

In the example below, we omit TorchScript type annotations and receive a TorchScript error because the input type for argument y, an int, does not match the default argument type, torch.Tensor.

def f2(x, y):
    return x + y

inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 347, in <module>
    script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor

However, torch.compile is easily able to handle f2.

compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
compile 2: True
~~~~~~~~~~
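For completeness, the TorchScript error above can be avoided by annotating y explicitly, which is the kind of code change scripting tends to require. A minimal sketch (f2_annotated is our name):

```python
import torch

@torch.jit.script
def f2_annotated(x: torch.Tensor, y: int) -> torch.Tensor:
    # With `y` annotated as int, TorchScript accepts a Python int here.
    return x + y

inp1 = torch.randn(5, 5)
print(torch.equal(f2_annotated(inp1, 3), inp1 + 3))
```

torch.compile, by contrast, accepted the unannotated f2 unchanged.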

Another case that torch.compile handles well compared to previous compiler solutions is the use of non-PyTorch functions.

import scipy.fft
def f3(x):
    x = x * 2
    x = scipy.fft.dct(x.numpy())
    x = torch.from_numpy(x)
    x = x * 2
    return x

TorchScript tracing treats results from non-PyTorch function calls as constants, and so our results can be silently wrong.

inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
traced_f3 = torch.jit.trace(f3, (inp1,))
print("traced 3:", test_fns(f3, traced_f3, (inp2,)))
/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py:365: TracerWarning:

Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py:366: TracerWarning:

torch.from_numpy results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.

traced 3: False

TorchScript scripting and FX tracing disallow non-PyTorch function calls.

try:
    torch.jit.script(f3)
except Exception:
    tb.print_exc()

try:
    torch.fx.symbolic_trace(f3)
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 383, in <module>
    torch.jit.script(f3)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/jit/_script.py", line 1381, in script
    fn = torch._C._jit_script_compile(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_jit_internal.py", line 1205, in _try_get_dispatched_fn
    return boolean_dispatched.get(fn)
  File "/opt/conda/envs/py_3.10/lib/python3.10/weakref.py", line 453, in get
    return self.data.get(ref(key),default)
TypeError: cannot create weak reference to 'uarray._Function' object
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 388, in <module>
    torch.fx.symbolic_trace(f3)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1150, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 817, in trace
    (self.create_arg(fn(*args)),),
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 365, in f3
    x = scipy.fft.dct(x.numpy())
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/scipy/fft/_backend.py", line 25, in __ua_function__
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/scipy/fft/_pocketfft/realtransforms.py", line 19, in _r2r
    tmp = _asfarray(x)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/scipy/fft/_pocketfft/helper.py", line 89, in _asfarray
    if x.dtype == np.float16:
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 542, in impl
    return tracer.create_proxy('call_function', target, args, kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 184, in create_proxy
    args_ = self.create_arg(args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 385, in create_arg
    return super().create_arg(a)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 255, in create_arg
    return type(a)(self.create_arg(elem) for elem in a)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 255, in <genexpr>
    return type(a)(self.create_arg(elem) for elem in a)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 385, in create_arg
    return super().create_arg(a)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 291, in create_arg
    raise NotImplementedError(f"argument of type: {type(a)}")
NotImplementedError: argument of type: <class 'type'>

In comparison, torch.compile is easily able to handle the non-PyTorch function call.

compile_f3 = torch.compile(f3)
print("compile 3:", test_fns(f3, compile_f3, (inp2,)))
compile 3: True

TorchDynamo and FX Graphs

One important component of torch.compile is TorchDynamo. TorchDynamo is responsible for JIT compiling arbitrary Python code into FX graphs, which can then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode during runtime and detecting calls to PyTorch operations.
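Since TorchDynamo works at the level of CPython bytecode, it can be instructive to look at that bytecode directly. The standard-library dis module shows it for f1; the conditional jump generated for if x.sum() < 0: is the kind of instruction whose outcome depends on tensor data. This snippet does not call TorchDynamo, it only inspects f1's bytecode, and the exact opcode names vary between Python versions.

```python
import dis

def f1(x, y):
    if x.sum() < 0:
        return -y
    return y

# Jump opcode names differ across CPython versions (for example
# POP_JUMP_IF_FALSE vs. POP_JUMP_FORWARD_IF_FALSE), so match on "JUMP".
jumps = [ins.opname for ins in dis.get_instructions(f1) if "JUMP" in ins.opname]
print(jumps)
dis.dis(f1)
```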

Normally, TorchInductor, another component of torch.compile, further compiles the FX graphs into optimized kernels, but TorchDynamo allows for different backends to be used. In order to inspect the FX graphs that TorchDynamo outputs, let us create a custom backend that outputs the FX graph and simply returns the graph’s unoptimized forward method.

from typing import List
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward

# Reset since we are using a different backend.
torch._dynamo.reset()

opt_model = torch.compile(init_model(), backend=custom_backend)
opt_model(generate_data(16)[0])
custom backend called with FX graph:
opcode         name                                               target                                                      args                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       kwargs
-------------  -------------------------------------------------  ----------------------------------------------------------  ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  -----------------
placeholder    l_x_                                               L_x_                                                        ()                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_conv0                           L__self___features_conv0                                    (l_x_,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    {}
call_module    l__self___features_norm0                           L__self___features_norm0                                    (l__self___features_conv0,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_relu0                           L__self___features_relu0                                    (l__self___features_norm0,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_pool0                           L__self___features_pool0                                    (l__self___features_relu0,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  cat                                                <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_pool0], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock1_denselayer1_norm1   L__self___features_denseblock1_denselayer1_norm1            (cat,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_denseblock1_denselayer1_relu1   L__self___features_denseblock1_denselayer1_relu1            (l__self___features_denseblock1_denselayer1_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer1_conv1   L__self___features_denseblock1_denselayer1_conv1            (l__self___features_denseblock1_denselayer1_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer1_norm2   L__self___features_denseblock1_denselayer1_norm2            (l__self___features_denseblock1_denselayer1_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer1_relu2   L__self___features_denseblock1_denselayer1_relu2            (l__self___features_denseblock1_denselayer1_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer1_conv2   L__self___features_denseblock1_denselayer1_conv2            (l__self___features_denseblock1_denselayer1_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_1                                              <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_pool0, l__self___features_denseblock1_denselayer1_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          {}
call_module    l__self___features_denseblock1_denselayer2_norm1   L__self___features_denseblock1_denselayer2_norm1            (cat_1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock1_denselayer2_relu1   L__self___features_denseblock1_denselayer2_relu1            (l__self___features_denseblock1_denselayer2_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer2_conv1   L__self___features_denseblock1_denselayer2_conv1            (l__self___features_denseblock1_denselayer2_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer2_norm2   L__self___features_denseblock1_denselayer2_norm2            (l__self___features_denseblock1_denselayer2_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer2_relu2   L__self___features_denseblock1_denselayer2_relu2            (l__self___features_denseblock1_denselayer2_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer2_conv2   L__self___features_denseblock1_denselayer2_conv2            (l__self___features_denseblock1_denselayer2_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_2                                              <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_pool0, l__self___features_denseblock1_denselayer1_conv2, l__self___features_denseblock1_denselayer2_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer3_norm1   L__self___features_denseblock1_denselayer3_norm1            (cat_2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock1_denselayer3_relu1   L__self___features_denseblock1_denselayer3_relu1            (l__self___features_denseblock1_denselayer3_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer3_conv1   L__self___features_denseblock1_denselayer3_conv1            (l__self___features_denseblock1_denselayer3_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer3_norm2   L__self___features_denseblock1_denselayer3_norm2            (l__self___features_denseblock1_denselayer3_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer3_relu2   L__self___features_denseblock1_denselayer3_relu2            (l__self___features_denseblock1_denselayer3_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer3_conv2   L__self___features_denseblock1_denselayer3_conv2            (l__self___features_denseblock1_denselayer3_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_3                                              <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_pool0, l__self___features_denseblock1_denselayer1_conv2, l__self___features_denseblock1_denselayer2_conv2, l__self___features_denseblock1_denselayer3_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      {}
call_module    l__self___features_denseblock1_denselayer4_norm1   L__self___features_denseblock1_denselayer4_norm1            (cat_3,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock1_denselayer4_relu1   L__self___features_denseblock1_denselayer4_relu1            (l__self___features_denseblock1_denselayer4_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer4_conv1   L__self___features_denseblock1_denselayer4_conv1            (l__self___features_denseblock1_denselayer4_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer4_norm2   L__self___features_denseblock1_denselayer4_norm2            (l__self___features_denseblock1_denselayer4_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer4_relu2   L__self___features_denseblock1_denselayer4_relu2            (l__self___features_denseblock1_denselayer4_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer4_conv2   L__self___features_denseblock1_denselayer4_conv2            (l__self___features_denseblock1_denselayer4_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_4                                              <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_pool0, l__self___features_denseblock1_denselayer1_conv2, l__self___features_denseblock1_denselayer2_conv2, l__self___features_denseblock1_denselayer3_conv2, l__self___features_denseblock1_denselayer4_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    {}
call_module    l__self___features_denseblock1_denselayer5_norm1   L__self___features_denseblock1_denselayer5_norm1            (cat_4,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock1_denselayer5_relu1   L__self___features_denseblock1_denselayer5_relu1            (l__self___features_denseblock1_denselayer5_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer5_conv1   L__self___features_denseblock1_denselayer5_conv1            (l__self___features_denseblock1_denselayer5_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer5_norm2   L__self___features_denseblock1_denselayer5_norm2            (l__self___features_denseblock1_denselayer5_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer5_relu2   L__self___features_denseblock1_denselayer5_relu2            (l__self___features_denseblock1_denselayer5_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer5_conv2   L__self___features_denseblock1_denselayer5_conv2            (l__self___features_denseblock1_denselayer5_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_5                                              <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_pool0, l__self___features_denseblock1_denselayer1_conv2, l__self___features_denseblock1_denselayer2_conv2, l__self___features_denseblock1_denselayer3_conv2, l__self___features_denseblock1_denselayer4_conv2, l__self___features_denseblock1_denselayer5_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock1_denselayer6_norm1   L__self___features_denseblock1_denselayer6_norm1            (cat_5,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock1_denselayer6_relu1   L__self___features_denseblock1_denselayer6_relu1            (l__self___features_denseblock1_denselayer6_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer6_conv1   L__self___features_denseblock1_denselayer6_conv1            (l__self___features_denseblock1_denselayer6_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer6_norm2   L__self___features_denseblock1_denselayer6_norm2            (l__self___features_denseblock1_denselayer6_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer6_relu2   L__self___features_denseblock1_denselayer6_relu2            (l__self___features_denseblock1_denselayer6_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock1_denselayer6_conv2   L__self___features_denseblock1_denselayer6_conv2            (l__self___features_denseblock1_denselayer6_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_6                                              <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_pool0, l__self___features_denseblock1_denselayer1_conv2, l__self___features_denseblock1_denselayer2_conv2, l__self___features_denseblock1_denselayer3_conv2, l__self___features_denseblock1_denselayer4_conv2, l__self___features_denseblock1_denselayer5_conv2, l__self___features_denseblock1_denselayer6_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_transition1_norm                L__self___features_transition1_norm                         (cat_6,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_transition1_relu                L__self___features_transition1_relu                         (l__self___features_transition1_norm,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_transition1_conv                L__self___features_transition1_conv                         (l__self___features_transition1_relu,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_transition1_pool                L__self___features_transition1_pool                         (l__self___features_transition1_conv,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_function  cat_7                                              <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer1_norm1   L__self___features_denseblock2_denselayer1_norm1            (cat_7,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock2_denselayer1_relu1   L__self___features_denseblock2_denselayer1_relu1            (l__self___features_denseblock2_denselayer1_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer1_conv1   L__self___features_denseblock2_denselayer1_conv1            (l__self___features_denseblock2_denselayer1_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer1_norm2   L__self___features_denseblock2_denselayer1_norm2            (l__self___features_denseblock2_denselayer1_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer1_relu2   L__self___features_denseblock2_denselayer1_relu2            (l__self___features_denseblock2_denselayer1_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer1_conv2   L__self___features_denseblock2_denselayer1_conv2            (l__self___features_denseblock2_denselayer1_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_8                                              <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               {}
call_module    l__self___features_denseblock2_denselayer2_norm1   L__self___features_denseblock2_denselayer2_norm1            (cat_8,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock2_denselayer2_relu1   L__self___features_denseblock2_denselayer2_relu1            (l__self___features_denseblock2_denselayer2_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer2_conv1   L__self___features_denseblock2_denselayer2_conv1            (l__self___features_denseblock2_denselayer2_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer2_norm2   L__self___features_denseblock2_denselayer2_norm2            (l__self___features_denseblock2_denselayer2_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer2_relu2   L__self___features_denseblock2_denselayer2_relu2            (l__self___features_denseblock2_denselayer2_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer2_conv2   L__self___features_denseblock2_denselayer2_conv2            (l__self___features_denseblock2_denselayer2_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_9                                              <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2, l__self___features_denseblock2_denselayer2_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer3_norm1   L__self___features_denseblock2_denselayer3_norm1            (cat_9,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock2_denselayer3_relu1   L__self___features_denseblock2_denselayer3_relu1            (l__self___features_denseblock2_denselayer3_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer3_conv1   L__self___features_denseblock2_denselayer3_conv1            (l__self___features_denseblock2_denselayer3_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer3_norm2   L__self___features_denseblock2_denselayer3_norm2            (l__self___features_denseblock2_denselayer3_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer3_relu2   L__self___features_denseblock2_denselayer3_relu2            (l__self___features_denseblock2_denselayer3_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer3_conv2   L__self___features_denseblock2_denselayer3_conv2            (l__self___features_denseblock2_denselayer3_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_10                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2, l__self___features_denseblock2_denselayer2_conv2, l__self___features_denseblock2_denselayer3_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           {}
call_module    l__self___features_denseblock2_denselayer4_norm1   L__self___features_denseblock2_denselayer4_norm1            (cat_10,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock2_denselayer4_relu1   L__self___features_denseblock2_denselayer4_relu1            (l__self___features_denseblock2_denselayer4_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer4_conv1   L__self___features_denseblock2_denselayer4_conv1            (l__self___features_denseblock2_denselayer4_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer4_norm2   L__self___features_denseblock2_denselayer4_norm2            (l__self___features_denseblock2_denselayer4_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer4_relu2   L__self___features_denseblock2_denselayer4_relu2            (l__self___features_denseblock2_denselayer4_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer4_conv2   L__self___features_denseblock2_denselayer4_conv2            (l__self___features_denseblock2_denselayer4_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_11                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2, l__self___features_denseblock2_denselayer2_conv2, l__self___features_denseblock2_denselayer3_conv2, l__self___features_denseblock2_denselayer4_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_denseblock2_denselayer5_norm1   L__self___features_denseblock2_denselayer5_norm1            (cat_11,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock2_denselayer5_relu1   L__self___features_denseblock2_denselayer5_relu1            (l__self___features_denseblock2_denselayer5_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer5_conv1   L__self___features_denseblock2_denselayer5_conv1            (l__self___features_denseblock2_denselayer5_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer5_norm2   L__self___features_denseblock2_denselayer5_norm2            (l__self___features_denseblock2_denselayer5_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer5_relu2   L__self___features_denseblock2_denselayer5_relu2            (l__self___features_denseblock2_denselayer5_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer5_conv2   L__self___features_denseblock2_denselayer5_conv2            (l__self___features_denseblock2_denselayer5_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_12                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2, l__self___features_denseblock2_denselayer2_conv2, l__self___features_denseblock2_denselayer3_conv2, l__self___features_denseblock2_denselayer4_conv2, l__self___features_denseblock2_denselayer5_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer6_norm1   L__self___features_denseblock2_denselayer6_norm1            (cat_12,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock2_denselayer6_relu1   L__self___features_denseblock2_denselayer6_relu1            (l__self___features_denseblock2_denselayer6_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer6_conv1   L__self___features_denseblock2_denselayer6_conv1            (l__self___features_denseblock2_denselayer6_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer6_norm2   L__self___features_denseblock2_denselayer6_norm2            (l__self___features_denseblock2_denselayer6_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer6_relu2   L__self___features_denseblock2_denselayer6_relu2            (l__self___features_denseblock2_denselayer6_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer6_conv2   L__self___features_denseblock2_denselayer6_conv2            (l__self___features_denseblock2_denselayer6_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_13                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2, l__self___features_denseblock2_denselayer2_conv2, l__self___features_denseblock2_denselayer3_conv2, l__self___features_denseblock2_denselayer4_conv2, l__self___features_denseblock2_denselayer5_conv2, l__self___features_denseblock2_denselayer6_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_denseblock2_denselayer7_norm1   L__self___features_denseblock2_denselayer7_norm1            (cat_13,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock2_denselayer7_relu1   L__self___features_denseblock2_denselayer7_relu1            (l__self___features_denseblock2_denselayer7_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer7_conv1   L__self___features_denseblock2_denselayer7_conv1            (l__self___features_denseblock2_denselayer7_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer7_norm2   L__self___features_denseblock2_denselayer7_norm2            (l__self___features_denseblock2_denselayer7_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer7_relu2   L__self___features_denseblock2_denselayer7_relu2            (l__self___features_denseblock2_denselayer7_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer7_conv2   L__self___features_denseblock2_denselayer7_conv2            (l__self___features_denseblock2_denselayer7_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_14                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2, l__self___features_denseblock2_denselayer2_conv2, l__self___features_denseblock2_denselayer3_conv2, l__self___features_denseblock2_denselayer4_conv2, l__self___features_denseblock2_denselayer5_conv2, l__self___features_denseblock2_denselayer6_conv2, l__self___features_denseblock2_denselayer7_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock2_denselayer8_norm1   L__self___features_denseblock2_denselayer8_norm1            (cat_14,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock2_denselayer8_relu1   L__self___features_denseblock2_denselayer8_relu1            (l__self___features_denseblock2_denselayer8_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer8_conv1   L__self___features_denseblock2_denselayer8_conv1            (l__self___features_denseblock2_denselayer8_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer8_norm2   L__self___features_denseblock2_denselayer8_norm2            (l__self___features_denseblock2_denselayer8_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer8_relu2   L__self___features_denseblock2_denselayer8_relu2            (l__self___features_denseblock2_denselayer8_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer8_conv2   L__self___features_denseblock2_denselayer8_conv2            (l__self___features_denseblock2_denselayer8_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_15                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2, l__self___features_denseblock2_denselayer2_conv2, l__self___features_denseblock2_denselayer3_conv2, l__self___features_denseblock2_denselayer4_conv2, l__self___features_denseblock2_denselayer5_conv2, l__self___features_denseblock2_denselayer6_conv2, l__self___features_denseblock2_denselayer7_conv2, l__self___features_denseblock2_denselayer8_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer9_norm1   L__self___features_denseblock2_denselayer9_norm1            (cat_15,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock2_denselayer9_relu1   L__self___features_denseblock2_denselayer9_relu1            (l__self___features_denseblock2_denselayer9_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer9_conv1   L__self___features_denseblock2_denselayer9_conv1            (l__self___features_denseblock2_denselayer9_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer9_norm2   L__self___features_denseblock2_denselayer9_norm2            (l__self___features_denseblock2_denselayer9_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer9_relu2   L__self___features_denseblock2_denselayer9_relu2            (l__self___features_denseblock2_denselayer9_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer9_conv2   L__self___features_denseblock2_denselayer9_conv2            (l__self___features_denseblock2_denselayer9_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_16                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2, l__self___features_denseblock2_denselayer2_conv2, l__self___features_denseblock2_denselayer3_conv2, l__self___features_denseblock2_denselayer4_conv2, l__self___features_denseblock2_denselayer5_conv2, l__self___features_denseblock2_denselayer6_conv2, l__self___features_denseblock2_denselayer7_conv2, l__self___features_denseblock2_denselayer8_conv2, l__self___features_denseblock2_denselayer9_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               {}
call_module    l__self___features_denseblock2_denselayer10_norm1  L__self___features_denseblock2_denselayer10_norm1           (cat_16,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock2_denselayer10_relu1  L__self___features_denseblock2_denselayer10_relu1           (l__self___features_denseblock2_denselayer10_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer10_conv1  L__self___features_denseblock2_denselayer10_conv1           (l__self___features_denseblock2_denselayer10_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer10_norm2  L__self___features_denseblock2_denselayer10_norm2           (l__self___features_denseblock2_denselayer10_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer10_relu2  L__self___features_denseblock2_denselayer10_relu2           (l__self___features_denseblock2_denselayer10_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer10_conv2  L__self___features_denseblock2_denselayer10_conv2           (l__self___features_denseblock2_denselayer10_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_17                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2, l__self___features_denseblock2_denselayer2_conv2, l__self___features_denseblock2_denselayer3_conv2, l__self___features_denseblock2_denselayer4_conv2, l__self___features_denseblock2_denselayer5_conv2, l__self___features_denseblock2_denselayer6_conv2, l__self___features_denseblock2_denselayer7_conv2, l__self___features_denseblock2_denselayer8_conv2, l__self___features_denseblock2_denselayer9_conv2, l__self___features_denseblock2_denselayer10_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock2_denselayer11_norm1  L__self___features_denseblock2_denselayer11_norm1           (cat_17,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock2_denselayer11_relu1  L__self___features_denseblock2_denselayer11_relu1           (l__self___features_denseblock2_denselayer11_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer11_conv1  L__self___features_denseblock2_denselayer11_conv1           (l__self___features_denseblock2_denselayer11_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer11_norm2  L__self___features_denseblock2_denselayer11_norm2           (l__self___features_denseblock2_denselayer11_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer11_relu2  L__self___features_denseblock2_denselayer11_relu2           (l__self___features_denseblock2_denselayer11_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer11_conv2  L__self___features_denseblock2_denselayer11_conv2           (l__self___features_denseblock2_denselayer11_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_18                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2, l__self___features_denseblock2_denselayer2_conv2, l__self___features_denseblock2_denselayer3_conv2, l__self___features_denseblock2_denselayer4_conv2, l__self___features_denseblock2_denselayer5_conv2, l__self___features_denseblock2_denselayer6_conv2, l__self___features_denseblock2_denselayer7_conv2, l__self___features_denseblock2_denselayer8_conv2, l__self___features_denseblock2_denselayer9_conv2, l__self___features_denseblock2_denselayer10_conv2, l__self___features_denseblock2_denselayer11_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_denseblock2_denselayer12_norm1  L__self___features_denseblock2_denselayer12_norm1           (cat_18,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock2_denselayer12_relu1  L__self___features_denseblock2_denselayer12_relu1           (l__self___features_denseblock2_denselayer12_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer12_conv1  L__self___features_denseblock2_denselayer12_conv1           (l__self___features_denseblock2_denselayer12_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer12_norm2  L__self___features_denseblock2_denselayer12_norm2           (l__self___features_denseblock2_denselayer12_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer12_relu2  L__self___features_denseblock2_denselayer12_relu2           (l__self___features_denseblock2_denselayer12_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer12_conv2  L__self___features_denseblock2_denselayer12_conv2           (l__self___features_denseblock2_denselayer12_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_19                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition1_pool, l__self___features_denseblock2_denselayer1_conv2, l__self___features_denseblock2_denselayer2_conv2, l__self___features_denseblock2_denselayer3_conv2, l__self___features_denseblock2_denselayer4_conv2, l__self___features_denseblock2_denselayer5_conv2, l__self___features_denseblock2_denselayer6_conv2, l__self___features_denseblock2_denselayer7_conv2, l__self___features_denseblock2_denselayer8_conv2, l__self___features_denseblock2_denselayer9_conv2, l__self___features_denseblock2_denselayer10_conv2, l__self___features_denseblock2_denselayer11_conv2, l__self___features_denseblock2_denselayer12_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      {}
call_module    l__self___features_transition2_norm                L__self___features_transition2_norm                         (cat_19,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_transition2_relu                L__self___features_transition2_relu                         (l__self___features_transition2_norm,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_transition2_conv                L__self___features_transition2_conv                         (l__self___features_transition2_relu,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_transition2_pool                L__self___features_transition2_pool                         (l__self___features_transition2_conv,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_function  cat_20                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer1_norm1   L__self___features_denseblock3_denselayer1_norm1            (cat_20,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer1_relu1   L__self___features_denseblock3_denselayer1_relu1            (l__self___features_denseblock3_denselayer1_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer1_conv1   L__self___features_denseblock3_denselayer1_conv1            (l__self___features_denseblock3_denselayer1_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer1_norm2   L__self___features_denseblock3_denselayer1_norm2            (l__self___features_denseblock3_denselayer1_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer1_relu2   L__self___features_denseblock3_denselayer1_relu2            (l__self___features_denseblock3_denselayer1_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer1_conv2   L__self___features_denseblock3_denselayer1_conv2            (l__self___features_denseblock3_denselayer1_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_21                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               {}
call_module    l__self___features_denseblock3_denselayer2_norm1   L__self___features_denseblock3_denselayer2_norm1            (cat_21,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer2_relu1   L__self___features_denseblock3_denselayer2_relu1            (l__self___features_denseblock3_denselayer2_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer2_conv1   L__self___features_denseblock3_denselayer2_conv1            (l__self___features_denseblock3_denselayer2_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer2_norm2   L__self___features_denseblock3_denselayer2_norm2            (l__self___features_denseblock3_denselayer2_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer2_relu2   L__self___features_denseblock3_denselayer2_relu2            (l__self___features_denseblock3_denselayer2_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer2_conv2   L__self___features_denseblock3_denselayer2_conv2            (l__self___features_denseblock3_denselayer2_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_22                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer3_norm1   L__self___features_denseblock3_denselayer3_norm1            (cat_22,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer3_relu1   L__self___features_denseblock3_denselayer3_relu1            (l__self___features_denseblock3_denselayer3_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer3_conv1   L__self___features_denseblock3_denselayer3_conv1            (l__self___features_denseblock3_denselayer3_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer3_norm2   L__self___features_denseblock3_denselayer3_norm2            (l__self___features_denseblock3_denselayer3_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer3_relu2   L__self___features_denseblock3_denselayer3_relu2            (l__self___features_denseblock3_denselayer3_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer3_conv2   L__self___features_denseblock3_denselayer3_conv2            (l__self___features_denseblock3_denselayer3_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_23                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           {}
call_module    l__self___features_denseblock3_denselayer4_norm1   L__self___features_denseblock3_denselayer4_norm1            (cat_23,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer4_relu1   L__self___features_denseblock3_denselayer4_relu1            (l__self___features_denseblock3_denselayer4_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer4_conv1   L__self___features_denseblock3_denselayer4_conv1            (l__self___features_denseblock3_denselayer4_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer4_norm2   L__self___features_denseblock3_denselayer4_norm2            (l__self___features_denseblock3_denselayer4_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer4_relu2   L__self___features_denseblock3_denselayer4_relu2            (l__self___features_denseblock3_denselayer4_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer4_conv2   L__self___features_denseblock3_denselayer4_conv2            (l__self___features_denseblock3_denselayer4_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_24                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_denseblock3_denselayer5_norm1   L__self___features_denseblock3_denselayer5_norm1            (cat_24,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer5_relu1   L__self___features_denseblock3_denselayer5_relu1            (l__self___features_denseblock3_denselayer5_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer5_conv1   L__self___features_denseblock3_denselayer5_conv1            (l__self___features_denseblock3_denselayer5_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer5_norm2   L__self___features_denseblock3_denselayer5_norm2            (l__self___features_denseblock3_denselayer5_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer5_relu2   L__self___features_denseblock3_denselayer5_relu2            (l__self___features_denseblock3_denselayer5_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer5_conv2   L__self___features_denseblock3_denselayer5_conv2            (l__self___features_denseblock3_denselayer5_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_25                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer6_norm1   L__self___features_denseblock3_denselayer6_norm1            (cat_25,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer6_relu1   L__self___features_denseblock3_denselayer6_relu1            (l__self___features_denseblock3_denselayer6_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer6_conv1   L__self___features_denseblock3_denselayer6_conv1            (l__self___features_denseblock3_denselayer6_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer6_norm2   L__self___features_denseblock3_denselayer6_norm2            (l__self___features_denseblock3_denselayer6_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer6_relu2   L__self___features_denseblock3_denselayer6_relu2            (l__self___features_denseblock3_denselayer6_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer6_conv2   L__self___features_denseblock3_denselayer6_conv2            (l__self___features_denseblock3_denselayer6_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_26                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_denseblock3_denselayer7_norm1   L__self___features_denseblock3_denselayer7_norm1            (cat_26,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer7_relu1   L__self___features_denseblock3_denselayer7_relu1            (l__self___features_denseblock3_denselayer7_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer7_conv1   L__self___features_denseblock3_denselayer7_conv1            (l__self___features_denseblock3_denselayer7_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer7_norm2   L__self___features_denseblock3_denselayer7_norm2            (l__self___features_denseblock3_denselayer7_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer7_relu2   L__self___features_denseblock3_denselayer7_relu2            (l__self___features_denseblock3_denselayer7_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer7_conv2   L__self___features_denseblock3_denselayer7_conv2            (l__self___features_denseblock3_denselayer7_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_27                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock3_denselayer8_norm1   L__self___features_denseblock3_denselayer8_norm1            (cat_27,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer8_relu1   L__self___features_denseblock3_denselayer8_relu1            (l__self___features_denseblock3_denselayer8_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer8_conv1   L__self___features_denseblock3_denselayer8_conv1            (l__self___features_denseblock3_denselayer8_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer8_norm2   L__self___features_denseblock3_denselayer8_norm2            (l__self___features_denseblock3_denselayer8_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer8_relu2   L__self___features_denseblock3_denselayer8_relu2            (l__self___features_denseblock3_denselayer8_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer8_conv2   L__self___features_denseblock3_denselayer8_conv2            (l__self___features_denseblock3_denselayer8_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_28                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer9_norm1   L__self___features_denseblock3_denselayer9_norm1            (cat_28,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer9_relu1   L__self___features_denseblock3_denselayer9_relu1            (l__self___features_denseblock3_denselayer9_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer9_conv1   L__self___features_denseblock3_denselayer9_conv1            (l__self___features_denseblock3_denselayer9_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer9_norm2   L__self___features_denseblock3_denselayer9_norm2            (l__self___features_denseblock3_denselayer9_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer9_relu2   L__self___features_denseblock3_denselayer9_relu2            (l__self___features_denseblock3_denselayer9_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer9_conv2   L__self___features_denseblock3_denselayer9_conv2            (l__self___features_denseblock3_denselayer9_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_29                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               {}
call_module    l__self___features_denseblock3_denselayer10_norm1  L__self___features_denseblock3_denselayer10_norm1           (cat_29,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer10_relu1  L__self___features_denseblock3_denselayer10_relu1           (l__self___features_denseblock3_denselayer10_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer10_conv1  L__self___features_denseblock3_denselayer10_conv1           (l__self___features_denseblock3_denselayer10_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer10_norm2  L__self___features_denseblock3_denselayer10_norm2           (l__self___features_denseblock3_denselayer10_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer10_relu2  L__self___features_denseblock3_denselayer10_relu2           (l__self___features_denseblock3_denselayer10_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer10_conv2  L__self___features_denseblock3_denselayer10_conv2           (l__self___features_denseblock3_denselayer10_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_30                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock3_denselayer11_norm1  L__self___features_denseblock3_denselayer11_norm1           (cat_30,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer11_relu1  L__self___features_denseblock3_denselayer11_relu1           (l__self___features_denseblock3_denselayer11_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer11_conv1  L__self___features_denseblock3_denselayer11_conv1           (l__self___features_denseblock3_denselayer11_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer11_norm2  L__self___features_denseblock3_denselayer11_norm2           (l__self___features_denseblock3_denselayer11_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer11_relu2  L__self___features_denseblock3_denselayer11_relu2           (l__self___features_denseblock3_denselayer11_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer11_conv2  L__self___features_denseblock3_denselayer11_conv2           (l__self___features_denseblock3_denselayer11_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_31                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_denseblock3_denselayer12_norm1  L__self___features_denseblock3_denselayer12_norm1           (cat_31,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer12_relu1  L__self___features_denseblock3_denselayer12_relu1           (l__self___features_denseblock3_denselayer12_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer12_conv1  L__self___features_denseblock3_denselayer12_conv1           (l__self___features_denseblock3_denselayer12_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer12_norm2  L__self___features_denseblock3_denselayer12_norm2           (l__self___features_denseblock3_denselayer12_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer12_relu2  L__self___features_denseblock3_denselayer12_relu2           (l__self___features_denseblock3_denselayer12_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer12_conv2  L__self___features_denseblock3_denselayer12_conv2           (l__self___features_denseblock3_denselayer12_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_32                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      {}
call_module    l__self___features_denseblock3_denselayer13_norm1  L__self___features_denseblock3_denselayer13_norm1           (cat_32,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer13_relu1  L__self___features_denseblock3_denselayer13_relu1           (l__self___features_denseblock3_denselayer13_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer13_conv1  L__self___features_denseblock3_denselayer13_conv1           (l__self___features_denseblock3_denselayer13_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer13_norm2  L__self___features_denseblock3_denselayer13_norm2           (l__self___features_denseblock3_denselayer13_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer13_relu2  L__self___features_denseblock3_denselayer13_relu2           (l__self___features_denseblock3_denselayer13_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer13_conv2  L__self___features_denseblock3_denselayer13_conv2           (l__self___features_denseblock3_denselayer13_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_33                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock3_denselayer14_norm1  L__self___features_denseblock3_denselayer14_norm1           (cat_33,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer14_relu1  L__self___features_denseblock3_denselayer14_relu1           (l__self___features_denseblock3_denselayer14_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer14_conv1  L__self___features_denseblock3_denselayer14_conv1           (l__self___features_denseblock3_denselayer14_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer14_norm2  L__self___features_denseblock3_denselayer14_norm2           (l__self___features_denseblock3_denselayer14_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer14_relu2  L__self___features_denseblock3_denselayer14_relu2           (l__self___features_denseblock3_denselayer14_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer14_conv2  L__self___features_denseblock3_denselayer14_conv2           (l__self___features_denseblock3_denselayer14_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_34                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2, l__self___features_denseblock3_denselayer14_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer15_norm1  L__self___features_denseblock3_denselayer15_norm1           (cat_34,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer15_relu1  L__self___features_denseblock3_denselayer15_relu1           (l__self___features_denseblock3_denselayer15_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer15_conv1  L__self___features_denseblock3_denselayer15_conv1           (l__self___features_denseblock3_denselayer15_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer15_norm2  L__self___features_denseblock3_denselayer15_norm2           (l__self___features_denseblock3_denselayer15_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer15_relu2  L__self___features_denseblock3_denselayer15_relu2           (l__self___features_denseblock3_denselayer15_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer15_conv2  L__self___features_denseblock3_denselayer15_conv2           (l__self___features_denseblock3_denselayer15_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_35                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2, l__self___features_denseblock3_denselayer14_conv2, l__self___features_denseblock3_denselayer15_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer16_norm1  L__self___features_denseblock3_denselayer16_norm1           (cat_35,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer16_relu1  L__self___features_denseblock3_denselayer16_relu1           (l__self___features_denseblock3_denselayer16_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer16_conv1  L__self___features_denseblock3_denselayer16_conv1           (l__self___features_denseblock3_denselayer16_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer16_norm2  L__self___features_denseblock3_denselayer16_norm2           (l__self___features_denseblock3_denselayer16_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer16_relu2  L__self___features_denseblock3_denselayer16_relu2           (l__self___features_denseblock3_denselayer16_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer16_conv2  L__self___features_denseblock3_denselayer16_conv2           (l__self___features_denseblock3_denselayer16_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_36                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2, l__self___features_denseblock3_denselayer14_conv2, l__self___features_denseblock3_denselayer15_conv2, l__self___features_denseblock3_denselayer16_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                          {}
call_module    l__self___features_denseblock3_denselayer17_norm1  L__self___features_denseblock3_denselayer17_norm1           (cat_36,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer17_relu1  L__self___features_denseblock3_denselayer17_relu1           (l__self___features_denseblock3_denselayer17_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer17_conv1  L__self___features_denseblock3_denselayer17_conv1           (l__self___features_denseblock3_denselayer17_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer17_norm2  L__self___features_denseblock3_denselayer17_norm2           (l__self___features_denseblock3_denselayer17_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer17_relu2  L__self___features_denseblock3_denselayer17_relu2           (l__self___features_denseblock3_denselayer17_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer17_conv2  L__self___features_denseblock3_denselayer17_conv2           (l__self___features_denseblock3_denselayer17_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_37                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2, l__self___features_denseblock3_denselayer14_conv2, l__self___features_denseblock3_denselayer15_conv2, l__self___features_denseblock3_denselayer16_conv2, l__self___features_denseblock3_denselayer17_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer18_norm1  L__self___features_denseblock3_denselayer18_norm1           (cat_37,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer18_relu1  L__self___features_denseblock3_denselayer18_relu1           (l__self___features_denseblock3_denselayer18_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer18_conv1  L__self___features_denseblock3_denselayer18_conv1           (l__self___features_denseblock3_denselayer18_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer18_norm2  L__self___features_denseblock3_denselayer18_norm2           (l__self___features_denseblock3_denselayer18_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer18_relu2  L__self___features_denseblock3_denselayer18_relu2           (l__self___features_denseblock3_denselayer18_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer18_conv2  L__self___features_denseblock3_denselayer18_conv2           (l__self___features_denseblock3_denselayer18_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_38                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2, l__self___features_denseblock3_denselayer14_conv2, l__self___features_denseblock3_denselayer15_conv2, l__self___features_denseblock3_denselayer16_conv2, l__self___features_denseblock3_denselayer17_conv2, l__self___features_denseblock3_denselayer18_conv2], 1)                                                                                                                                                                                                                                                                                                                    {}
call_module    l__self___features_denseblock3_denselayer19_norm1  L__self___features_denseblock3_denselayer19_norm1           (cat_38,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer19_relu1  L__self___features_denseblock3_denselayer19_relu1           (l__self___features_denseblock3_denselayer19_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer19_conv1  L__self___features_denseblock3_denselayer19_conv1           (l__self___features_denseblock3_denselayer19_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer19_norm2  L__self___features_denseblock3_denselayer19_norm2           (l__self___features_denseblock3_denselayer19_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer19_relu2  L__self___features_denseblock3_denselayer19_relu2           (l__self___features_denseblock3_denselayer19_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer19_conv2  L__self___features_denseblock3_denselayer19_conv2           (l__self___features_denseblock3_denselayer19_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_39                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2, l__self___features_denseblock3_denselayer14_conv2, l__self___features_denseblock3_denselayer15_conv2, l__self___features_denseblock3_denselayer16_conv2, l__self___features_denseblock3_denselayer17_conv2, l__self___features_denseblock3_denselayer18_conv2, l__self___features_denseblock3_denselayer19_conv2], 1)                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer20_norm1  L__self___features_denseblock3_denselayer20_norm1           (cat_39,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer20_relu1  L__self___features_denseblock3_denselayer20_relu1           (l__self___features_denseblock3_denselayer20_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer20_conv1  L__self___features_denseblock3_denselayer20_conv1           (l__self___features_denseblock3_denselayer20_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer20_norm2  L__self___features_denseblock3_denselayer20_norm2           (l__self___features_denseblock3_denselayer20_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer20_relu2  L__self___features_denseblock3_denselayer20_relu2           (l__self___features_denseblock3_denselayer20_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer20_conv2  L__self___features_denseblock3_denselayer20_conv2           (l__self___features_denseblock3_denselayer20_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_40                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2, l__self___features_denseblock3_denselayer14_conv2, l__self___features_denseblock3_denselayer15_conv2, l__self___features_denseblock3_denselayer16_conv2, l__self___features_denseblock3_denselayer17_conv2, l__self___features_denseblock3_denselayer18_conv2, l__self___features_denseblock3_denselayer19_conv2, l__self___features_denseblock3_denselayer20_conv2], 1)                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock3_denselayer21_norm1  L__self___features_denseblock3_denselayer21_norm1           (cat_40,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer21_relu1  L__self___features_denseblock3_denselayer21_relu1           (l__self___features_denseblock3_denselayer21_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer21_conv1  L__self___features_denseblock3_denselayer21_conv1           (l__self___features_denseblock3_denselayer21_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer21_norm2  L__self___features_denseblock3_denselayer21_norm2           (l__self___features_denseblock3_denselayer21_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer21_relu2  L__self___features_denseblock3_denselayer21_relu2           (l__self___features_denseblock3_denselayer21_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer21_conv2  L__self___features_denseblock3_denselayer21_conv2           (l__self___features_denseblock3_denselayer21_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_41                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2, l__self___features_denseblock3_denselayer14_conv2, l__self___features_denseblock3_denselayer15_conv2, l__self___features_denseblock3_denselayer16_conv2, l__self___features_denseblock3_denselayer17_conv2, l__self___features_denseblock3_denselayer18_conv2, l__self___features_denseblock3_denselayer19_conv2, l__self___features_denseblock3_denselayer20_conv2, l__self___features_denseblock3_denselayer21_conv2], 1)                                                                                                                                                           {}
call_module    l__self___features_denseblock3_denselayer22_norm1  L__self___features_denseblock3_denselayer22_norm1           (cat_41,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer22_relu1  L__self___features_denseblock3_denselayer22_relu1           (l__self___features_denseblock3_denselayer22_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer22_conv1  L__self___features_denseblock3_denselayer22_conv1           (l__self___features_denseblock3_denselayer22_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer22_norm2  L__self___features_denseblock3_denselayer22_norm2           (l__self___features_denseblock3_denselayer22_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer22_relu2  L__self___features_denseblock3_denselayer22_relu2           (l__self___features_denseblock3_denselayer22_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer22_conv2  L__self___features_denseblock3_denselayer22_conv2           (l__self___features_denseblock3_denselayer22_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_42                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2, l__self___features_denseblock3_denselayer14_conv2, l__self___features_denseblock3_denselayer15_conv2, l__self___features_denseblock3_denselayer16_conv2, l__self___features_denseblock3_denselayer17_conv2, l__self___features_denseblock3_denselayer18_conv2, l__self___features_denseblock3_denselayer19_conv2, l__self___features_denseblock3_denselayer20_conv2, l__self___features_denseblock3_denselayer21_conv2, l__self___features_denseblock3_denselayer22_conv2], 1)                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer23_norm1  L__self___features_denseblock3_denselayer23_norm1           (cat_42,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer23_relu1  L__self___features_denseblock3_denselayer23_relu1           (l__self___features_denseblock3_denselayer23_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer23_conv1  L__self___features_denseblock3_denselayer23_conv1           (l__self___features_denseblock3_denselayer23_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer23_norm2  L__self___features_denseblock3_denselayer23_norm2           (l__self___features_denseblock3_denselayer23_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer23_relu2  L__self___features_denseblock3_denselayer23_relu2           (l__self___features_denseblock3_denselayer23_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer23_conv2  L__self___features_denseblock3_denselayer23_conv2           (l__self___features_denseblock3_denselayer23_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_43                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2, l__self___features_denseblock3_denselayer14_conv2, l__self___features_denseblock3_denselayer15_conv2, l__self___features_denseblock3_denselayer16_conv2, l__self___features_denseblock3_denselayer17_conv2, l__self___features_denseblock3_denselayer18_conv2, l__self___features_denseblock3_denselayer19_conv2, l__self___features_denseblock3_denselayer20_conv2, l__self___features_denseblock3_denselayer21_conv2, l__self___features_denseblock3_denselayer22_conv2, l__self___features_denseblock3_denselayer23_conv2], 1)                                                     {}
call_module    l__self___features_denseblock3_denselayer24_norm1  L__self___features_denseblock3_denselayer24_norm1           (cat_43,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer24_relu1  L__self___features_denseblock3_denselayer24_relu1           (l__self___features_denseblock3_denselayer24_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer24_conv1  L__self___features_denseblock3_denselayer24_conv1           (l__self___features_denseblock3_denselayer24_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer24_norm2  L__self___features_denseblock3_denselayer24_norm2           (l__self___features_denseblock3_denselayer24_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer24_relu2  L__self___features_denseblock3_denselayer24_relu2           (l__self___features_denseblock3_denselayer24_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer24_conv2  L__self___features_denseblock3_denselayer24_conv2           (l__self___features_denseblock3_denselayer24_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_44                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition2_pool, l__self___features_denseblock3_denselayer1_conv2, l__self___features_denseblock3_denselayer2_conv2, l__self___features_denseblock3_denselayer3_conv2, l__self___features_denseblock3_denselayer4_conv2, l__self___features_denseblock3_denselayer5_conv2, l__self___features_denseblock3_denselayer6_conv2, l__self___features_denseblock3_denselayer7_conv2, l__self___features_denseblock3_denselayer8_conv2, l__self___features_denseblock3_denselayer9_conv2, l__self___features_denseblock3_denselayer10_conv2, l__self___features_denseblock3_denselayer11_conv2, l__self___features_denseblock3_denselayer12_conv2, l__self___features_denseblock3_denselayer13_conv2, l__self___features_denseblock3_denselayer14_conv2, l__self___features_denseblock3_denselayer15_conv2, l__self___features_denseblock3_denselayer16_conv2, l__self___features_denseblock3_denselayer17_conv2, l__self___features_denseblock3_denselayer18_conv2, l__self___features_denseblock3_denselayer19_conv2, l__self___features_denseblock3_denselayer20_conv2, l__self___features_denseblock3_denselayer21_conv2, l__self___features_denseblock3_denselayer22_conv2, l__self___features_denseblock3_denselayer23_conv2, l__self___features_denseblock3_denselayer24_conv2], 1)  {}
call_module    l__self___features_transition3_norm                L__self___features_transition3_norm                         (cat_44,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_transition3_relu                L__self___features_transition3_relu                         (l__self___features_transition3_norm,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_transition3_conv                L__self___features_transition3_conv                         (l__self___features_transition3_relu,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_transition3_pool                L__self___features_transition3_pool                         (l__self___features_transition3_conv,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_function  cat_45                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock4_denselayer1_norm1   L__self___features_denseblock4_denselayer1_norm1            (cat_45,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer1_relu1   L__self___features_denseblock4_denselayer1_relu1            (l__self___features_denseblock4_denselayer1_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer1_conv1   L__self___features_denseblock4_denselayer1_conv1            (l__self___features_denseblock4_denselayer1_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer1_norm2   L__self___features_denseblock4_denselayer1_norm2            (l__self___features_denseblock4_denselayer1_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer1_relu2   L__self___features_denseblock4_denselayer1_relu2            (l__self___features_denseblock4_denselayer1_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer1_conv2   L__self___features_denseblock4_denselayer1_conv2            (l__self___features_denseblock4_denselayer1_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_46                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               {}
call_module    l__self___features_denseblock4_denselayer2_norm1   L__self___features_denseblock4_denselayer2_norm1            (cat_46,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer2_relu1   L__self___features_denseblock4_denselayer2_relu1            (l__self___features_denseblock4_denselayer2_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer2_conv1   L__self___features_denseblock4_denselayer2_conv1            (l__self___features_denseblock4_denselayer2_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer2_norm2   L__self___features_denseblock4_denselayer2_norm2            (l__self___features_denseblock4_denselayer2_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer2_relu2   L__self___features_denseblock4_denselayer2_relu2            (l__self___features_denseblock4_denselayer2_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer2_conv2   L__self___features_denseblock4_denselayer2_conv2            (l__self___features_denseblock4_denselayer2_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_47                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer3_norm1   L__self___features_denseblock4_denselayer3_norm1            (cat_47,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer3_relu1   L__self___features_denseblock4_denselayer3_relu1            (l__self___features_denseblock4_denselayer3_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer3_conv1   L__self___features_denseblock4_denselayer3_conv1            (l__self___features_denseblock4_denselayer3_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer3_norm2   L__self___features_denseblock4_denselayer3_norm2            (l__self___features_denseblock4_denselayer3_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer3_relu2   L__self___features_denseblock4_denselayer3_relu2            (l__self___features_denseblock4_denselayer3_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer3_conv2   L__self___features_denseblock4_denselayer3_conv2            (l__self___features_denseblock4_denselayer3_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_48                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           {}
call_module    l__self___features_denseblock4_denselayer4_norm1   L__self___features_denseblock4_denselayer4_norm1            (cat_48,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer4_relu1   L__self___features_denseblock4_denselayer4_relu1            (l__self___features_denseblock4_denselayer4_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer4_conv1   L__self___features_denseblock4_denselayer4_conv1            (l__self___features_denseblock4_denselayer4_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer4_norm2   L__self___features_denseblock4_denselayer4_norm2            (l__self___features_denseblock4_denselayer4_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer4_relu2   L__self___features_denseblock4_denselayer4_relu2            (l__self___features_denseblock4_denselayer4_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer4_conv2   L__self___features_denseblock4_denselayer4_conv2            (l__self___features_denseblock4_denselayer4_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_49                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_denseblock4_denselayer5_norm1   L__self___features_denseblock4_denselayer5_norm1            (cat_49,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer5_relu1   L__self___features_denseblock4_denselayer5_relu1            (l__self___features_denseblock4_denselayer5_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer5_conv1   L__self___features_denseblock4_denselayer5_conv1            (l__self___features_denseblock4_denselayer5_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer5_norm2   L__self___features_denseblock4_denselayer5_norm2            (l__self___features_denseblock4_denselayer5_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer5_relu2   L__self___features_denseblock4_denselayer5_relu2            (l__self___features_denseblock4_denselayer5_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer5_conv2   L__self___features_denseblock4_denselayer5_conv2            (l__self___features_denseblock4_denselayer5_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_50                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer6_norm1   L__self___features_denseblock4_denselayer6_norm1            (cat_50,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer6_relu1   L__self___features_denseblock4_denselayer6_relu1            (l__self___features_denseblock4_denselayer6_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer6_conv1   L__self___features_denseblock4_denselayer6_conv1            (l__self___features_denseblock4_denselayer6_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer6_norm2   L__self___features_denseblock4_denselayer6_norm2            (l__self___features_denseblock4_denselayer6_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer6_relu2   L__self___features_denseblock4_denselayer6_relu2            (l__self___features_denseblock4_denselayer6_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer6_conv2   L__self___features_denseblock4_denselayer6_conv2            (l__self___features_denseblock4_denselayer6_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_51                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2, l__self___features_denseblock4_denselayer6_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_denseblock4_denselayer7_norm1   L__self___features_denseblock4_denselayer7_norm1            (cat_51,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer7_relu1   L__self___features_denseblock4_denselayer7_relu1            (l__self___features_denseblock4_denselayer7_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer7_conv1   L__self___features_denseblock4_denselayer7_conv1            (l__self___features_denseblock4_denselayer7_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer7_norm2   L__self___features_denseblock4_denselayer7_norm2            (l__self___features_denseblock4_denselayer7_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer7_relu2   L__self___features_denseblock4_denselayer7_relu2            (l__self___features_denseblock4_denselayer7_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer7_conv2   L__self___features_denseblock4_denselayer7_conv2            (l__self___features_denseblock4_denselayer7_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_52                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2, l__self___features_denseblock4_denselayer6_conv2, l__self___features_denseblock4_denselayer7_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock4_denselayer8_norm1   L__self___features_denseblock4_denselayer8_norm1            (cat_52,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer8_relu1   L__self___features_denseblock4_denselayer8_relu1            (l__self___features_denseblock4_denselayer8_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer8_conv1   L__self___features_denseblock4_denselayer8_conv1            (l__self___features_denseblock4_denselayer8_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer8_norm2   L__self___features_denseblock4_denselayer8_norm2            (l__self___features_denseblock4_denselayer8_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer8_relu2   L__self___features_denseblock4_denselayer8_relu2            (l__self___features_denseblock4_denselayer8_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer8_conv2   L__self___features_denseblock4_denselayer8_conv2            (l__self___features_denseblock4_denselayer8_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_53                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2, l__self___features_denseblock4_denselayer6_conv2, l__self___features_denseblock4_denselayer7_conv2, l__self___features_denseblock4_denselayer8_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock4_denselayer9_norm1   L__self___features_denseblock4_denselayer9_norm1            (cat_53,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer9_relu1   L__self___features_denseblock4_denselayer9_relu1            (l__self___features_denseblock4_denselayer9_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer9_conv1   L__self___features_denseblock4_denselayer9_conv1            (l__self___features_denseblock4_denselayer9_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer9_norm2   L__self___features_denseblock4_denselayer9_norm2            (l__self___features_denseblock4_denselayer9_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer9_relu2   L__self___features_denseblock4_denselayer9_relu2            (l__self___features_denseblock4_denselayer9_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer9_conv2   L__self___features_denseblock4_denselayer9_conv2            (l__self___features_denseblock4_denselayer9_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        {}
call_function  cat_54                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2, l__self___features_denseblock4_denselayer6_conv2, l__self___features_denseblock4_denselayer7_conv2, l__self___features_denseblock4_denselayer8_conv2, l__self___features_denseblock4_denselayer9_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               {}
call_module    l__self___features_denseblock4_denselayer10_norm1  L__self___features_denseblock4_denselayer10_norm1           (cat_54,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer10_relu1  L__self___features_denseblock4_denselayer10_relu1           (l__self___features_denseblock4_denselayer10_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer10_conv1  L__self___features_denseblock4_denselayer10_conv1           (l__self___features_denseblock4_denselayer10_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer10_norm2  L__self___features_denseblock4_denselayer10_norm2           (l__self___features_denseblock4_denselayer10_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer10_relu2  L__self___features_denseblock4_denselayer10_relu2           (l__self___features_denseblock4_denselayer10_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer10_conv2  L__self___features_denseblock4_denselayer10_conv2           (l__self___features_denseblock4_denselayer10_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_55                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2, l__self___features_denseblock4_denselayer6_conv2, l__self___features_denseblock4_denselayer7_conv2, l__self___features_denseblock4_denselayer8_conv2, l__self___features_denseblock4_denselayer9_conv2, l__self___features_denseblock4_denselayer10_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock4_denselayer11_norm1  L__self___features_denseblock4_denselayer11_norm1           (cat_55,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer11_relu1  L__self___features_denseblock4_denselayer11_relu1           (l__self___features_denseblock4_denselayer11_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer11_conv1  L__self___features_denseblock4_denselayer11_conv1           (l__self___features_denseblock4_denselayer11_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer11_norm2  L__self___features_denseblock4_denselayer11_norm2           (l__self___features_denseblock4_denselayer11_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer11_relu2  L__self___features_denseblock4_denselayer11_relu2           (l__self___features_denseblock4_denselayer11_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer11_conv2  L__self___features_denseblock4_denselayer11_conv2           (l__self___features_denseblock4_denselayer11_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_56                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2, l__self___features_denseblock4_denselayer6_conv2, l__self___features_denseblock4_denselayer7_conv2, l__self___features_denseblock4_denselayer8_conv2, l__self___features_denseblock4_denselayer9_conv2, l__self___features_denseblock4_denselayer10_conv2, l__self___features_denseblock4_denselayer11_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_denseblock4_denselayer12_norm1  L__self___features_denseblock4_denselayer12_norm1           (cat_56,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer12_relu1  L__self___features_denseblock4_denselayer12_relu1           (l__self___features_denseblock4_denselayer12_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer12_conv1  L__self___features_denseblock4_denselayer12_conv1           (l__self___features_denseblock4_denselayer12_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer12_norm2  L__self___features_denseblock4_denselayer12_norm2           (l__self___features_denseblock4_denselayer12_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer12_relu2  L__self___features_denseblock4_denselayer12_relu2           (l__self___features_denseblock4_denselayer12_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer12_conv2  L__self___features_denseblock4_denselayer12_conv2           (l__self___features_denseblock4_denselayer12_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_57                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2, l__self___features_denseblock4_denselayer6_conv2, l__self___features_denseblock4_denselayer7_conv2, l__self___features_denseblock4_denselayer8_conv2, l__self___features_denseblock4_denselayer9_conv2, l__self___features_denseblock4_denselayer10_conv2, l__self___features_denseblock4_denselayer11_conv2, l__self___features_denseblock4_denselayer12_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      {}
call_module    l__self___features_denseblock4_denselayer13_norm1  L__self___features_denseblock4_denselayer13_norm1           (cat_57,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer13_relu1  L__self___features_denseblock4_denselayer13_relu1           (l__self___features_denseblock4_denselayer13_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer13_conv1  L__self___features_denseblock4_denselayer13_conv1           (l__self___features_denseblock4_denselayer13_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer13_norm2  L__self___features_denseblock4_denselayer13_norm2           (l__self___features_denseblock4_denselayer13_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer13_relu2  L__self___features_denseblock4_denselayer13_relu2           (l__self___features_denseblock4_denselayer13_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer13_conv2  L__self___features_denseblock4_denselayer13_conv2           (l__self___features_denseblock4_denselayer13_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_58                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2, l__self___features_denseblock4_denselayer6_conv2, l__self___features_denseblock4_denselayer7_conv2, l__self___features_denseblock4_denselayer8_conv2, l__self___features_denseblock4_denselayer9_conv2, l__self___features_denseblock4_denselayer10_conv2, l__self___features_denseblock4_denselayer11_conv2, l__self___features_denseblock4_denselayer12_conv2, l__self___features_denseblock4_denselayer13_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock4_denselayer14_norm1  L__self___features_denseblock4_denselayer14_norm1           (cat_58,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer14_relu1  L__self___features_denseblock4_denselayer14_relu1           (l__self___features_denseblock4_denselayer14_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer14_conv1  L__self___features_denseblock4_denselayer14_conv1           (l__self___features_denseblock4_denselayer14_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer14_norm2  L__self___features_denseblock4_denselayer14_norm2           (l__self___features_denseblock4_denselayer14_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer14_relu2  L__self___features_denseblock4_denselayer14_relu2           (l__self___features_denseblock4_denselayer14_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer14_conv2  L__self___features_denseblock4_denselayer14_conv2           (l__self___features_denseblock4_denselayer14_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_59                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2, l__self___features_denseblock4_denselayer6_conv2, l__self___features_denseblock4_denselayer7_conv2, l__self___features_denseblock4_denselayer8_conv2, l__self___features_denseblock4_denselayer9_conv2, l__self___features_denseblock4_denselayer10_conv2, l__self___features_denseblock4_denselayer11_conv2, l__self___features_denseblock4_denselayer12_conv2, l__self___features_denseblock4_denselayer13_conv2, l__self___features_denseblock4_denselayer14_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock4_denselayer15_norm1  L__self___features_denseblock4_denselayer15_norm1           (cat_59,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer15_relu1  L__self___features_denseblock4_denselayer15_relu1           (l__self___features_denseblock4_denselayer15_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer15_conv1  L__self___features_denseblock4_denselayer15_conv1           (l__self___features_denseblock4_denselayer15_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer15_norm2  L__self___features_denseblock4_denselayer15_norm2           (l__self___features_denseblock4_denselayer15_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer15_relu2  L__self___features_denseblock4_denselayer15_relu2           (l__self___features_denseblock4_denselayer15_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer15_conv2  L__self___features_denseblock4_denselayer15_conv2           (l__self___features_denseblock4_denselayer15_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_60                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2, l__self___features_denseblock4_denselayer6_conv2, l__self___features_denseblock4_denselayer7_conv2, l__self___features_denseblock4_denselayer8_conv2, l__self___features_denseblock4_denselayer9_conv2, l__self___features_denseblock4_denselayer10_conv2, l__self___features_denseblock4_denselayer11_conv2, l__self___features_denseblock4_denselayer12_conv2, l__self___features_denseblock4_denselayer13_conv2, l__self___features_denseblock4_denselayer14_conv2, l__self___features_denseblock4_denselayer15_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer16_norm1  L__self___features_denseblock4_denselayer16_norm1           (cat_60,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer16_relu1  L__self___features_denseblock4_denselayer16_relu1           (l__self___features_denseblock4_denselayer16_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer16_conv1  L__self___features_denseblock4_denselayer16_conv1           (l__self___features_denseblock4_denselayer16_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer16_norm2  L__self___features_denseblock4_denselayer16_norm2           (l__self___features_denseblock4_denselayer16_conv1,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer16_relu2  L__self___features_denseblock4_denselayer16_relu2           (l__self___features_denseblock4_denselayer16_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer16_conv2  L__self___features_denseblock4_denselayer16_conv2           (l__self___features_denseblock4_denselayer16_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  cat_61                                             <built-in method cat of type object at 0x7f48231e0a40>      ([l__self___features_transition3_pool, l__self___features_denseblock4_denselayer1_conv2, l__self___features_denseblock4_denselayer2_conv2, l__self___features_denseblock4_denselayer3_conv2, l__self___features_denseblock4_denselayer4_conv2, l__self___features_denseblock4_denselayer5_conv2, l__self___features_denseblock4_denselayer6_conv2, l__self___features_denseblock4_denselayer7_conv2, l__self___features_denseblock4_denselayer8_conv2, l__self___features_denseblock4_denselayer9_conv2, l__self___features_denseblock4_denselayer10_conv2, l__self___features_denseblock4_denselayer11_conv2, l__self___features_denseblock4_denselayer12_conv2, l__self___features_denseblock4_denselayer13_conv2, l__self___features_denseblock4_denselayer14_conv2, l__self___features_denseblock4_denselayer15_conv2, l__self___features_denseblock4_denselayer16_conv2], 1)                                                                                                                                                                                                                                                                                                                                                                                                                          {}
call_module    l__self___features_norm5                           L__self___features_norm5                                    (cat_61,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_function  relu                                               <function relu at 0x7f478cade950>                           (l__self___features_norm5,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                {'inplace': True}
call_function  adaptive_avg_pool2d                                <function adaptive_avg_pool2d at 0x7f478cade440>            (relu, (1, 1))                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_function  flatten                                            <built-in method flatten of type object at 0x7f48231e0a40>  (adaptive_avg_pool2d, 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___classifier                               L__self___classifier                                        (flatten,)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 {}
output         output                                             output                                                      ((l__self___classifier,),)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 {}

tensor([[ 0.0614, -0.4023, -0.2792,  ..., -0.5549,  0.0976, -0.0634],
        [-0.2032, -0.2706, -0.0935,  ..., -0.4815,  0.0758, -0.1038],
        [ 0.0637, -0.3492, -0.1492,  ..., -0.4841,  0.1776, -0.0723],
        ...,
        [-0.1050, -0.3393,  0.0092,  ..., -0.4862,  0.0555, -0.1058],
        [ 0.0018, -0.2431, -0.1656,  ..., -0.5072,  0.0977, -0.1387],
        [ 0.1192, -0.3563, -0.1147,  ..., -0.4839,  0.1770, -0.0659]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

Using our custom backend, we can now see how TorchDynamo handles data-dependent control flow. Consider the function below, where the line if b.sum() < 0 introduces data-dependent control flow: which branch runs depends on the values stored in b, not only on its shape or dtype.

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

opt_bar = torch.compile(bar, backend=custom_backend)
inp1 = torch.randn(10)
inp2 = torch.randn(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
custom backend called with FX graph:
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    l_a_     L_a_                                                    ()                {}
placeholder    l_b_     L_b_                                                    ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f48231e0a40>  (l_a_,)           {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (l_a_, add)       {}
call_method    sum_1    sum                                                     (l_b_,)           {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}
custom backend called with FX graph:
opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    l_b_    L_b_                     ()            {}
placeholder    l_x_    L_x_                     ()            {}
call_function  mul     <built-in function mul>  (l_x_, l_b_)  {}
output         output  output                   ((mul,),)     {}
custom backend called with FX graph:
opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    l_b_    L_b_                     ()           {}
placeholder    l_x_    L_x_                     ()           {}
call_function  mul     <built-in function mul>  (l_b_, -1)   {}
call_function  mul_1   <built-in function mul>  (l_x_, mul)  {}
output         output  output                   ((mul_1,),)  {}

tensor([-0.0176,  1.0753,  0.0282,  0.0756, -0.0176,  0.0633, -0.9161,  0.1333,
        -0.1971, -0.3406])

The output reveals that TorchDynamo extracted 3 different FX graphs corresponding to the following code (order may differ from the output above):

  1. x = a / (torch.abs(a) + 1), together with the branch condition b.sum() < 0

  2. b = b * -1; return x * b

  3. return x * b

When TorchDynamo encounters unsupported Python features, such as data-dependent control flow, it breaks the computation graph, lets the default Python interpreter handle the unsupported code, then resumes capturing the graph.

Let’s step through bar by example. If b.sum() < 0, TorchDynamo runs graph 1, lets Python evaluate the conditional, then runs graph 2. Otherwise (b.sum() >= 0), TorchDynamo runs graph 1, lets Python evaluate the conditional, then runs graph 3.
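To make the stitching concrete, here is a hand-written sketch of what the compiled bar effectively does at run time. The functions graph1, graph2, graph3, and stitched_bar are hypothetical stand-ins written for illustration only; TorchDynamo generates its own graphs and resume functions internally, and their exact form differs.

```python
import torch

# Same definition of bar as used in this tutorial.
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

# Hypothetical hand-written stand-ins for the three captured FX graphs.
def graph1(a, b):
    # Computes x and the branch condition, mirroring graph 1's two outputs.
    x = a / (torch.abs(a) + 1)
    return x, b.sum() < 0

def graph2(b, x):
    # Taken when the condition is true: negate b, then multiply.
    b = b * -1
    return x * b

def graph3(b, x):
    # Taken when the condition is false: multiply directly.
    return x * b

def stitched_bar(a, b):
    x, cond = graph1(a, b)
    # This is where the graph break happens: plain Python converts the
    # 0-dim boolean tensor to a bool and picks which graph runs next.
    if cond:
        return graph2(b, x)
    return graph3(b, x)

a, b = torch.randn(10), torch.randn(10)
print(torch.allclose(stitched_bar(a, b), bar(a, b)))
print(torch.allclose(stitched_bar(a, -b), bar(a, -b)))
```

Passing b and -b exercises both sides of the conditional, since their sums have opposite signs.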

This graph-breaking behavior highlights a major difference between TorchDynamo and previous PyTorch compiler solutions. When encountering unsupported Python features, previous solutions either raise an error or silently fail. TorchDynamo, on the other hand, breaks the computation graph and keeps running the rest of the program.
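As a quick check that a graph break is not an error, the sketch below compiles bar with a minimal recording backend and runs both sides of the branch. recording_backend is a hypothetical helper in the same spirit as the custom backend used earlier: it stores each FX graph it receives and returns gm.forward, so the graph runs unoptimized. Assuming the behavior shown in the output above, the compiled function should match eager mode, and the backend should be called once per captured graph.

```python
import torch
import torch._dynamo
import torch.fx

# Hypothetical inspection backend: record each FX graph, then run it as-is.
captured_graphs = []

def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    captured_graphs.append(gm)
    return gm.forward  # no optimization, just execute the captured graph

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

torch._dynamo.reset()  # clear previously compiled code for a clean count
opt_bar = torch.compile(bar, backend=recording_backend)

a, b = torch.randn(10), torch.ones(10)
for arg_b in (b, -b):  # exercise both sides of the data-dependent branch
    assert torch.allclose(opt_bar(a, arg_b), bar(a, arg_b))

print(f"{len(captured_graphs)} graphs captured")
```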

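We can see where TorchDynamo breaks the graph by using torch._dynamo.explain. Since explain runs bar only once, on the inputs it is given, it reports the graphs along a single branch of the conditional: 2 graphs rather than the 3 captured above.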

# Reset since we are using a different backend.
torch._dynamo.reset()
explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
print(explain_output)
Graph Count: 2
Graph Break Count: 1
Op Count: 6
Break Reasons:
  Break Reason 1:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file /var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py, line 434 in bar>
Ops per Graph:
  Ops 1:
    <built-in method abs of type object at 0x7f48231e0a40>
    <built-in function add>
    <built-in function truediv>
    <built-in function lt>
  Ops 2:
    <built-in function mul>
    <built-in function mul>
Out Guards:
  Guard 1:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: ['GRAD_MODE']
    Code List: ['___is_grad_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 2:
    Name: "G['torch']"
    Source: global
    Create Function: FUNCTION_MATCH
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 3:
    Name: ''
    Source: global
    Create Function: DETERMINISTIC_ALGORITHMS
    Guard Types: ['DETERMINISTIC_ALGORITHMS']
    Code List: ['not ___are_deterministic_algorithms_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 4:
    Name: ''
    Source: global
    Create Function: TORCH_FUNCTION_STATE
    Guard Types: ['TORCH_FUNCTION_STATE']
    Code List: ['___is_torch_function_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 5:
    Name: "L['b']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['b'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f474824da80; dead>
    Guarded Class Weakref: <weakref at 0x7f478d00f650; to 'torch._C._TensorMeta' at 0x50ddf90 (Tensor)>
  Guard 6:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 7:
    Name: "L['a']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['a'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f4641dab240; dead>
    Guarded Class Weakref: <weakref at 0x7f478d00f650; to 'torch._C._TensorMeta' at 0x50ddf90 (Tensor)>
  Guard 8:
    Name: ''
    Source: shape_env
    Create Function: SHAPE_ENV
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 9:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: ['GRAD_MODE']
    Code List: ['___is_grad_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 10:
    Name: "L['b']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['b'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f474824da80; dead>
    Guarded Class Weakref: <weakref at 0x7f478d00f650; to 'torch._C._TensorMeta' at 0x50ddf90 (Tensor)>
  Guard 11:
    Name: ''
    Source: global
    Create Function: DETERMINISTIC_ALGORITHMS
    Guard Types: ['DETERMINISTIC_ALGORITHMS']
    Code List: ['not ___are_deterministic_algorithms_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 12:
    Name: "L['x']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['x'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f4640689f30; dead>
    Guarded Class Weakref: <weakref at 0x7f478d00f650; to 'torch._C._TensorMeta' at 0x50ddf90 (Tensor)>
  Guard 13:
    Name: ''
    Source: global
    Create Function: TORCH_FUNCTION_STATE
    Guard Types: ['TORCH_FUNCTION_STATE']
    Code List: ['___is_torch_function_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 14:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 15:
    Name: ''
    Source: shape_env
    Create Function: SHAPE_ENV
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
Compile Times: TorchDynamo compilation metrics:
Function                         Runtimes (s)
-------------------------------  --------------
_compile.<locals>.compile_inner  0.0140, 0.0072
OutputGraph.call_user_compiler   0.0010, 0.0000

To maximize speedups, keep graph breaks to a minimum. We can force TorchDynamo to raise an error at the first graph break it encounters by passing fullgraph=True:

import traceback as tb

opt_bar = torch.compile(bar, fullgraph=True)
try:
    opt_bar(torch.randn(10), torch.randn(10))
except Exception:
    # Print the full traceback instead of stopping the script.
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 482, in <module>
    opt_bar(torch.randn(10), torch.randn(10))
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2069, in run
    super().run()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 719, in run
    and self.step()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 370, in inner
    raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow

from user code:
   File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 434, in bar
    if b.sum() < 0:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
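One way to remove this particular break, assuming the branch can be expressed as an elementwise selection, is to replace the Python if with torch.where: both candidates are computed and the choice happens inside the graph. This is a swapped-in rewrite, not the functorch cond approach suggested by the error message. The sketch uses backend="eager", which captures the graph without generating optimized kernels, so it runs without a GPU or compiler toolchain; bar_where is a hypothetical name.

```python
import torch
import torch._dynamo

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def bar_where(a, b):
    x = a / (torch.abs(a) + 1)
    # Select between -b and b with a tensor op instead of a Python branch,
    # so there is no data-dependent control flow to break the graph.
    b = torch.where(b.sum() < 0, b * -1, b)
    return x * b

torch._dynamo.reset()
# fullgraph=True would raise if any graph break remained.
opt_bar_where = torch.compile(bar_where, fullgraph=True, backend="eager")

a, b = torch.randn(10), torch.randn(10)
for arg_b in (b, -b):
    print(torch.allclose(opt_bar_where(a, arg_b), bar(a, arg_b)))
```

The trade-off is that bar_where always computes b * -1, even when the result is discarded. For cheap elementwise work the extra computation is small, but this rewrite does not suit branches that are expensive or have side effects.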

Below, we demonstrate that TorchDynamo does not break the graph on the model we used earlier to demonstrate speedups.

opt_model = torch.compile(init_model(), fullgraph=True)
print(opt_model(generate_data(16)[0]))
tensor([[ 0.1346,  0.1981, -0.2058,  ...,  0.1868, -0.1626, -0.1644],
        [ 0.2679,  0.2347, -0.1904,  ...,  0.2167, -0.0060,  0.0307],
        [ 0.0166,  0.2182, -0.1113,  ...,  0.1708, -0.1683, -0.0637],
        ...,
        [ 0.0808,  0.2680, -0.1887,  ...,  0.0378, -0.2078, -0.1444],
        [-0.0211,  0.0857, -0.2459,  ...,  0.1863, -0.1282, -0.0283],
        [-0.0388,  0.0728, -0.1961,  ...,  0.0860, -0.2200, -0.1485]],
       device='cuda:0', grad_fn=<CompiledFunctionBackward>)

We can use torch.export (from PyTorch 2.1+) to extract a single, exportable FX graph from the input PyTorch program. The exported graph is intended to run in other environments, including ones without Python. One important restriction is that torch.export does not support graph breaks. See the torch.export tutorial for more details.

Conclusion

In this tutorial, we introduced torch.compile by covering basic usage, demonstrating speedups over eager mode, comparing it to previous PyTorch compiler solutions, and briefly investigating TorchDynamo and its interactions with FX graphs. We hope that you will give torch.compile a try!

Total running time of the script: ( 6 minutes 52.759 seconds)
