torch.compile Troubleshooting

You’re trying to use torch.compile on your PyTorch model to enhance its performance but it’s not working as expected. Perhaps performance isn’t improving, crashes are happening, or compilation time is too long. This article provides tips, workarounds, and debugging tools to help you overcome these challenges.

Setting Expectations

torch.compile is designed as a general-purpose PyTorch compiler. Unlike the previous compiler solution, TorchScript, torch.compile requires fewer code changes, meaning models typically don’t need to be rewritten from scratch. It also manages unsupported code more gracefully - unsupported code results in a lost optimization opportunity rather than a crash.

In the ideal world, one can simply apply torch.compile to any PyTorch model and enjoy automatic speedups. However, in reality, code complexities can lead to one of three scenarios:

  1. torch.compile works seamlessly, providing speedups.

  2. Some code modifications are necessary. torch.compile doesn’t crash or take too long, but you might not be seeing significant performance gains.

  3. Extensive changes to your code are required.

We anticipate most code will fall under scenarios (1) and (2). This document provides tips, arranged by level of involvement, to help address code issues in scenario (2).

Compile times

torch.compile functions as a just-in-time compiler, so the initial one or two runs of the compiled function are expected to be significantly slower. Recompilations, which can occur under certain conditions (detailed below), will also make runs slower. Various torch.compile components cache results to reduce compilation time for future invocations, even in different processes. Cold-start (uncached) compilation time typically ranges from seconds to minutes for common or benchmarked models. Larger models may take upwards of 30 minutes to a few hours.
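As a rough illustration of the just-in-time behavior described above, the sketch below times a few consecutive calls to a compiled function. It assumes backend="eager" (graph capture only) so that no C++ or Triton toolchain is needed; absolute numbers will differ on your machine and with the default backend.

```python
import time

import torch

# backend="eager" only runs graph capture, which keeps this sketch lightweight.
@torch.compile(backend="eager")
def fn(x):
    return torch.sin(x) + torch.cos(x)

x = torch.ones(64, 64)
timings = []
for _ in range(3):
    start = time.perf_counter()
    out = fn(x)
    timings.append(time.perf_counter() - start)

# The first entry includes tracing/compilation; later entries reuse compiled code.
print([f"{t * 1e3:.2f} ms" for t in timings])
```

The first call is expected to dominate; if later calls are also slow, recompilations are a likely suspect.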

Terminology

The following terms are relevant to troubleshooting torch.compile problems.

Graph break

torch.compile traces your code and attempts to capture your PyTorch code into a single computation graph of PyTorch operators (FX graph). However, this is not always possible. When encountering code that can’t be traced, a “graph break” occurs. A graph break involves compiling the FX graph that has been determined so far, running the unsupported code, then resuming tracing after the unsupported code with a new FX graph. Because the computation graph is broken up, we lose optimization opportunities, so model code should avoid graph breaks whenever possible. Graph breaks occur on things like:

  • Data-dependent if-statements

  • Many Python built-in functions

  • C functions

Below is an example of a graph break caused by the Python built-in function open (exact output may differ).

import torch

@torch.compile
def fn(x):
    x = x + 1
    with open("test.txt", "r") as f:
        return x + len(f.read())

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in fn
    with open("test.txt", "r") as f:
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 635, in wrapper
    return inner_fn(self, inst)
        ^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2414, in CALL
    self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _call
    self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 962, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 997, in call_function
    return handler(tx, args, kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 831, in <lambda>
    return lambda *args: unimplemented(error_msg)
                        ^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False
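One way to see the effect of a graph break without reading logs is to count the FX graphs that reach the backend. The sketch below is a minimal illustration, not part of the logging tools: counting_backend is a hypothetical helper written for this example, and torch._dynamo.graph_break() is used only as a stand-in for unsupported code.

```python
import torch

captured_graphs = []

def counting_backend(gm, example_inputs):
    # Record each FX GraphModule Dynamo hands over, then run it unoptimized.
    captured_graphs.append(gm)
    return gm.forward

@torch.compile(backend=counting_backend)
def fn(x):
    x = x + 1
    torch._dynamo.graph_break()  # force a graph break for demonstration
    return x * 2

out = fn(torch.ones(3, 3))
print(f"FX graphs captured: {len(captured_graphs)}")
```

A single graph break splits the function into two separately compiled graphs, one before and one after the break.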

Guards

torch.compile makes some assumptions about runtime values as we trace through code. During tracing, we generate “guards”, which are runtime checks for these assumptions. Guards are run in future calls to the compiled function to determine if we can reuse previously compiled code. Examples of runtime checks are constant values, types, and object IDs.

Below is an example of generated guards. The TENSOR_MATCH guard checks for the input’s type, device, dtype, shape, etc.

import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
$ TORCH_LOGS="guards" python playground.py
GUARDS:

TREE_GUARD_MANAGER:
+- RootGuardManager
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:471 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state()
| +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
| +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1])  # return x + 1  # playground.py:6 in fn
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False           # return x + 1  # playground.py:6 in fn

Recompilation

If the guards fail for every instance of previously compiled code, then torch.compile must “recompile” the function, requiring the original code to be traced again.

In the example below, recompilation is necessary because the guard checking the tensor argument’s shape failed.

import torch

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
    triggered by the following guard failure(s):
    - 0/0: tensor 'L['x']' size mismatch at index 0. expected 3, actual 4

Dynamic Shapes

torch.compile initially assumes tensor shapes are static/constant and guards based on these assumptions. By using “dynamic shapes,” we can get torch.compile to produce compiled code that accepts tensor inputs with different shapes, so we avoid recompiling every time shapes differ. By default, automatic dynamic shapes are enabled (torch.compile(dynamic=None)): if a recompilation is triggered by a shape mismatch, the recompilation is attempted with dynamic shapes. Dynamic shapes can also be fully enabled (dynamic=True) or disabled (dynamic=False).

Below, we enable dynamic shapes and note that we no longer need to recompile.

import torch

@torch.compile(dynamic=True)
def fn(x):
    return x + 1

fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="dynamic,recompiles" python playground.py
create_symbol s0 = 3 for L['x'].size()[0] [2, int_oo] at playground.py:5 in fn (_dynamo/variables/builder.py:2718 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
produce_guards
produce_guards

For more information on dynamic shapes, see the dynamic shapes manual.

Logging Tools

tlparse / TORCH_TRACE

tlparse / TORCH_TRACE are a pair of tools that produce compilation reports that look like this: https://web.mit.edu/~ezyang/Public/bhack-20240609-tlparse/index.html.

Traces are very easy to collect. To collect a trace, run your reproduction command with

TORCH_TRACE="/tmp/tracedir" python foo.py
pip install tlparse
tlparse /tmp/tracedir

This approach works even if you are running a distributed job, providing a trace for each rank. It will open your browser with HTML similar to what’s generated above. If you are making a bug report for a complicated problem that you don’t have a standalone reproduction for, you can still greatly assist PyTorch developers by attaching the trace log generated in /tmp/tracedir.

Warning

The trace log contains all of your model code. Do not share the trace log if the model you are working on is sensitive. The trace log does NOT contain weights.

The output of tlparse is aimed primarily at PyTorch developers, and the log format is easy to upload and share on GitHub. However, as a non-PyTorch developer, you can still extract useful information from it. We recommend starting with the inline help text in the report, which explains its contents. Here are some insights you can gain from a tlparse:

  • What model code was compiled? You can find this by looking at the stack trie. This is especially useful if you’re not familiar with the codebase being compiled!

  • How many graph breaks / distinct compilation regions are there? (Each distinct compile is its own color coded block like [0/0]). Frames that are potentially graph-broken are light green [2/4]. If there are a lot of frames, that is suspicious, and suggests that you had some catastrophic graph breaks, or maybe your code isn’t a good match for torch.compile.

  • How many times did I recompile a particular frame? Something that recompiled a lot will look like: [10/0] [10/1] [10/2] - if something is being recompiled a lot, that is very suspicious and worth looking into, even if it isn’t the root cause of your problem.

  • Was there a compilation error? Frames that errored will look like [0/1].

  • What intermediate compiler products did I generate for a given frame? For example, you can look at the high-level generated FX graph or the generated Triton code.

  • Is there relevant information for a particular frame? You can find these in compilation_metrics.

TORCH_LOGS

You can use the TORCH_LOGS environment variable to selectively enable parts of the torch.compile stack to log. TORCH_LOGS is in fact the source of logs for tlparse. The format of the TORCH_LOGS environment variable looks like this:

TORCH_LOGS="<option1>,<option2>,..." python foo.py

Useful high-level options include:

  • graph_breaks: logs locations of graph breaks in user code and the reason for the graph break

  • guards: logs guards that are generated

  • recompiles: logs which function recompiled and the guards that failed, leading to the recompilation

  • dynamic: logs related to dynamic shapes

Also, you can programmatically set logging options using torch._logging.set_logs:

import torch

torch._logging.set_logs(graph_breaks=True)
...

More TORCH_LOGS options are detailed below. For the full list of options, see torch._logging and torch._logging.set_logs.

tlparse vs. TORCH_LOGS

Generally, we suggest first using tlparse when encountering issues. tlparse is ideal for debugging large models and gaining a high-level overview of how your model was compiled. On the other hand, TORCH_LOGS is preferred for small examples and fine-grained debugging detail, when we already have an idea of which torch.compile component is causing the problem.

Simple Workarounds

Here, we describe some workarounds to torch.compile issues involving small code modifications or changing some torch.compile settings.

Where to apply torch.compile?

We recommend applying torch.compile to the highest-level function that doesn’t cause excessive problems. Typically, it is your train or eval step with the optimizer but without the loop, your top-level nn.Module, or some sub-nn.Modules. torch.compile specifically doesn’t handle distributed wrapper modules like DDP or FSDP very well, so consider applying torch.compile to the inner module passed to the wrapper.

# inference
model = ...
opt_model = torch.compile(model)

for _ in range(N_ITERS):
    inp = ...
    out = opt_model(inp)
# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
# DistributedDataParallel
model = ...
opt_model = torch.compile(model)
model_ddp = DistributedDataParallel(opt_model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)

Disabling and Suppressing Errors

For some model architectures, there are portions of the model which are particularly difficult to compile - either there are many graph breaks, or there are crashes. You may want to explicitly disable these portions of the model which are problematic so that you can apply torch.compile to the parts that work. You can do this by using the @torch.compiler.disable decorator. When torch.compile attempts to call a disabled function, it breaks the graph and skips tracing the disabled function, resuming tracing after the call. By default, all recursive calls made from a disabled function are also disabled. Use the recursive=False option to allow compilation for recursive calls.

def bad1_inner(...):
    # skipped

@torch.compiler.disable
def bad1_outer(...):
    # skipped
    bad1_inner(...)

def bad2_inner(...):
    # traced

@torch.compiler.disable(recursive=False)
def bad2_outer(...):
    # skipped
    bad2_inner(...)

@torch.compile
def fn(...):
    # graph break
    bad1_outer(...)
    ...
    # graph break
    bad2_outer(...)

For example, we use torch.compiler.disable to disable torch.compile on sparse architecture in recommendation models, as the sparse arch is difficult to compile. Preprocessing and logging functions are other examples of functions that typically cause a lot of graph breaks and do not get value from being compiled.

If you are experiencing compiler crashes and you want to continue regardless, you can set torch._dynamo.config.suppress_errors = True. When the compiler crashes, we will just skip tracing the function and try again later. This is not best practice - it is better to eventually manually add disable annotations as necessary.

Resolving graph breaks

To maximize optimization opportunities, it’s important to reduce the number of graph breaks. Recall that you can see what graph breaks are happening using tlparse or TORCH_LOGS="graph_breaks". In general, graph breaks are caused by one of the following:

  1. You’re trying to do something that fundamentally cannot be traced, such as data-dependent control flow.

  2. You’re trying to do something not yet supported. For example, we currently have limited support for tracing code that uses the built-in Python inspect module.

  3. Your code has an error in it. For example, you may have tried calling a function with an incorrect number of arguments.

Graph break logs will tell you the user code location and reason for the graph break. Unfortunately, many graph breaks are not actionable without a deeper understanding of Dynamo. It can even be challenging to determine which of the three causes was the true cause of your graph break. We are working on making graph break messages more actionable.

Additionally, the impact of lost optimization opportunities differs between graph breaks. For example, graph breaks that happen in the middle of your model’s forward are likely to have a more negative impact than graph breaks in a preprocessing part at the beginning of the forward. So it is not crucial to prevent every single break, but rather to prevent the ones that cause significant performance hits.

If a graph break message doesn’t suggest any action, you suspect that the cause of your graph break is (2), and you believe that the graph break is causing performance hits, then please report the graph break as an issue. If a function has many graph breaks, consider disabling compilation on that function, as the overhead cost for the graph breaks may become prohibitive.
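When hunting for graph breaks in a small region, it can help to make them fail loudly. The sketch below relies on the fullgraph=True option of torch.compile, which is not otherwise covered in this document, and uses torch._dynamo.graph_break() as a stand-in for unsupported code. The exact exception type is treated as an unknown here because it may vary across versions.

```python
import torch

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    x = x + 1
    torch._dynamo.graph_break()  # stand-in for code that cannot be traced
    return x * 2

try:
    fn(torch.ones(3, 3))
    raised = False
except Exception as exc:  # exact exception type may differ between versions
    raised = True
    print(f"{type(exc).__name__}: graph break rejected under fullgraph=True")
```

This is mainly useful for functions you expect to compile as a single graph; for code with known, tolerable breaks the default fullgraph=False is the better fit.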

Below are some common graph breaks and some workarounds.

Data-dependent operations

torch.compile graph breaks on data-dependent operations such as data-dependent control flow (if-statements, loops with tensors) and direct tensor data accesses (.item, .data_ptr).

import torch

@torch.compile
def fn(x):
    y = x.sum()
    if y > 0:
        return x + y.item()
    return x - y.item()

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:6
Reason: Data-dependent jump
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 6, in fn
    if y > 0:

Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: Tensor.item
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in torch_dynamo_resume_in_fn_at_6
    return x + y.item()
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 616, in wrapper
    return inner_fn(self, inst)
        ^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2288, in CALL
    self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2282, in _call
    self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 838, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py", line 1038, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 527, in call_method
    result = handler_method(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 773, in method_item
    unimplemented("Tensor.item")
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 304, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item

The general workaround for these graph breaks is to avoid doing data-dependent operations. Some specific workarounds are:

  • If your control flow doesn’t actually depend on data values, consider modifying your code to perform control flow on constants.

# old
x = torch.randn(3, 3)
@torch.compile
def fn(y):
    if x.sum() > 0:
        return y + x
    else:
        return y - x

# new
x = torch.randn(3, 3)
cond = (x.sum() > 0).item()
@torch.compile
def fn(y):
    if cond:
        return y + x
    else:
        return y - x
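
  • If your control flow does depend on data values, consider using a higher-order op such as torch.cond in place of the data-dependent control flow.

# old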
@torch.compile
def fn(x):
    if x.sum() > 0:
        return x + 1
    return x - 1

# new
@torch.compile
def fn(x):
    return torch.cond(
        x.sum() > 0,
        lambda x: x + 1,
        lambda x: x - 1,
        (x,),
    )
  • If you have a .item() call, try torch._dynamo.config.capture_scalar_outputs = True or TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1

  • Wrap problematic parts of the function in a custom op

Custom ops

If you have code that torch.compile has trouble tracing through, either due to missing support or fundamental incompatibility, you can consider wrapping the problematic code in a custom op.

Custom ops require a little bit of additional work to get them to be compatible with torch.compile. See https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details.

Printing

Printing/logging/issuing warnings will result in a graph break. If you have a function that makes many logging calls, for example, a function that logs data about a training iteration, consider applying torch.compiler.disable on it.

Alternatively, you can try using torch._dynamo.config.reorderable_logging_functions. This config is used to reorder logging functions so that they are called at the end of the traced function, thus avoiding a graph break. However, the logged contents may differ if, for example, a mutation occurs.

import torch

torch._dynamo.config.reorderable_logging_functions.add(print)

@torch.compile
def fn(x):
    x += 1
    print("log!")
    return torch.sin(x)

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
log!

Incorrect code

Your code may be wrong, or is otherwise encountering an error from outside torch.compile. In the code below, we made a typo in the torch.sin call by providing an extra argument.

import torch

@torch.compile
def fn(x):
    y = torch.sin(x, x)
    return y

fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:5
Reason: Unsupported: TypeError <built-in method sin of type object at 0x7fd6fd764600>: sin() takes 1 positional argument but 2 were given
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 5, in fn
    y = torch.sin(x, x)
...

It can be difficult to tell from the logs if the error is caused by your code or because of a torch.compile bug. In order to differentiate, we recommend trying to run your code without torch.compile to see if you still get the error.
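A quick way to apply this check is to call the undecorated function directly. The sketch below reuses the typo from the example above and shows that the error also occurs in plain eager mode, which points to user code rather than torch.compile.

```python
import torch

def fn(x):
    return torch.sin(x, x)  # same typo as above: one argument too many

try:
    fn(torch.ones(3, 3))  # plain eager call, no torch.compile involved
    eager_error = None
except TypeError as exc:
    eager_error = exc

print("fails in eager mode too:", eager_error is not None)
```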

Dealing with recompilations

You can view recompilations and their reasons using tlparse or TORCH_LOGS=recompiles.

Are dynamic shapes enabled?

Recompilations due to mismatched shapes are in the form:

tensor 'L['x']' size mismatch at index 0. expected 3, actual 4

Make sure that the dynamic option of torch.compile is not set to False. The default option, dynamic=None, will only attempt dynamic shapes after the first compilation. You can set dynamic=True to compile as dynamically as possible up front.

For more information on dynamic shapes, see the dynamic shapes manual.

Changing the cache size limit

There is a limit to how many times a function can be recompiled, determined by torch._dynamo.config.cache_size_limit and torch._dynamo.config.accumulated_cache_size_limit. If either limit is exceeded, then we will not attempt to compile the function again and instead will run the function eagerly. torch.compile will also issue a warning containing the affected function and which limit was hit. In the example below, each function call results in a recompile attempt. When we hit the cache size limit (8), we stop attempting to recompile.

import torch

@torch.compile(dynamic=False)
def fn(x):
    return x + 1

for i in range(1, 10):
    fn(torch.ones(i))
$ python playground.py
torch._dynamo hit config.cache_size_limit (8)
    function: 'fn' (/data/users/williamwen/pytorch/playground.py:5)
    last reason: 0/0: tensor 'L['x']' size mismatch at index 0. expected 1, actual 9

If you know that the number of recompilations has a reasonable constant upper bound, you can raise the cache size limit. If the cost of recompilation outweighs the benefit of compilation, then you can consider lowering the cache size limit.

Wrapping constants with tensors

By default, int / float variables are treated as constants and are guarded as such. In the example below, we have a recompilation for each function call.

import torch

@torch.compile
def fn(x, c):
    return x + c

for i in range(1, 10):
    fn(torch.ones(i), 0.5 + i)
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
    triggered by the following guard failure(s):
    - 0/7: L['c'] == 8.5
    - 0/6: L['c'] == 7.5
    - 0/5: L['c'] == 6.5
    - 0/4: L['c'] == 5.5
    - 0/3: L['c'] == 4.5
    - 0/2: L['c'] == 3.5
    - 0/1: L['c'] == 2.5
    - 0/0: L['c'] == 1.5
torch._dynamo hit config.cache_size_limit (8)
    function: 'fn' (/data/users/williamwen/pytorch/playground.py:3)
    last reason: 0/0: L['c'] == 1.5

In particular, for LR schedulers, initializing with a constant can lead to recompilations:

import torch

mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def fn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(1, 10):
    fn(torch.ones(3, 3))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function step in /data/users/williamwen/pytorch/torch/optim/adam.py:189
    triggered by the following guard failure(s):
    - 3/7: L['self'].param_groups[0]['lr'] == 0.004782969000000002
    - 3/6: L['self'].param_groups[0]['lr'] == 0.005314410000000002
    - 3/5: L['self'].param_groups[0]['lr'] == 0.005904900000000002
    - 3/4: L['self'].param_groups[0]['lr'] == 0.006561000000000002
    - 3/3: L['self'].param_groups[0]['lr'] == 0.007290000000000001
    - 3/2: L['self'].param_groups[0]['lr'] == 0.008100000000000001
    - 3/1: L['self'].param_groups[0]['lr'] == 0.009000000000000001
    - 3/0: L['self'].param_groups[0]['lr'] == 0.01
torch._dynamo hit config.cache_size_limit (8)
    function: 'step' (/data/users/williamwen/pytorch/torch/optim/adam.py:189)
    last reason: 3/0: L['self'].param_groups[0]['lr'] == 0.01

In both examples, we can wrap float variables in tensors in order to prevent recompilations.

# first example
for i in range(1, 10):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))

Reporting Issues

If the workarounds provided above were not enough to get torch.compile working, then you should consider reporting the issue to PyTorch. But there are a few things that you can do to make our lives significantly easier.

Ablation

Check which component of the torch.compile stack is the one causing the issue using the backend= option for torch.compile. In particular, try:

  • torch.compile(fn, backend="eager"), which only runs TorchDynamo, the graph capture component of torch.compile.

  • torch.compile(fn, backend="aot_eager"), which runs TorchDynamo and AOTAutograd, which additionally generates the backward graph during compilation.

  • torch.compile(fn, backend="aot_eager_decomp_partition"), which runs TorchDynamo and AOTAutograd with operator decompositions/partitions.

  • torch.compile(fn, backend="inductor"), which runs TorchDynamo, AOTAutograd, and TorchInductor, the backend ML compiler that generates compiled kernels.

If you only fail with the Inductor backend, you can additionally test various Inductor modes:

  • torch.compile(fn, backend="inductor", mode="default")

  • torch.compile(fn, backend="inductor", mode="reduce-overhead")

  • torch.compile(fn, backend="inductor", mode="max-autotune")

You can also check if dynamic shapes are causing issues with any backend:

  • torch.compile(fn, dynamic=True) (always use dynamic shapes)

  • torch.compile(fn, dynamic=False) (never use dynamic shapes)

  • torch.compile(fn, dynamic=None) (automatic dynamic shapes)

Bisecting

Did you try on the latest nightly? Did something work in the past but now no longer works? Can you bisect to determine the first nightly where your issue occurs? Bisecting is especially helpful for performance, accuracy, or compile time regressions, where it is not immediately obvious where the problem originates from.

Creating a reproducer

Creating reproducers is a lot of work, and it is perfectly fine if you do not have the time to do it. However, if you are a motivated user unfamiliar with the internals of torch.compile, creating a standalone reproducer can have a huge impact on our ability to fix the bug. Without a reproducer, your bug report must contain enough information for us to identify the root cause of the problem and write a reproducer from scratch.

Here’s a list of useful reproducers, ranked from most to least preferred:

  1. Self-contained, small reproducer: A script with no external dependencies, under 100 lines of code, that reproduces the problem when run.

  2. Self-contained, large reproducer: Even if it’s large, being self-contained is a huge advantage!

  3. Non-self-contained reproducer with manageable dependencies: For example, if you can reproduce the problem by running a script after pip install transformers, that’s manageable. We can likely run it and investigate.

  4. Non-self-contained reproducer requiring substantial setup: This might involve downloading datasets, multiple environment setup steps, or specific system library versions requiring a Docker image. The more complex the setup, the harder it is for us to recreate the environment.

    Note

    Docker simplifies setup but complicates changes to the environment, so it’s not a perfect solution, though we’ll use it if necessary.

Somewhat orthogonally, a reproducer that can be run in a single process is better than a reproducer that requires multiprocess training (but once again, if you only have a multiprocess reproducer, we’ll take it!).

Additionally, below is a non-exhaustive list of aspects to check in your issue that you can attempt to replicate in your reproducer:

  • Autograd. Did you have tensor inputs with requires_grad=True? Did you call backward() on the output?

  • Dynamic shapes. Did you set dynamic=True? Or did you run the test code multiple times with varying shapes?

  • Custom operators. Is there a custom operator involved in the real workflow? Can you replicate some of its important characteristics using the Python custom operator API?

  • Configuration. Did you set all the same configuration? This includes torch._dynamo.config and torch._inductor.config settings, as well as arguments to torch.compile like backend / mode.

  • Context managers. Did you replicate any active context managers? This could be torch.no_grad, automatic mixed precision, TorchFunctionMode / TorchDispatchMode, activation checkpointing, compiled autograd etc.

  • Tensor subclasses. Is there a tensor subclass involved?

Minifier

The minifier is an early torch.compile tool that, given an FX graph that crashes when we attempt to run or compile it, finds a subgraph that also crashes and outputs the code that performs that subgraph’s operations. Essentially, the minifier finds a minimal repro for a certain class of torch.compile-related crashes. This assumes that we were able to successfully trace through code.

Unfortunately, most of the time nowadays, the minifier doesn’t work as expected, and alternative methods may be necessary. This is likely because bugs that can be automatically reproduced in this manner are generally easier to fix and have already been addressed, leaving more complex issues that do not reproduce easily. However, it is straightforward to attempt using the minifier, so it is worth trying even if it may not succeed.

Instructions for operating the minifier can be found in the PyTorch documentation. If the compiler is crashing, you can set TORCHDYNAMO_REPRO_AFTER="dynamo" or TORCHDYNAMO_REPRO_AFTER="aot". The aot option is more likely to succeed, although it may not identify the AOTAutograd issues. This will generate the repro.py file, which may help to diagnose the problem. For accuracy-related issues, consider setting TORCHDYNAMO_REPRO_LEVEL=4. Please note that this may not always successfully identify the problematic subgraph.

Debugging Deeper

This section provides tools and techniques for independently debugging torch.compile issues or for gaining a deeper understanding of the torch.compile stack. These methods are more involved than those presented above and are used by PyTorch developers regularly to debug real torch.compile issues.

Below is a high-level overview of the stack:

[Figure: high-level overview of the torch.compile stack (_images/td_stack.png)]

The stack comprises three main components: TorchDynamo, AOTAutograd, and Inductor. Our debugging strategy involves first identifying the component in which the error occurs and then individually debugging the component. To determine the component responsible for the issue, see the Ablation section under Reporting Issues above. For guidance on debugging a specific component, consult the sections below.

TorchDynamo

Logging what Dynamo is tracing

The TORCH_LOGS=trace_bytecode option enables you to view the precise bytecode instructions that Dynamo is tracing, as well as a symbolic representation of the Python interpreter stack. When encountering a graph break or crash, it is advisable to inspect the last few bytecode instructions traced.

You can also use TORCH_LOGS=trace_source to see which lines of source code Dynamo is tracing through. This is useful in combination with trace_bytecode to see the line of source code each traced bytecode instruction corresponds to.

Finally, you can use TORCH_LOGS=graph_code to see the Python code representing the FX graph that Dynamo traced. You can view this code to double check that the correct ops are being traced.

import torch

def g(x, y):
    return x + y

@torch.compile(backend="eager")
def f(x):
    x = torch.sin(x)
    x = g(x, x)
    return x

f(torch.ones(3, 3))

$ TORCH_LOGS="trace_bytecode,trace_source,graph_code" python playground.py
TRACE starts_line /data/users/williamwen/pytorch/playground.py:6 in f ()
    @torch.compile(backend="eager")
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:8 in f (f)
        x = torch.sin(x)
TRACE LOAD_GLOBAL torch []
TRACE LOAD_ATTR sin [NullVariable(), PythonModuleVariable(<module 'torch' from '/data/users/williamwen/pytorch/torch/__init__.py'>)]
TRACE LOAD_FAST x [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>)]
TRACE CALL 1 [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>), LazyVariableTracker()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:9 in f (f)
        x = g(x, x)
TRACE LOAD_GLOBAL g []
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable()]
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable(), TensorVariable()]
TRACE CALL 2 [NullVariable(), UserFunctionVariable(), TensorVariable(), TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:3 in g (g) (inline depth: 1)
    def g(x, y):
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:4 in g (g) (inline depth: 1)
        return x + y
TRACE LOAD_FAST x []
TRACE LOAD_FAST y [TensorVariable()]
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
TRACE RETURN_VALUE None [TensorVariable()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:10 in f (f)
        return x
TRACE LOAD_FAST x []
TRACE RETURN_VALUE None [TensorVariable()]
TRACED GRAPH
===== __compiled_fn_1 =====
/data/users/williamwen/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3][3, 1]cpu"):
        l_x_ = L_x_

        # File: /data/users/williamwen/pytorch/playground.py:8 in f, code: x = torch.sin(x)
        x: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_);  l_x_ = None

        # File: /data/users/williamwen/pytorch/playground.py:4 in g, code: return x + y
        x_1: "f32[3, 3][3, 1]cpu" = x + x;  x = None
        return (x_1,)
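
When reading a trace_bytecode log, it can help to compare it against the bytecode that CPython itself reports for the same function. The sketch below uses only the standard library's dis module and does not involve Dynamo at all; opnames is a hypothetical helper name. The exact opcode names depend on your Python version (the log above shows BINARY_OP, while older interpreters emit BINARY_ADD).

```python
import dis


def g(x, y):
    return x + y


def opnames(fn):
    """Return the opcode names CPython compiled `fn` into, in order."""
    return [inst.opname for inst in dis.get_instructions(fn)]


# Print one line per instruction: the opcode name and its human-readable argument.
for inst in dis.get_instructions(g):
    print(inst.opname, inst.argrepr)
```

Each printed opcode should correspond to a TRACE line for the same function in the log, which makes it easier to tell where tracing stopped.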

Breakpointing Dynamo tracing

Inserting a breakpoint in Dynamo/user code is helpful at times to see what the state of Dynamo is when tracing through user code. Unfortunately, inserting a breakpoint in the normal Python fashion will result in a graph break in TorchDynamo, so we will not be able to view the state of Dynamo at the point where we intended to breakpoint.

The first method for setting a breakpoint is to insert it within the Dynamo source code. Three recommended locations to place a breakpoint are:

  • In torch/_dynamo/symbolic_convert.py, breakpoint at functions that are named after the problematic bytecode instruction, such as def CALL_FUNCTION and def STORE_ATTR. Because some bytecode opcodes are used frequently, you may want to breakpoint conditionally on the inputs, for example, on the argval of the instruction or on the name of the object at the top of the stack.

  • Breakpoint where the graph break or error originates from. Typically, graph breaks are emitted from a call to unimplemented(...).

  • Breakpoint in the _wrap function in torch/_dynamo/variables/builder.py. You will likely have to conditionally breakpoint on the input. This function determines how to symbolically represent a given value. Consider breakpointing here if you suspect that a value is represented incorrectly.
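
The conditional breakpoint described in the first bullet boils down to a predicate on the instruction. The sketch below shows that pattern on instructions from the standard library's dis module rather than inside Dynamo. should_break and set_fields are hypothetical names, and the sketch assumes that the instruction object passed to a Dynamo handler exposes opname and argval in the same way.

```python
import dis


def should_break(inst, opname, argval):
    """True when `inst` is the one instruction we want to stop on."""
    return inst.opname == opname and inst.argval == argval


def set_fields(obj, value):
    obj.bias = value
    obj.weight = value


# Inside a handler such as `def STORE_ATTR(self, inst)` (signature assumed) you would write:
#     if should_break(inst, "STORE_ATTR", "weight"):
#         breakpoint()
# Here we only collect the instructions for which the condition would fire.
hits = [
    inst
    for inst in dis.get_instructions(set_fields)
    if should_break(inst, "STORE_ATTR", "weight")
]
print(len(hits))
```

Filtering on both the opcode and its argument keeps the debugger from stopping on every attribute store in the traced code.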

The second way to insert a breakpoint is to use torch._dynamo.comptime.comptime.breakpoint:

from torch._dynamo.comptime import comptime

@torch.compile
def f(...):
    ...
    comptime.breakpoint()
    ...

A comptime breakpoint is convenient as it enables you to inspect the Dynamo state at a specific location within the user code being traced. It does not require you to insert a breakpoint in the Dynamo source or to conditionally breakpoint based on variables.

When a comptime breakpoint is triggered, you can do the following:

  • ctx.print_bt() to print the user stack trace

  • ctx.print_locals() to print all current locals

  • ctx.print_graph() to print the currently traced graph

  • ctx.disas() to print the currently traced function’s bytecode

  • Use standard pdb commands, such as bt/u/d/n/s/r. You can go up the pdb stack to inspect more Dynamo internals.

import torch
from torch._dynamo.comptime import comptime

@torch.compile(backend="eager")
def f(x):
    y = x + 1
    comptime.breakpoint()
    y = y + 1
    return y

f(torch.ones(3, 3))

$ python playground.py
--Return--
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_bt()
File "/data/users/williamwen/pytorch/playground.py", line 7, in f
    comptime.breakpoint()

(Pdb) ctx.print_locals()
x = FakeTensor(..., size=(3, 3))
y = FakeTensor(..., size=(3, 3))
(Pdb) bt
...
/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py(826)call_function()
-> self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py(331)call_function()
-> func(ComptimeContext(tx))
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_graph()



def forward(self, L_x_: "f32[3, 3]"):
    l_x_ = L_x_

    # File: /data/users/williamwen/pytorch/playground.py:6 in f, code: y = x + 1
    y: "f32[3, 3]" = l_x_ + 1;  l_x_ = y = None

Bytecode generation errors

Although uncommon, Dynamo may generate incorrect bytecode. This may be the case if you determine the following:

  • Ablation reveals the error is happening at the TorchDynamo level

  • The error is not being emitted from TorchDynamo stack frames

  • The error looks more like a user error rather than a Dynamo error, or is a segmentation fault

  • The error does not occur without torch.compile

Bytecode generation bugs are generally tricky to fix, and we recommend submitting an issue instead of trying to fix them yourself. If you are interested in seeing the bytecode that Dynamo generates, you can use TORCH_LOGS=bytecode. You can see a high-level overview of what bytecode Dynamo generates here.

AOTAutograd

AOTAutograd errors are typically difficult to debug - we recommend just submitting an issue. AOTAutograd logging output is primarily helpful to see what the input to Inductor is.

Summary of TORCH_LOGS options

A summary of helpful TORCH_LOGS options is:

Option              Description
+all                Output debug logs from all torch.compile components
+dynamo             Output debug logs from TorchDynamo
+aot                Output debug logs from AOTAutograd
+inductor           Output debug logs from TorchInductor
dynamic             Output logs from dynamic shapes
graph_code          Output the Python code for the FX graph that Dynamo generated
graph_sizes         Output the tensor sizes of the FX graph that Dynamo generated
trace_bytecode      Output the bytecode instructions that Dynamo is tracing through and the symbolic interpreter stack Dynamo is keeping track of
trace_source        Output the line of code in the original source that Dynamo is currently tracing through
bytecode            Output Dynamo-generated bytecode
guards              Output generated guards
recompiles          Output recompilation reasons (only the first guard check that fails)
recompiles_verbose  Output all guard checks that fail when a recompilation occurs
aot_graphs          Output graph generated by AOTAutograd
aot_joint_graphs    Output the joint forward-backward graph generated by AOTAutograd
output_code         Output code generated by Inductor
kernel_code         Output code generated by Inductor on a per-kernel basis
schedule            Output Inductor scheduling logs
perf_hints          Output Inductor perf hint logs
fusion              Output Inductor fusion logs

For the full list of options, see torch._logging and torch._logging.set_logs.
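
Several options from the table can be combined into one comma-separated TORCH_LOGS value, as in the trace example earlier in this section. A small sketch of composing that value follows. torch_logs_value is a hypothetical helper, and the comment about setting the variable before torch is imported is an assumption about when the variable is read.

```python
import os


def torch_logs_value(artifacts=(), debug_components=()):
    """Build a TORCH_LOGS string.

    artifacts: option names such as "graph_code" or "recompiles".
    debug_components: component names that get the "+" (debug) prefix,
        such as "dynamo" or "inductor".
    """
    parts = ["+" + name for name in debug_components]
    parts.extend(artifacts)
    return ",".join(parts)


value = torch_logs_value(["graph_code", "recompiles"], debug_components=["dynamo"])
print(value)

# Assumption: the variable should be in place before torch is imported,
# so set it in the shell or at the very top of the script.
os.environ["TORCH_LOGS"] = value
```

Setting the variable in the shell, as in the examples above, remains the simplest route; a helper like this is mainly useful when a test harness launches many runs with different logging settings.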
