torch.compile Tutorial¶
Author: William Wen
torch.compile is the latest method to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling PyTorch code into optimized kernels, all while requiring minimal code changes.
In this tutorial, we cover basic torch.compile usage and demonstrate the advantages of torch.compile over previous PyTorch compiler solutions, such as TorchScript and FX Tracing.
Contents
Basic Usage
Demonstrating Speedups
Comparison to TorchScript and FX Tracing
TorchDynamo and FX Graphs
Conclusion
Required pip Dependencies
torch >= 2.0
torchvision
numpy
scipy
tabulate
NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in order to reproduce the speedup numbers shown below and documented elsewhere.
import torch
import warnings
gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True
if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )
Basic Usage¶
torch.compile is included in the latest PyTorch. Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly binary. If Triton is still missing, try installing torchtriton via pip (pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117" for CUDA 11.7).
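You can check whether Triton is importable before compiling for GPU with a small sketch like the following. It assumes that the Triton package, whether installed as torchtriton or bundled with PyTorch, exposes a top-level module named triton. The helper name triton_available is hypothetical.

```python
import importlib.util


def triton_available() -> bool:
    """Return True if a top-level module named `triton` can be located."""
    # find_spec returns None when the module cannot be found, and it does so
    # without actually importing the module.
    return importlib.util.find_spec("triton") is not None


print("Triton found:", triton_available())
```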
Arbitrary Python functions can be optimized by passing the callable to torch.compile. We can then call the returned optimized function in place of the original function.
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
tensor([[ 1.2028, -0.8198, 1.3456, 0.0646, -0.8120, 0.3657, -0.3298, 1.2512,
0.0617, -0.7635],
[-1.3895, 1.0898, 0.7058, -1.0972, 0.5735, 1.3662, 0.6120, 0.8521,
-0.6361, 1.4060],
[-0.5332, 0.9340, -0.0763, 0.1795, 0.4085, -0.1216, 0.3937, 0.4600,
-0.2902, 1.1192],
[ 0.9617, 1.3994, 1.2001, 1.4135, 1.0895, -1.2855, -0.1238, 0.9845,
-0.6927, 1.3706],
[-0.3904, 0.6147, 1.3255, 0.9164, 0.6438, 1.2898, -0.2472, 0.9181,
1.1070, 0.3966],
[ 1.3430, 0.9218, 0.8789, 1.3962, 1.0236, 0.5662, 1.4078, 1.2897,
1.3904, 0.2585],
[ 1.2359, 1.3456, -0.1253, 1.0474, 1.3414, -0.0851, 1.3539, 0.9941,
0.7292, 0.6930],
[-1.3179, 1.0045, 1.2682, 1.3234, 0.2889, 1.3933, 1.3966, 0.9050,
0.3804, 1.2449],
[ 0.5815, -0.5598, 0.8970, 1.4046, -0.5406, 0.5134, 0.3247, 0.8744,
1.3375, 0.9972],
[-0.3084, 1.3705, 0.6644, 1.0475, -0.2088, 0.0941, 1.1855, 1.2274,
-0.1658, 1.2139]])
Alternatively, we can decorate the function.
@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
tensor([[ 0.3756, 1.3493, 0.4288, 0.9815, 1.3079, 0.0590, -0.1370, 0.6304,
0.0464, -0.0692],
[ 1.3692, 1.3246, 1.3717, -0.5833, -0.8194, 1.2445, 0.7953, 1.3003,
0.0527, 1.3214],
[ 0.6659, 1.0319, -0.0831, -0.0723, 0.9993, 1.4142, 1.1154, 0.8908,
0.5703, 1.1342],
[-0.3525, -1.4139, 0.8958, -0.4677, -0.3192, 1.0065, -0.6645, 1.4127,
0.5887, 1.4139],
[ 0.5756, 0.8351, 1.1729, -0.0425, 0.0660, 1.3983, -0.2973, 1.0626,
0.5217, 0.2817],
[ 0.7342, 0.5069, 0.6673, 1.1095, 0.3386, 0.1767, 1.4053, -1.1923,
1.3794, 0.8764],
[ 0.6734, 0.9689, 0.1351, -0.3152, -1.4078, 0.1769, 0.9358, 0.8269,
1.4024, -0.6573],
[-1.2371, 1.3530, -1.3570, 1.0353, 1.4141, 0.3330, 0.9643, 1.1254,
1.0723, 1.3444],
[-0.1108, 0.9511, 0.1918, 1.2880, 0.4769, 1.2417, 1.2710, 0.6166,
0.8541, 0.7168],
[ 1.0000, 1.4102, 1.3548, 0.0645, 1.1483, 1.4108, 1.3512, 1.1252,
0.6453, 1.4137]])
We can also optimize torch.nn.Module instances.
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)
    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))
mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))
tensor([[0.2217, 0.2177, 0.7180, 0.0000, 0.1718, 0.0000, 0.0000, 0.0000, 0.1294,
0.2342],
[0.3194, 0.4982, 0.5793, 0.2452, 0.0000, 0.6886, 0.1196, 0.0000, 1.6327,
0.0000],
[0.0000, 0.5968, 0.0000, 0.1686, 0.0000, 0.0000, 0.0000, 0.0000, 0.2766,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.5773, 0.0000, 0.0000, 0.0000, 0.0660,
0.1944],
[0.5061, 0.0000, 0.0000, 0.0000, 0.0000, 0.3677, 0.2802, 0.0000, 0.0000,
0.0000],
[0.5282, 0.0000, 0.0165, 0.1677, 0.0000, 0.0368, 0.0000, 0.0000, 0.1776,
0.0000],
[0.0000, 0.1107, 0.0000, 0.0000, 0.3115, 0.0000, 0.0740, 0.4278, 0.0000,
0.0000],
[0.7416, 0.0000, 0.0000, 0.0000, 0.0000, 0.5975, 0.0357, 0.2453, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0490, 0.7279, 0.0000, 0.5957, 0.4366, 0.0000,
0.0000],
[0.0000, 0.4120, 0.0000, 0.2766, 0.1493, 0.0000, -0.0000, -0.0000, -0.0000,
-0.0000]], grad_fn=<CompiledFunctionBackward>)
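As a quick sanity check, you can compare the compiled module against the original. The sketch below uses backend="eager". This runs TorchDynamo's graph capture but skips TorchInductor code generation, so it should run quickly on CPU. It also checks that the wrapper returned by torch.compile reuses the original module's parameter tensors rather than copying them.

```python
import torch


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))


mod = MyModule()
# backend="eager" captures the graph with TorchDynamo but runs it as-is.
opt_mod = torch.compile(mod, backend="eager")

x = torch.randn(10, 100)
same_output = torch.allclose(mod(x), opt_mod(x))
# The compiled wrapper holds the original module as a submodule, so iterating
# its parameters yields the very same tensor objects.
same_params = all(p is q for p, q in zip(opt_mod.parameters(), mod.parameters()))
print("same output:", same_output, "| shared parameters:", same_params)
```

Because the parameters are shared, an optimizer built from mod.parameters() also updates the model that opt_mod runs.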
Demonstrating Speedups¶
Let’s now demonstrate that using torch.compile can speed up real models. We will compare standard eager mode and torch.compile by evaluating and training DenseNet-121 on random data.
Before we start, we need to define some utility functions.
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    # elapsed_time reports milliseconds; convert to seconds.
    return result, start.elapsed_time(end) / 1000
# Generates random input and target data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )
N_ITERS = 10
from torchvision.models import densenet121
def init_model():
    return densenet121().to(torch.float32).cuda()
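The timed helper above relies on CUDA events, so it only works with a GPU. Readers without one can use a hypothetical CPU counterpart, timed_cpu (not part of PyTorch), that falls back to a wall-clock timer. This is a rough sketch, and it is less precise than CUDA events for timing GPU work.

```python
import time

import torch


def timed_cpu(fn):
    """Return (fn(), elapsed seconds) measured with time.perf_counter."""
    if torch.cuda.is_available():
        # Drain already-queued GPU work so it is not attributed to fn().
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn()
    if torch.cuda.is_available():
        # Wait for any GPU work launched by fn() before stopping the clock.
        torch.cuda.synchronize()
    return result, time.perf_counter() - start


out, secs = timed_cpu(lambda: torch.randn(64, 64) @ torch.randn(64, 64))
print(f"elapsed: {secs:.6f} s")
```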
First, let’s compare inference.
Note that in the call to torch.compile, we have the additional mode argument, which we will discuss below.
def evaluate(mod, inp):
    return mod(inp)
model = init_model()
# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()
evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")
inp = generate_data(16)[0]
print("eager:", timed(lambda: evaluate(model, inp))[1])
print("compile:", timed(lambda: evaluate_opt(model, inp))[1])
eager: 2.835186767578125
compile: 68.630125
Notice that torch.compile takes a lot longer to complete compared to eager. This is because torch.compile compiles the model into optimized kernels as it executes. In our example, the structure of the model doesn’t change, and so recompilation is not needed. So if we run our optimized model several more times, we should see a significant improvement compared to eager.
eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    _, eager_time = timed(lambda: evaluate(model, inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")
print("~" * 10)
compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    _, compile_time = timed(lambda: evaluate_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)
import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
eager eval time 0: 0.04076748657226562
eager eval time 1: 0.040182785034179686
eager eval time 2: 0.02405990409851074
eager eval time 3: 0.024440832138061523
eager eval time 4: 0.024440832138061523
eager eval time 5: 0.02411008071899414
eager eval time 6: 0.024790016174316407
eager eval time 7: 0.024460287094116212
eager eval time 8: 0.02457088088989258
eager eval time 9: 0.02351820755004883
~~~~~~~~~~
compile eval time 0: 0.011045887947082519
compile eval time 1: 0.010876928329467773
compile eval time 2: 0.01080832004547119
compile eval time 3: 0.010898431777954102
compile eval time 4: 0.01061683177947998
compile eval time 5: 0.010513407707214355
compile eval time 6: 0.010568703651428223
compile eval time 7: 0.010604543685913086
compile eval time 8: 0.010563584327697753
compile eval time 9: 0.01063526439666748
~~~~~~~~~~
(eval) eager median: 0.024450559616088868, compile median: 0.010626048088073731, speedup: 2.301002161239157x
~~~~~~~~~~
And indeed, we can see that running our model with torch.compile results in a significant speedup. The speedup comes mainly from reduced Python overhead and fewer GPU reads/writes, so the observed speedup may vary depending on factors such as model architecture and batch size. For example, if a model’s architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.
You may also see different speedup results depending on the chosen mode argument. Since our model and data are small, we want to reduce overhead as much as possible, and so we chose "reduce-overhead". For your own models, you may need to experiment with different modes to maximize speedup. You can read more about modes in the torch.compile documentation.
For general PyTorch benchmarking, you can try using torch.utils.benchmark instead of the timed function we defined above. We wrote our own timing function in this tutorial to show torch.compile’s compilation latency.
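Here is a minimal sketch of the torch.utils.benchmark alternative. It times the foo function from the Basic Usage section on CPU tensors. Timer performs its own warm-up runs and handles CUDA synchronization. The repetition count of 100 is an arbitrary choice.

```python
import torch
import torch.utils.benchmark as benchmark


def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b


x = torch.randn(10, 10)
y = torch.randn(10, 10)

# `globals` supplies every name referenced in `stmt`.
timer = benchmark.Timer(
    stmt="foo(x, y)",
    globals={"foo": foo, "x": x, "y": y},
)
# timeit(number) runs `stmt` `number` times and returns a Measurement.
measurement = timer.timeit(100)
print(measurement)
print(f"mean per call: {measurement.mean * 1e6:.2f} us")
```

To measure a compiled function this way, warm it up with one call before timing. Otherwise, compilation latency may leak into the measurement.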
Now, let’s compare training.
model = init_model()
opt = torch.optim.Adam(model.parameters())
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()
eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)
model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")
compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
eager train time 0: 0.631172119140625
eager train time 1: 0.060521472930908204
eager train time 2: 0.06008115386962891
eager train time 3: 0.05933977508544922
eager train time 4: 0.05841715240478516
eager train time 5: 0.059529216766357425
eager train time 6: 0.05788467025756836
eager train time 7: 0.061462528228759764
eager train time 8: 0.058535934448242184
eager train time 9: 0.05813759994506836
~~~~~~~~~~
compile train time 0: 94.708734375
compile train time 1: 0.03828736114501953
compile train time 2: 0.030166015625
compile train time 3: 0.02949017524719238
compile train time 4: 0.029459455490112304
compile train time 5: 0.029302783966064453
compile train time 6: 0.02916659164428711
compile train time 7: 0.02911129570007324
compile train time 8: 0.029648895263671874
compile train time 9: 0.029459455490112304
~~~~~~~~~~
(train) eager median: 0.05943449592590332, compile median: 0.029474815368652343, speedup: 2.0164501518511395x
~~~~~~~~~~
Again, we can see that torch.compile takes longer in the first iteration, as it must compile the model, but in subsequent iterations, we see significant speedups compared to eager.
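The slow first iteration occurs because TorchDynamo compiles once, then reuses the result while its guards (such as input shapes and dtypes) still hold. A minimal sketch makes this visible by counting how often a backend receives a new graph. Here, counting_backend is a hypothetical helper. It returns the graph's unoptimized forward, so no real compilation happens. Calling opt_foo with a differently shaped input would typically trigger another compilation.

```python
import torch
import torch._dynamo

compile_count = 0


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    """Count how often TorchDynamo hands a newly captured FX graph to us."""
    global compile_count
    compile_count += 1
    return gm.forward  # run the captured graph without further optimization


def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b


torch._dynamo.reset()  # start from an empty compilation cache
opt_foo = torch.compile(foo, backend=counting_backend)

for _ in range(5):
    # Same shapes and dtypes every call, so the cached graph is reused.
    opt_foo(torch.randn(10, 10), torch.randn(10, 10))

print("compilations after 5 calls:", compile_count)
```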
Comparison to TorchScript and FX Tracing¶
We have seen that torch.compile can speed up PyTorch code. Why else should we use torch.compile over existing PyTorch compiler solutions, such as TorchScript or FX Tracing? Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.
One case that torch.compile can handle that other compiler solutions struggle with is data-dependent control flow (the if x.sum() < 0: line below).
def f1(x, y):
    if x.sum() < 0:
        return -y
    return y
# Test that `fn1` and `fn2` return the same result, given
# the same arguments `args`. Typically, `fn1` will be an eager function
# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
TorchScript tracing f1 produces silently incorrect results, since only the control flow path taken for the example inputs is traced.
/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py:268: TracerWarning:
Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
traced 1, 1: True
traced 1, 2: False
FX tracing f1 results in an error due to the presence of data-dependent control flow.
import traceback as tb
try:
    torch.fx.symbolic_trace(f1)
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 298, in <module>
torch.fx.symbolic_trace(f1)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1109, in symbolic_trace
graph = tracer.trace(root, concrete_args)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
(self.create_arg(fn(*args)),),
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 268, in f1
if x.sum() < 0:
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 413, in __bool__
return self.tracer.to_bool(self)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 276, in to_bool
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
If we provide a value for x as we try to FX trace f1, then we run into the same problem as TorchScript tracing, as the data-dependent control flow is removed in the traced function.
/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:602: UserWarning:
Was not able to add assertion to guarantee correct input x to specialized function. It is up to the user to make sure that your inputs match the inputs you specialized the function with.
fx 1, 1: True
fx 1, 2: False
Now we can see that torch.compile correctly handles data-dependent control flow.
# Reset since we are using a different mode.
torch._dynamo.reset()
compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~
TorchScript scripting can handle data-dependent control flow, but this solution comes with its own set of problems. Namely, TorchScript scripting can require major code changes and will raise errors when unsupported Python is used.
In the example below, we forget TorchScript type annotations and we receive a TorchScript error because the input type for argument y, an int, does not match the default argument type, torch.Tensor.
def f2(x, y):
    return x + y
inp1 = torch.randn(5, 5)
inp2 = 3
script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 341, in <module>
script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
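For comparison, the TorchScript error goes away once the missing annotations are added. The sketch below uses a hypothetical f2_annotated that declares y as an int. This is the kind of code change that scripting demands.

```python
import torch


# Explicit annotations tell TorchScript that `y` is a Python int rather than
# the Tensor it infers for an unannotated argument.
@torch.jit.script
def f2_annotated(x: torch.Tensor, y: int) -> torch.Tensor:
    return x + y


inp1 = torch.randn(5, 5)
out = f2_annotated(inp1, 3)
print("annotated script matches eager:", torch.allclose(out, inp1 + 3))
```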
However, torch.compile is easily able to handle f2.
compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
compile 2: True
~~~~~~~~~~
Another case that torch.compile handles well compared to previous compiler solutions is the usage of non-PyTorch functions.
import scipy.fft
def f3(x):
    x = x * 2
    x = scipy.fft.dct(x.numpy())
    x = torch.from_numpy(x)
    x = x * 2
    return x
TorchScript tracing treats results from non-PyTorch function calls as constants, and so our results can be silently wrong.
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
traced_f3 = torch.jit.trace(f3, (inp1,))
print("traced 3:", test_fns(f3, traced_f3, (inp2,)))
/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py:359: TracerWarning:
Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py:360: TracerWarning:
torch.from_numpy results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
traced 3: False
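The sketch below makes this failure mode concrete. It traces a hypothetical function g that mirrors f3 but swaps scipy.fft.dct for np.cumsum, an assumption made only to avoid the SciPy dependency. The trace bakes in the NumPy result as a constant. As a result, the traced function returns the value computed from the tracing input no matter what it is called with.

```python
import numpy as np
import torch


def g(x):
    # Hypothetical stand-in for f3: a NumPy call between two torch ops.
    x = x * 2
    x = np.cumsum(x.numpy(), axis=0)
    x = torch.from_numpy(x)
    return x * 2


torch.manual_seed(0)
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

traced_g = torch.jit.trace(g, (inp1,))
# The NumPy result computed from inp1 was recorded as a constant, so calling
# the trace with inp2 still returns inp1's result.
stale = torch.allclose(traced_g(inp2), g(inp1))
wrong = not torch.allclose(traced_g(inp2), g(inp2))
print("trace replays inp1's result:", stale, "| differs from eager on inp2:", wrong)
```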
TorchScript scripting and FX tracing disallow non-PyTorch function calls.
try:
    torch.jit.script(f3)
except Exception:
    tb.print_exc()
try:
    torch.fx.symbolic_trace(f3)
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 377, in <module>
torch.jit.script(f3)
File "/opt/conda/lib/python3.10/site-packages/torch/jit/_script.py", line 1341, in script
fn = torch._C._jit_script_compile(
File "/opt/conda/lib/python3.10/site-packages/torch/_jit_internal.py", line 1198, in _try_get_dispatched_fn
return boolean_dispatched.get(fn)
File "/opt/conda/lib/python3.10/weakref.py", line 453, in get
return self.data.get(ref(key),default)
TypeError: cannot create weak reference to 'uarray._Function' object
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 382, in <module>
torch.fx.symbolic_trace(f3)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1109, in symbolic_trace
graph = tracer.trace(root, concrete_args)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace
(self.create_arg(fn(*args)),),
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 359, in f3
x = scipy.fft.dct(x.numpy())
File "/opt/conda/lib/python3.10/site-packages/scipy/fft/_backend.py", line 25, in __ua_function__
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/scipy/fft/_pocketfft/realtransforms.py", line 19, in _r2r
tmp = _asfarray(x)
File "/opt/conda/lib/python3.10/site-packages/scipy/fft/_pocketfft/helper.py", line 89, in _asfarray
if x.dtype == np.float16:
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 518, in impl
return tracer.create_proxy('call_function', target, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 151, in create_proxy
args_ = self.create_arg(args)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 373, in create_arg
return super().create_arg(a)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 239, in create_arg
return type(a)(self.create_arg(elem) for elem in a)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 239, in <genexpr>
return type(a)(self.create_arg(elem) for elem in a)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 373, in create_arg
return super().create_arg(a)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 267, in create_arg
raise NotImplementedError(f"argument of type: {type(a)}")
NotImplementedError: argument of type: <class 'type'>
In comparison, torch.compile is easily able to handle the non-PyTorch function call.
compile_f3 = torch.compile(f3)
print("compile 3:", test_fns(f3, compile_f3, (inp2,)))
compile 3: True
TorchDynamo and FX Graphs¶
One important component of torch.compile is TorchDynamo. TorchDynamo is responsible for JIT compiling arbitrary Python code into FX graphs, which can then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode during runtime and detecting calls to PyTorch operations.
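Here is a small sketch of that extraction. A hypothetical capture_backend stores the FX graph that TorchDynamo hands it. For foo from the Basic Usage section, that graph should contain call_function nodes for torch.sin and torch.cos. The backend returns gm.forward unchanged, so nothing beyond graph capture is performed.

```python
import torch
import torch._dynamo

captured = []


def capture_backend(gm: torch.fx.GraphModule, example_inputs):
    """Record the FX graph TorchDynamo extracted, then run it unoptimized."""
    captured.append(gm)
    return gm.forward


def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b


torch._dynamo.reset()
torch.compile(foo, backend=capture_backend)(torch.randn(10, 10), torch.randn(10, 10))

graph = captured[0].graph
print(graph)
# The targets of call_function nodes are the PyTorch callables Dynamo detected.
called = [node.target for node in graph.nodes if node.op == "call_function"]
print(called)
```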
Normally, TorchInductor, another component of torch.compile, further compiles the FX graphs into optimized kernels, but TorchDynamo allows for different backends to be used. In order to inspect the FX graphs that TorchDynamo outputs, let us create a custom backend that outputs the FX graph and simply returns the graph’s unoptimized forward method.
from typing import List
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward
# Reset since we are using a different backend.
torch._dynamo.reset()
opt_model = torch.compile(init_model(), backend=custom_backend)
opt_model(generate_data(16)[0])
custom backend called with FX graph:
opcode name target args kwargs
------------- ----------------------------------- ---------------------------------------------------------- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -----------------
placeholder x x () {}
call_module self_features_0 self_features_0 (x,) {}
call_module self_features_1 self_features_1 (self_features_0,) {}
call_module self_features_2 self_features_2 (self_features_1,) {}
call_module self_features_3 self_features_3 (self_features_2,) {}
call_function cat <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_3], 1) {}
call_module self_features_4_denselayer1_norm1 self_features_4_denselayer1_norm1 (cat,) {}
call_module self_features_4_denselayer1_relu1 self_features_4_denselayer1_relu1 (self_features_4_denselayer1_norm1,) {}
call_module self_features_4_denselayer1_conv1 self_features_4_denselayer1_conv1 (self_features_4_denselayer1_relu1,) {}
call_module self_features_4_denselayer1_norm2 self_features_4_denselayer1_norm2 (self_features_4_denselayer1_conv1,) {}
call_module self_features_4_denselayer1_relu2 self_features_4_denselayer1_relu2 (self_features_4_denselayer1_norm2,) {}
call_module self_features_4_denselayer1_conv2 self_features_4_denselayer1_conv2 (self_features_4_denselayer1_relu2,) {}
call_function cat_1 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_3, self_features_4_denselayer1_conv2], 1) {}
call_module self_features_4_denselayer2_norm1 self_features_4_denselayer2_norm1 (cat_1,) {}
call_module self_features_4_denselayer2_relu1 self_features_4_denselayer2_relu1 (self_features_4_denselayer2_norm1,) {}
call_module self_features_4_denselayer2_conv1 self_features_4_denselayer2_conv1 (self_features_4_denselayer2_relu1,) {}
call_module self_features_4_denselayer2_norm2 self_features_4_denselayer2_norm2 (self_features_4_denselayer2_conv1,) {}
call_module self_features_4_denselayer2_relu2 self_features_4_denselayer2_relu2 (self_features_4_denselayer2_norm2,) {}
call_module self_features_4_denselayer2_conv2 self_features_4_denselayer2_conv2 (self_features_4_denselayer2_relu2,) {}
call_function cat_2 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_3, self_features_4_denselayer1_conv2, self_features_4_denselayer2_conv2], 1) {}
call_module self_features_4_denselayer3_norm1 self_features_4_denselayer3_norm1 (cat_2,) {}
call_module self_features_4_denselayer3_relu1 self_features_4_denselayer3_relu1 (self_features_4_denselayer3_norm1,) {}
call_module self_features_4_denselayer3_conv1 self_features_4_denselayer3_conv1 (self_features_4_denselayer3_relu1,) {}
call_module self_features_4_denselayer3_norm2 self_features_4_denselayer3_norm2 (self_features_4_denselayer3_conv1,) {}
call_module self_features_4_denselayer3_relu2 self_features_4_denselayer3_relu2 (self_features_4_denselayer3_norm2,) {}
call_module self_features_4_denselayer3_conv2 self_features_4_denselayer3_conv2 (self_features_4_denselayer3_relu2,) {}
call_function cat_3 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_3, self_features_4_denselayer1_conv2, self_features_4_denselayer2_conv2, self_features_4_denselayer3_conv2], 1) {}
call_module self_features_4_denselayer4_norm1 self_features_4_denselayer4_norm1 (cat_3,) {}
call_module self_features_4_denselayer4_relu1 self_features_4_denselayer4_relu1 (self_features_4_denselayer4_norm1,) {}
call_module self_features_4_denselayer4_conv1 self_features_4_denselayer4_conv1 (self_features_4_denselayer4_relu1,) {}
call_module self_features_4_denselayer4_norm2 self_features_4_denselayer4_norm2 (self_features_4_denselayer4_conv1,) {}
call_module self_features_4_denselayer4_relu2 self_features_4_denselayer4_relu2 (self_features_4_denselayer4_norm2,) {}
call_module self_features_4_denselayer4_conv2 self_features_4_denselayer4_conv2 (self_features_4_denselayer4_relu2,) {}
call_function cat_4 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_3, self_features_4_denselayer1_conv2, self_features_4_denselayer2_conv2, self_features_4_denselayer3_conv2, self_features_4_denselayer4_conv2], 1) {}
call_module self_features_4_denselayer5_norm1 self_features_4_denselayer5_norm1 (cat_4,) {}
call_module self_features_4_denselayer5_relu1 self_features_4_denselayer5_relu1 (self_features_4_denselayer5_norm1,) {}
call_module self_features_4_denselayer5_conv1 self_features_4_denselayer5_conv1 (self_features_4_denselayer5_relu1,) {}
call_module self_features_4_denselayer5_norm2 self_features_4_denselayer5_norm2 (self_features_4_denselayer5_conv1,) {}
call_module self_features_4_denselayer5_relu2 self_features_4_denselayer5_relu2 (self_features_4_denselayer5_norm2,) {}
call_module self_features_4_denselayer5_conv2 self_features_4_denselayer5_conv2 (self_features_4_denselayer5_relu2,) {}
call_function cat_5 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_3, self_features_4_denselayer1_conv2, self_features_4_denselayer2_conv2, self_features_4_denselayer3_conv2, self_features_4_denselayer4_conv2, self_features_4_denselayer5_conv2], 1) {}
call_module self_features_4_denselayer6_norm1 self_features_4_denselayer6_norm1 (cat_5,) {}
call_module self_features_4_denselayer6_relu1 self_features_4_denselayer6_relu1 (self_features_4_denselayer6_norm1,) {}
call_module self_features_4_denselayer6_conv1 self_features_4_denselayer6_conv1 (self_features_4_denselayer6_relu1,) {}
call_module self_features_4_denselayer6_norm2 self_features_4_denselayer6_norm2 (self_features_4_denselayer6_conv1,) {}
call_module self_features_4_denselayer6_relu2 self_features_4_denselayer6_relu2 (self_features_4_denselayer6_norm2,) {}
call_module self_features_4_denselayer6_conv2 self_features_4_denselayer6_conv2 (self_features_4_denselayer6_relu2,) {}
call_function cat_6 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_3, self_features_4_denselayer1_conv2, self_features_4_denselayer2_conv2, self_features_4_denselayer3_conv2, self_features_4_denselayer4_conv2, self_features_4_denselayer5_conv2, self_features_4_denselayer6_conv2], 1) {}
call_module self_features_5_0 self_features_5_0 (cat_6,) {}
call_module self_features_5_1 self_features_5_1 (self_features_5_0,) {}
call_module self_features_5_2 self_features_5_2 (self_features_5_1,) {}
call_module self_features_5_3 self_features_5_3 (self_features_5_2,) {}
call_function cat_7 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3], 1) {}
call_module self_features_6_denselayer1_norm1 self_features_6_denselayer1_norm1 (cat_7,) {}
call_module self_features_6_denselayer1_relu1 self_features_6_denselayer1_relu1 (self_features_6_denselayer1_norm1,) {}
call_module self_features_6_denselayer1_conv1 self_features_6_denselayer1_conv1 (self_features_6_denselayer1_relu1,) {}
call_module self_features_6_denselayer1_norm2 self_features_6_denselayer1_norm2 (self_features_6_denselayer1_conv1,) {}
call_module self_features_6_denselayer1_relu2 self_features_6_denselayer1_relu2 (self_features_6_denselayer1_norm2,) {}
call_module self_features_6_denselayer1_conv2 self_features_6_denselayer1_conv2 (self_features_6_denselayer1_relu2,) {}
call_function cat_8 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2], 1) {}
call_module self_features_6_denselayer2_norm1 self_features_6_denselayer2_norm1 (cat_8,) {}
call_module self_features_6_denselayer2_relu1 self_features_6_denselayer2_relu1 (self_features_6_denselayer2_norm1,) {}
call_module self_features_6_denselayer2_conv1 self_features_6_denselayer2_conv1 (self_features_6_denselayer2_relu1,) {}
call_module self_features_6_denselayer2_norm2 self_features_6_denselayer2_norm2 (self_features_6_denselayer2_conv1,) {}
call_module self_features_6_denselayer2_relu2 self_features_6_denselayer2_relu2 (self_features_6_denselayer2_norm2,) {}
call_module self_features_6_denselayer2_conv2 self_features_6_denselayer2_conv2 (self_features_6_denselayer2_relu2,) {}
call_function cat_9 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2, self_features_6_denselayer2_conv2], 1) {}
call_module self_features_6_denselayer3_norm1 self_features_6_denselayer3_norm1 (cat_9,) {}
call_module self_features_6_denselayer3_relu1 self_features_6_denselayer3_relu1 (self_features_6_denselayer3_norm1,) {}
call_module self_features_6_denselayer3_conv1 self_features_6_denselayer3_conv1 (self_features_6_denselayer3_relu1,) {}
call_module self_features_6_denselayer3_norm2 self_features_6_denselayer3_norm2 (self_features_6_denselayer3_conv1,) {}
call_module self_features_6_denselayer3_relu2 self_features_6_denselayer3_relu2 (self_features_6_denselayer3_norm2,) {}
call_module self_features_6_denselayer3_conv2 self_features_6_denselayer3_conv2 (self_features_6_denselayer3_relu2,) {}
call_function cat_10 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2, self_features_6_denselayer2_conv2, self_features_6_denselayer3_conv2], 1) {}
call_module self_features_6_denselayer4_norm1 self_features_6_denselayer4_norm1 (cat_10,) {}
call_module self_features_6_denselayer4_relu1 self_features_6_denselayer4_relu1 (self_features_6_denselayer4_norm1,) {}
call_module self_features_6_denselayer4_conv1 self_features_6_denselayer4_conv1 (self_features_6_denselayer4_relu1,) {}
call_module self_features_6_denselayer4_norm2 self_features_6_denselayer4_norm2 (self_features_6_denselayer4_conv1,) {}
call_module self_features_6_denselayer4_relu2 self_features_6_denselayer4_relu2 (self_features_6_denselayer4_norm2,) {}
call_module self_features_6_denselayer4_conv2 self_features_6_denselayer4_conv2 (self_features_6_denselayer4_relu2,) {}
call_function cat_11 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2, self_features_6_denselayer2_conv2, self_features_6_denselayer3_conv2, self_features_6_denselayer4_conv2], 1) {}
call_module self_features_6_denselayer5_norm1 self_features_6_denselayer5_norm1 (cat_11,) {}
call_module self_features_6_denselayer5_relu1 self_features_6_denselayer5_relu1 (self_features_6_denselayer5_norm1,) {}
call_module self_features_6_denselayer5_conv1 self_features_6_denselayer5_conv1 (self_features_6_denselayer5_relu1,) {}
call_module self_features_6_denselayer5_norm2 self_features_6_denselayer5_norm2 (self_features_6_denselayer5_conv1,) {}
call_module self_features_6_denselayer5_relu2 self_features_6_denselayer5_relu2 (self_features_6_denselayer5_norm2,) {}
call_module self_features_6_denselayer5_conv2 self_features_6_denselayer5_conv2 (self_features_6_denselayer5_relu2,) {}
call_function cat_12 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2, self_features_6_denselayer2_conv2, self_features_6_denselayer3_conv2, self_features_6_denselayer4_conv2, self_features_6_denselayer5_conv2], 1) {}
call_module self_features_6_denselayer6_norm1 self_features_6_denselayer6_norm1 (cat_12,) {}
call_module self_features_6_denselayer6_relu1 self_features_6_denselayer6_relu1 (self_features_6_denselayer6_norm1,) {}
call_module self_features_6_denselayer6_conv1 self_features_6_denselayer6_conv1 (self_features_6_denselayer6_relu1,) {}
call_module self_features_6_denselayer6_norm2 self_features_6_denselayer6_norm2 (self_features_6_denselayer6_conv1,) {}
call_module self_features_6_denselayer6_relu2 self_features_6_denselayer6_relu2 (self_features_6_denselayer6_norm2,) {}
call_module self_features_6_denselayer6_conv2 self_features_6_denselayer6_conv2 (self_features_6_denselayer6_relu2,) {}
call_function cat_13 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2, self_features_6_denselayer2_conv2, self_features_6_denselayer3_conv2, self_features_6_denselayer4_conv2, self_features_6_denselayer5_conv2, self_features_6_denselayer6_conv2], 1) {}
call_module self_features_6_denselayer7_norm1 self_features_6_denselayer7_norm1 (cat_13,) {}
call_module self_features_6_denselayer7_relu1 self_features_6_denselayer7_relu1 (self_features_6_denselayer7_norm1,) {}
call_module self_features_6_denselayer7_conv1 self_features_6_denselayer7_conv1 (self_features_6_denselayer7_relu1,) {}
call_module self_features_6_denselayer7_norm2 self_features_6_denselayer7_norm2 (self_features_6_denselayer7_conv1,) {}
call_module self_features_6_denselayer7_relu2 self_features_6_denselayer7_relu2 (self_features_6_denselayer7_norm2,) {}
call_module self_features_6_denselayer7_conv2 self_features_6_denselayer7_conv2 (self_features_6_denselayer7_relu2,) {}
call_function cat_14 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2, self_features_6_denselayer2_conv2, self_features_6_denselayer3_conv2, self_features_6_denselayer4_conv2, self_features_6_denselayer5_conv2, self_features_6_denselayer6_conv2, self_features_6_denselayer7_conv2], 1) {}
call_module self_features_6_denselayer8_norm1 self_features_6_denselayer8_norm1 (cat_14,) {}
call_module self_features_6_denselayer8_relu1 self_features_6_denselayer8_relu1 (self_features_6_denselayer8_norm1,) {}
call_module self_features_6_denselayer8_conv1 self_features_6_denselayer8_conv1 (self_features_6_denselayer8_relu1,) {}
call_module self_features_6_denselayer8_norm2 self_features_6_denselayer8_norm2 (self_features_6_denselayer8_conv1,) {}
call_module self_features_6_denselayer8_relu2 self_features_6_denselayer8_relu2 (self_features_6_denselayer8_norm2,) {}
call_module self_features_6_denselayer8_conv2 self_features_6_denselayer8_conv2 (self_features_6_denselayer8_relu2,) {}
call_function cat_15 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2, self_features_6_denselayer2_conv2, self_features_6_denselayer3_conv2, self_features_6_denselayer4_conv2, self_features_6_denselayer5_conv2, self_features_6_denselayer6_conv2, self_features_6_denselayer7_conv2, self_features_6_denselayer8_conv2], 1) {}
call_module self_features_6_denselayer9_norm1 self_features_6_denselayer9_norm1 (cat_15,) {}
call_module self_features_6_denselayer9_relu1 self_features_6_denselayer9_relu1 (self_features_6_denselayer9_norm1,) {}
call_module self_features_6_denselayer9_conv1 self_features_6_denselayer9_conv1 (self_features_6_denselayer9_relu1,) {}
call_module self_features_6_denselayer9_norm2 self_features_6_denselayer9_norm2 (self_features_6_denselayer9_conv1,) {}
call_module self_features_6_denselayer9_relu2 self_features_6_denselayer9_relu2 (self_features_6_denselayer9_norm2,) {}
call_module self_features_6_denselayer9_conv2 self_features_6_denselayer9_conv2 (self_features_6_denselayer9_relu2,) {}
call_function cat_16 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2, self_features_6_denselayer2_conv2, self_features_6_denselayer3_conv2, self_features_6_denselayer4_conv2, self_features_6_denselayer5_conv2, self_features_6_denselayer6_conv2, self_features_6_denselayer7_conv2, self_features_6_denselayer8_conv2, self_features_6_denselayer9_conv2], 1) {}
call_module self_features_6_denselayer10_norm1 self_features_6_denselayer10_norm1 (cat_16,) {}
call_module self_features_6_denselayer10_relu1 self_features_6_denselayer10_relu1 (self_features_6_denselayer10_norm1,) {}
call_module self_features_6_denselayer10_conv1 self_features_6_denselayer10_conv1 (self_features_6_denselayer10_relu1,) {}
call_module self_features_6_denselayer10_norm2 self_features_6_denselayer10_norm2 (self_features_6_denselayer10_conv1,) {}
call_module self_features_6_denselayer10_relu2 self_features_6_denselayer10_relu2 (self_features_6_denselayer10_norm2,) {}
call_module self_features_6_denselayer10_conv2 self_features_6_denselayer10_conv2 (self_features_6_denselayer10_relu2,) {}
call_function cat_17 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2, self_features_6_denselayer2_conv2, self_features_6_denselayer3_conv2, self_features_6_denselayer4_conv2, self_features_6_denselayer5_conv2, self_features_6_denselayer6_conv2, self_features_6_denselayer7_conv2, self_features_6_denselayer8_conv2, self_features_6_denselayer9_conv2, self_features_6_denselayer10_conv2], 1) {}
call_module self_features_6_denselayer11_norm1 self_features_6_denselayer11_norm1 (cat_17,) {}
call_module self_features_6_denselayer11_relu1 self_features_6_denselayer11_relu1 (self_features_6_denselayer11_norm1,) {}
call_module self_features_6_denselayer11_conv1 self_features_6_denselayer11_conv1 (self_features_6_denselayer11_relu1,) {}
call_module self_features_6_denselayer11_norm2 self_features_6_denselayer11_norm2 (self_features_6_denselayer11_conv1,) {}
call_module self_features_6_denselayer11_relu2 self_features_6_denselayer11_relu2 (self_features_6_denselayer11_norm2,) {}
call_module self_features_6_denselayer11_conv2 self_features_6_denselayer11_conv2 (self_features_6_denselayer11_relu2,) {}
call_function cat_18 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2, self_features_6_denselayer2_conv2, self_features_6_denselayer3_conv2, self_features_6_denselayer4_conv2, self_features_6_denselayer5_conv2, self_features_6_denselayer6_conv2, self_features_6_denselayer7_conv2, self_features_6_denselayer8_conv2, self_features_6_denselayer9_conv2, self_features_6_denselayer10_conv2, self_features_6_denselayer11_conv2], 1) {}
call_module self_features_6_denselayer12_norm1 self_features_6_denselayer12_norm1 (cat_18,) {}
call_module self_features_6_denselayer12_relu1 self_features_6_denselayer12_relu1 (self_features_6_denselayer12_norm1,) {}
call_module self_features_6_denselayer12_conv1 self_features_6_denselayer12_conv1 (self_features_6_denselayer12_relu1,) {}
call_module self_features_6_denselayer12_norm2 self_features_6_denselayer12_norm2 (self_features_6_denselayer12_conv1,) {}
call_module self_features_6_denselayer12_relu2 self_features_6_denselayer12_relu2 (self_features_6_denselayer12_norm2,) {}
call_module self_features_6_denselayer12_conv2 self_features_6_denselayer12_conv2 (self_features_6_denselayer12_relu2,) {}
call_function cat_19 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_5_3, self_features_6_denselayer1_conv2, self_features_6_denselayer2_conv2, self_features_6_denselayer3_conv2, self_features_6_denselayer4_conv2, self_features_6_denselayer5_conv2, self_features_6_denselayer6_conv2, self_features_6_denselayer7_conv2, self_features_6_denselayer8_conv2, self_features_6_denselayer9_conv2, self_features_6_denselayer10_conv2, self_features_6_denselayer11_conv2, self_features_6_denselayer12_conv2], 1) {}
call_module self_features_7_0 self_features_7_0 (cat_19,) {}
call_module self_features_7_1 self_features_7_1 (self_features_7_0,) {}
call_module self_features_7_2 self_features_7_2 (self_features_7_1,) {}
call_module self_features_7_3 self_features_7_3 (self_features_7_2,) {}
call_function cat_20 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3], 1) {}
call_module self_features_8_denselayer1_norm1 self_features_8_denselayer1_norm1 (cat_20,) {}
call_module self_features_8_denselayer1_relu1 self_features_8_denselayer1_relu1 (self_features_8_denselayer1_norm1,) {}
call_module self_features_8_denselayer1_conv1 self_features_8_denselayer1_conv1 (self_features_8_denselayer1_relu1,) {}
call_module self_features_8_denselayer1_norm2 self_features_8_denselayer1_norm2 (self_features_8_denselayer1_conv1,) {}
call_module self_features_8_denselayer1_relu2 self_features_8_denselayer1_relu2 (self_features_8_denselayer1_norm2,) {}
call_module self_features_8_denselayer1_conv2 self_features_8_denselayer1_conv2 (self_features_8_denselayer1_relu2,) {}
call_function cat_21 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2], 1) {}
call_module self_features_8_denselayer2_norm1 self_features_8_denselayer2_norm1 (cat_21,) {}
call_module self_features_8_denselayer2_relu1 self_features_8_denselayer2_relu1 (self_features_8_denselayer2_norm1,) {}
call_module self_features_8_denselayer2_conv1 self_features_8_denselayer2_conv1 (self_features_8_denselayer2_relu1,) {}
call_module self_features_8_denselayer2_norm2 self_features_8_denselayer2_norm2 (self_features_8_denselayer2_conv1,) {}
call_module self_features_8_denselayer2_relu2 self_features_8_denselayer2_relu2 (self_features_8_denselayer2_norm2,) {}
call_module self_features_8_denselayer2_conv2 self_features_8_denselayer2_conv2 (self_features_8_denselayer2_relu2,) {}
call_function cat_22 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2], 1) {}
call_module self_features_8_denselayer3_norm1 self_features_8_denselayer3_norm1 (cat_22,) {}
call_module self_features_8_denselayer3_relu1 self_features_8_denselayer3_relu1 (self_features_8_denselayer3_norm1,) {}
call_module self_features_8_denselayer3_conv1 self_features_8_denselayer3_conv1 (self_features_8_denselayer3_relu1,) {}
call_module self_features_8_denselayer3_norm2 self_features_8_denselayer3_norm2 (self_features_8_denselayer3_conv1,) {}
call_module self_features_8_denselayer3_relu2 self_features_8_denselayer3_relu2 (self_features_8_denselayer3_norm2,) {}
call_module self_features_8_denselayer3_conv2 self_features_8_denselayer3_conv2 (self_features_8_denselayer3_relu2,) {}
call_function cat_23 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2], 1) {}
call_module self_features_8_denselayer4_norm1 self_features_8_denselayer4_norm1 (cat_23,) {}
call_module self_features_8_denselayer4_relu1 self_features_8_denselayer4_relu1 (self_features_8_denselayer4_norm1,) {}
call_module self_features_8_denselayer4_conv1 self_features_8_denselayer4_conv1 (self_features_8_denselayer4_relu1,) {}
call_module self_features_8_denselayer4_norm2 self_features_8_denselayer4_norm2 (self_features_8_denselayer4_conv1,) {}
call_module self_features_8_denselayer4_relu2 self_features_8_denselayer4_relu2 (self_features_8_denselayer4_norm2,) {}
call_module self_features_8_denselayer4_conv2 self_features_8_denselayer4_conv2 (self_features_8_denselayer4_relu2,) {}
call_function cat_24 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2], 1) {}
call_module self_features_8_denselayer5_norm1 self_features_8_denselayer5_norm1 (cat_24,) {}
call_module self_features_8_denselayer5_relu1 self_features_8_denselayer5_relu1 (self_features_8_denselayer5_norm1,) {}
call_module self_features_8_denselayer5_conv1 self_features_8_denselayer5_conv1 (self_features_8_denselayer5_relu1,) {}
call_module self_features_8_denselayer5_norm2 self_features_8_denselayer5_norm2 (self_features_8_denselayer5_conv1,) {}
call_module self_features_8_denselayer5_relu2 self_features_8_denselayer5_relu2 (self_features_8_denselayer5_norm2,) {}
call_module self_features_8_denselayer5_conv2 self_features_8_denselayer5_conv2 (self_features_8_denselayer5_relu2,) {}
call_function cat_25 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2], 1) {}
call_module self_features_8_denselayer6_norm1 self_features_8_denselayer6_norm1 (cat_25,) {}
call_module self_features_8_denselayer6_relu1 self_features_8_denselayer6_relu1 (self_features_8_denselayer6_norm1,) {}
call_module self_features_8_denselayer6_conv1 self_features_8_denselayer6_conv1 (self_features_8_denselayer6_relu1,) {}
call_module self_features_8_denselayer6_norm2 self_features_8_denselayer6_norm2 (self_features_8_denselayer6_conv1,) {}
call_module self_features_8_denselayer6_relu2 self_features_8_denselayer6_relu2 (self_features_8_denselayer6_norm2,) {}
call_module self_features_8_denselayer6_conv2 self_features_8_denselayer6_conv2 (self_features_8_denselayer6_relu2,) {}
call_function cat_26 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2], 1) {}
call_module self_features_8_denselayer7_norm1 self_features_8_denselayer7_norm1 (cat_26,) {}
call_module self_features_8_denselayer7_relu1 self_features_8_denselayer7_relu1 (self_features_8_denselayer7_norm1,) {}
call_module self_features_8_denselayer7_conv1 self_features_8_denselayer7_conv1 (self_features_8_denselayer7_relu1,) {}
call_module self_features_8_denselayer7_norm2 self_features_8_denselayer7_norm2 (self_features_8_denselayer7_conv1,) {}
call_module self_features_8_denselayer7_relu2 self_features_8_denselayer7_relu2 (self_features_8_denselayer7_norm2,) {}
call_module self_features_8_denselayer7_conv2 self_features_8_denselayer7_conv2 (self_features_8_denselayer7_relu2,) {}
call_function cat_27 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2], 1) {}
call_module self_features_8_denselayer8_norm1 self_features_8_denselayer8_norm1 (cat_27,) {}
call_module self_features_8_denselayer8_relu1 self_features_8_denselayer8_relu1 (self_features_8_denselayer8_norm1,) {}
call_module self_features_8_denselayer8_conv1 self_features_8_denselayer8_conv1 (self_features_8_denselayer8_relu1,) {}
call_module self_features_8_denselayer8_norm2 self_features_8_denselayer8_norm2 (self_features_8_denselayer8_conv1,) {}
call_module self_features_8_denselayer8_relu2 self_features_8_denselayer8_relu2 (self_features_8_denselayer8_norm2,) {}
call_module self_features_8_denselayer8_conv2 self_features_8_denselayer8_conv2 (self_features_8_denselayer8_relu2,) {}
call_function cat_28 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2], 1) {}
call_module self_features_8_denselayer9_norm1 self_features_8_denselayer9_norm1 (cat_28,) {}
call_module self_features_8_denselayer9_relu1 self_features_8_denselayer9_relu1 (self_features_8_denselayer9_norm1,) {}
call_module self_features_8_denselayer9_conv1 self_features_8_denselayer9_conv1 (self_features_8_denselayer9_relu1,) {}
call_module self_features_8_denselayer9_norm2 self_features_8_denselayer9_norm2 (self_features_8_denselayer9_conv1,) {}
call_module self_features_8_denselayer9_relu2 self_features_8_denselayer9_relu2 (self_features_8_denselayer9_norm2,) {}
call_module self_features_8_denselayer9_conv2 self_features_8_denselayer9_conv2 (self_features_8_denselayer9_relu2,) {}
call_function cat_29 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2], 1) {}
call_module self_features_8_denselayer10_norm1 self_features_8_denselayer10_norm1 (cat_29,) {}
call_module self_features_8_denselayer10_relu1 self_features_8_denselayer10_relu1 (self_features_8_denselayer10_norm1,) {}
call_module self_features_8_denselayer10_conv1 self_features_8_denselayer10_conv1 (self_features_8_denselayer10_relu1,) {}
call_module self_features_8_denselayer10_norm2 self_features_8_denselayer10_norm2 (self_features_8_denselayer10_conv1,) {}
call_module self_features_8_denselayer10_relu2 self_features_8_denselayer10_relu2 (self_features_8_denselayer10_norm2,) {}
call_module self_features_8_denselayer10_conv2 self_features_8_denselayer10_conv2 (self_features_8_denselayer10_relu2,) {}
call_function cat_30 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2], 1) {}
call_module self_features_8_denselayer11_norm1 self_features_8_denselayer11_norm1 (cat_30,) {}
call_module self_features_8_denselayer11_relu1 self_features_8_denselayer11_relu1 (self_features_8_denselayer11_norm1,) {}
call_module self_features_8_denselayer11_conv1 self_features_8_denselayer11_conv1 (self_features_8_denselayer11_relu1,) {}
call_module self_features_8_denselayer11_norm2 self_features_8_denselayer11_norm2 (self_features_8_denselayer11_conv1,) {}
call_module self_features_8_denselayer11_relu2 self_features_8_denselayer11_relu2 (self_features_8_denselayer11_norm2,) {}
call_module self_features_8_denselayer11_conv2 self_features_8_denselayer11_conv2 (self_features_8_denselayer11_relu2,) {}
call_function cat_31 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2], 1) {}
call_module self_features_8_denselayer12_norm1 self_features_8_denselayer12_norm1 (cat_31,) {}
call_module self_features_8_denselayer12_relu1 self_features_8_denselayer12_relu1 (self_features_8_denselayer12_norm1,) {}
call_module self_features_8_denselayer12_conv1 self_features_8_denselayer12_conv1 (self_features_8_denselayer12_relu1,) {}
call_module self_features_8_denselayer12_norm2 self_features_8_denselayer12_norm2 (self_features_8_denselayer12_conv1,) {}
call_module self_features_8_denselayer12_relu2 self_features_8_denselayer12_relu2 (self_features_8_denselayer12_norm2,) {}
call_module self_features_8_denselayer12_conv2 self_features_8_denselayer12_conv2 (self_features_8_denselayer12_relu2,) {}
call_function cat_32 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2], 1) {}
call_module self_features_8_denselayer13_norm1 self_features_8_denselayer13_norm1 (cat_32,) {}
call_module self_features_8_denselayer13_relu1 self_features_8_denselayer13_relu1 (self_features_8_denselayer13_norm1,) {}
call_module self_features_8_denselayer13_conv1 self_features_8_denselayer13_conv1 (self_features_8_denselayer13_relu1,) {}
call_module self_features_8_denselayer13_norm2 self_features_8_denselayer13_norm2 (self_features_8_denselayer13_conv1,) {}
call_module self_features_8_denselayer13_relu2 self_features_8_denselayer13_relu2 (self_features_8_denselayer13_norm2,) {}
call_module self_features_8_denselayer13_conv2 self_features_8_denselayer13_conv2 (self_features_8_denselayer13_relu2,) {}
call_function cat_33 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2], 1) {}
call_module self_features_8_denselayer14_norm1 self_features_8_denselayer14_norm1 (cat_33,) {}
call_module self_features_8_denselayer14_relu1 self_features_8_denselayer14_relu1 (self_features_8_denselayer14_norm1,) {}
call_module self_features_8_denselayer14_conv1 self_features_8_denselayer14_conv1 (self_features_8_denselayer14_relu1,) {}
call_module self_features_8_denselayer14_norm2 self_features_8_denselayer14_norm2 (self_features_8_denselayer14_conv1,) {}
call_module self_features_8_denselayer14_relu2 self_features_8_denselayer14_relu2 (self_features_8_denselayer14_norm2,) {}
call_module self_features_8_denselayer14_conv2 self_features_8_denselayer14_conv2 (self_features_8_denselayer14_relu2,) {}
call_function cat_34 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2, self_features_8_denselayer14_conv2], 1) {}
call_module self_features_8_denselayer15_norm1 self_features_8_denselayer15_norm1 (cat_34,) {}
call_module self_features_8_denselayer15_relu1 self_features_8_denselayer15_relu1 (self_features_8_denselayer15_norm1,) {}
call_module self_features_8_denselayer15_conv1 self_features_8_denselayer15_conv1 (self_features_8_denselayer15_relu1,) {}
call_module self_features_8_denselayer15_norm2 self_features_8_denselayer15_norm2 (self_features_8_denselayer15_conv1,) {}
call_module self_features_8_denselayer15_relu2 self_features_8_denselayer15_relu2 (self_features_8_denselayer15_norm2,) {}
call_module self_features_8_denselayer15_conv2 self_features_8_denselayer15_conv2 (self_features_8_denselayer15_relu2,) {}
call_function cat_35 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2, self_features_8_denselayer14_conv2, self_features_8_denselayer15_conv2], 1) {}
call_module self_features_8_denselayer16_norm1 self_features_8_denselayer16_norm1 (cat_35,) {}
call_module self_features_8_denselayer16_relu1 self_features_8_denselayer16_relu1 (self_features_8_denselayer16_norm1,) {}
call_module self_features_8_denselayer16_conv1 self_features_8_denselayer16_conv1 (self_features_8_denselayer16_relu1,) {}
call_module self_features_8_denselayer16_norm2 self_features_8_denselayer16_norm2 (self_features_8_denselayer16_conv1,) {}
call_module self_features_8_denselayer16_relu2 self_features_8_denselayer16_relu2 (self_features_8_denselayer16_norm2,) {}
call_module self_features_8_denselayer16_conv2 self_features_8_denselayer16_conv2 (self_features_8_denselayer16_relu2,) {}
call_function cat_36 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2, self_features_8_denselayer14_conv2, self_features_8_denselayer15_conv2, self_features_8_denselayer16_conv2], 1) {}
call_module self_features_8_denselayer17_norm1 self_features_8_denselayer17_norm1 (cat_36,) {}
call_module self_features_8_denselayer17_relu1 self_features_8_denselayer17_relu1 (self_features_8_denselayer17_norm1,) {}
call_module self_features_8_denselayer17_conv1 self_features_8_denselayer17_conv1 (self_features_8_denselayer17_relu1,) {}
call_module self_features_8_denselayer17_norm2 self_features_8_denselayer17_norm2 (self_features_8_denselayer17_conv1,) {}
call_module self_features_8_denselayer17_relu2 self_features_8_denselayer17_relu2 (self_features_8_denselayer17_norm2,) {}
call_module self_features_8_denselayer17_conv2 self_features_8_denselayer17_conv2 (self_features_8_denselayer17_relu2,) {}
call_function cat_37 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2, self_features_8_denselayer14_conv2, self_features_8_denselayer15_conv2, self_features_8_denselayer16_conv2, self_features_8_denselayer17_conv2], 1) {}
call_module self_features_8_denselayer18_norm1 self_features_8_denselayer18_norm1 (cat_37,) {}
call_module self_features_8_denselayer18_relu1 self_features_8_denselayer18_relu1 (self_features_8_denselayer18_norm1,) {}
call_module self_features_8_denselayer18_conv1 self_features_8_denselayer18_conv1 (self_features_8_denselayer18_relu1,) {}
call_module self_features_8_denselayer18_norm2 self_features_8_denselayer18_norm2 (self_features_8_denselayer18_conv1,) {}
call_module self_features_8_denselayer18_relu2 self_features_8_denselayer18_relu2 (self_features_8_denselayer18_norm2,) {}
call_module self_features_8_denselayer18_conv2 self_features_8_denselayer18_conv2 (self_features_8_denselayer18_relu2,) {}
call_function cat_38 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2, self_features_8_denselayer14_conv2, self_features_8_denselayer15_conv2, self_features_8_denselayer16_conv2, self_features_8_denselayer17_conv2, self_features_8_denselayer18_conv2], 1) {}
call_module self_features_8_denselayer19_norm1 self_features_8_denselayer19_norm1 (cat_38,) {}
call_module self_features_8_denselayer19_relu1 self_features_8_denselayer19_relu1 (self_features_8_denselayer19_norm1,) {}
call_module self_features_8_denselayer19_conv1 self_features_8_denselayer19_conv1 (self_features_8_denselayer19_relu1,) {}
call_module self_features_8_denselayer19_norm2 self_features_8_denselayer19_norm2 (self_features_8_denselayer19_conv1,) {}
call_module self_features_8_denselayer19_relu2 self_features_8_denselayer19_relu2 (self_features_8_denselayer19_norm2,) {}
call_module self_features_8_denselayer19_conv2 self_features_8_denselayer19_conv2 (self_features_8_denselayer19_relu2,) {}
call_function cat_39 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2, self_features_8_denselayer14_conv2, self_features_8_denselayer15_conv2, self_features_8_denselayer16_conv2, self_features_8_denselayer17_conv2, self_features_8_denselayer18_conv2, self_features_8_denselayer19_conv2], 1) {}
call_module self_features_8_denselayer20_norm1 self_features_8_denselayer20_norm1 (cat_39,) {}
call_module self_features_8_denselayer20_relu1 self_features_8_denselayer20_relu1 (self_features_8_denselayer20_norm1,) {}
call_module self_features_8_denselayer20_conv1 self_features_8_denselayer20_conv1 (self_features_8_denselayer20_relu1,) {}
call_module self_features_8_denselayer20_norm2 self_features_8_denselayer20_norm2 (self_features_8_denselayer20_conv1,) {}
call_module self_features_8_denselayer20_relu2 self_features_8_denselayer20_relu2 (self_features_8_denselayer20_norm2,) {}
call_module self_features_8_denselayer20_conv2 self_features_8_denselayer20_conv2 (self_features_8_denselayer20_relu2,) {}
call_function cat_40 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2, self_features_8_denselayer14_conv2, self_features_8_denselayer15_conv2, self_features_8_denselayer16_conv2, self_features_8_denselayer17_conv2, self_features_8_denselayer18_conv2, self_features_8_denselayer19_conv2, self_features_8_denselayer20_conv2], 1) {}
call_module self_features_8_denselayer21_norm1 self_features_8_denselayer21_norm1 (cat_40,) {}
call_module self_features_8_denselayer21_relu1 self_features_8_denselayer21_relu1 (self_features_8_denselayer21_norm1,) {}
call_module self_features_8_denselayer21_conv1 self_features_8_denselayer21_conv1 (self_features_8_denselayer21_relu1,) {}
call_module self_features_8_denselayer21_norm2 self_features_8_denselayer21_norm2 (self_features_8_denselayer21_conv1,) {}
call_module self_features_8_denselayer21_relu2 self_features_8_denselayer21_relu2 (self_features_8_denselayer21_norm2,) {}
call_module self_features_8_denselayer21_conv2 self_features_8_denselayer21_conv2 (self_features_8_denselayer21_relu2,) {}
call_function cat_41 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2, self_features_8_denselayer14_conv2, self_features_8_denselayer15_conv2, self_features_8_denselayer16_conv2, self_features_8_denselayer17_conv2, self_features_8_denselayer18_conv2, self_features_8_denselayer19_conv2, self_features_8_denselayer20_conv2, self_features_8_denselayer21_conv2], 1) {}
call_module self_features_8_denselayer22_norm1 self_features_8_denselayer22_norm1 (cat_41,) {}
call_module self_features_8_denselayer22_relu1 self_features_8_denselayer22_relu1 (self_features_8_denselayer22_norm1,) {}
call_module self_features_8_denselayer22_conv1 self_features_8_denselayer22_conv1 (self_features_8_denselayer22_relu1,) {}
call_module self_features_8_denselayer22_norm2 self_features_8_denselayer22_norm2 (self_features_8_denselayer22_conv1,) {}
call_module self_features_8_denselayer22_relu2 self_features_8_denselayer22_relu2 (self_features_8_denselayer22_norm2,) {}
call_module self_features_8_denselayer22_conv2 self_features_8_denselayer22_conv2 (self_features_8_denselayer22_relu2,) {}
call_function cat_42 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2, self_features_8_denselayer14_conv2, self_features_8_denselayer15_conv2, self_features_8_denselayer16_conv2, self_features_8_denselayer17_conv2, self_features_8_denselayer18_conv2, self_features_8_denselayer19_conv2, self_features_8_denselayer20_conv2, self_features_8_denselayer21_conv2, self_features_8_denselayer22_conv2], 1) {}
call_module self_features_8_denselayer23_norm1 self_features_8_denselayer23_norm1 (cat_42,) {}
call_module self_features_8_denselayer23_relu1 self_features_8_denselayer23_relu1 (self_features_8_denselayer23_norm1,) {}
call_module self_features_8_denselayer23_conv1 self_features_8_denselayer23_conv1 (self_features_8_denselayer23_relu1,) {}
call_module self_features_8_denselayer23_norm2 self_features_8_denselayer23_norm2 (self_features_8_denselayer23_conv1,) {}
call_module self_features_8_denselayer23_relu2 self_features_8_denselayer23_relu2 (self_features_8_denselayer23_norm2,) {}
call_module self_features_8_denselayer23_conv2 self_features_8_denselayer23_conv2 (self_features_8_denselayer23_relu2,) {}
call_function cat_43 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2, self_features_8_denselayer14_conv2, self_features_8_denselayer15_conv2, self_features_8_denselayer16_conv2, self_features_8_denselayer17_conv2, self_features_8_denselayer18_conv2, self_features_8_denselayer19_conv2, self_features_8_denselayer20_conv2, self_features_8_denselayer21_conv2, self_features_8_denselayer22_conv2, self_features_8_denselayer23_conv2], 1) {}
call_module self_features_8_denselayer24_norm1 self_features_8_denselayer24_norm1 (cat_43,) {}
call_module self_features_8_denselayer24_relu1 self_features_8_denselayer24_relu1 (self_features_8_denselayer24_norm1,) {}
call_module self_features_8_denselayer24_conv1 self_features_8_denselayer24_conv1 (self_features_8_denselayer24_relu1,) {}
call_module self_features_8_denselayer24_norm2 self_features_8_denselayer24_norm2 (self_features_8_denselayer24_conv1,) {}
call_module self_features_8_denselayer24_relu2 self_features_8_denselayer24_relu2 (self_features_8_denselayer24_norm2,) {}
call_module self_features_8_denselayer24_conv2 self_features_8_denselayer24_conv2 (self_features_8_denselayer24_relu2,) {}
call_function cat_44 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_7_3, self_features_8_denselayer1_conv2, self_features_8_denselayer2_conv2, self_features_8_denselayer3_conv2, self_features_8_denselayer4_conv2, self_features_8_denselayer5_conv2, self_features_8_denselayer6_conv2, self_features_8_denselayer7_conv2, self_features_8_denselayer8_conv2, self_features_8_denselayer9_conv2, self_features_8_denselayer10_conv2, self_features_8_denselayer11_conv2, self_features_8_denselayer12_conv2, self_features_8_denselayer13_conv2, self_features_8_denselayer14_conv2, self_features_8_denselayer15_conv2, self_features_8_denselayer16_conv2, self_features_8_denselayer17_conv2, self_features_8_denselayer18_conv2, self_features_8_denselayer19_conv2, self_features_8_denselayer20_conv2, self_features_8_denselayer21_conv2, self_features_8_denselayer22_conv2, self_features_8_denselayer23_conv2, self_features_8_denselayer24_conv2], 1) {}
call_module self_features_9_0 self_features_9_0 (cat_44,) {}
call_module self_features_9_1 self_features_9_1 (self_features_9_0,) {}
call_module self_features_9_2 self_features_9_2 (self_features_9_1,) {}
call_module self_features_9_3 self_features_9_3 (self_features_9_2,) {}
call_function cat_45 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3], 1) {}
call_module self_features_10_denselayer1_norm1 self_features_10_denselayer1_norm1 (cat_45,) {}
call_module self_features_10_denselayer1_relu1 self_features_10_denselayer1_relu1 (self_features_10_denselayer1_norm1,) {}
call_module self_features_10_denselayer1_conv1 self_features_10_denselayer1_conv1 (self_features_10_denselayer1_relu1,) {}
call_module self_features_10_denselayer1_norm2 self_features_10_denselayer1_norm2 (self_features_10_denselayer1_conv1,) {}
call_module self_features_10_denselayer1_relu2 self_features_10_denselayer1_relu2 (self_features_10_denselayer1_norm2,) {}
call_module self_features_10_denselayer1_conv2 self_features_10_denselayer1_conv2 (self_features_10_denselayer1_relu2,) {}
call_function cat_46 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2], 1) {}
call_module self_features_10_denselayer2_norm1 self_features_10_denselayer2_norm1 (cat_46,) {}
call_module self_features_10_denselayer2_relu1 self_features_10_denselayer2_relu1 (self_features_10_denselayer2_norm1,) {}
call_module self_features_10_denselayer2_conv1 self_features_10_denselayer2_conv1 (self_features_10_denselayer2_relu1,) {}
call_module self_features_10_denselayer2_norm2 self_features_10_denselayer2_norm2 (self_features_10_denselayer2_conv1,) {}
call_module self_features_10_denselayer2_relu2 self_features_10_denselayer2_relu2 (self_features_10_denselayer2_norm2,) {}
call_module self_features_10_denselayer2_conv2 self_features_10_denselayer2_conv2 (self_features_10_denselayer2_relu2,) {}
call_function cat_47 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2], 1) {}
call_module self_features_10_denselayer3_norm1 self_features_10_denselayer3_norm1 (cat_47,) {}
call_module self_features_10_denselayer3_relu1 self_features_10_denselayer3_relu1 (self_features_10_denselayer3_norm1,) {}
call_module self_features_10_denselayer3_conv1 self_features_10_denselayer3_conv1 (self_features_10_denselayer3_relu1,) {}
call_module self_features_10_denselayer3_norm2 self_features_10_denselayer3_norm2 (self_features_10_denselayer3_conv1,) {}
call_module self_features_10_denselayer3_relu2 self_features_10_denselayer3_relu2 (self_features_10_denselayer3_norm2,) {}
call_module self_features_10_denselayer3_conv2 self_features_10_denselayer3_conv2 (self_features_10_denselayer3_relu2,) {}
call_function cat_48 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2], 1) {}
call_module self_features_10_denselayer4_norm1 self_features_10_denselayer4_norm1 (cat_48,) {}
call_module self_features_10_denselayer4_relu1 self_features_10_denselayer4_relu1 (self_features_10_denselayer4_norm1,) {}
call_module self_features_10_denselayer4_conv1 self_features_10_denselayer4_conv1 (self_features_10_denselayer4_relu1,) {}
call_module self_features_10_denselayer4_norm2 self_features_10_denselayer4_norm2 (self_features_10_denselayer4_conv1,) {}
call_module self_features_10_denselayer4_relu2 self_features_10_denselayer4_relu2 (self_features_10_denselayer4_norm2,) {}
call_module self_features_10_denselayer4_conv2 self_features_10_denselayer4_conv2 (self_features_10_denselayer4_relu2,) {}
call_function cat_49 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2], 1) {}
call_module self_features_10_denselayer5_norm1 self_features_10_denselayer5_norm1 (cat_49,) {}
call_module self_features_10_denselayer5_relu1 self_features_10_denselayer5_relu1 (self_features_10_denselayer5_norm1,) {}
call_module self_features_10_denselayer5_conv1 self_features_10_denselayer5_conv1 (self_features_10_denselayer5_relu1,) {}
call_module self_features_10_denselayer5_norm2 self_features_10_denselayer5_norm2 (self_features_10_denselayer5_conv1,) {}
call_module self_features_10_denselayer5_relu2 self_features_10_denselayer5_relu2 (self_features_10_denselayer5_norm2,) {}
call_module self_features_10_denselayer5_conv2 self_features_10_denselayer5_conv2 (self_features_10_denselayer5_relu2,) {}
call_function cat_50 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2], 1) {}
call_module self_features_10_denselayer6_norm1 self_features_10_denselayer6_norm1 (cat_50,) {}
call_module self_features_10_denselayer6_relu1 self_features_10_denselayer6_relu1 (self_features_10_denselayer6_norm1,) {}
call_module self_features_10_denselayer6_conv1 self_features_10_denselayer6_conv1 (self_features_10_denselayer6_relu1,) {}
call_module self_features_10_denselayer6_norm2 self_features_10_denselayer6_norm2 (self_features_10_denselayer6_conv1,) {}
call_module self_features_10_denselayer6_relu2 self_features_10_denselayer6_relu2 (self_features_10_denselayer6_norm2,) {}
call_module self_features_10_denselayer6_conv2 self_features_10_denselayer6_conv2 (self_features_10_denselayer6_relu2,) {}
call_function cat_51 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2, self_features_10_denselayer6_conv2], 1) {}
call_module self_features_10_denselayer7_norm1 self_features_10_denselayer7_norm1 (cat_51,) {}
call_module self_features_10_denselayer7_relu1 self_features_10_denselayer7_relu1 (self_features_10_denselayer7_norm1,) {}
call_module self_features_10_denselayer7_conv1 self_features_10_denselayer7_conv1 (self_features_10_denselayer7_relu1,) {}
call_module self_features_10_denselayer7_norm2 self_features_10_denselayer7_norm2 (self_features_10_denselayer7_conv1,) {}
call_module self_features_10_denselayer7_relu2 self_features_10_denselayer7_relu2 (self_features_10_denselayer7_norm2,) {}
call_module self_features_10_denselayer7_conv2 self_features_10_denselayer7_conv2 (self_features_10_denselayer7_relu2,) {}
call_function cat_52 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2, self_features_10_denselayer6_conv2, self_features_10_denselayer7_conv2], 1) {}
call_module self_features_10_denselayer8_norm1 self_features_10_denselayer8_norm1 (cat_52,) {}
call_module self_features_10_denselayer8_relu1 self_features_10_denselayer8_relu1 (self_features_10_denselayer8_norm1,) {}
call_module self_features_10_denselayer8_conv1 self_features_10_denselayer8_conv1 (self_features_10_denselayer8_relu1,) {}
call_module self_features_10_denselayer8_norm2 self_features_10_denselayer8_norm2 (self_features_10_denselayer8_conv1,) {}
call_module self_features_10_denselayer8_relu2 self_features_10_denselayer8_relu2 (self_features_10_denselayer8_norm2,) {}
call_module self_features_10_denselayer8_conv2 self_features_10_denselayer8_conv2 (self_features_10_denselayer8_relu2,) {}
call_function cat_53 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2, self_features_10_denselayer6_conv2, self_features_10_denselayer7_conv2, self_features_10_denselayer8_conv2], 1) {}
call_module self_features_10_denselayer9_norm1 self_features_10_denselayer9_norm1 (cat_53,) {}
call_module self_features_10_denselayer9_relu1 self_features_10_denselayer9_relu1 (self_features_10_denselayer9_norm1,) {}
call_module self_features_10_denselayer9_conv1 self_features_10_denselayer9_conv1 (self_features_10_denselayer9_relu1,) {}
call_module self_features_10_denselayer9_norm2 self_features_10_denselayer9_norm2 (self_features_10_denselayer9_conv1,) {}
call_module self_features_10_denselayer9_relu2 self_features_10_denselayer9_relu2 (self_features_10_denselayer9_norm2,) {}
call_module self_features_10_denselayer9_conv2 self_features_10_denselayer9_conv2 (self_features_10_denselayer9_relu2,) {}
call_function cat_54 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2, self_features_10_denselayer6_conv2, self_features_10_denselayer7_conv2, self_features_10_denselayer8_conv2, self_features_10_denselayer9_conv2], 1) {}
call_module self_features_10_denselayer10_norm1 self_features_10_denselayer10_norm1 (cat_54,) {}
call_module self_features_10_denselayer10_relu1 self_features_10_denselayer10_relu1 (self_features_10_denselayer10_norm1,) {}
call_module self_features_10_denselayer10_conv1 self_features_10_denselayer10_conv1 (self_features_10_denselayer10_relu1,) {}
call_module self_features_10_denselayer10_norm2 self_features_10_denselayer10_norm2 (self_features_10_denselayer10_conv1,) {}
call_module self_features_10_denselayer10_relu2 self_features_10_denselayer10_relu2 (self_features_10_denselayer10_norm2,) {}
call_module self_features_10_denselayer10_conv2 self_features_10_denselayer10_conv2 (self_features_10_denselayer10_relu2,) {}
call_function cat_55 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2, self_features_10_denselayer6_conv2, self_features_10_denselayer7_conv2, self_features_10_denselayer8_conv2, self_features_10_denselayer9_conv2, self_features_10_denselayer10_conv2], 1) {}
call_module self_features_10_denselayer11_norm1 self_features_10_denselayer11_norm1 (cat_55,) {}
call_module self_features_10_denselayer11_relu1 self_features_10_denselayer11_relu1 (self_features_10_denselayer11_norm1,) {}
call_module self_features_10_denselayer11_conv1 self_features_10_denselayer11_conv1 (self_features_10_denselayer11_relu1,) {}
call_module self_features_10_denselayer11_norm2 self_features_10_denselayer11_norm2 (self_features_10_denselayer11_conv1,) {}
call_module self_features_10_denselayer11_relu2 self_features_10_denselayer11_relu2 (self_features_10_denselayer11_norm2,) {}
call_module self_features_10_denselayer11_conv2 self_features_10_denselayer11_conv2 (self_features_10_denselayer11_relu2,) {}
call_function cat_56 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2, self_features_10_denselayer6_conv2, self_features_10_denselayer7_conv2, self_features_10_denselayer8_conv2, self_features_10_denselayer9_conv2, self_features_10_denselayer10_conv2, self_features_10_denselayer11_conv2], 1) {}
call_module self_features_10_denselayer12_norm1 self_features_10_denselayer12_norm1 (cat_56,) {}
call_module self_features_10_denselayer12_relu1 self_features_10_denselayer12_relu1 (self_features_10_denselayer12_norm1,) {}
call_module self_features_10_denselayer12_conv1 self_features_10_denselayer12_conv1 (self_features_10_denselayer12_relu1,) {}
call_module self_features_10_denselayer12_norm2 self_features_10_denselayer12_norm2 (self_features_10_denselayer12_conv1,) {}
call_module self_features_10_denselayer12_relu2 self_features_10_denselayer12_relu2 (self_features_10_denselayer12_norm2,) {}
call_module self_features_10_denselayer12_conv2 self_features_10_denselayer12_conv2 (self_features_10_denselayer12_relu2,) {}
call_function cat_57 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2, self_features_10_denselayer6_conv2, self_features_10_denselayer7_conv2, self_features_10_denselayer8_conv2, self_features_10_denselayer9_conv2, self_features_10_denselayer10_conv2, self_features_10_denselayer11_conv2, self_features_10_denselayer12_conv2], 1) {}
call_module self_features_10_denselayer13_norm1 self_features_10_denselayer13_norm1 (cat_57,) {}
call_module self_features_10_denselayer13_relu1 self_features_10_denselayer13_relu1 (self_features_10_denselayer13_norm1,) {}
call_module self_features_10_denselayer13_conv1 self_features_10_denselayer13_conv1 (self_features_10_denselayer13_relu1,) {}
call_module self_features_10_denselayer13_norm2 self_features_10_denselayer13_norm2 (self_features_10_denselayer13_conv1,) {}
call_module self_features_10_denselayer13_relu2 self_features_10_denselayer13_relu2 (self_features_10_denselayer13_norm2,) {}
call_module self_features_10_denselayer13_conv2 self_features_10_denselayer13_conv2 (self_features_10_denselayer13_relu2,) {}
call_function cat_58 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2, self_features_10_denselayer6_conv2, self_features_10_denselayer7_conv2, self_features_10_denselayer8_conv2, self_features_10_denselayer9_conv2, self_features_10_denselayer10_conv2, self_features_10_denselayer11_conv2, self_features_10_denselayer12_conv2, self_features_10_denselayer13_conv2], 1) {}
call_module self_features_10_denselayer14_norm1 self_features_10_denselayer14_norm1 (cat_58,) {}
call_module self_features_10_denselayer14_relu1 self_features_10_denselayer14_relu1 (self_features_10_denselayer14_norm1,) {}
call_module self_features_10_denselayer14_conv1 self_features_10_denselayer14_conv1 (self_features_10_denselayer14_relu1,) {}
call_module self_features_10_denselayer14_norm2 self_features_10_denselayer14_norm2 (self_features_10_denselayer14_conv1,) {}
call_module self_features_10_denselayer14_relu2 self_features_10_denselayer14_relu2 (self_features_10_denselayer14_norm2,) {}
call_module self_features_10_denselayer14_conv2 self_features_10_denselayer14_conv2 (self_features_10_denselayer14_relu2,) {}
call_function cat_59 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2, self_features_10_denselayer6_conv2, self_features_10_denselayer7_conv2, self_features_10_denselayer8_conv2, self_features_10_denselayer9_conv2, self_features_10_denselayer10_conv2, self_features_10_denselayer11_conv2, self_features_10_denselayer12_conv2, self_features_10_denselayer13_conv2, self_features_10_denselayer14_conv2], 1) {}
call_module self_features_10_denselayer15_norm1 self_features_10_denselayer15_norm1 (cat_59,) {}
call_module self_features_10_denselayer15_relu1 self_features_10_denselayer15_relu1 (self_features_10_denselayer15_norm1,) {}
call_module self_features_10_denselayer15_conv1 self_features_10_denselayer15_conv1 (self_features_10_denselayer15_relu1,) {}
call_module self_features_10_denselayer15_norm2 self_features_10_denselayer15_norm2 (self_features_10_denselayer15_conv1,) {}
call_module self_features_10_denselayer15_relu2 self_features_10_denselayer15_relu2 (self_features_10_denselayer15_norm2,) {}
call_module self_features_10_denselayer15_conv2 self_features_10_denselayer15_conv2 (self_features_10_denselayer15_relu2,) {}
call_function cat_60 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2, self_features_10_denselayer6_conv2, self_features_10_denselayer7_conv2, self_features_10_denselayer8_conv2, self_features_10_denselayer9_conv2, self_features_10_denselayer10_conv2, self_features_10_denselayer11_conv2, self_features_10_denselayer12_conv2, self_features_10_denselayer13_conv2, self_features_10_denselayer14_conv2, self_features_10_denselayer15_conv2], 1) {}
call_module self_features_10_denselayer16_norm1 self_features_10_denselayer16_norm1 (cat_60,) {}
call_module self_features_10_denselayer16_relu1 self_features_10_denselayer16_relu1 (self_features_10_denselayer16_norm1,) {}
call_module self_features_10_denselayer16_conv1 self_features_10_denselayer16_conv1 (self_features_10_denselayer16_relu1,) {}
call_module self_features_10_denselayer16_norm2 self_features_10_denselayer16_norm2 (self_features_10_denselayer16_conv1,) {}
call_module self_features_10_denselayer16_relu2 self_features_10_denselayer16_relu2 (self_features_10_denselayer16_norm2,) {}
call_module self_features_10_denselayer16_conv2 self_features_10_denselayer16_conv2 (self_features_10_denselayer16_relu2,) {}
call_function cat_61 <built-in method cat of type object at 0x7f5ab02ab540> ([self_features_9_3, self_features_10_denselayer1_conv2, self_features_10_denselayer2_conv2, self_features_10_denselayer3_conv2, self_features_10_denselayer4_conv2, self_features_10_denselayer5_conv2, self_features_10_denselayer6_conv2, self_features_10_denselayer7_conv2, self_features_10_denselayer8_conv2, self_features_10_denselayer9_conv2, self_features_10_denselayer10_conv2, self_features_10_denselayer11_conv2, self_features_10_denselayer12_conv2, self_features_10_denselayer13_conv2, self_features_10_denselayer14_conv2, self_features_10_denselayer15_conv2, self_features_10_denselayer16_conv2], 1) {}
call_module self_features_11 self_features_11 (cat_61,) {}
call_function relu <function relu at 0x7f5a3ae0d2d0> (self_features_11,) {'inplace': True}
call_function adaptive_avg_pool2d <function adaptive_avg_pool2d at 0x7f5a3ae0cdc0> (relu, (1, 1)) {}
call_function flatten <built-in method flatten of type object at 0x7f5ab02ab540> (adaptive_avg_pool2d, 1) {}
call_module self_classifier self_classifier (flatten,) {}
output output output ((self_classifier,),) {}
tensor([[-5.5279e-02, -1.0446e-02, 4.9728e-01, ..., -3.5208e-02,
-1.0469e-01, 4.3263e-01],
[-7.7735e-02, 9.0289e-02, 5.1820e-01, ..., 4.9915e-02,
-1.9321e-01, 3.2847e-01],
[-4.0582e-02, 8.9672e-02, 5.9331e-01, ..., -8.9621e-02,
-1.1201e-01, 1.9783e-01],
...,
[-6.2944e-02, 4.8209e-02, 7.3849e-01, ..., -1.5723e-01,
-6.1362e-02, 3.3617e-01],
[-3.2220e-02, -1.8281e-02, 6.0680e-01, ..., -9.6612e-02,
-1.0105e-01, 2.8053e-01],
[-1.1590e-01, -5.4388e-02, 5.2024e-01, ..., 3.7762e-02,
-1.2914e-04, 3.8643e-01]], device='cuda:0', grad_fn=<AddmmBackward0>)
Using our custom backend, we can now see how TorchDynamo handles data-dependent control flow. Consider the function below, in which the condition b.sum() < 0 is the source of data-dependent control flow.
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
opt_bar = torch.compile(bar, backend=custom_backend)
inp1 = torch.randn(10)
inp2 = torch.randn(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
custom backend called with FX graph:
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f5ab02ab540> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
custom backend called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- --------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (x, b) {}
output output output ((mul,),) {}
custom backend called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- ----------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (b, -1) {}
call_function mul_1 <built-in function mul> (x, mul) {}
output output output ((mul_1,),) {}
tensor([-0.2891, -0.8032, 0.0415, 0.4794, 0.7609, -0.5526, -0.0287, -0.9205,
0.4987, 0.0521])
The output reveals that TorchDynamo extracted 3 different FX graphs corresponding to the following code (the order may differ from the output above):
1. x = a / (torch.abs(a) + 1), together with the evaluation of b.sum() < 0
2. b = b * -1; return x * b
3. return x * b
When TorchDynamo encounters unsupported Python features, such as data-dependent control flow, it breaks the computation graph, lets the default Python interpreter handle the unsupported code, then resumes capturing the graph.
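Not every Python if statement causes a graph break; what matters is whether the condition depends on tensor values. The minimal sketch below assumes PyTorch 2.x and uses an illustrative function, scale_by_length, that is not part of this tutorial. It branches on x.shape rather than on tensor values. TorchDynamo knows the shape at trace time and installs a guard on it, so the function compiles under fullgraph=True, and a call with a different shape triggers a recompile rather than a graph break. backend="eager" is used so that only graph capture is exercised, without generating kernels.

```python
import torch

def scale_by_length(x):
    # Hypothetical example: the branch depends on the tensor's shape, which
    # TorchDynamo knows while tracing, not on the tensor's values.
    if x.shape[0] > 5:
        return x * 2
    return x - 1

# fullgraph=True raises on any graph break; none occurs for this function.
opt_scale = torch.compile(scale_by_length, backend="eager", fullgraph=True)

print(opt_scale(torch.ones(10)))  # takes the x * 2 branch
print(opt_scale(torch.ones(3)))   # new shape: guard fails, recompiles, takes x - 1
```

Contrast this with bar, where the condition b.sum() < 0 can only be known by looking at the data, so TorchDynamo has to hand the decision back to Python.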
Let’s investigate by example how TorchDynamo would step through bar. If b.sum() < 0, then TorchDynamo runs graph 1, lets Python determine the result of the conditional, then runs graph 2. Otherwise (b.sum() >= 0), TorchDynamo runs graph 1, lets Python determine the result of the conditional, then runs graph 3.
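To make the step-through concrete, here is a hypothetical pure-Python model of the stitched execution. Lists of floats stand in for tensors, and graph_1, graph_2, graph_3, and run_bar are illustrative names, not TorchDynamo APIs; in the real system TorchDynamo achieves this by rewriting the function's bytecode. The model only shows which pieces are compiled and where control returns to Python.

```python
# Hypothetical model: lists of floats stand in for tensors, and each graph_*
# function stands in for one compiled FX graph from the output above.

def graph_1(a, b):
    # Everything before the `if`: compute x and the value of the condition.
    x = [ai / (abs(ai) + 1) for ai in a]
    cond = sum(b) < 0
    return x, cond

def graph_2(b, x):
    # Resume path when the condition is true: b = b * -1; return x * b
    b = [bi * -1 for bi in b]
    return [xi * bi for xi, bi in zip(x, b)]

def graph_3(b, x):
    # Resume path when the condition is false: return x * b
    return [xi * bi for xi, bi in zip(x, b)]

def run_bar(a, b):
    x, cond = graph_1(a, b)   # compiled graph 1
    if cond:                  # plain Python decides the branch
        return graph_2(b, x)  # compiled graph 2
    return graph_3(b, x)      # compiled graph 3

print(run_bar([1.0, -2.0, 3.0], [-1.0, -1.0, -1.0]))
```

Both paths share graph_1, and only the code after the conditional is compiled separately per branch. This is why the custom backend above was called three times across the two calls to opt_bar.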
This highlights a major difference between TorchDynamo and previous PyTorch compiler solutions. When encountering unsupported Python features, previous solutions either raise an error or silently fail. TorchDynamo, on the other hand, breaks the computation graph and lets Python run the unsupported code, so the function still executes correctly.
We can see where TorchDynamo breaks the graph by using torch._dynamo.explain:
# Reset since we are using a different backend.
torch._dynamo.reset()
explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = torch._dynamo.explain(
bar, torch.randn(10), torch.randn(10)
)
print(explanation_verbose)
Dynamo produced 2 graphs with 1 graph break and 6 ops
Break reasons:
1. generic_jump TensorVariable()
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 428, in bar
if b.sum() < 0:
2. return_value
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 430, in <graph break in bar>
return x * b
TorchDynamo compilation metrics:
Function Runtimes (s)
------------------------------ --------------
_compile 0.0135, 0.0060
OutputGraph.call_user_compiler 0.0001, 0.0000
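Because torch._dynamo.explain runs bar on a single set of inputs, only one branch of the conditional is traced; this is why it reports 2 graphs rather than the 3 observed with the custom backend above.
To maximize speedup, graph breaks should be kept to a minimum. We can force TorchDynamo to raise an error upon the first graph break it encounters by passing fullgraph=True: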
opt_bar = torch.compile(bar, fullgraph=True)
try:
    opt_bar(torch.randn(10), torch.randn(10))
except Exception:
    # Print the Unsupported error raised at the graph break.
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 478, in <module>
opt_bar(torch.randn(10), torch.randn(10))
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
return callback(frame, cache_size, hooks)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
return _compile(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
super().run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 327, in inner
unimplemented(f"generic_jump {typestr(value)}")
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 71, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: generic_jump TensorVariable()
from user code:
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 428, in bar
if b.sum() < 0:
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
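The graph break in bar comes from branching on a tensor value (if b.sum() < 0:). A common way to remove this kind of break is to express the branch as a tensor operation such as torch.where, so that both outcomes stay inside one graph. The sketch below is illustrative only. bar_branch and bar_where are hypothetical functions written to resemble bar, and the exact body of bar is assumed here. The sketch also uses backend="eager" so that only TorchDynamo's graph capture is exercised, without needing Triton or a C++ compiler.

```python
import torch
import torch._dynamo


def bar_branch(a, b):
    # Hypothetical stand-in for bar: the `if` depends on a tensor value,
    # so TorchDynamo cannot capture the function as a single graph.
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


def bar_where(a, b):
    # Hypothetical rewrite: the branch becomes a tensor op. The condition
    # is a 0-d bool tensor that broadcasts against b.
    x = a / (torch.abs(a) + 1)
    b = torch.where(b.sum() < 0, -b, b)
    return x * b


def compiles_without_break(fn, *args):
    """Return True if fn compiles under fullgraph=True, False otherwise."""
    torch._dynamo.reset()  # clear cached compilations from earlier calls
    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    try:
        opt_fn(*args)
        return True
    except Exception:
        return False


a, b = torch.randn(10), torch.randn(10)
print(compiles_without_break(bar_branch, a, b))  # False
print(compiles_without_break(bar_where, a, b))   # True
# Both versions compute the same values in eager mode.
print(torch.allclose(bar_branch(a, b), bar_where(a, b)))  # True
```

Because torch.where evaluates both -b and b, this rewrite trades the branch for a small amount of extra computation. That trade-off is usually worthwhile for cheap branches like this one.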
Below, we show that TorchDynamo captures the model we used earlier to demonstrate speedups as a single graph: compiling it with fullgraph=True raises no error.
opt_model = torch.compile(init_model(), fullgraph=True)
print(opt_model(generate_data(16)[0]))
tensor([[ 0.0946, 0.2639, 0.0113, ..., 0.0328, -0.1605, -0.4471],
[ 0.2962, 0.4572, 0.1513, ..., -0.0847, -0.3181, -0.3021],
[ 0.2369, 0.3794, 0.3332, ..., -0.0374, -0.3046, -0.4236],
...,
[ 0.1613, 0.3540, 0.1016, ..., 0.0168, -0.3515, -0.2999],
[ 0.1791, 0.3969, 0.2462, ..., -0.0157, -0.1855, -0.3647],
[ 0.3133, 0.2200, 0.2195, ..., 0.0322, -0.2236, -0.4182]],
device='cuda:0', grad_fn=<CompiledFunctionBackward>)
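A lightweight way to check how many graphs TorchDynamo produces is to pass a custom backend to torch.compile. A backend is a callable that receives each captured torch.fx.GraphModule along with example inputs, and it returns a callable to run in the graph's place. The sketch below uses small hypothetical functions (no_break and with_break) rather than the tutorial's model. A count of 1 means the function was captured without graph breaks.

```python
import torch
import torch._dynamo
import torch.fx


def count_graphs(fn, *args):
    """Compile fn with a counting backend, call it once, and return how
    many FX graphs TorchDynamo handed to the backend."""
    torch._dynamo.reset()  # drop cached compilations so counts start fresh
    graphs = []

    def counting_backend(gm: torch.fx.GraphModule, example_inputs):
        graphs.append(gm)
        return gm.forward  # run the captured graph unchanged

    torch.compile(fn, backend=counting_backend)(*args)
    return len(graphs)


def no_break(x):
    # Only tensor ops: expected to be captured as a single graph.
    return torch.sin(x) + torch.cos(x)


def with_break(x):
    # Branching on a tensor value forces a graph break.
    if x.sum() > 0:
        return x + 1
    return x - 1


x = torch.randn(8)
print(count_graphs(no_break, x))    # 1
print(count_graphs(with_break, x))  # more than 1
```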
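Finally, if we simply want TorchDynamo to output the FX graph for export, we can use torch._dynamo.export. Like fullgraph=True, torch._dynamo.export raises an error if TorchDynamo breaks the graph. On success, it returns a tuple whose first element is the captured FX graph module, which is why the code below calls model_exp[0].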
try:
    torch._dynamo.export(bar, torch.randn(10), torch.randn(10))
except Exception:
    tb.print_exc()
model_exp = torch._dynamo.export(init_model(), generate_data(16)[0])
print(model_exp[0](generate_data(16)[0]))
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 495, in <module>
torch._dynamo.export(bar, torch.randn(10), torch.randn(10))
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 601, in export
result_traced = opt_f(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
return callback(frame, cache_size, hooks)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
return _compile(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
super().run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 327, in inner
unimplemented(f"generic_jump {typestr(value)}")
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 71, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: generic_jump TensorVariable()
from user code:
File "/var/lib/jenkins/workspace/intermediate_source/torch_compile_tutorial.py", line 428, in bar
if b.sum() < 0:
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
tensor([[ 0.4881, 0.0524, 0.2258, ..., -0.1283, -0.1812, -0.3570],
[ 0.3723, -0.1481, 0.3892, ..., -0.0518, -0.3119, -0.3431],
[ 0.5131, -0.2636, 0.2015, ..., -0.0433, -0.2169, -0.3728],
...,
[ 0.3709, -0.0996, 0.2814, ..., -0.0221, -0.2553, -0.4394],
[ 0.4731, 0.0373, 0.1736, ..., 0.0506, -0.2278, -0.4136],
[ 0.4274, 0.0640, 0.3437, ..., -0.1889, -0.3916, -0.3896]],
device='cuda:0', grad_fn=<AddmmBackward0>)
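To look inside a captured FX graph, you can print its generated code or walk its nodes. The sketch below obtains the graph through a custom backend rather than torch._dynamo.export, so it does not depend on the exact export signature in your PyTorch version. capture_backend and sin_cos are hypothetical names used only for illustration.

```python
import torch
import torch._dynamo
import torch.fx

captured = []


def capture_backend(gm: torch.fx.GraphModule, example_inputs):
    # Keep a reference to the FX graph TorchDynamo built, then run it unchanged.
    captured.append(gm)
    return gm.forward


def sin_cos(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


torch._dynamo.reset()
opt_sin_cos = torch.compile(sin_cos, backend=capture_backend, fullgraph=True)
opt_sin_cos(torch.randn(4), torch.randn(4))

gm = captured[0]
print(gm.code)  # Python source that FX generated for the captured graph
# List the callables recorded as call_function nodes.
targets = [node.target for node in gm.graph.nodes if node.op == "call_function"]
print(targets)
```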
Conclusion
In this tutorial, we introduced torch.compile by covering basic usage, demonstrating speedups over eager mode, comparing it to previous PyTorch compiler solutions, and briefly investigating TorchDynamo and its interactions with FX graphs. We hope that you will give torch.compile a try!
Total running time of the script: ( 4 minutes 13.771 seconds)