
Dynamic Compilation Control with torch.compiler.set_stance

Author: William Wen

torch.compiler.set_stance is a torch.compiler API that enables you to change the behavior of torch.compile across different calls to your model without having to reapply torch.compile to your model.

This recipe provides some examples of how to use torch.compiler.set_stance.

Prerequisites

  • torch >= 2.6

Description

torch.compiler.set_stance can be used as a decorator, context manager, or raw function to change the behavior of torch.compile across different calls to your model.

In the example below, the "force_eager" stance ignores all torch.compile directives.

import torch


@torch.compile
def foo(x):
    if torch.compiler.is_compiling():
        # torch.compile is active
        return x + 1
    else:
        # torch.compile is not active
        return x - 1


inp = torch.zeros(3)

print(foo(inp))  # compiled, prints 1
tensor([1., 1., 1.])

Sample decorator usage

@torch.compiler.set_stance("force_eager")
def bar(x):
    # force disable the compiler
    return foo(x)


print(bar(inp))  # not compiled, prints -1
tensor([-1., -1., -1.])

Sample context manager usage

with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1
tensor([-1., -1., -1.])

Sample raw function usage

torch.compiler.set_stance("force_eager")
print(foo(inp))  # not compiled, prints -1
torch.compiler.set_stance("default")

print(foo(inp))  # compiled, prints 1
tensor([-1., -1., -1.])
tensor([1., 1., 1.])

The torch.compile stance can only be changed outside of any torch.compile region. Attempting to change it inside a compiled function, either directly or by calling a function decorated with set_stance, results in an error.

@torch.compile
def baz(x):
    # error!
    with torch.compiler.set_stance("force_eager"):
        return x + 1


try:
    baz(inp)
except Exception as e:
    print(e)


@torch.compiler.set_stance("force_eager")
def inner(x):
    return x + 1


@torch.compile
def outer(x):
    # error!
    return inner(x)


try:
    outer(inp)
except Exception as e:
    print(e)
Attempt to trace forbidden callable <function set_stance at 0x7fd08ee79870>

from user code:
   File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 85, in baz
    with torch.compiler.set_stance("force_eager"):

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Attempt to trace forbidden callable <function inner at 0x7fcf493fdf30>

from user code:
   File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 103, in outer
    return inner(x)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Other stances include:
  • "default": The default stance, used for normal compilation.

  • "eager_on_recompile": Run code eagerly when a recompile is necessary. If there is cached compiled code valid for the input, it will still be used.

  • "fail_on_recompile": Raise an error when recompiling a function.

See the torch.compiler.set_stance doc page for more stances and options. More stances and options may be added in the future.

Examples

Preventing recompilation

Some models are not expected to recompile at all; for example, you may always pass inputs with the same shape. Since recompilations can be expensive, we may want to raise an error whenever a recompile is attempted so that we can detect and fix the cause. The "fail_on_recompile" stance can be used for this.

@torch.compile
def my_big_model(x):
    return torch.relu(x)


# first compilation
my_big_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    my_big_model(torch.randn(3))  # no recompilation - OK
    try:
        my_big_model(torch.randn(4))  # recompilation - error
    except Exception as e:
        print(e)
Detected recompile when torch.compile stance is 'fail_on_recompile'

If erroring out is too disruptive, we can use "eager_on_recompile" instead, which causes torch.compile to fall back to eager execution rather than raising an error. This may be useful if we don't expect recompilations to happen frequently, but when one is required, we'd rather pay the cost of running eagerly than the cost of recompiling.

@torch.compile
def my_huge_model(x):
    if torch.compiler.is_compiling():
        return x + 1
    else:
        return x - 1


# first compilation
print(my_huge_model(torch.zeros(3)))  # 1

with torch.compiler.set_stance("eager_on_recompile"):
    print(my_huge_model(torch.zeros(3)))  # 1
    print(my_huge_model(torch.zeros(4)))  # -1
    print(my_huge_model(torch.zeros(3)))  # 1
tensor([1., 1., 1.])
tensor([1., 1., 1.])
tensor([-1., -1., -1., -1.])
tensor([1., 1., 1.])

Measuring performance gains

torch.compiler.set_stance can be used to compare eager vs. compiled performance without having to define a separate eager model.

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements, so this helper requires a CUDA device.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


@torch.compile
def my_gigantic_model(x, y):
    x = x @ y
    x = x @ y
    x = x @ y
    return x


inps = torch.randn(5, 5), torch.randn(5, 5)

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed(lambda: my_gigantic_model(*inps))[1])

# warmups
for _ in range(3):
    my_gigantic_model(*inps)

print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])
eager: 0.00014342400431632997
compiled: 0.0001496960073709488

Crashing sooner

Running one eager iteration with the "force_eager" stance before the first compiled iteration can catch errors unrelated to torch.compile before we spend time on a potentially very long compile.

@torch.compile
def my_humongous_model(x):
    return torch.sin(x, x)


try:
    with torch.compiler.set_stance("force_eager"):
        print(my_humongous_model(torch.randn(3)))
    # this call to the compiled model won't run
    print(my_humongous_model(torch.randn(3)))
except Exception as e:
    print(e)
sin() takes 1 positional argument but 2 were given

Conclusion

In this recipe, we have learned how to use the torch.compiler.set_stance API to modify the behavior of torch.compile across different calls to a model without needing to reapply it. The recipe demonstrated using torch.compiler.set_stance as a decorator, context manager, or raw function to select stances such as "force_eager", "default", "eager_on_recompile", and "fail_on_recompile".

For more information, see: torch.compiler.set_stance API documentation.

Total running time of the script: ( 0 minutes 13.378 seconds)
