Shortcuts

Source code for torch_tensorrt._Device

import sys
from typing import Any, Optional, Tuple

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

import warnings

# from torch_tensorrt import _enums
import tensorrt as trt
import torch
from torch_tensorrt import logging

try:
    from torch_tensorrt import _C
except ImportError:
    warnings.warn(
        "Unable to import torchscript frontend core and torch-tensorrt runtime. Some dependent features may be unavailable."
    )


[docs]class Device(object): """ Defines a device that can be used to specify target devices for engines Attributes: device_type (torch_tensorrt.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified. gpu_id (int): Device ID for target GPU dla_core (int): Core ID for target DLA core allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed """ device_type: Optional[ trt.DeviceType ] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified. gpu_id: int = -1 #: Device ID for target GPU dla_core: int = -1 #: Core ID for target DLA core allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed
[docs] def __init__(self, *args: Any, **kwargs: Any): """__init__ Method for torch_tensorrt.Device Device accepts one of a few construction patterns Args: spec (str): String with device spec e.g. "dla:0" for dla, core_id 0 Keyword Arguments: gpu_id (int): ID of target GPU (will get overrided if dla_core is specified to the GPU managing DLA). If specified, no positional arguments should be provided dla_core (int): ID of target DLA core. If specified, no positional arguments should be provided. allow_gpu_fallback (bool): Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA) Examples: - Device("gpu:1") - Device("cuda:1") - Device("dla:0", allow_gpu_fallback=True) - Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True) - Device(dla_core=0, allow_gpu_fallback=True) - Device(gpu_id=1) """ if len(args) == 1: if not isinstance(args[0], str): raise TypeError( "When specifying Device through positional argument, argument must be str" ) else: (self.device_type, id) = Device._parse_device_str(args[0]) if self.device_type == trt.DeviceType.GPU: self.gpu_id = id else: self.dla_core = id self.gpu_id = 0 logging.log( logging.Level.Warning, "Setting GPU id to 0 for device because device 0 manages DLA on Xavier", ) elif len(args) == 0: if "gpu_id" in kwargs or "dla_core" in kwargs: if "dla_core" in kwargs: self.device_type = trt.DeviceType.DLA self.dla_core = kwargs["dla_core"] if "gpu_id" in kwargs: self.gpu_id = kwargs["gpu_id"] else: self.gpu_id = 0 logging.log( logging.Level.Warning, "Setting GPU id to 0 for device because device 0 manages DLA on Xavier", ) else: self.gpu_id = kwargs["gpu_id"] self.device_type = trt.DeviceType.GPU else: raise ValueError( "Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg" ) else: raise ValueError( "Unexpected number of positional arguments for class Device \n Found {} arguments, expected either zero or a single positional arguments".format( len(args) ) ) if "allow_gpu_fallback" in kwargs: if not isinstance(kwargs["allow_gpu_fallback"], bool): raise TypeError("allow_gpu_fallback must be a bool") self.allow_gpu_fallback = kwargs["allow_gpu_fallback"]
def __str__(self) -> str: return ( "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) + ")" if self.device_type == trt.DeviceType.GPU else ", dla_core={}, allow_gpu_fallback={}".format( self.dla_core, self.allow_gpu_fallback ) ) def __repr__(self) -> str: return self.__str__() def _to_internal(self) -> _C.Device: internal_dev = _C.Device() if self.device_type == trt.DeviceType.GPU: internal_dev.device_type = _C.DeviceType.GPU elif self.device_type == trt.DeviceType.DLA: internal_dev.device_type = _C.DeviceType.DLA else: raise ValueError( "Invalid DeviceType detected while parsing the Device class" ) internal_dev.gpu_id = self.gpu_id internal_dev.dla_core = self.dla_core internal_dev.allow_gpu_fallback = self.allow_gpu_fallback return internal_dev def _to_serialized_rt_device(self) -> str: internal_dev = self._to_internal() serialized_rt_device: str = internal_dev._to_serialized_rt_device() return serialized_rt_device @classmethod def _from_torch_device(cls, torch_dev: torch.device) -> Self: if torch_dev.type != "cuda": raise ValueError('Torch Device specs must have type "cuda"') gpu_id = torch_dev.index return cls(gpu_id=gpu_id) @classmethod def _current_device(cls) -> Self: dev = _C._get_current_device() return cls(gpu_id=dev.gpu_id) @staticmethod def _parse_device_str(s: str) -> Tuple[trt.DeviceType, int]: s = s.lower() spec = s.split(":") if spec[0] == "gpu" or spec[0] == "cuda": return (trt.DeviceType.GPU, int(spec[1])) elif spec[0] == "dla": return (trt.DeviceType.DLA, int(spec[1])) else: raise ValueError(f"Unknown device type {spec[0]}")

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources