Source code for torch_xla.experimental.eager

import functools
from contextlib import contextmanager

import torch_xla

def eager_mode(enable: bool) -> None:
    """Configure torch_xla's default execution mode.

    Under eager mode, only functions that were wrapped with
    ``torch_xla.compile`` are traced and compiled; all other torch ops are
    executed eagerly.

    Args:
        enable: ``True`` to turn eager mode on, ``False`` to turn it off.
    """
    torch_xla._XLAC._set_use_eager_mode(enable)
def is_eager_mode() -> bool:
    """Report whether torch_xla is currently running in eager mode."""
    current_mode = torch_xla._XLAC._get_use_eager_mode()
    return current_mode


@contextmanager
def eager_mode_context(enable: bool):
    """Temporarily switch eager mode on or off for a ``with`` block.

    Args:
        enable: Eager-mode setting to apply while the ``with`` body runs.

    Yields:
        The eager-mode setting that was active before entering. That setting
        is re-applied when the body exits, whether normally or by exception.
    """
    previous_mode = is_eager_mode()
    eager_mode(enable)
    try:
        yield previous_mode
    finally:
        # Always put the caller's original setting back.
        eager_mode(previous_mode)
def compile(func):
    """Compile ``func`` with Lazy Tensor.

    Returns an optimized wrapper that takes exactly the same inputs as
    ``func``. Each call to the wrapper does the following:

    1. Disables eager mode, so ``func`` runs under Lazy Tensor tracing.
    2. Syncs any pending graph before calling ``func``.
    3. Syncs the graph ``func`` generated.
    4. Restores whatever eager-mode setting was active before the call, even
       if ``func`` or a sync raises.

    Args:
        func: The callable to trace and compile.

    Returns:
        A wrapper around ``func`` that keeps its name, docstring and other
        metadata (via ``functools.wraps``).

    Raises:
        Any exception raised by ``func`` or by ``torch_xla.sync()`` is
        propagated unchanged after the eager-mode setting is restored.
    """

    @functools.wraps(func)  # Keep function's name, docstring, etc.
    def wrapper(*args, **kwargs):
        saved_eager_mode = torch_xla._XLAC._get_use_eager_mode()
        torch_xla._XLAC._set_use_eager_mode(False)
        try:
            # Clear the pending graph, if any, before tracing func.
            torch_xla.sync()
            # Target function execution (traced, since eager mode is off).
            result = func(*args, **kwargs)
            # Sync the graph generated by the target function.
            torch_xla.sync()
        finally:
            # Restore the caller's mode even on failure; otherwise an
            # exception would leave the process stuck outside eager mode.
            torch_xla._XLAC._set_use_eager_mode(saved_eager_mode)
        return result

    return wrapper


