Dynamic Compilation Control with torch.compiler.set_stance
Author: William Wen
torch.compiler.set_stance is a torch.compiler API that enables you to change the behavior of torch.compile across different calls to your model without having to reapply torch.compile to your model. This recipe provides some examples of how to use torch.compiler.set_stance.
Prerequisites
torch >= 2.6
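To confirm that the installed version is new enough, a quick check like the following can be used (a minimal sketch of our own, not part of the original recipe; it only assumes that torch.__version__ and the torch.compiler namespace exist):

import torch

print(torch.__version__)  # should report 2.6 or later
assert hasattr(torch.compiler, "set_stance"), "torch >= 2.6 is required for set_stance"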
Description
torch.compiler.set_stance can be used as a decorator, context manager, or raw function to change the behavior of torch.compile across different calls to your model.
In the example below, the "force_eager" stance ignores all torch.compile directives.
import torch

@torch.compile
def foo(x):
    if torch.compiler.is_compiling():
        # torch.compile is active
        return x + 1
    else:
        # torch.compile is not active
        return x - 1

inp = torch.zeros(3)
print(foo(inp))  # compiled, prints 1
tensor([1., 1., 1.])
Sample decorator usage
@torch.compiler.set_stance("force_eager")
def bar(x):
    # force disable the compiler
    return foo(x)

print(bar(inp))  # not compiled, prints -1
tensor([-1., -1., -1.])
Sample context manager usage
with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1
tensor([-1., -1., -1.])
Sample raw function usage
torch.compiler.set_stance("force_eager")
print(foo(inp)) # not compiled, prints -1
torch.compiler.set_stance("default")
print(foo(inp)) # compiled, prints 1
tensor([-1., -1., -1.])
tensor([1., 1., 1.])
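Because the context manager restores the previous stance when it exits (the raw-function example above restores "default" manually), stance changes can also be scoped and nested. The snippet below is an illustrative sketch of our own, not part of the original recipe, and assumes that nested set_stance context managers each restore the stance that was active when they were entered:

with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1
    with torch.compiler.set_stance("default"):
        print(foo(inp))  # compiled, prints 1
    print(foo(inp))  # not compiled again, prints -1
print(foo(inp))  # compiled, prints 1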
The torch.compile stance can only be changed outside of any torch.compile region. Attempts to do otherwise will result in an error.
@torch.compile
def baz(x):
    # error!
    with torch.compiler.set_stance("force_eager"):
        return x + 1

try:
    baz(inp)
except Exception as e:
    print(e)
@torch.compiler.set_stance("force_eager")
def inner(x):
    return x + 1

@torch.compile
def outer(x):
    # error!
    return inner(x)

try:
    outer(inp)
except Exception as e:
    print(e)
Attempt to trace forbidden callable <function set_stance at 0x7fd5e807d870>
from user code:
File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 85, in baz
with torch.compiler.set_stance("force_eager"):
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Attempt to trace forbidden callable <function inner at 0x7fd4a27bb0a0>
from user code:
File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 103, in outer
return inner(x)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Other stances include:

- "default": The default stance, used for normal compilation.
- "eager_on_recompile": Run code eagerly when a recompile is necessary. If there is cached compiled code valid for the input, it will still be used.
- "fail_on_recompile": Raise an error when recompiling a function.
See the torch.compiler.set_stance doc page for more stances and options. More stances/options may also be added in the future.
Examples
Preventing recompilation
Some models do not expect any recompilations - for example, you may always have inputs with the same shape. Since recompilations may be expensive, we may wish to error out when we attempt to recompile so we can detect and fix recompilation cases. The "fail_on_recompile" stance can be used for this.
@torch.compile
def my_big_model(x):
    return torch.relu(x)

# first compilation
my_big_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    my_big_model(torch.randn(3))  # no recompilation - OK
    try:
        my_big_model(torch.randn(4))  # recompilation - error
    except Exception as e:
        print(e)
Detected recompile when torch.compile stance is 'fail_on_recompile'
If erroring out is too disruptive, we can use "eager_on_recompile" instead, which will cause torch.compile to fall back to eager instead of erroring out. This may be useful if we don't expect recompilations to happen frequently, but when one is required, we'd rather pay the cost of running eagerly over the cost of recompilation.
@torch.compile
def my_huge_model(x):
    if torch.compiler.is_compiling():
        return x + 1
    else:
        return x - 1

# first compilation
print(my_huge_model(torch.zeros(3)))  # 1

with torch.compiler.set_stance("eager_on_recompile"):
    print(my_huge_model(torch.zeros(3)))  # 1
    print(my_huge_model(torch.zeros(4)))  # -1
    print(my_huge_model(torch.zeros(3)))  # 1
tensor([1., 1., 1.])
tensor([1., 1., 1.])
tensor([-1., -1., -1., -1.])
tensor([1., 1., 1.])
Measuring performance gains
torch.compiler.set_stance can be used to compare eager vs. compiled performance without having to define a separate eager model.
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000
@torch.compile
def my_gigantic_model(x, y):
    x = x @ y
    x = x @ y
    x = x @ y
    return x

inps = torch.randn(5, 5), torch.randn(5, 5)

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed(lambda: my_gigantic_model(*inps))[1])

# warmups
for _ in range(3):
    my_gigantic_model(*inps)

print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])
eager: 0.00016115200519561766
compiled: 0.00016368000209331514
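The timed helper above uses CUDA events, so it assumes the model runs on a GPU. On a CPU-only machine, a rough equivalent (our own variant, not part of the recipe) can use time.perf_counter for wall-clock timing instead:

import time

# CPU-only variant of `timed`: returns the result of `fn()` and the wall-clock
# time it took, in seconds.
def timed_cpu(fn):
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed_cpu(lambda: my_gigantic_model(*inps))[1])
print("compiled:", timed_cpu(lambda: my_gigantic_model(*inps))[1])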
Crashing sooner
Running an eager iteration first before a compiled iteration using the "force_eager" stance can help us catch errors unrelated to torch.compile before attempting a very long compile.
@torch.compile
def my_humongous_model(x):
    return torch.sin(x, x)

try:
    with torch.compiler.set_stance("force_eager"):
        print(my_humongous_model(torch.randn(3)))
    # this call to the compiled model won't run
    print(my_humongous_model(torch.randn(3)))
except Exception as e:
    print(e)
sin() takes 1 positional argument but 2 were given
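If this check is something you do routinely, the pattern can be wrapped in a small helper. The function below is a hypothetical convenience of our own (its name and structure are not part of the recipe): it runs the model once under the "force_eager" stance to surface plain eager-mode errors quickly, then runs it again under the current stance so compilation can proceed:

# Hypothetical helper: one cheap eager sanity-check call, then a normal call.
def check_eager_then_run(model, *args, **kwargs):
    with torch.compiler.set_stance("force_eager"):
        model(*args, **kwargs)  # catches errors unrelated to torch.compile
    return model(*args, **kwargs)  # runs under the current stance (e.g., compiled)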
Conclusion
In this recipe, we have learned how to use the torch.compiler.set_stance API to modify the behavior of torch.compile across different calls to a model without needing to reapply it. The recipe demonstrates using torch.compiler.set_stance as a decorator, context manager, or raw function to control compilation stances like "force_eager", "default", "eager_on_recompile", and "fail_on_recompile".
For more information, see: torch.compiler.set_stance API documentation.