
Note: This page describes an internal API which is not intended to be used outside of the PyTorch codebase and can be modified or removed without notice.

Source code for torch._dynamo

from . import allowed_functions, convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, register_backend
from .convert_frame import replay
from .eval_frame import (
    assume_constant_result,
    disable,
    explain,
    export,
    optimize,
    optimize_assert,
    OptimizedModule,
    reset_code,
    run,
    skip,
)
from .external_utils import is_compiling
from .utils import compilation_metrics, guard_failures, orig_code_map, reset_frame_count

__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
    "graph_break",
    "optimize",
    "optimize_assert",
    "export",
    "explain",
    "run",
    "replay",
    "disable",
    "reset",
    "skip",
    "OptimizedModule",
    "is_compiling",
    "register_backend",
    "list_backends",
]


def reset():
    """Clear all compile caches and restore initial state"""
    for weak_code in convert_frame.input_codes.seen + convert_frame.output_codes.seen:
        code = weak_code()
        if code:
            reset_code(code)
    convert_frame.input_codes.clear()
    convert_frame.output_codes.clear()
    orig_code_map.clear()
    guard_failures.clear()
    resume_execution.ContinueExecutionCache.cache.clear()
    eval_frame.most_recent_backend = None
    compilation_metrics.clear()
    reset_frame_count()
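
A minimal usage sketch (not part of the module source): calling reset() between runs drops the cached compiled code, so the next call through an optimized function recompiles from a clean state. The "eager" backend below is only illustrative.

    import torch
    import torch._dynamo

    @torch._dynamo.optimize("eager")  # "eager" just runs the captured graphs; handy for testing
    def fn(x):
        return torch.relu(x) + 1

    fn(torch.randn(4))       # first call triggers compilation and populates the caches
    torch._dynamo.reset()    # clears cached code objects, guards, and metrics
    fn(torch.randn(4))       # recompiles from scratch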
def allow_in_graph(fn):
    """
    Customize which functions TorchDynamo will include in the generated
    graph. Similar to `torch.fx.wrap()`.
    ::

        torch._dynamo.allow_in_graph(my_custom_function)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = my_custom_function(x)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will capture a single graph containing `my_custom_function()`.
    """
    if isinstance(fn, (list, tuple)):
        return [allow_in_graph(x) for x in fn]
    assert callable(fn), "allow_in_graph expects a callable"
    allowed_functions._allowed_function_ids.add(id(fn))
    allowed_functions._disallowed_function_ids.remove(id(fn))
    return fn
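
A runnable variant of the docstring example, added for illustration; my_custom_function is a hypothetical stand-in for user code, and the "eager" backend is only an assumption for the sketch.

    import torch
    import torch._dynamo

    def my_custom_function(x):          # hypothetical helper; any Python callable works
        return x * 2

    torch._dynamo.allow_in_graph(my_custom_function)

    @torch._dynamo.optimize("eager")
    def fn(a):
        x = torch.add(a, 1)
        x = my_custom_function(x)       # appears as a call in the captured graph instead of causing a break
        x = torch.add(x, 1)
        return x

    fn(torch.randn(3))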
def disallow_in_graph(fn):
    """
    Customize which functions TorchDynamo will exclude from the generated
    graph and force a graph break on.
    ::

        torch._dynamo.disallow_in_graph(torch.sub)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = torch.sub(x, 1)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will break the graph on `torch.sub`, and give two graphs each with a
    single `torch.add()` op.
    """
    if isinstance(fn, (list, tuple)):
        return [disallow_in_graph(x) for x in fn]
    assert callable(fn), "disallow_in_graph expects a callable"
    allowed_functions._allowed_function_ids.remove(id(fn))
    allowed_functions._disallowed_function_ids.add(id(fn))
    return fn
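
A sketch of the same idea in runnable form (illustrative only; note that disallowing `torch.sub` mutates global state for the process). `explain` is used here just to observe the resulting graph break.

    import torch
    import torch._dynamo

    torch._dynamo.disallow_in_graph(torch.sub)

    @torch._dynamo.optimize("eager")
    def fn(a):
        x = torch.add(a, 1)
        x = torch.sub(x, 1)   # forces a graph break at this call
        x = torch.add(x, 1)
        return x

    fn(torch.randn(3))
    torch._dynamo.explain(fn, torch.randn(3))  # reports graph counts and break reasons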
@disallow_in_graph
def graph_break():
    """Force a graph break"""
    pass
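
As a usage sketch (assuming the "eager" backend for illustration), graph_break() manually splits a compiled region: the code above and below the call is captured as separate graphs.

    import torch
    import torch._dynamo

    @torch._dynamo.optimize("eager")
    def fn(x):
        x = torch.sin(x)
        torch._dynamo.graph_break()   # splits the function into two captured graphs
        return torch.cos(x)

    fn(torch.randn(3))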
