Source code for torch_tensorrt._Device

from __future__ import annotations

import logging
import sys
from typing import Any, Optional, Tuple

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

import torch
from torch_tensorrt._enums import DeviceType
from torch_tensorrt._features import ENABLED_FEATURES

import tensorrt as trt

class Device(object):
    """
    Defines a device that can be used to specify target devices for engines

    Attributes:
        device_type (torch_tensorrt.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

    # Class-level defaults; __init__ overrides them on the instance as needed.
    device_type: DeviceType = (
        DeviceType.UNKNOWN
    )  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
    allow_gpu_fallback: bool = (
        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns

        Args:
            spec (str): String with device spec e.g. "dla:0" for dla, core_id 0

        Keyword Arguments:
            gpu_id (int): ID of target GPU (will get overridden if dla_core is specified to the GPU managing DLA). If specified, no positional arguments should be provided
            dla_core (int): ID of target DLA core. If specified, no positional arguments should be provided.
            allow_gpu_fallback (bool): Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA)
            device_type (torch_tensorrt.DeviceType | tensorrt.DeviceType): Explicit device type; overrides the type inferred from the other arguments. Values of any other type are ignored.

        Raises:
            TypeError: If the positional spec is not a str, or allow_gpu_fallback is not a bool
            ValueError: If the spec string is malformed, if neither gpu_id nor dla_core is given
                when no positional spec is provided, or if more than one positional argument is given

        Examples:
            - Device("gpu:1")
            - Device("cuda:1")
            - Device("dla:0", allow_gpu_fallback=True)
            - Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)
            - Device(dla_core=0, allow_gpu_fallback=True)
            - Device(gpu_id=1)
        """
        if len(args) == 1:
            if not isinstance(args[0], str):
                raise TypeError(
                    "When specifying Device through positional argument, argument must be str"
                )
            self.device_type, dev_id = Device._parse_device_str(args[0])
            if self.device_type == DeviceType.DLA:
                self.dla_core = dev_id
                # DLA engines are always managed by GPU 0 here.
                self.gpu_id = 0
                logging.warning(
                    "Setting GPU id to 0 for device because device 0 manages DLA on AGX Devices",
                )
            else:
                self.gpu_id = dev_id
        elif len(args) == 0:
            if "gpu_id" not in kwargs and "dla_core" not in kwargs:
                raise ValueError(
                    "Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg"
                )
            if "dla_core" in kwargs:
                self.dla_core = kwargs["dla_core"]
            if "gpu_id" in kwargs:
                self.gpu_id = kwargs["gpu_id"]

            # A non-negative DLA core selects DLA; otherwise the device is a GPU.
            if self.dla_core >= 0:
                self.device_type = DeviceType.DLA
                if self.gpu_id != 0:
                    self.gpu_id = 0
                    logging.warning(
                        "Setting GPU id to 0 for device because device 0 manages DLA on AGX Platforms",
                    )
            else:
                self.device_type = DeviceType.GPU
        else:
            raise ValueError(
                f"Unexpected number of positional arguments for class Device \n Found {len(args)} arguments, expected either zero or a single positional arguments"
            )

        if "allow_gpu_fallback" in kwargs:
            if not isinstance(kwargs["allow_gpu_fallback"], bool):
                raise TypeError("allow_gpu_fallback must be a bool")
            self.allow_gpu_fallback = kwargs["allow_gpu_fallback"]

        if "device_type" in kwargs:
            requested_type = kwargs["device_type"]
            # Accept both the torch_tensorrt enum and the raw TensorRT enum;
            # previously a torch_tensorrt DeviceType was silently ignored.
            if isinstance(requested_type, DeviceType):
                self.device_type = requested_type
            elif isinstance(requested_type, trt.DeviceType):
                self.device_type = DeviceType._from(requested_type)

    def __str__(self) -> str:
        """Human-readable form; DLA-specific fields are shown for non-GPU devices only."""
        suffix = (
            ")"
            if self.device_type == DeviceType.GPU
            else f", dla_core={self.dla_core}, allow_gpu_fallback={self.allow_gpu_fallback})"
        )
        dev_str: str = f"Device(type={self.device_type}, gpu_id={self.gpu_id}{suffix}"
        return dev_str

    def __repr__(self) -> str:
        return self.__str__()

    @classmethod
    def _from(cls, d: Optional[Self | torch.device | str]) -> Device:
        """Cast a device-type to torch_tensorrt.Device

        Args:
            d: An existing Device (returned unchanged), a CUDA ``torch.device``,
                a device spec string (e.g. ``"cuda:0"``, ``"dla:1"``), or ``None``
                for the current CUDA device.

        Returns:
            The corresponding torch_tensorrt.Device

        Raises:
            ValueError: If ``d`` is a ``torch.device`` whose type is not ``"cuda"``
        """
        if isinstance(d, Device):
            return d
        if isinstance(d, torch.device):
            if d.type != "cuda":
                raise ValueError('Torch Device specs must have type "cuda"')
            # torch.device("cuda") has index None, which torch defines as
            # "the current device"; resolve it instead of storing gpu_id=None.
            index = d.index if d.index is not None else torch.cuda.current_device()
            return cls(gpu_id=index)
        if d is None:
            return cls(gpu_id=torch.cuda.current_device())
        return cls(d)

    @classmethod
    def _from_torch_device(cls, torch_dev: torch.device) -> Device:
        """Convert a CUDA ``torch.device`` into a Device (see ``_from``)."""
        return cls._from(torch_dev)

    @classmethod
    def _current_device(cls) -> Device:
        """Return a GPU Device for the currently selected CUDA device."""
        dev_id = torch.cuda.current_device()
        return cls(gpu_id=dev_id)

    @staticmethod
    def _parse_device_str(s: str) -> Tuple[DeviceType, int]:
        """Parse a ``"<kind>:<index>"`` spec such as ``"cuda:0"`` or ``"dla:1"``.

        ``kind`` is case-insensitive and must be ``gpu``, ``cuda`` or ``dla``.

        Returns:
            ``(DeviceType, index)`` with a non-negative integer index

        Raises:
            ValueError: If the kind is unknown, or the index is missing,
                not an integer, or negative
        """
        kind, sep, index_str = s.strip().lower().partition(":")
        if kind in ("gpu", "cuda"):
            device_type = DeviceType.GPU
        elif kind == "dla":
            device_type = DeviceType.DLA
        else:
            raise ValueError(f"Unknown device type {kind}")

        if not sep or not index_str:
            raise ValueError(
                f"Device spec {s!r} is missing an index, expected e.g. '{kind}:0'"
            )
        try:
            index = int(index_str)
        except ValueError as err:
            raise ValueError(
                f"Device index in spec {s!r} must be an integer"
            ) from err
        if index < 0:
            raise ValueError(f"Device index in spec {s!r} must be non-negative")
        return (device_type, index)

    def to(self, t: type) -> torch.device:
        """Convert this Device to another device representation.

        Args:
            t: Target type; only ``torch.device`` is supported.

        Returns:
            ``torch.device("cuda", gpu_id)``

        Raises:
            TypeError: If ``t`` is not ``torch.device``
            ValueError: If ``gpu_id`` is not a valid (non-negative int) CUDA index
        """
        if t is not torch.device:
            raise TypeError("Unsupported target type for device conversion")
        if not isinstance(self.gpu_id, int) or self.gpu_id < 0:
            raise ValueError("Invalid GPU ID provided for the CUDA device provided")
        # Be explicit about "cuda": newer torch releases may map a bare int
        # index to whatever the current accelerator type is.
        return torch.device("cuda", self.gpu_id)

    def _to_serialized_rt_device(self) -> str:
        """Serialize this device for the Torch-TensorRT runtime.

        Packs gpu_id, the GPU's compute capability (major, minor), the device
        type converted to a TensorRT ``DeviceType`` integer, and the GPU name
        into one string joined by the runtime-provided delimiter.

        Raises:
            NotImplementedError: If the Torch-TensorRT runtime is not available
        """
        if not ENABLED_FEATURES.torch_tensorrt_runtime:
            raise NotImplementedError("Torch-TensorRT runtime is not available")

        delim = torch.ops.tensorrt.SERIALIZED_RT_DEVICE_DELIM()[0]
        dev_info = torch.cuda.get_device_properties(self.gpu_id)
        # NOTE(review): field order presumably must match the runtime-side
        # deserializer (not visible here) — verify before reordering.
        rt_info = [
            self.gpu_id,
            dev_info.major,
            dev_info.minor,
            int(self.device_type.to(trt.DeviceType)),  # type: ignore[arg-type]
            dev_info.name,
        ]
        packed_rt_info: str = delim.join(str(field) for field in rt_info)
        logging.debug(f"Serialized Device Info: {packed_rt_info}")
        return packed_rt_info


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources