Source code for torch.mps
# mypy: allow-untyped-defs
r"""
This package enables an interface for accessing MPS (Metal Performance Shaders) backend in Python.
Metal is Apple's API for programming metal GPU (graphics processor unit). Using MPS means that increased
performance can be achieved, by running work on the metal GPU(s).
See https://developer.apple.com/documentation/metalperformanceshaders for more details.
"""
from typing import Union
import torch
from torch import Tensor
_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
_default_mps_generator: torch._C.Generator = None # type: ignore[assignment]
# local helper function (not public or exported)
def _get_default_mps_generator() -> torch._C.Generator:
global _default_mps_generator
if _default_mps_generator is None:
_default_mps_generator = torch._C._mps_get_default_generator()
return _default_mps_generator
[docs]def device_count() -> int:
r"""Returns the number of available MPS devices."""
return int(torch._C._has_mps and torch._C._mps_is_available())
[docs]def synchronize() -> None:
r"""Waits for all kernels in all streams on a MPS device to complete."""
return torch._C._mps_deviceSynchronize()
[docs]def get_rng_state(device: Union[int, str, torch.device] = "mps") -> Tensor:
r"""Returns the random number generator state as a ByteTensor.
Args:
device (torch.device or int, optional): The device to return the RNG state of.
Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
"""
return _get_default_mps_generator().get_state()
[docs]def set_rng_state(
new_state: Tensor, device: Union[int, str, torch.device] = "mps"
) -> None:
r"""Sets the random number generator state.
Args:
new_state (torch.ByteTensor): The desired state
device (torch.device or int, optional): The device to set the RNG state.
Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
"""
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
_get_default_mps_generator().set_state(new_state_copy)
[docs]def manual_seed(seed: int) -> None:
r"""Sets the seed for generating random numbers.
Args:
seed (int): The desired seed.
"""
# the torch.mps.manual_seed() can be called from the global
# torch.manual_seed() in torch/random.py. So we need to make
# sure mps is available (otherwise we just return without
# erroring out)
if not torch._C._has_mps:
return
seed = int(seed)
_get_default_mps_generator().manual_seed(seed)
[docs]def seed() -> None:
r"""Sets the seed for generating random numbers to a random number."""
_get_default_mps_generator().seed()
[docs]def empty_cache() -> None:
r"""Releases all unoccupied cached memory currently held by the caching
allocator so that those can be used in other GPU applications.
"""
torch._C._mps_emptyCache()
[docs]def set_per_process_memory_fraction(fraction) -> None:
r"""Set memory fraction for limiting process's memory allocation on MPS device.
The allowed value equals the fraction multiplied by recommended maximum device memory
(obtained from Metal API device.recommendedMaxWorkingSetSize).
If trying to allocate more than the allowed value in a process, it will raise an out of
memory error in allocator.
Args:
fraction(float): Range: 0~2. Allowed memory equals total_memory * fraction.
.. note::
Passing 0 to fraction means unlimited allocations
(may cause system failure if out of memory).
Passing fraction greater than 1.0 allows limits beyond the value
returned from device.recommendedMaxWorkingSetSize.
"""
if not isinstance(fraction, float):
raise TypeError("Invalid type for fraction argument, must be `float`")
if fraction < 0 or fraction > 2:
raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~2")
torch._C._mps_setMemoryFraction(fraction)
[docs]def current_allocated_memory() -> int:
r"""Returns the current GPU memory occupied by tensors in bytes.
.. note::
The returned size does not include cached allocations in
memory pools of MPSAllocator.
"""
return torch._C._mps_currentAllocatedMemory()
[docs]def driver_allocated_memory() -> int:
r"""Returns total GPU memory allocated by Metal driver for the process in bytes.
.. note::
The returned size includes cached allocations in MPSAllocator pools
as well as allocations from MPS/MPSGraph frameworks.
"""
return torch._C._mps_driverAllocatedMemory()
[docs]def recommended_max_memory() -> int:
r"""Returns recommended max Working set size for GPU memory in bytes.
.. note::
Recommended max working set size for Metal.
returned from device.recommendedMaxWorkingSetSize.
"""
return torch._C._mps_recommendedMaxMemory()
def _compile_shader(source: str):
r"""Compiles compute shader from source and allows one to invoke kernels
defined there from the comfort of Python runtime
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_MPS)
>>> lib = torch.mps._compile_shader(
... "kernel void full(device float* out, constant float& val, uint idx [[thread_position_in_grid]]) { out[idx] = val; }"
... )
>>> x = torch.zeros(16, device="mps")
>>> lib.full(x, 3.14)
"""
if not hasattr(torch._C, "_mps_compileShader"):
raise RuntimeError("MPS is not available")
return torch._C._mps_compileShader(source)
def is_available() -> bool:
return device_count() > 0
from . import profiler
from .event import Event
__all__ = [
"device_count",
"get_rng_state",
"manual_seed",
"seed",
"set_rng_state",
"synchronize",
"empty_cache",
"set_per_process_memory_fraction",
"current_allocated_memory",
"driver_allocated_memory",
"Event",
"profiler",
"recommended_max_memory",
"is_available",
]