Source code for torch_tensorrt._Device
from __future__ import annotations

import logging
import sys
from typing import Any, Optional, Tuple

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

import torch
from torch_tensorrt._enums import DeviceType
from torch_tensorrt._features import needs_torch_tensorrt_runtime

import tensorrt as trt


class Device(object):
    """
    Defines a device that can be used to specify target devices for engines

    Attributes:
        device_type (torch_tensorrt.DeviceType): Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

    device_type: DeviceType = (
        DeviceType.UNKNOWN
    )  #: Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
    allow_gpu_fallback: bool = (
        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns

        Args:
            spec (str): String with device spec e.g. "dla:0" for DLA core 0

        Keyword Arguments:
            gpu_id (int): ID of target GPU (overridden to the GPU managing the DLA core if dla_core is specified). If specified, no positional arguments should be provided
            dla_core (int): ID of target DLA core. If specified, no positional arguments should be provided.
            allow_gpu_fallback (bool): Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA)

        Examples:
            - Device("gpu:1")
            - Device("cuda:1")
            - Device("dla:0", allow_gpu_fallback=True)
            - Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)
            - Device(dla_core=0, allow_gpu_fallback=True)
            - Device(gpu_id=1)
        """
        if len(args) == 1:
            if not isinstance(args[0], str):
                raise TypeError(
                    "When specifying Device through positional argument, argument must be str"
                )
            else:
                (self.device_type, dev_id) = Device._parse_device_str(args[0])
                if self.device_type == DeviceType.DLA:
                    self.dla_core = dev_id
                    self.gpu_id = 0
                    logging.warning(
                        "Setting GPU id to 0 for device because device 0 manages DLA on AGX platforms",
                    )
                else:
                    self.gpu_id = dev_id
        elif len(args) == 0:
            if "gpu_id" in kwargs or "dla_core" in kwargs:
                if "dla_core" in kwargs:
                    self.dla_core = kwargs["dla_core"]
                if "gpu_id" in kwargs:
                    self.gpu_id = kwargs["gpu_id"]
                if self.dla_core >= 0:
                    self.device_type = DeviceType.DLA
                    if self.gpu_id != 0:
                        self.gpu_id = 0
                        logging.warning(
                            "Setting GPU id to 0 for device because device 0 manages DLA on AGX platforms",
                        )
                else:
                    self.device_type = DeviceType.GPU
            else:
                raise ValueError(
                    "Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg"
                )
        else:
            raise ValueError(
                f"Unexpected number of positional arguments for class Device \n Found {len(args)} arguments, expected either zero or one positional argument"
            )

        if "allow_gpu_fallback" in kwargs:
            if not isinstance(kwargs["allow_gpu_fallback"], bool):
                raise TypeError("allow_gpu_fallback must be a bool")
            self.allow_gpu_fallback = kwargs["allow_gpu_fallback"]

        if "device_type" in kwargs:
            if isinstance(kwargs["device_type"], trt.DeviceType):
                self.device_type = DeviceType._from(kwargs["device_type"])
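
    # Illustrative sketch (not part of the original source): the construction
    # patterns accepted above, assuming a CUDA-capable machine. A spec string
    # and the gpu_id/dla_core keywords are alternative ways to select a device.
    #
    #     Device("cuda:1")                             # GPU 1 via spec string
    #     Device("dla:0", allow_gpu_fallback=True)     # DLA core 0; GPU id forced to 0
    #     Device(gpu_id=1)                             # GPU 1 via keyword
    #     Device(dla_core=0, allow_gpu_fallback=True)  # device_type implicitly set to DLA
    #     Device(0)                                    # raises TypeError (positional arg must be str)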

    def __str__(self) -> str:
        suffix = (
            ")"
            if self.device_type == DeviceType.GPU
            else f", dla_core={self.dla_core}, allow_gpu_fallback={self.allow_gpu_fallback})"
        )
        dev_str: str = f"Device(type={self.device_type}, gpu_id={self.gpu_id}{suffix}"
        return dev_str

    def __repr__(self) -> str:
        return self.__str__()
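
    # Illustrative sketch (not part of the original source): the string forms
    # produced above. The exact DeviceType rendering depends on that enum's
    # own __str__, so treat these as approximate.
    #
    #     str(Device(gpu_id=0))
    #     # -> "Device(type=DeviceType.GPU, gpu_id=0)"
    #     str(Device(dla_core=1, allow_gpu_fallback=True))
    #     # -> "Device(type=DeviceType.DLA, gpu_id=0, dla_core=1, allow_gpu_fallback=True)"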

    @classmethod
    def _from(cls, d: Optional[Self | torch.device | str]) -> Device:
        """Cast a device-type to torch_tensorrt.Device

        Returns the corresponding torch_tensorrt.Device
        """
        if isinstance(d, Device):
            return d
        elif isinstance(d, torch.device):
            if d.type != "cuda":
                raise ValueError('Torch Device specs must have type "cuda"')
            # torch.device("cuda") carries no index; fall back to the current device
            if d.index is None:
                return cls(gpu_id=torch.cuda.current_device())
            return cls(gpu_id=d.index)
        elif d is None:
            return cls(gpu_id=torch.cuda.current_device())
        else:
            return cls(d)

    @classmethod
    def _from_torch_device(cls, torch_dev: torch.device) -> Device:
        return cls._from(torch_dev)

    @classmethod
    def _current_device(cls) -> Device:
        dev_id = torch.cuda.current_device()
        return cls(gpu_id=dev_id)
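
    # Illustrative sketch (not part of the original source): the casting paths
    # handled by _from, assuming CUDA is available.
    #
    #     Device._from(Device(gpu_id=0))        # returned unchanged
    #     Device._from(torch.device("cuda:1"))  # -> Device(gpu_id=1)
    #     Device._from("dla:0")                 # parsed as a spec string
    #     Device._from(None)                    # -> Device on torch.cuda.current_device()
    #     Device._from(torch.device("cpu"))     # raises ValueError (type must be "cuda")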

    @staticmethod
    def _parse_device_str(s: str) -> Tuple[DeviceType, int]:
        s = s.lower()
        spec = s.split(":")
        if len(spec) != 2:
            raise ValueError(
                f'Device spec must take the form "<device_type>:<id>", got "{s}"'
            )
        if spec[0] == "gpu" or spec[0] == "cuda":
            return (DeviceType.GPU, int(spec[1]))
        elif spec[0] == "dla":
            return (DeviceType.DLA, int(spec[1]))
        else:
            raise ValueError(f"Unknown device type {spec[0]}")

    def to(self, t: type) -> torch.device:
        if t == torch.device:
            if self.gpu_id != -1:
                return torch.device(self.gpu_id)
            else:
                raise ValueError("Invalid GPU ID provided for the CUDA device")
        else:
            raise TypeError("Unsupported target type for device conversion")

    @needs_torch_tensorrt_runtime
    def _to_serialized_rt_device(self) -> str:
        delim = torch.ops.tensorrt.SERIALIZED_RT_DEVICE_DELIM()[0]
        dev_info = torch.cuda.get_device_properties(self.gpu_id)
        rt_info = [
            self.gpu_id,
            dev_info.major,
            dev_info.minor,
            int(self.device_type.to(trt.DeviceType)),  # type: ignore[arg-type]
            dev_info.name,
        ]
        rt_info = [str(i) for i in rt_info]
        packed_rt_info: str = delim.join(rt_info)
        logging.debug(f"Serialized Device Info: {packed_rt_info}")
        return packed_rt_info
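
    # Illustrative sketch (not part of the original source): the serialized
    # record is gpu_id, the SM major/minor compute-capability versions, the
    # TensorRT device-type enum value, and the device name, joined by a
    # delimiter queried from the torch_tensorrt runtime. For example, with the
    # delimiter shown as "%" purely for illustration:
    #
    #     "0%8%6%0%NVIDIA GeForce RTX 3090"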