PyTorch 2.0 Troubleshooting
Author: Michael Lazos
We are actively developing debug tools and profilers and improving our error and warning messages. Below is a table of the available tools and their typical usage. For additional help, see Diagnosing Runtime Errors.
Tool | Purpose | Usage |
---|---|---|
Info logging | View summarized steps of compilation | torch._logging.set_logs(dynamo = logging.INFO) or TORCH_LOGS="dynamo" |
Debug logging | View detailed steps of compilation (print every instruction traced) | torch._logging.set_logs(dynamo = logging.DEBUG) or TORCH_LOGS="+dynamo" |
Minifier for any backend | Find the smallest subgraph that reproduces errors for any backend | set environment variable TORCHDYNAMO_REPRO_AFTER="dynamo" |
Minifier for TorchInductor | If the error is known to occur after AOTAutograd, find the smallest subgraph that reproduces errors during TorchInductor lowering | set environment variable TORCHDYNAMO_REPRO_AFTER="aot" |
Dynamo accuracy minifier | Find the smallest subgraph that reproduces an accuracy issue between an eager mode model and an optimized model, when you suspect the problem is in AOTAutograd | TORCHDYNAMO_REPRO_AFTER="dynamo" TORCHDYNAMO_REPRO_LEVEL=4 |
Inductor accuracy minifier | Find the smallest subgraph that reproduces an accuracy issue between an eager mode model and an optimized model, when you suspect the problem is in the backend (e.g., inductor). If this doesn’t work, try the Dynamo accuracy minifier instead. | TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4 |
torch._dynamo.explain | Find graph breaks and display reasoning for them | torch._dynamo.explain(fn)(*inputs) |
Record/Replay | Record and replay frames to reproduce errors during graph capture | torch._dynamo.config.replay_record_enabled = True |
TorchDynamo function name filtering | Only compile functions with the given name to reduce noise when debugging an issue | set environment variable TORCHDYNAMO_DEBUG_FUNCTION=<desired function name> |
TorchInductor Debug logging | Print general TorchInductor debug info and generated Triton/C++ code | |
TorchInductor Tracing | Show time taken in each TorchInductor stage, plus output code and graph visualization | set the environment variable TORCH_COMPILE_DEBUG=1, or enable the torch._inductor.config.trace.* options |
In addition to info and debug logging, you can use torch._logging for more fine-grained logging.
Diagnosing Runtime Errors
At a high level, the TorchDynamo stack consists of a graph capture from Python code (TorchDynamo) and a backend compiler. For example, a backend compiler may consist of backward graph tracing (AOTAutograd) and graph lowering (TorchInductor)*. Errors can occur in any component of the stack, and each produces a full stack trace.
To determine in which component an error occurred, you may use info-level logging, torch._logging.set_logs(dynamo = logging.INFO) or TORCH_LOGS="dynamo", and look for Step #: ... outputs. Logs are made at the beginning and end of each step, so the step in which an error occurred is the most recently logged step whose end has not yet been logged. The steps correspond to the following parts of the stack:
Step | Component |
---|---|
1 | TorchDynamo |
2 | Compiler Backend |
3 | TorchInductor |
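The rule above (the failing step is the latest one that has started but not finished) can be applied mechanically to a captured log. The sketch below is a hypothetical helper, not part of PyTorch. It assumes step messages look like "Step 1: torchdynamo start tracing ..." and that the word "done" marks the end of a step; check your own logs, since the exact wording may differ between releases.

```python
import re

# Assumed message shape: "Step <n>: <text>", where <text> contains the word
# "done" when a step finishes. Steps can nest (step 3 runs inside step 2).
STEP_RE = re.compile(r"Step (\d+): (.*)")
COMPONENTS = {1: "TorchDynamo", 2: "Compiler Backend", 3: "TorchInductor"}


def failing_step(log_lines):
    """Return (step, component) for the latest unfinished step, or None."""
    open_steps = []  # stack of steps that have started but not finished
    for line in log_lines:
        match = STEP_RE.search(line)
        if match is None:
            continue
        step, message = int(match.group(1)), match.group(2)
        if re.search(r"\bdone\b", message):
            # Close the most recently opened entry for this step number.
            for i in range(len(open_steps) - 1, -1, -1):
                if open_steps[i] == step:
                    del open_steps[i]
                    break
        else:
            open_steps.append(step)
    if not open_steps:
        return None
    step = open_steps[-1]
    return step, COMPONENTS.get(step, "unknown")


log = [
    "[INFO] Step 1: torchdynamo start tracing fn",
    "[INFO] Step 1: torchdynamo done tracing fn",
    "[INFO] Step 2: calling compiler function inductor",
    "[INFO] Step 3: torchinductor compiling FORWARDS graph 0",
]
print(failing_step(log))  # (3, 'TorchInductor')
```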
If info logging is insufficient, you can use available backend options. These options include:
"eager": only runs TorchDynamo forward graph capture and then runs the captured graph with PyTorch. This indicates whether TorchDynamo is raising the error.
"aot_eager": runs TorchDynamo to capture a forward graph, and then AOTAutograd to trace the backward graph without any additional backend compiler steps. PyTorch eager will then be used to run the forward and backward graphs. This is useful to narrow down the issue to AOTAutograd.
The general procedure to narrow down an issue is the following:
1. Run your program with the "eager" backend. If the error no longer occurs, the issue is in the backend compiler that is being used (if using TorchInductor, proceed to step 2; if not, see Minifying Backend Compiler Errors). If the error still occurs with the "eager" backend, it is an error while running TorchDynamo.
2. This step is only necessary if TorchInductor is used as the backend compiler. Run the model with the "aot_eager" backend. If this backend raises an error, then the error is occurring during AOTAutograd tracing. If the error no longer occurs with this backend, then the error is in TorchInductor*.
Each of these cases is analyzed in the following sections.
Note
The TorchInductor backend consists of both AOTAutograd tracing and the TorchInductor compiler itself. We will disambiguate by referring to TorchInductor as the backend, and TorchInductor lowering as the phase which lowers the graph traced by AOTAutograd.
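The two-step narrowing procedure can be scripted. The sketch below uses a hypothetical helper (not a PyTorch API) that compiles and runs a function with "eager", then "aot_eager", then "inductor", calling torch._dynamo.reset() between attempts so each backend compiles from scratch, and reports the first backend that fails. It assumes compilation errors propagate to the caller, which is the default while torch._dynamo.config.suppress_errors is False.

```python
import torch
import torch._dynamo


def first_failing_backend(fn, *args, backends=("eager", "aot_eager", "inductor")):
    """Return (backend, exception) for the first backend that fails, else None.

    Hypothetical helper: "eager" failing points at TorchDynamo, "aot_eager"
    failing points at AOTAutograd, and only "inductor" failing points at
    TorchInductor lowering.
    """
    for backend in backends:
        torch._dynamo.reset()  # drop compiled code cached by earlier attempts
        try:
            torch.compile(fn, backend=backend)(*args)
        except Exception as exc:  # compile errors surface on the first call
            return backend, exc
    return None


def healthy(x):  # hypothetical function that compiles cleanly
    return torch.relu(x) * 2


# Limiting the run to the first two backends avoids needing a C++/Triton toolchain.
print(first_failing_backend(healthy, torch.randn(8), backends=("eager", "aot_eager")))
```

Run the function once without torch.compile first: an exception there is a bug in the model itself rather than in the compiler stack.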
TorchDynamo Errors
If the error that is generated occurs with the "eager" backend, then TorchDynamo is most likely the source of the error. Here is sample code that generates an error.
import torch
import torch._dynamo as dynamo
def test_assertion_error():
    y = torch.ones(200, 200)
    z = {y: 5}
    return z
compiled_test_assertion_error = torch.compile(test_assertion_error, backend="eager")
compiled_test_assertion_error()
The code above generates the following error:
torch._dynamo.convert_frame: [ERROR] WON'T CONVERT test_assertion_error /scratch/mlazos/torchdynamo/../test/errors.py line 26
due to:
Traceback (most recent call last):
File "/scratch/mlazos/torchdynamo/torchdynamo/symbolic_convert.py", line 837, in BUILD_MAP
assert isinstance(k, ConstantVariable) or (
AssertionError
from user code:
File "/scratch/mlazos/torchdynamo/../test/errors.py", line 34, in test_assertion_error
z = {y: 5}
Set torch._dynamo.config.verbose=True for more information
==========
As the message suggests, you can set torch._dynamo.config.verbose=True to get a full stack trace of both the error in TorchDynamo and the user code. In addition to this flag, you can also set the log_level of TorchDynamo through torch._logging.set_logs(dynamo = logging.INFO) or TORCH_LOGS="dynamo". These levels include:
logging.DEBUG or TORCH_LOGS="+dynamo": Print every instruction that is encountered, in addition to all the log levels listed below.
logging.INFO: Print each function that is compiled (original and modified bytecode) and the graph that is captured, in addition to all the log levels listed below.
logging.WARNING (default): Print graph breaks, in addition to all the log levels listed below.
logging.ERROR: Print errors only.
If a model is very large, the logs can become overwhelming. If an error occurs deep within a model’s Python code, it can be useful to execute only the frame in which the error occurs to enable easier debugging. There are two tools available to enable this:
Setting the environment variable TORCHDYNAMO_DEBUG_FUNCTION to the desired function name will only run TorchDynamo on functions with that name.
Enabling the record/replay tool (set torch._dynamo.config.replay_record_enabled = True) dumps an execution record when an error is encountered. This record can then be replayed to run only the frame where an error occurred.
Diagnosing TorchInductor Errors
If the error does not occur with the "eager" backend, then the backend compiler is the source of the error (example error). There are different choices for backend compilers for TorchDynamo, with TorchInductor fitting the needs of most users. This section focuses on TorchInductor as the motivating example, but some tools can also be used with other backend compilers.
Below is the portion of the stack which we are focusing on:
With TorchInductor as the chosen backend, AOTAutograd is used to generate the backward graph from the forward graph captured by TorchDynamo. It is important to note that errors can occur during this tracing and also while TorchInductor lowers the forward and backward graphs to GPU code or C++. A model can often consist of hundreds or thousands of FX nodes, so narrowing down the exact nodes where this problem occurred can be very difficult. Fortunately, there are tools available to automatically minify these input graphs to the nodes which are causing the issue. The first step is to determine whether the error occurs during tracing of the backward graph with AOTAutograd or during TorchInductor lowering. As mentioned above in step 2, the "aot_eager" backend can be used to run only AOTAutograd in isolation without lowering. If the error still occurs with this backend, this indicates that the error is occurring during AOTAutograd tracing.
Here is an example:
import torch
import torch._dynamo as dynamo
model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)])
def test_backend_error():
    y = torch.ones(200, 200)
    x = torch.ones(200, 200)
    z = x + y
    a = torch.ops.aten._foobar(z)  # dummy function which errors
    return model(a)
compiled_test_backend_error = torch.compile(test_backend_error, backend="inductor")
compiled_test_backend_error()
Running this should give you this error with a longer stack trace below it:
Traceback (most recent call last):
File "/scratch/mlazos/torchdynamo/torchinductor/graph.py", line 246, in call_function
return lowerings[target](*args, **kwargs)
File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 185, in wrapped
return decomp_fn(*args, **kwargs)
File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 810, in _foobar
assert False
AssertionError
...
If you then change torch.compile(backend="inductor") to torch.compile(backend="aot_eager"), it will run without error, because the issue is in the TorchInductor lowering process, not in AOTAutograd.
Minifying TorchInductor Errors
From here, let’s run the minifier to get a minimal repro. Setting the environment variable TORCHDYNAMO_REPRO_AFTER="aot" (or setting torch._dynamo.config.repro_after="aot" directly) will generate a Python program which reduces the graph produced by AOTAutograd to the smallest subgraph which reproduces the error. (See below for an example where we minify the graph produced by TorchDynamo.) Running the program with this environment variable should show nearly identical output, with an additional line indicating where minifier_launcher.py has been written. The output directory is configurable by setting torch._dynamo.config.base_dir to a valid directory name. The final step is to run the minifier and check that it runs successfully. A successful run looks like this.
If the minifier runs successfully, it generates runnable Python code which reproduces the exact error. For our example, this is the following code:
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx
# torch version: 1.13.0a0+gitfddfc44
# torch cuda version: 11.6
# torch git version: fddfc4488afb207971c54ad4bf58130fdc8a4dc5
# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Thu_Feb_10_18:23:41_PST_2022
# Cuda compilation tools, release 11.6, V11.6.112
# Build cuda_11.6.r11.6/compiler.30978841_0
# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 8
from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, add):
        _foobar = torch.ops.aten._foobar.default(add); add = None
        return (_foobar,)
args = [((200, 200), (200, 1), torch.float32, 'cpu')]
args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args]
mod = make_fx(Repro())(*args)
from torch._inductor.compile_fx import compile_fx_inner
compiled = compile_fx_inner(mod, args)
compiled(*args)
The forward method of the Repro module contains the exact op which causes the issue. When filing an issue, please include any minified repros to aid in debugging.
Minifying Backend Compiler Errors
With backend compilers other than TorchInductor, the process for finding the subgraph causing the error is nearly identical to the procedure for TorchInductor errors, with one important caveat: the minifier now runs on the graph traced by TorchDynamo, not on the output graph of AOTAutograd. Let’s walk through an example.
import torch
import torch._dynamo as dynamo
model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)])
# toy compiler which fails if graph contains relu
def toy_compiler(gm: torch.fx.GraphModule, _):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            assert False
    return gm
def test_backend_error():
    y = torch.ones(200, 200)
    x = torch.ones(200, 200)
    z = x + y
    a = torch.relu(z)
    return model(a)
compiled_test_backend_error = torch.compile(test_backend_error, backend=toy_compiler)
compiled_test_backend_error()
To run the minifier after TorchDynamo has traced the forward graph, you can use the TORCHDYNAMO_REPRO_AFTER environment variable. Running this program with TORCHDYNAMO_REPRO_AFTER="dynamo" (or torch._dynamo.config.repro_after="dynamo") should produce this output and the following code in {torch._dynamo.config.base_dir}/repro.py.
Note
The other option for TORCHDYNAMO_REPRO_AFTER is "aot", which will run the minifier after the backward graph has been generated.
import torch
import torch._dynamo as dynamo
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch.nn import *
# toy_compiler copied from the example above: the generated file passed
# backend="None", which is not a valid backend, so the failing compiler is supplied by hand
def toy_compiler(gm: torch.fx.GraphModule, _):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            assert False
    return gm
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, add):
        relu = torch.relu(add); add = None
        return (relu,)
mod = Repro().cuda()
opt_mod = torch.compile(mod, backend=toy_compiler)
args = [((200, 200), (200, 1), torch.float32, 'cpu', False)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]
with torch.cuda.amp.autocast(enabled=False):
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)
The minifier successfully reduced the graph to the op that raises the error in toy_compiler. The other difference from the procedure in TorchInductor Errors is that the minifier is automatically run after encountering a backend compiler error. After a successful run, the minifier writes repro.py to torch._dynamo.config.base_dir.
Performance Profiling
Accessing TorchDynamo Profiler
TorchDynamo has a built-in stats function for collecting and displaying the time spent in each compilation phase. These stats can be accessed by calling torch._dynamo.utils.compile_times() after running code compiled with TorchDynamo. By default, this returns a string representation of the compile times spent in each TorchDynamo function, by name.
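A minimal sketch of reading these stats, assuming compile_times() keeps its default repr="str" behavior in your release; the compiled function is a hypothetical placeholder.

```python
import torch
from torch._dynamo.utils import compile_times


@torch.compile(backend="eager")
def scaled_cos(x):  # hypothetical function to compile
    return torch.cos(x) * 2


scaled_cos(torch.randn(16))  # the first call triggers compilation

# With the default repr="str", compile_times() returns a printable table of
# the time spent in each timed TorchDynamo function.
report = compile_times()
print(report)
```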
TorchInductor Debugging using TORCH_COMPILE_DEBUG
TorchInductor has a built-in stats and trace function for displaying the time spent in each compilation phase, the output code, an output graph visualization, and an IR dump. This debugging tool is designed to make it easier to understand and troubleshoot the internals of TorchInductor.
Let’s run an example with the following test program (repro.py):
import torch
@torch.compile()
def test_model(x):
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 10),
        torch.nn.LayerNorm(10),
        torch.nn.ReLU(),
    )
    return model(x)
y = test_model(torch.ones(10, 10))
Setting the environment variable TORCH_COMPILE_DEBUG=1 will cause a debug trace directory to be created. By default, this directory is in the current directory and named torch_compile_debug (this can be overridden in the torchdynamo configuration field debug_dir_root and also the env var TORCH_COMPILE_DEBUG_DIR). Inside this directory, each run will have a separate folder named with the timestamp and process id of the run:
$ env TORCH_COMPILE_DEBUG=1 python repro.py
$ cd torch_compile_debug
$ ls
run_2023_03_01_08_20_52_143510-pid_180167
In the run folder there will be a torchdynamo directory, which contains debug logs, and a torchinductor folder, which contains a subfolder for each compiled kernel with inductor debug artifacts.
$ cd run_2023_03_01_08_20_52_143510-pid_180167
$ ls
torchinductor torchdynamo
Moving further into the torchinductor directory, the *.log files are logs from the AOT Autograd phase of compilation, and model__0_forward_1.0 contains the inductor debug artifacts.
$ cd torchinductor
$ ls
aot_model___0_debug.log model__0_forward_1.0
$ cd model__0_forward_1.0
$ ls
debug.log fx_graph_readable.py fx_graph_runnable.py fx_graph_transformed.py ir_post_fusion.txt ir_pre_fusion.txt output_code.py
Here is a summary of the contents:
fx_graph_readable.py and fx_graph_runnable.py are the readable and runnable versions of the fx_graph received by inductor.
fx_graph_transformed.py is the fx graph after inductor has run all fx passes.
ir*.txt is the inductor IR pre and post fusion.
output_code.py is the generated Triton/C++ kernel code for the subgraph.
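When a script is run repeatedly, it helps to locate the newest run folder and its generated code programmatically. A sketch with hypothetical helper names; it assumes the directory layout shown above (run_<timestamp>-pid_<pid>/torchinductor/<graph>/output_code.py), which may differ in other releases.

```python
from pathlib import Path


def latest_debug_run(root="torch_compile_debug"):
    """Return the newest run_* folder under root, or None if there is none."""
    root = Path(root)
    if not root.is_dir():
        return None
    # Folder names begin with a zero-padded timestamp, e.g.
    # run_2023_03_01_08_20_52_143510-pid_180167, so string order is time order.
    runs = sorted(p for p in root.glob("run_*") if p.is_dir())
    return runs[-1] if runs else None


def output_code_files(run_dir):
    """Every output_code.py (one per compiled graph) inside a run folder."""
    return sorted(Path(run_dir).glob("torchinductor/*/output_code.py"))
```

For example, after TORCH_COMPILE_DEBUG=1 python repro.py, output_code_files(latest_debug_run()) lists the generated code for the most recent run.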
Here are example debug directory contents for the test program:
Each file in that debug trace can be enabled and disabled through torch._inductor.config.trace.*. The profile and the diagram are both disabled by default since they are expensive to generate.
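A sketch of toggling these options from Python instead of the environment. The field names (trace.enabled, trace.output_code, trace.graph_diagram, trace.compile_profile) are taken from torch._inductor.config in the releases we have looked at; treat them as assumptions and check your installed version.

```python
import torch._inductor.config as inductor_config

# Master switch for the Inductor trace directory (what TORCH_COMPILE_DEBUG=1 sets).
inductor_config.trace.enabled = True
# Keep the cheap, readable artifacts...
inductor_config.trace.output_code = True
# ...and leave the expensive ones off, matching their defaults.
inductor_config.trace.graph_diagram = False   # SVG of the post-fusion graph
inductor_config.trace.compile_profile = False  # profile of the compile step
```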
A single node in this new debug format looks like:
buf1: SchedulerNode(ComputedBuffer)
buf1.writes =
    {   MemoryDep(name='buf1', index=0, size=()),
        MemoryDep(name='buf1', index=0, size=(s0,))}
buf1.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))}
buf1.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))}
buf1.group.device = cuda:0
buf1.group.iteration = (1, s0)
buf1.sizes = ([], [s0])
class buf1_loop_body:
    var_ranges = {z0: s0}
    index0 = z0
    index1 = 0
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('buf0', get_index, False)
        get_index_1 = self.get_index('index0')
        load_1 = ops.load('primals_2', get_index_1, False)
        add = ops.add(load, load_1)
        get_index_2 = self.get_index('index1')
        reduction = ops.reduction('buf1', torch.float32, torch.float32, 'sum', get_index_2, add)
        return reduction
See the example debug directory output for more examples.
Graph Breaks
Given a program like this:
def some_fun(x):
    ...
compiled_fun = torch.compile(some_fun, ...)
...
TorchDynamo will attempt to compile all of the torch/tensor operations within some_fun into a single FX graph, but it may fail to capture everything into one graph.
Some graph break reasons are insurmountable to TorchDynamo and can’t be easily fixed. For example, calling into a C extension other than torch is invisible to TorchDynamo and could do arbitrary things without TorchDynamo being able to introduce the necessary guards (see Making Dynamo Sound: Guards) to ensure that the compiled program would be safe to reuse. Graph breaks can hinder performance if the resulting fragments are small. To maximize performance, it’s important to have as few graph breaks as possible.
Identifying the Cause of a Graph Break
To identify all graph breaks in a program and the associated reasons for the breaks, torch._dynamo.explain can be used. This tool runs TorchDynamo on the supplied function and aggregates the graph breaks that are encountered. Here is an example usage:
import torch
import torch._dynamo as dynamo
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    print("woo")
    if b.sum() < 0:
        b = b * -1
    return x * b
explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation)
"""
Graph Count: 3
Graph Break Count: 2
Op Count: 5
Break Reasons:
Break Reason 1:
Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
User Stack:
<FrameSummary file foo.py, line 5 in toy_example>
Break Reason 2:
Reason: generic_jump TensorVariable()
User Stack:
<FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
Ops per Graph:
...
Out Guards:
...
"""
Outputs include:
out_guards - a list of lists where each sublist contains the guards that must pass to ensure the traced graphs are valid.
graphs - a list of graph modules which were successfully traced.
ops_per_graph - a list of lists where each sublist contains the ops that are run in the graph.
To throw an error on the first graph break encountered, use the fullgraph mode. This mode disables TorchDynamo’s Python fallback and only succeeds if the entire program is convertible into a single graph. Example usage:
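# toy_example as defined in the explain example above
compiled_toy = torch.compile(toy_example, fullgraph=True, backend="eager")  # substitute your backend
compiled_toy(torch.randn(10), torch.randn(10))  # raises at the first graph break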
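To see the difference between the two modes, the sketch below compiles the same hypothetical function with and without fullgraph=True. A branch on a tensor value cannot be captured in one FX graph, so the strict version raises while the default version splits at the branch and falls back to Python. The exact exception class is an assumption (torch._dynamo.exc.Unsupported in the versions we have tried), so the sketch catches Exception.

```python
import torch
import torch._dynamo


def data_dependent(x):  # hypothetical function with a data-dependent branch
    if x.sum() > 0:
        return x * 2
    return x - 1


torch._dynamo.reset()
strict = torch.compile(data_dependent, fullgraph=True, backend="eager")
try:
    strict(torch.ones(3))
    error = None
except Exception as exc:  # typically torch._dynamo.exc.Unsupported
    error = exc
print(type(error).__name__)

# Without fullgraph=True, Dynamo breaks the graph at the branch and keeps going.
torch._dynamo.reset()
lenient = torch.compile(data_dependent, backend="eager")
print(lenient(torch.ones(3)))
```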
Excessive Recompilation
When TorchDynamo compiles a function (or part of one), it makes certain assumptions about locals and globals in order to allow compiler optimizations, and expresses these assumptions as guards that check particular values at runtime. If any of these guards fail, Dynamo will recompile that function (or part) up to torch._dynamo.config.cache_size_limit times. If your program is hitting the cache limit, you will first need to determine which guard is failing and what part of your program is triggering it.
The compile profiler automates the process of setting TorchDynamo’s cache limit to 1 and running your program under an observation-only ‘compiler’ that records the causes of any guard failures. You should be sure to run your program for at least as long (as many iterations) as you were running when you ran into trouble, and the profiler will accumulate statistics over this duration.
If your program exhibits a bounded amount of dynamism, you may be able to tune the TorchDynamo cache limit to allow for each variation to be compiled and cached, but if the cache limit is too high you may find the cost of recompilation outweighs any optimization benefits.
torch._dynamo.config.cache_size_limit = <your desired cache limit>
TorchDynamo plans to support many common cases of dynamic tensor shapes, such as varying batch size or sequence length. It does not plan to support rank-dynamism. In the meantime, setting a specific cache limit can be used in coordination with bucketing techniques to achieve an acceptable number of recompilations for some dynamic models.
import torch
from torch._dynamo.utils import CompileProfiler
def my_model():
    ...
with CompileProfiler() as prof:
    profiler_model = torch.compile(my_model, backend=prof)
    profiler_model()
    print(prof.report())
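One bucketing technique, sketched here under our own assumptions (the helper name and bucket boundaries are hypothetical), rounds a variable dimension such as sequence length up to a fixed set of sizes. Inputs are then padded to the bucket size, so TorchDynamo only ever sees at most len(BUCKETS) distinct lengths, and cache_size_limit can be set to that number.

```python
import bisect

BUCKETS = (32, 64, 128, 256, 512)  # hypothetical boundaries; tune for your data


def bucket_length(length, buckets=BUCKETS):
    """Round length up to the smallest bucket that can hold it."""
    i = bisect.bisect_left(buckets, length)
    if i == len(buckets):
        raise ValueError(f"length {length} exceeds the largest bucket {buckets[-1]}")
    return buckets[i]


lengths = [10, 33, 64, 65, 300]
padded = [bucket_length(n) for n in lengths]
print(padded)            # [32, 64, 64, 128, 512]
print(len(set(padded)))  # 4
```

The padding itself can be done with torch.nn.functional.pad(x, (0, target - x.shape[-1])) on the last dimension, together with a mask so padded positions do not affect the result.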
Accuracy Debugging
Accuracy issues can also be minified if you set the environment variable TORCHDYNAMO_REPRO_LEVEL=4. It operates with a similar git bisect model, and a full repro might be something like TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4. The reason we need this is that downstream compilers generate code, whether it’s Triton code or the C++ backend, and the numerics from those downstream compilers can differ in subtle ways yet have a dramatic impact on your training stability. So the accuracy debugger is very useful for us to detect bugs in our codegen or with a backend compiler.
If you’d like to ensure that random number generation is the same across both torch and triton, then you can enable torch._inductor.config.fallback_random = True.
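Because the minifier settings are passed as environment variables, it can be convenient to launch the failing script in a fresh interpreter with them already set, so they are in place before torch is imported. A minimal sketch; minifier_env and run_with_minifier are hypothetical helpers, and repro.py stands in for your own script.

```python
import os
import subprocess
import sys


def minifier_env(repro_after="aot", repro_level=None, base=None):
    """Return a copy of the environment with the minifier variables set."""
    env = dict(os.environ if base is None else base)
    env["TORCHDYNAMO_REPRO_AFTER"] = repro_after
    if repro_level is not None:
        env["TORCHDYNAMO_REPRO_LEVEL"] = str(repro_level)
    return env


def run_with_minifier(script, **kwargs):
    """Run script in a new Python process with the minifier variables set."""
    return subprocess.run([sys.executable, script], env=minifier_env(**kwargs), check=False)
```

For instance, run_with_minifier("repro.py", repro_after="aot", repro_level=4) is equivalent to TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4 python repro.py.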
Extended Debugging
Extended debugging can be enabled by using the following experimental flags.
TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED - provides extended debug information if the string representation of a guard matches this flag value. For example, set it to “Ne(s0, 10)” to generate a full Python and C++ backtrace whenever that guard is issued.
TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL - provides extended debug information when a particular symbol is allocated. For example, set this to “u2” to generate a full Python and C++ backtrace whenever this symbol is created.
TORCHDYNAMO_EXTENDED_DEBUG_CPP - provides extended debug information (C++ backtrace) for all extended debug settings as well as errors. For example, set this to “1”. The C++ backtrace is slow and very spammy, so it is not included by default with extended debugging.
Cold Start Timing and Cache Corruption Debugging
To measure the cold start compilation time or debug a cache corruption, you can pass TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 or set torch._inductor.config.force_disable_caches = True, which overrides any other caching config option and disables all compile time caching.