Source code for torch.xpu
# mypy: allow-untyped-defs
r"""
This package introduces support for the XPU backend, specifically tailored for
Intel GPU optimization.
This package is lazily initialized, so you can always import it, and use
:func:`is_available()` to determine if your system supports XPU.
"""
import threading
import traceback
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch._C
from torch import device as _device
from torch._utils import _dummy_type, _LazySeedTracker
from ._utils import _get_device_index
from .streams import Event, Stream
# Set to True by `_lazy_init` once the queued calls have been replayed.
_initialized = False
# Thread-local storage; `_lazy_init` sets `is_initializing` on it while it
# replays queued calls so that reentrant calls on the same thread return early.
_tls = threading.local()
# Held by `_lazy_init` for the whole initialization sequence.
_initialization_lock = threading.Lock()
# Pairs of (callable, formatted stack of the call site) appended by `_lazy_call`.
_queued_calls: List[
    Tuple[Callable[[], None], List[str]]
] = []  # don't invoke these until initialization occurs
# Falls back to a constant-False callable when the C extension does not expose
# `_xpu_isInBadFork`.
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
# Accepted spellings of a device argument used in this module's signatures.
_device_t = Union[_device, str, int, None]
# Collects seed-related calls queued by `_lazy_call`; drained by `_lazy_init`.
_lazy_seed_tracker = _LazySeedTracker()
# Indexed by device index in `_get_generator`; empty at import time.
# NOTE(review): presumably replaced by the backend during initialization — confirm.
default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]
def _is_compiled() -> bool:
r"""Return true if compile with XPU support."""
return torch._C._has_xpu
# Bind the device-property type and the device-exchange helpers either to the
# C extension's implementations or to stand-ins, depending on the build.
if _is_compiled():
    _XpuDeviceProperties = torch._C._XpuDeviceProperties
    _exchange_device = torch._C._xpu_exchangeDevice
    _maybe_exchange_device = torch._C._xpu_maybeExchangeDevice
else:
    # Define dummy if PyTorch was compiled without XPU
    _XpuDeviceProperties = _dummy_type("_XpuDeviceProperties")  # type: ignore[assignment, misc]

    # Stand-in: always raises because there is no XPU backend to talk to.
    def _exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")

    # Stand-in: always raises because there is no XPU backend to talk to.
    def _maybe_exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")
@lru_cache(maxsize=1)
def device_count() -> int:
    r"""Return the number of XPU devices available.

    The value is memoized by ``lru_cache`` after the first successful call.
    A build without XPU support always reports ``0``.
    """
    return torch._C._xpu_getDeviceCount() if _is_compiled() else 0
def is_available() -> bool:
    r"""Return a bool indicating whether at least one XPU device is available."""
    # This function never throws.
    count = device_count()
    return count > 0
def is_bf16_supported() -> bool:
    r"""Return a bool indicating if the current XPU device supports dtype bfloat16.

    The current implementation returns ``True`` unconditionally; no device is
    queried.
    """
    return True
def is_initialized():
    r"""Return whether PyTorch's XPU state has been initialized.

    The state only counts as initialized when the module-level flag is set and
    the process is not in a bad fork.
    """
    if not _initialized:
        return False
    return not _is_in_bad_fork()
def _lazy_call(callable, **kwargs):
    r"""Run ``callable`` now if XPU is initialized, otherwise queue it.

    Args:
        callable: zero-argument callable to run.
        **kwargs: ``seed_all=True`` or ``seed=True`` route the call to the lazy
            seed tracker instead of the general queue; ``seed_all`` is checked
            first.
    """
    if is_initialized():
        callable()
        return

    # Keep only the formatted stack, not the actual traceback, to avoid a
    # memory cycle.
    call_site = traceback.format_stack()
    if kwargs.get("seed_all", False):
        _lazy_seed_tracker.queue_seed_all(callable, call_site)
    elif kwargs.get("seed", False):
        _lazy_seed_tracker.queue_seed(callable, call_site)
    else:
        _queued_calls.append((callable, call_site))
def init():
    r"""Initialize PyTorch's XPU state.

    XPU is initialized lazily, i.e. not until the first time it is accessed;
    this function is the public way to trigger that initialization explicitly.
    It does nothing if the XPU state is already initialized.
    """
    # All the work, including the "already initialized" check, is in _lazy_init.
    _lazy_init()
def _lazy_init() -> None:
    r"""Initialize the XPU backend and replay the calls queued by ``_lazy_call``.

    Returns immediately when XPU is already initialized, or when the calling
    thread is already inside this function (``_tls.is_initializing`` is set).

    Raises:
        RuntimeError: if ``_is_in_bad_fork()`` reports a bad fork.
        AssertionError: if PyTorch was built without XPU support.
        Exception: if a queued call fails; the message contains the stack that
            was recorded when the call was queued, and the original error is
            chained as the cause.
    """
    global _initialized, _queued_calls
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # This test was protected via GIL. Double-check whether XPU has
        # already been initialized.
        if is_initialized():
            return
        # Stop promptly upon encountering a bad fork error.
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize XPU in forked subprocess. To use XPU with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        if not _is_compiled():
            raise AssertionError("Torch not compiled with XPU enabled")
        # This function inits XPU backend and detects bad fork processing.
        torch._C._xpu_init()
        # Some of the queued calls may reentrantly call _lazy_init(); We need to
        # just return without initializing in that case.
        _tls.is_initializing = True

        # Append the non-empty entries from the seed tracker after the
        # generally queued calls, so they are replayed last.
        _queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"XPU call failed lazily at initialization with error: {str(e)}\n\n"
                        f"XPU call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise Exception(msg) from e  # noqa: TRY002
        finally:
            # Always clear the reentrancy marker, even when a queued call failed.
            delattr(_tls, "is_initializing")
        # Only reached when every queued call succeeded.
        _initialized = True
class _DeviceGuard:
    r"""Context manager that exchanges the current device for ``index``.

    On entry it calls ``torch.xpu._exchange_device`` with the stored index and
    remembers the returned value; on exit it passes that value to
    ``torch.xpu._maybe_exchange_device``. Exceptions are never suppressed.
    """

    def __init__(self, index: int):
        self.idx, self.prev_idx = index, -1

    def __enter__(self):
        previous = torch.xpu._exchange_device(self.idx)
        self.prev_idx = previous

    def __exit__(self, type: Any, value: Any, traceback: Any):
        swapped_out = torch.xpu._maybe_exchange_device(self.prev_idx)
        self.idx = swapped_out
        # Returning False lets any in-flight exception propagate.
        return False
[docs]class device:
r"""Context-manager that changes the selected device.
Args:
device (torch.device or int or str): device index to select. It's a no-op if
this argument is a negative integer or ``None``.
"""
def __init__(self, device: Any):
self.idx = _get_device_index(device, optional=True)
self.prev_idx = -1
def __enter__(self):
self.prev_idx = torch.xpu._exchange_device(self.idx)
def __exit__(self, type: Any, value: Any, traceback: Any):
self.idx = torch.xpu._maybe_exchange_device(self.prev_idx)
return False
[docs]class device_of(device):
r"""Context-manager that changes the current device to that of given object.
You can use both tensors and storages as arguments. If a given object is
not allocated on a XPU, this is a no-op.
Args:
obj (Tensor or Storage): object allocated on the selected device.
"""
def __init__(self, obj):
idx = obj.get_device() if obj.is_xpu else -1
super().__init__(idx)
def set_device(device: _device_t) -> None:
    r"""Set the current device.

    XPU is lazily initialized first; the device itself is left unchanged when
    the resolved index is negative.

    Args:
        device (torch.device or int or str): selected device. No device is set
            if this argument is negative.
    """
    _lazy_init()
    index = _get_device_index(device)
    if index < 0:
        return
    torch._C._xpu_setDevice(index)
def get_device_name(device: Optional[_device_t] = None) -> str:
    r"""Get the name of a device.

    This is the ``name`` attribute of :func:`~torch.xpu.get_device_properties`
    for the same argument, so an index outside ``[0, device_count())`` raises
    ``AssertionError`` there.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the name. It uses the current device, given by
            :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).

    Returns:
        str: the name of the device
    """
    properties = get_device_properties(device)
    return properties.name
@lru_cache(None)
def get_device_capability(device: Optional[_device_t] = None) -> Dict[str, Any]:
    r"""Get the xpu capability of a device.

    The result is memoized per distinct ``device`` argument, so repeated calls
    with the same argument return the same dictionary object.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the device capability. It uses the current device, given by
            :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).

    Returns:
        Dict[str, Any]: the xpu capability dictionary of the device
    """
    # NOTE(review): a call with ``device=None`` is cached under the ``None``
    # key, so it keeps returning the first result even if the current device
    # changes afterwards — confirm this is intended.
    props = get_device_properties(device)

    # pybind service attributes are no longer needed and their presence breaks
    # the further logic related to the serialization of the created dictionary.
    # In particular this filters out
    # `<bound method PyCapsule._pybind11_conduit_v1_ of _XpuDeviceProperties..>`
    # to fix Triton tests. That field appears after updating pybind to 2.13.6.
    hidden_prefixes = ("__", "_pybind11_")

    capability: Dict[str, Any] = {}
    for attr_name in dir(props):
        if attr_name.startswith(hidden_prefixes):
            continue
        capability[attr_name] = getattr(props, attr_name)
    return capability
def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties:
    r"""Get the properties of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the properties of the device.
            NOTE(review): ``None`` is presumably resolved to the current device
            by ``_get_device_index(..., optional=True)`` — confirm.

    Returns:
        _XpuDeviceProperties: the properties of the device

    Raises:
        AssertionError: if the resolved index is negative or not smaller than
            :func:`~torch.xpu.device_count`.
    """
    _lazy_init()
    index = _get_device_index(device, optional=True)
    if not 0 <= index < device_count():
        raise AssertionError("Invalid device index")
    return _get_device_properties(index)  # type: ignore[name-defined]  # noqa: F821
def current_device() -> int:
    r"""Return the index of the currently selected device.

    XPU is lazily initialized before the backend is queried.
    """
    _lazy_init()
    selected = torch._C._xpu_getDevice()
    return selected
def _get_device(device: Union[int, str, torch.device]) -> torch.device:
    r"""Return the torch.device type object from the passed in device.

    A string is parsed by ``torch.device``; an integer is taken as an index on
    the ``"xpu"`` device type; anything else is returned unchanged.

    Args:
        device (torch.device or int or str): selected device.
    """
    if isinstance(device, str):
        return torch.device(device)
    if isinstance(device, int):
        return torch.device("xpu", device)
    return device
class StreamContext:
    r"""Context-manager that selects a given stream.

    All XPU kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.

    .. note:: Streams are per-device.
    """

    cur_stream: Optional["torch.xpu.Stream"]

    def __init__(self, stream: Optional["torch.xpu.Stream"]):
        self.stream = stream
        current_idx = _get_device_index(None, True)
        # A missing index is stored as -1, which turns enter/exit into no-ops.
        self.idx = -1 if current_idx is None else current_idx

    def __enter__(self):
        target = self.stream
        if target is None or self.idx == -1:
            return

        # Remember the stream that is current right now so __exit__ can put it back.
        self.src_prev_stream = torch.xpu.current_stream(None)

        # If the target stream lives on another device, also remember the
        # stream that is current on that device.
        if self.src_prev_stream.device != target.device:
            with device(target.device):
                self.dst_prev_stream = torch.xpu.current_stream(target.device)

        torch.xpu.set_stream(target)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        target = self.stream
        if target is None or self.idx == -1:
            return

        # Reset the stream on the destination device first, then on the
        # original device.
        if self.src_prev_stream.device != target.device:
            torch.xpu.set_stream(self.dst_prev_stream)
        torch.xpu.set_stream(self.src_prev_stream)
def stream(stream: Optional["torch.xpu.Stream"]) -> StreamContext:
    r"""Wrap around the Context-manager StreamContext that selects a given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's ``None``.

    Returns:
        StreamContext: a context manager built around ``stream``.
    """
    context = StreamContext(stream)
    return context
def _set_stream_by_id(stream_id, device_index, device_type):
    r"""Set the stream specified by the stream id, device index and device type.

    The three values are forwarded unchanged, as keyword arguments, to
    ``torch._C._xpu_setStream``.

    Args:
        stream_id (int): not visible to the user, used to assigned to the specific stream.
        device_index (int): selected device index.
        device_type (int): selected device type.
    """
    torch._C._xpu_setStream(
        stream_id=stream_id,
        device_index=device_index,
        device_type=device_type,
    )
[docs]def set_stream(stream: Stream):
r"""Set the current stream.This is a wrapper API to set the stream.
Usage of this function is discouraged in favor of the ``stream``
context manager.
Args:
stream (Stream): selected stream. This function is a no-op
if this argument is ``None``.
"""
if stream is None:
return
_lazy_init()
_set_stream_by_id(
stream_id=stream.stream_id,
device_index=stream.device_index,
device_type=stream.device_type,
)
def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    index = _get_device_index(device, optional=True)
    streamdata = torch._C._xpu_getCurrentStream(index)
    # The backend result is read positionally: id, device index, device type.
    stream_id = streamdata[0]
    device_index = streamdata[1]
    device_type = streamdata[2]
    return Stream(
        stream_id=stream_id, device_index=device_index, device_type=device_type
    )
def synchronize(device: _device_t = None) -> None:
    r"""Wait for all kernels in all streams on a XPU device to complete.

    Args:
        device (torch.device or int, optional): device for which to synchronize.
            It uses the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    _lazy_init()
    index = _get_device_index(device, optional=True)
    return torch._C._xpu_synchronize(index)
def get_arch_list() -> List[str]:
    r"""Return list XPU architectures this library was compiled for.

    The list is empty when XPU is not available or when the backend reports no
    architecture flags; otherwise the flags string is split on whitespace.
    """
    if not is_available():
        return []
    flags = torch._C._xpu_getArchFlags()
    return [] if flags is None else flags.split()
def get_gencode_flags() -> str:
    r"""Return XPU AOT(ahead-of-time) build flags this library was compiled with.

    The result is ``"-device "`` followed by the comma-separated architecture
    list, or an empty string when that list is empty.
    """
    arch_list = get_arch_list()
    if not arch_list:
        return ""
    joined = ",".join(arch_list)
    return f"-device {joined}"
def _get_generator(device: torch.device) -> torch._C.Generator:
    r"""Return the XPU Generator object for the given device.

    Args:
        device (torch.device): selected device. When it carries no index, the
            index of the current device is used.
    """
    index = current_device() if device.index is None else device.index
    # Read ``default_generators`` only after the index is resolved, because
    # ``current_device()`` may trigger lazy initialization.
    return torch.xpu.default_generators[index]
def _set_rng_state_offset(
    offset: int, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the random number generator state offset of the specified GPU.

    The update goes through ``_lazy_call``, so it is applied immediately when
    XPU is already initialized and queued otherwise.

    Args:
        offset (int): The desired offset
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    target = _get_device(device)

    def apply_offset():
        # The generator is looked up when the callback runs, not when it is queued.
        _get_generator(target).set_offset(offset)

    _lazy_call(apply_offset)
def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
    r"""Return the random number generator state offset of the specified GPU.

    Args:
        device (torch.device or int, optional): The device to return the RNG state offset of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    generator = _get_generator(_get_device(device))
    return generator.get_offset()
# import here to avoid circular import
from .memory import (
empty_cache,
max_memory_allocated,
max_memory_reserved,
memory_allocated,
memory_reserved,
memory_stats,
memory_stats_as_nested_dict,
reset_accumulated_memory_stats,
reset_peak_memory_stats,
)
from .random import (
get_rng_state,
get_rng_state_all,
initial_seed,
manual_seed,
manual_seed_all,
seed,
seed_all,
set_rng_state,
set_rng_state_all,
)
# Public API of this package. Every entry must be bound as an attribute of the
# module, otherwise ``from torch.xpu import *`` raises ``AttributeError``.
# "get_stream" is intentionally absent: no such name is defined or imported
# in this module.
__all__ = [
    "Event",
    "Stream",
    "StreamContext",
    "current_device",
    "current_stream",
    "default_generators",
    "device",
    "device_of",
    "device_count",
    "empty_cache",
    "get_arch_list",
    "get_device_capability",
    "get_device_name",
    "get_device_properties",
    "get_gencode_flags",
    "get_rng_state",
    "get_rng_state_all",
    "init",
    "initial_seed",
    "is_available",
    "is_bf16_supported",
    "is_initialized",
    "manual_seed",
    "manual_seed_all",
    "max_memory_allocated",
    "max_memory_reserved",
    "memory_allocated",
    "memory_reserved",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "seed",
    "seed_all",
    "set_device",
    "set_rng_state",
    "set_rng_state_all",
    "set_stream",
    "stream",
    "streams",
    "synchronize",
]