get_device
- torchtune.utils.get_device(device: Optional[str] = None) → torch.device
Function that takes an optional device string, verifies that it is correct and available given the machine and distributed settings, and returns a torch.device. If no device string is provided, this function infers the device from the environment. If a CUDA-like device is available and being used, this function also sets it as the current CUDA-like device.
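The sketch below models the described behavior in plain Python so it runs without PyTorch: infer a device when none is given, reject unsupported or unavailable names, and attach a per-process index for CUDA-like devices. It is a hypothetical illustration, not torchtune's implementation. The function name sketch_get_device, the "available" mapping that stands in for real hardware checks, the inference order, and the use of the LOCAL_RANK environment variable to choose the device index are all assumptions.

```python
import os
from typing import Mapping, Optional

# Device names accepted by get_device, per its documented `device` parameter.
SUPPORTED_DEVICES = ("cuda", "cpu", "npu", "xpu", "mps")

# Assumption: these are the "CUDA-like" devices that get a per-process index.
CUDA_LIKE = ("cuda", "npu", "xpu")

# Assumption: the order in which devices are tried when none is requested.
INFERENCE_ORDER = ("cuda", "npu", "xpu", "mps", "cpu")


def sketch_get_device(
    device: Optional[str],
    available: Mapping[str, bool],
    env: Optional[Mapping[str, str]] = None,
) -> str:
    """Hypothetical model of get_device; returns a string such as 'cuda:1'."""
    env = os.environ if env is None else env
    if device is None:
        # Infer: take the first available device; "cpu" is always accepted.
        device = next(
            name for name in INFERENCE_ORDER
            if name == "cpu" or available.get(name, False)
        )
    if device not in SUPPORTED_DEVICES:
        raise ValueError(
            f"Unsupported device {device!r}; expected one of {SUPPORTED_DEVICES}"
        )
    if device != "cpu" and not available.get(device, False):
        raise RuntimeError(f"Device {device!r} was requested but is not available")
    if device in CUDA_LIKE:
        # Assumption: in a distributed launch, each process uses the accelerator
        # indexed by its LOCAL_RANK environment variable (0 when unset).
        local_rank = int(env.get("LOCAL_RANK", "0"))
        return f"{device}:{local_rank}"
    return device


if __name__ == "__main__":
    # Second process on a CUDA machine: no device requested, so one is inferred.
    print(sketch_get_device(None, {"cuda": True}, env={"LOCAL_RANK": "1"}))  # cuda:1
```

Passing availability and the environment in explicitly keeps the sketch deterministic and testable; the real function queries the machine directly.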
- Parameters:
device (Optional[str]) – The name of the device to use, one of “cuda”, “cpu”, “npu”, “xpu”, or “mps”.
Example
>>> device = get_device("cuda")
>>> device
device(type='cuda', index=0)
- Returns:
Device
- Return type: