Source code for torch_xla.experimental.eager
import functools
from contextlib import contextmanager
import torch_xla
def eager_mode(enable: bool):
  """Configure torch_xla's default execution mode.

  Under eager mode, only functions that were `torch_xla.compile`d will be
  traced and compiled. Other torch ops will be executed eagerly.
  """
  torch_xla._XLAC._set_use_eager_mode(enable)
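
# Illustrative usage sketch (not part of the module): flipping the default
# execution mode and checking it with `is_eager_mode()` below.
#
#   import torch_xla.experimental.eager as eager
#
#   eager.eager_mode(True)   # torch ops now execute eagerly by default
#   assert eager.is_eager_mode()
#   eager.eager_mode(False)  # back to lazy tracing as the default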


def is_eager_mode() -> bool:
  """Return True if torch_xla is currently in eager mode."""
  return torch_xla._XLAC._get_use_eager_mode()


@contextmanager
def eager_mode_context(enable: bool):
  """Context manager to enable/disable the eager mode."""
  saved_eager_mode = is_eager_mode()
  eager_mode(enable)
  try:
    yield saved_eager_mode
  finally:
    eager_mode(saved_eager_mode)
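
# Illustrative usage sketch (not part of the module): running a few ops
# eagerly and restoring the previous mode on exit, even if the block raises.
# The tensor math and the `torch_xla.device()` call are assumptions for the
# demo, not taken from this module.
#
#   import torch
#   import torch_xla
#   from torch_xla.experimental.eager import eager_mode_context
#
#   with eager_mode_context(True):
#     t = torch.randn(4, 4, device=torch_xla.device())
#     t = t @ t  # executed eagerly on the XLA device
#   # previous mode is restored here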


def compile(func):
  """Compile the func with Lazy Tensor.

  Return the optimized function that takes the exact same inputs. `compile`
  runs the target func in tracing mode using Lazy tensors.
  """

  @functools.wraps(func)  # Keep function's name, docstring, etc.
  def wrapper(*args, **kwargs):
    # compile should only be called while eager mode is enabled.
    assert torch_xla._XLAC._get_use_eager_mode()
    torch_xla._XLAC._set_use_eager_mode(False)
    # Clear the pending graph if any.
    torch_xla.sync()
    try:
      # Target function execution.
      result = func(*args, **kwargs)
      # Sync the graph generated by the target function.
      torch_xla.sync()
    except Exception as e:
      print(f"Error in target function: {e}")
      raise  # Re-raise the exception.
    finally:
      # Restore eager mode even when the target function raises, so a
      # failure inside the compiled region does not leave tracing mode on.
      torch_xla._XLAC._set_use_eager_mode(True)
    return result

  return wrapper
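
# Illustrative usage sketch (not part of the module): compiling a toy step
# function while everything outside it stays eager. The model, input shapes,
# and the `torch_xla.device()` call are assumptions for the demo.
#
#   import torch
#   import torch_xla
#   from torch_xla.experimental.eager import compile, eager_mode
#
#   eager_mode(True)
#   model = torch.nn.Linear(8, 2).to(torch_xla.device())
#
#   @compile
#   def step(x):
#     return model(x).sum()  # traced, then synced by the wrapper
#
#   loss = step(torch.randn(4, 8, device=torch_xla.device()))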