Note
This page describes an internal API which is not intended to be used outside of the PyTorch codebase and can be modified or removed without notice.
Source code for torch._dynamo
from . import allowed_functions, convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, register_backend
from .convert_frame import replay
from .eval_frame import (
assume_constant_result,
disable,
explain,
export,
optimize,
optimize_assert,
OptimizedModule,
reset_code,
run,
skip,
)
from .external_utils import is_compiling
from .utils import compilation_metrics, guard_failures, orig_code_map, reset_frame_count
# Names exported by ``from torch._dynamo import *``. Each entry is expected to
# be either imported above or defined in this module.
# NOTE(review): "graph_break" is listed here but is not imported in the lines
# above -- presumably it is defined elsewhere in this module; confirm.
__all__ = [
"allow_in_graph",
"assume_constant_result",
"disallow_in_graph",
"graph_break",
"optimize",
"optimize_assert",
"export",
"explain",
"run",
"replay",
"disable",
"reset",
"skip",
"OptimizedModule",
"is_compiling",
"register_backend",
"list_backends",
]
def reset():
    """Clear all compile caches and restore initial state.

    In order, this:

    * calls ``reset_code`` on every still-resolvable code object recorded in
      ``convert_frame.input_codes`` and ``convert_frame.output_codes``;
    * clears those two trackers, ``orig_code_map``, ``guard_failures``,
      ``ContinueExecutionCache.cache`` and ``compilation_metrics``;
    * sets ``eval_frame.most_recent_backend`` back to ``None``;
    * calls ``reset_frame_count()``.
    """
    # Entries in ``seen`` are callables that yield the code object or a falsy
    # value (presumably weak references whose target may already have been
    # collected -- confirm against convert_frame). Falsy results are skipped.
    for weak_code in convert_frame.input_codes.seen + convert_frame.output_codes.seen:
        code = weak_code()
        if code:
            reset_code(code)

    # Drop the trackers only after every live code object has been reset,
    # since the loop above reads from them.
    convert_frame.input_codes.clear()
    convert_frame.output_codes.clear()
    orig_code_map.clear()
    guard_failures.clear()
    resume_execution.ContinueExecutionCache.cache.clear()
    eval_frame.most_recent_backend = None
    compilation_metrics.clear()
    reset_frame_count()
def allow_in_graph(fn):
    """
    Customize which functions TorchDynamo will include in the generated
    graph. Similar to `torch.fx.wrap()`.
    ::

        torch._dynamo.allow_in_graph(my_custom_function)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = my_custom_function(x)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will capture a single graph containing `my_custom_function()`.

    Args:
        fn: A callable, or a list/tuple of callables (possibly nested), to
            mark as allowed.

    Returns:
        ``fn`` itself when given a single callable; a list of the per-element
        results when given a list or tuple.

    Raises:
        AssertionError: If ``fn`` is neither a list/tuple nor callable.
    """
    if isinstance(fn, (list, tuple)):
        # Apply element-wise; a tuple input still produces a list.
        return [allow_in_graph(x) for x in fn]
    assert callable(fn), "allow_in_graph expects a callable"
    # Registration is keyed on object identity, not on name or equality.
    fn_id = id(fn)
    allowed_functions._allowed_function_ids.add(fn_id)
    # NOTE(review): assumes ``remove`` tolerates an id that was never
    # disallowed -- confirm against allowed_functions.
    allowed_functions._disallowed_function_ids.remove(fn_id)
    return fn
def disallow_in_graph(fn):
    """
    Customize which functions TorchDynamo will exclude in the generated
    graph and force a graph break on.
    ::

        torch._dynamo.disallow_in_graph(torch.sub)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = torch.sub(x, 1)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will break the graph on `torch.sub`, and give two graphs each with a
    single `torch.add()` op.

    Args:
        fn: A callable, or a list/tuple of callables (possibly nested), to
            mark as disallowed.

    Returns:
        ``fn`` itself when given a single callable; a list of the per-element
        results when given a list or tuple.

    Raises:
        AssertionError: If ``fn`` is neither a list/tuple nor callable.
    """
    if isinstance(fn, (list, tuple)):
        # Apply element-wise; a tuple input still produces a list.
        return [disallow_in_graph(x) for x in fn]
    assert callable(fn), "disallow_in_graph expects a callable"
    # Registration is keyed on object identity, not on name or equality.
    fn_id = id(fn)
    # NOTE(review): assumes ``remove`` tolerates an id that was never
    # allowed -- confirm against allowed_functions.
    allowed_functions._allowed_function_ids.remove(fn_id)
    allowed_functions._disallowed_function_ids.add(fn_id)
    return fn