Shortcuts

Source code for torch.xpu

# mypy: allow-untyped-defs
r"""
This package introduces support for the XPU backend, specifically tailored for
Intel GPU optimization.

This package is lazily initialized, so you can always import it, and use
:func:`is_available()` to determine if your system supports XPU.
"""
import threading
import traceback
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch._C
from torch import device as _device
from torch._utils import _dummy_type, _LazySeedTracker

from ._utils import _get_device_index
from .streams import Event, Stream


# Set to True at the end of a successful _lazy_init(); read by is_initialized().
_initialized = False
# Thread-local storage: _lazy_init() sets an ``is_initializing`` attribute on
# it while queued calls run, so reentrant _lazy_init() calls return early.
_tls = threading.local()
# Held by _lazy_init() while initialization is in progress.
_initialization_lock = threading.Lock()
_queued_calls: List[
    Tuple[Callable[[], None], List[str]]
] = []  # don't invoke these until initialization occurs
# Falls back to a callable returning False when torch._C does not expose
# ``_xpu_isInBadFork``.
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
# Accepted spellings of a "device" argument in this module.
_device_t = Union[_device, str, int, None]
# Receives the seed / seed_all callables that _lazy_call() queues before
# initialization; _lazy_init() drains it via get_calls().
_lazy_seed_tracker = _LazySeedTracker()
# Indexed by device index in _get_generator(); starts out as an empty tuple.
default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]


def _is_compiled() -> bool:
    r"""Return true if compile with XPU support."""
    return torch._C._has_xpu


# Bind the device-properties type and the device-exchange helpers from the C
# extension when XPU support is compiled in; otherwise install placeholders so
# the names always exist at import time.
if _is_compiled():
    _XpuDeviceProperties = torch._C._XpuDeviceProperties
    _exchange_device = torch._C._xpu_exchangeDevice
    _maybe_exchange_device = torch._C._xpu_maybeExchangeDevice
else:
    # Define dummy if PyTorch was compiled without XPU
    _XpuDeviceProperties = _dummy_type("_XpuDeviceProperties")  # type: ignore[assignment, misc]

    # Placeholder that always raises, so use on a non-XPU build fails loudly.
    def _exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")

    # Placeholder that always raises, mirroring _exchange_device above.
    def _maybe_exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")


@lru_cache(maxsize=1)
def device_count() -> int:
    r"""Return how many XPU devices are available.

    The value is memoized after the first successful call. Builds without XPU
    support always report ``0``.
    """
    # Skip the backend query entirely when XPU support was not compiled in.
    return torch._C._xpu_getDeviceCount() if _is_compiled() else 0
def is_available() -> bool:
    r"""Return ``True`` when at least one XPU device is available."""
    # This function never throws.
    num_devices = device_count()
    return num_devices > 0
def is_bf16_supported():
    r"""Report whether the current XPU device supports the bfloat16 dtype.

    The answer is unconditionally ``True``; no device is queried.
    """
    supported = True
    return supported
def is_initialized():
    r"""Return whether PyTorch's XPU state has been initialized.

    A process flagged by ``_is_in_bad_fork()`` is reported as not initialized
    even when the module-level flag is set.
    """
    if _initialized:
        return not _is_in_bad_fork()
    return _initialized
def _lazy_call(callable, **kwargs):
    r"""Run ``callable`` immediately if XPU is initialized, otherwise queue it.

    Args:
        callable: zero-argument callable to run.
        **kwargs: ``seed_all=True`` or ``seed=True`` routes the callable to the
            lazy seed tracker instead of the general queue; ``seed_all`` takes
            precedence when both are set.
    """
    if is_initialized():
        callable()
        return

    # Each queued entry records the formatted stack of the call site so that a
    # failure during initialization can report where the call came from.
    if kwargs.get("seed_all", False):
        _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
        return
    if kwargs.get("seed", False):
        _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
        return
    # Don't store the actual traceback to avoid memory cycle
    _queued_calls.append((callable, traceback.format_stack()))
def init():
    r"""Initialize PyTorch's XPU state.

    XPU is initialized lazily, on first access; this is the Python entry point
    for triggering that initialization explicitly. Calling it when the XPU
    state is already initialized does nothing.
    """
    _lazy_init()
def _lazy_init():
    r"""Initialize XPU state once, then run the calls queued before initialization.

    Returns immediately when the state is already initialized, or when the
    calling thread is already inside initialization (a queued call re-entering
    this function).

    Raises:
        RuntimeError: if ``_is_in_bad_fork()`` reports a bad fork.
        AssertionError: if PyTorch was not compiled with XPU enabled.
        Exception: if a queued call fails; the message includes the stack that
            was captured when the call was queued.
    """
    global _initialized, _queued_calls
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # The check above ran without holding the lock. Double-check whether
        # XPU has already been initialized now that the lock is held.
        if is_initialized():
            return
        # Stop promptly upon encountering a bad fork error.
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize XPU in forked subprocess. To use XPU with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        if not _is_compiled():
            raise AssertionError("Torch not compiled with XPU enabled")
        # This function inits XPU backend and detects bad fork processing.
        torch._C._xpu_init()
        # Some of the queued calls may reentrantly call _lazy_init(); We need to
        # just return without initializing in that case.
        _tls.is_initializing = True

        # Append the non-empty entries reported by the seed tracker to the
        # general queue so they run in the loop below.
        _queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"XPU call failed lazily at initialization with error: {str(e)}\n\n"
                        f"XPU call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise Exception(msg) from e  # noqa: TRY002
        finally:
            # Always clear the reentrancy marker, even when a queued call failed.
            delattr(_tls, "is_initializing")
        # Only reached when every queued call succeeded.
        _initialized = True


class _DeviceGuard:
    r"""Context manager that exchanges the current device on entry and on exit."""

    def __init__(self, index: int):
        self.idx = index
        self.prev_idx = -1

    def __enter__(self):
        # Keep whatever _exchange_device returns so __exit__ can hand it back.
        # NOTE(review): presumably this is the previously selected device index.
        self.prev_idx = torch.xpu._exchange_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        self.idx = torch.xpu._maybe_exchange_device(self.prev_idx)
        # Returning False lets an exception raised inside the block propagate.
        return False
class device:
    r"""Context-manager that changes the selected device.

    Args:
        device (torch.device or int or str): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        resolved_idx = _get_device_index(device, optional=True)
        self.idx = resolved_idx
        self.prev_idx = -1

    def __enter__(self):
        # Remember what the exchange returns so __exit__ can pass it back.
        previous = torch.xpu._exchange_device(self.idx)
        self.prev_idx = previous

    def __exit__(self, type: Any, value: Any, traceback: Any):
        replaced = torch.xpu._maybe_exchange_device(self.prev_idx)
        self.idx = replaced
        # Returning False lets an exception raised inside the block propagate.
        return False
class device_of(device):
    r"""Context-manager that changes the current device to that of given object.

    Both tensors and storages are accepted. When the object is not allocated on
    an XPU, the manager is a no-op.

    Args:
        obj (Tensor or Storage): object allocated on the selected device.
    """

    def __init__(self, obj):
        # -1 makes the base-class context manager a no-op.
        if obj.is_xpu:
            target_idx = obj.get_device()
        else:
            target_idx = -1
        super().__init__(target_idx)
def set_device(device: _device_t) -> None:
    r"""Set the current device.

    Args:
        device (torch.device or int or str): selected device. This function is a
            no-op if this argument is negative.
    """
    _lazy_init()
    index = _get_device_index(device)
    # A negative index means "leave the current device alone".
    if index >= 0:
        torch._C._xpu_setDevice(index)
def get_device_name(device: Optional[_device_t] = None) -> str:
    r"""Get the name of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the name. The argument is forwarded unchanged to
            :func:`~torch.xpu.get_device_properties`, which resolves and
            validates it.

    Returns:
        str: the name of the device
    """
    props = get_device_properties(device)
    return props.name
@lru_cache(None)
def get_device_capability(device: Optional[_device_t] = None) -> Dict[str, Any]:
    r"""Get the xpu capability of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the device capability. The argument is forwarded unchanged to
            :func:`~torch.xpu.get_device_properties`.

    Returns:
        Dict[str, Any]: the xpu capability dictionary of the device
    """
    # NOTE(review): results are memoized per argument value, so a call with
    # ``device=None`` keeps returning the first cached dictionary even if the
    # current device changes later, and all callers share the same dict object
    # -- confirm this is intended.
    props = get_device_properties(device)
    # Dunder attributes and pybind service attributes (for example the bound
    # method `PyCapsule._pybind11_conduit_v1_`, which appeared after updating
    # pybind to 2.13.6) are skipped: they are not capabilities, and their
    # presence breaks later serialization of the dictionary (Triton tests).
    hidden_prefixes = ("__", "_pybind11_")
    capability = {}
    for attr_name in dir(props):
        if attr_name.startswith(hidden_prefixes):
            continue
        capability[attr_name] = getattr(props, attr_name)
    return capability
def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties:
    r"""Get the properties of a device.

    Args:
        device (torch.device or int or str): device for which to return the
            properties of the device.

    Returns:
        _XpuDeviceProperties: the properties of the device

    Raises:
        AssertionError: if the resolved index is negative or not smaller than
            :func:`~torch.xpu.device_count`.
    """
    _lazy_init()
    index = _get_device_index(device, optional=True)
    # Valid indices are 0 .. device_count() - 1.
    if not 0 <= index < device_count():
        raise AssertionError("Invalid device index")
    return _get_device_properties(index)  # type: ignore[name-defined]  # noqa: F821
def current_device() -> int:
    r"""Return the index of the currently selected device.

    XPU state is lazily initialized before the query.
    """
    _lazy_init()
    selected_index = torch._C._xpu_getDevice()
    return selected_index
def _get_device(device: Union[int, str, torch.device]) -> torch.device:
    r"""Normalize ``device`` to a :class:`torch.device`.

    Strings are parsed by ``torch.device``; integers are taken as an index on
    the ``"xpu"`` device type; any other value is returned unchanged.

    Args:
        device (torch.device or int or str): selected device.
    """
    if isinstance(device, str):
        return torch.device(device)
    if isinstance(device, int):
        return torch.device("xpu", device)
    return device
class StreamContext:
    r"""Context-manager that selects a given stream.

    All XPU kernels queued within its context will be enqueued on the selected
    stream.

    Args:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: Streams are per-device.
    """

    cur_stream: Optional["torch.xpu.Stream"]

    def __init__(self, stream: Optional["torch.xpu.Stream"]):
        self.stream = stream
        # -1 marks "no device index available" and turns the manager into a no-op.
        current_idx = _get_device_index(None, True)
        self.idx = -1 if current_idx is None else current_idx

    def __enter__(self):
        target = self.stream
        if target is None or self.idx == -1:
            return
        self.src_prev_stream = torch.xpu.current_stream(None)

        # If the stream is not on the current device, also remember the stream
        # that is current on the target stream's device.
        crosses_devices = self.src_prev_stream.device != target.device
        if crosses_devices:
            with device(target.device):
                self.dst_prev_stream = torch.xpu.current_stream(target.device)
        torch.xpu.set_stream(target)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        target = self.stream
        if target is None or self.idx == -1:
            return

        # Restore the stream on the destination device first (when it differs
        # from the original device), then on the original device.
        crosses_devices = self.src_prev_stream.device != target.device
        if crosses_devices:
            torch.xpu.set_stream(self.dst_prev_stream)
        torch.xpu.set_stream(self.src_prev_stream)
def stream(stream: Optional["torch.xpu.Stream"]) -> StreamContext:
    r"""Build a :class:`StreamContext` context manager for the given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    """
    context = StreamContext(stream)
    return context
def _set_stream_by_id(stream_id, device_index, device_type):
    r"""Set the stream identified by stream id, device index and device type.

    Args:
        stream_id (int): not visible to the user, used to assigned to the
            specific stream.
        device_index (int): selected device index.
        device_type (int): selected device type.
    """
    stream_spec = {
        "stream_id": stream_id,
        "device_index": device_index,
        "device_type": device_type,
    }
    torch._C._xpu_setStream(**stream_spec)
def set_stream(stream: Stream):
    r"""Set the current stream. This is a wrapper API to set the stream.

    Usage of this function is discouraged in favor of the ``stream``
    context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    if stream is not None:
        _lazy_init()
        _set_stream_by_id(
            stream_id=stream.stream_id,
            device_index=stream.device_index,
            device_type=stream.device_type,
        )
def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    index = _get_device_index(device, optional=True)
    streamdata = torch._C._xpu_getCurrentStream(index)
    # The backend result is read positionally: id, device index, device type.
    stream_id = streamdata[0]
    device_index = streamdata[1]
    device_type = streamdata[2]
    return Stream(
        stream_id=stream_id, device_index=device_index, device_type=device_type
    )
def synchronize(device: _device_t = None) -> None:
    r"""Wait for all kernels in all streams on a XPU device to complete.

    Args:
        device (torch.device or int, optional): device for which to synchronize.
            It uses the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    _lazy_init()
    index = _get_device_index(device, optional=True)
    return torch._C._xpu_synchronize(index)
def get_arch_list() -> List[str]:
    r"""Return the list of XPU architectures this library was compiled for.

    The list is empty when XPU is not available or when the backend reports no
    architecture flags.
    """
    if not is_available():
        return []
    arch_flags = torch._C._xpu_getArchFlags()
    # The flags arrive as one whitespace-separated string (or ``None``).
    return [] if arch_flags is None else arch_flags.split()
def get_gencode_flags() -> str:
    r"""Return XPU AOT(ahead-of-time) build flags this library was compiled with.

    The result is an empty string when :func:`get_arch_list` reports no
    architectures; otherwise it is ``-device`` followed by the architectures
    joined with commas.
    """
    arch_list = get_arch_list()
    if not arch_list:
        return ""
    joined = ",".join(arch_list)
    return f"-device {joined}"
def _get_generator(device: torch.device) -> torch._C.Generator:
    r"""Return the XPU Generator object for the given device.

    When ``device.index`` is ``None`` the index of the current device is used.

    Args:
        device (torch.device): selected device.
    """
    idx = device.index
    if idx is None:
        idx = current_device()
    return torch.xpu.default_generators[idx]


def _set_rng_state_offset(
    offset: int, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the random number generator state offset of the specified GPU.

    The update goes through ``_lazy_call``: it runs immediately when XPU is
    already initialized and is queued until initialization otherwise.

    Args:
        offset (int): The desired offset
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    # Resolve the device now so the callback does not depend on the argument's
    # original form.
    final_device = _get_device(device)

    def cb():
        default_generator = _get_generator(final_device)
        default_generator.set_offset(offset)

    _lazy_call(cb)


def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
    r"""Return the random number generator state offset of the specified GPU.

    Args:
        device (torch.device or int, optional): The device to return the RNG state offset of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    final_device = _get_device(device)
    default_generator = _get_generator(final_device)
    return default_generator.get_offset()


# import here to avoid circular import
from .memory import (
    empty_cache,
    max_memory_allocated,
    max_memory_reserved,
    memory_allocated,
    memory_reserved,
    memory_stats,
    memory_stats_as_nested_dict,
    reset_accumulated_memory_stats,
    reset_peak_memory_stats,
)
from .random import (
    get_rng_state,
    get_rng_state_all,
    initial_seed,
    manual_seed,
    manual_seed_all,
    seed,
    seed_all,
    set_rng_state,
    set_rng_state_all,
)


# NOTE(review): "get_stream" is exported below, but no definition or import of
# that name is visible in this file -- confirm it is provided elsewhere,
# otherwise ``from torch.xpu import *`` would fail on it.
__all__ = [
    "Event",
    "Stream",
    "StreamContext",
    "current_device",
    "current_stream",
    "default_generators",
    "device",
    "device_of",
    "device_count",
    "empty_cache",
    "get_arch_list",
    "get_device_capability",
    "get_device_name",
    "get_device_properties",
    "get_gencode_flags",
    "get_rng_state",
    "get_rng_state_all",
    "get_stream",
    "init",
    "initial_seed",
    "is_available",
    "is_bf16_supported",
    "is_initialized",
    "manual_seed",
    "manual_seed_all",
    "max_memory_allocated",
    "max_memory_reserved",
    "memory_allocated",
    "memory_reserved",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "seed",
    "seed_all",
    "set_device",
    "set_rng_state",
    "set_rng_state_all",
    "set_stream",
    "stream",
    "streams",
    "synchronize",
]

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources