from __future__ import annotations
import logging
from enum import Enum, auto
from typing import Any, Optional, Type, Union
import numpy as np
import tensorrt as trt
import torch
from torch_tensorrt._features import ENABLED_FEATURES, needs_torch_tensorrt_runtime
[docs]class dtype(Enum):
"""Enum to describe data types to Torch-TensorRT, has compatibility with torch, tensorrt and numpy dtypes"""
# Supported types in Torch-TensorRT
unknown = auto()
"""Sentinel value
:meta hide-value:
"""
u8 = auto()
"""Unsigned 8 bit integer, equivalent to ``dtype.uint8``
:meta hide-value:
"""
i8 = auto()
"""Signed 8 bit integer, equivalent to ``dtype.int8``, when enabled as a kernel precision typically requires the model to support quantization
:meta hide-value:
"""
i32 = auto()
"""Signed 32 bit integer, equivalent to ``dtype.int32`` and ``dtype.int``
:meta hide-value:
"""
i64 = auto()
"""Signed 64 bit integer, equivalent to ``dtype.int64`` and ``dtype.long``
:meta hide-value:
"""
f16 = auto()
"""16 bit floating-point number, equivalent to ``dtype.half``, ``dtype.fp16`` and ``dtype.float16``
:meta hide-value:
"""
f32 = auto()
"""32 bit floating-point number, equivalent to ``dtype.float``, ``dtype.fp32`` and ``dtype.float32``
:meta hide-value:
"""
f64 = auto()
"""64 bit floating-point number, equivalent to ``dtype.double``, ``dtype.fp64`` and ``dtype.float64``
:meta hide-value:
"""
b = auto()
"""Boolean value, equivalent to ``dtype.bool``
:meta hide-value:
"""
bf16 = auto()
"""16 bit "Brain" floating-point number, equivalent to ``dtype.bfloat16``
:meta hide-value:
"""
f8 = auto()
"""8 bit floating-point number, equivalent to ``dtype.fp8`` and ``dtype.float8``
:meta hide-value:
"""
uint8 = u8
int8 = i8
int32 = i32
long = i64
int64 = i64
float8 = f8
fp8 = f8
half = f16
fp16 = f16
float16 = f16
float = f32
fp32 = f32
float32 = f32
double = f64
fp64 = f64
float64 = f64
bfloat16 = bf16
@staticmethod
def _is_np_obj(t: Any) -> bool:
if isinstance(t, np.dtype):
return True
elif isinstance(t, type):
if issubclass(t, np.generic):
return True
return False
@classmethod
def _from(
cls,
t: Union[torch.dtype, trt.DataType, np.dtype, dtype, type],
use_default: bool = False,
) -> dtype:
"""Create a Torch-TensorRT dtype from another library's dtype system.
Takes a dtype enum from one of numpy, torch, and tensorrt and create a ``torch_tensorrt.dtype``.
If the source dtype system is not supported or the type is not supported in Torch-TensorRT,
then an exception will be raised. As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.dtype.try_from()``
Arguments:
t (Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype)): Data type enum from another library
use_default (bool): In some cases a catch all type (such as ``torch_tensorrt.dtype.f32``) is sufficient, so instead of throwing an exception, return default value.
Returns:
dtype: Equivalent ``torch_tensorrt.dtype`` to ``t``
Raises:
TypeError: Unsupported data type or unknown source
Examples:
.. code:: py
# Succeeds
float_dtype = torch_tensorrt.dtype._from(torch.float) # Returns torch_tensorrt.dtype.f32
# Throws exception
float_dtype = torch_tensorrt.dtype._from(torch.complex128)
"""
# TODO: Ideally implemented with match statement but need to wait for Py39 EoL
if isinstance(t, torch.dtype):
if t == torch.uint8:
return dtype.u8
elif t == torch.int8:
return dtype.i8
elif t == torch.long:
return dtype.i64
elif t == torch.int32:
return dtype.i32
elif t == torch.float8_e4m3fn:
return dtype.f8
elif t == torch.half:
return dtype.f16
elif t == torch.float:
return dtype.f32
elif t == torch.float64:
return dtype.f64
elif t == torch.bool:
return dtype.b
elif t == torch.bfloat16:
return dtype.bf16
elif use_default:
logging.warning(
f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float"
)
return dtype.float
else:
raise TypeError(
f"Provided an unsupported data type as a data type for translation (support: bool, int, long, half, float, bfloat16), got: {t}"
)
elif isinstance(t, trt.DataType):
if t == trt.DataType.UINT8:
return dtype.u8
elif t == trt.DataType.INT8:
return dtype.i8
elif t == trt.DataType.FP8:
return dtype.f8
elif t == trt.DataType.INT32:
return dtype.i32
elif t == trt.DataType.INT64:
return dtype.i64
elif t == trt.DataType.HALF:
return dtype.f16
elif t == trt.DataType.FLOAT:
return dtype.f32
elif t == trt.DataType.BOOL:
return dtype.b
elif t == trt.DataType.BF16:
return dtype.bf16
else:
raise TypeError(
f"Provided an unsupported data type as a data type for translation (support: bool, int, half, float, bfloat16), got: {t}"
)
elif dtype._is_np_obj(t):
if t == np.uint8:
return dtype.u8
elif t == np.int8:
return dtype.i8
elif t == np.int32:
return dtype.i32
elif t == np.int64:
return dtype.i64
elif t == np.float16:
return dtype.f16
elif t == np.float32:
return dtype.f32
elif t == np.float64:
return dtype.f64
elif t == np.bool_:
return dtype.b
# TODO: Consider using ml_dtypes when issues like this are resolved:
# https://github.com/pytorch/pytorch/issues/109873
# elif t == ml_dtypes.bfloat16:
# return dtype.bf16
elif use_default:
logging.warning(
f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float"
)
return dtype.float
else:
raise TypeError(
"Provided an unsupported data type as an input data type (support: bool, int, long, half, float, bfloat16), got: "
+ str(t)
)
elif isinstance(t, dtype):
return t
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if isinstance(t, _C.dtype):
if t == _C.dtype.long:
return dtype.i64
elif t == _C.dtype.int32:
return dtype.i32
elif t == _C.dtype.int8:
return dtype.i8
elif t == _C.dtype.half:
return dtype.f16
elif t == _C.dtype.float:
return dtype.f32
elif t == _C.dtype.double:
return dtype.f64
elif t == _C.dtype.bool:
return dtype.b
elif t == _C.dtype.unknown:
return dtype.unknown
else:
raise TypeError(
f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {t}"
)
# else: # commented out for mypy
raise TypeError(
f"Provided unsupported source type for dtype conversion (got: {t})"
)
[docs] @classmethod
def try_from(
cls,
t: Union[torch.dtype, trt.DataType, np.dtype, dtype],
use_default: bool = False,
) -> Optional[dtype]:
"""Create a Torch-TensorRT dtype from another library's dtype system.
Takes a dtype enum from one of numpy, torch, and tensorrt and create a ``torch_tensorrt.dtype``.
If the source dtype system is not supported or the type is not supported in Torch-TensorRT,
then returns ``None``.
Arguments:
t (Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype)): Data type enum from another library
use_default (bool): In some cases a catch all type (such as ``torch_tensorrt.dtype.f32``) is sufficient, so instead of throwing an exception, return default value.
Returns:
Optional(dtype): Equivalent ``torch_tensorrt.dtype`` to ``t`` or ``None``
Examples:
.. code:: py
# Succeeds
float_dtype = torch_tensorrt.dtype.try_from(torch.float) # Returns torch_tensorrt.dtype.f32
# Unsupported type
float_dtype = torch_tensorrt.dtype.try_from(torch.complex128) # Returns None
"""
try:
casted_format = dtype._from(t, use_default=use_default)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(
f"Conversion from {t} to torch_tensorrt.dtype failed", exc_info=True
)
return None
[docs] def to(
self,
t: Union[Type[torch.dtype], Type[trt.DataType], Type[np.dtype], Type[dtype]],
use_default: bool = False,
) -> Union[torch.dtype, trt.DataType, np.dtype, dtype]:
"""Convert dtype into the equivalent type in [torch, numpy, tensorrt]
Converts ``self`` into one of numpy, torch, and tensorrt equivalent dtypes.
If ``self`` is not supported in the target library, then an exception will be raised.
As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.dtype.try_to()``
Arguments:
t (Union(Type(torch.dtype), Type(tensorrt.DataType), Type(numpy.dtype), Type(dtype))): Data type enum from another library to convert to
use_default (bool): In some cases a catch all type (such as ``torch.float``) is sufficient, so instead of throwing an exception, return default value.
Returns:
Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype): dtype equivalent ``torch_tensorrt.dtype`` from library enum ``t``
Raises:
TypeError: Unsupported data type or unknown target
Examples:
.. code:: py
# Succeeds
float_dtype = torch_tensorrt.dtype.f32.to(torch.dtype) # Returns torch.float
# Failure
float_dtype = torch_tensorrt.dtype.bf16.to(numpy.dtype) # Throws exception
"""
# TODO: Ideally implemented with match statement but need to wait for Py39 EoL
if t == torch.dtype:
if self == dtype.u8:
return torch.uint8
elif self == dtype.i8:
return torch.int8
elif self == dtype.i32:
return torch.int
elif self == dtype.i64:
return torch.long
elif self == dtype.f8:
return torch.float8_e4m3fn
elif self == dtype.f16:
return torch.half
elif self == dtype.f32:
return torch.float
elif self == dtype.f64:
return torch.double
elif self == dtype.b:
return torch.bool
elif self == dtype.bf16:
return torch.bfloat16
elif use_default:
logging.warning(
f"Given dtype that does not have direct mapping to torch ({self}), defaulting to torch.float"
)
return torch.float
else:
raise TypeError(f"Unsupported torch dtype (had: {self})")
elif t == trt.DataType:
if self == dtype.u8:
return trt.DataType.UINT8
if self == dtype.i8:
return trt.DataType.INT8
elif self == dtype.i32:
return trt.DataType.INT32
elif self == dtype.f8:
return trt.DataType.FP8
elif self == dtype.i64:
return trt.DataType.INT64
elif self == dtype.f16:
return trt.DataType.HALF
elif self == dtype.f32:
return trt.DataType.FLOAT
elif self == dtype.b:
return trt.DataType.BOOL
elif self == dtype.bf16:
return trt.DataType.BF16
elif use_default:
return trt.DataType.FLOAT
else:
raise TypeError("Unsupported tensorrt dtype")
elif t == np.dtype:
if self == dtype.u8:
return np.uint8
elif self == dtype.i8:
return np.int8
elif self == dtype.i32:
return np.int32
elif self == dtype.i64:
return np.int64
elif self == dtype.f16:
return np.float16
elif self == dtype.f32:
return np.float32
elif self == dtype.f64:
return np.float64
elif self == dtype.b:
return np.bool_
# TODO: Consider using ml_dtypes when issues like this are resolved:
# https://github.com/pytorch/pytorch/issues/109873
# elif self == dtype.bf16:
# return ml_dtypes.bfloat16
elif use_default:
return np.float32
else:
raise TypeError("Unsupported numpy dtype")
elif t == dtype:
return self
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if t == _C.dtype:
if self == dtype.i64:
return _C.dtype.long
elif self == dtype.i8:
return _C.dtype.int8
elif self == dtype.i32:
return _C.dtype.int32
elif self == dtype.f16:
return _C.dtype.half
elif self == dtype.f32:
return _C.dtype.float
elif self == dtype.f64:
return _C.dtype.double
elif self == dtype.b:
return _C.dtype.bool
elif self == dtype.unknown:
return _C.dtype.unknown
else:
raise TypeError(
f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {self}"
)
# else: # commented out for mypy
raise TypeError(
f"Provided unsupported destination type for dtype conversion {t}"
)
[docs] def try_to(
self,
t: Union[Type[torch.dtype], Type[trt.DataType], Type[np.dtype], Type[dtype]],
use_default: bool,
) -> Optional[Union[torch.dtype, trt.DataType, np.dtype, dtype]]:
"""Convert dtype into the equivalent type in [torch, numpy, tensorrt]
Converts ``self`` into one of numpy, torch, and tensorrt equivalent dtypes.
If ``self`` is not supported in the target library, then returns ``None``.
Arguments:
t (Union(Type(torch.dtype), Type(tensorrt.DataType), Type(numpy.dtype), Type(dtype))): Data type enum from another library to convert to
use_default (bool): In some cases a catch all type (such as ``torch.float``) is sufficient, so instead of throwing an exception, return default value.
Returns:
Optional(Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype)): dtype equivalent ``torch_tensorrt.dtype`` from library enum ``t``
Examples:
.. code:: py
# Succeeds
float_dtype = torch_tensorrt.dtype.f32.to(torch.dtype) # Returns torch.float
# Failure
float_dtype = torch_tensorrt.dtype.bf16.to(numpy.dtype) # Returns None
"""
try:
casted_format = self.to(t, use_default)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(
f"torch_tensorrt.dtype conversion to target type {t} failed",
exc_info=True,
)
return None
def __eq__(self, other: Union[torch.dtype, trt.DataType, np.dtype, dtype]) -> bool:
other_ = dtype._from(other)
return bool(self.value == other_.value)
def __hash__(self) -> int:
return hash(self.value)
# Putting aliases here that mess with mypy
bool = b
int = i32
class DeviceType(Enum):
    """Type of device TensorRT will target

    Members compare equal (``==``) to the matching ``tensorrt.DeviceType`` values, because
    ``__eq__`` first converts the other operand with ``DeviceType._from``. Operands that
    cannot be converted simply compare unequal.
    """

    UNKNOWN = auto()
    """
    Sentinel value

    :meta hide-value:
    """

    GPU = auto()
    """
    Target is a GPU

    :meta hide-value:
    """

    DLA = auto()
    """
    Target is a DLA core

    :meta hide-value:
    """

    @classmethod
    def _from(cls, d: Union[trt.DeviceType, DeviceType]) -> DeviceType:
        """Create a Torch-TensorRT device type enum from a TensorRT device type enum.

        Takes a device type enum from tensorrt and create a ``torch_tensorrt.DeviceType``.
        If the source is not supported or the device type is not supported in Torch-TensorRT,
        then an exception will be raised. As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.DeviceType.try_from()``

        Arguments:
            d (Union(tensorrt.DeviceType, DeviceType)): Device type enum from another library

        Returns:
            DeviceType: Equivalent ``torch_tensorrt.DeviceType`` to ``d``

        Raises:
            TypeError: Unknown source type
            ValueError: Known source type but unsupported device type

        Examples:

            .. code:: py

                torchtrt_dla = torch_tensorrt.DeviceType._from(tensorrt.DeviceType.DLA)

        """
        # Checked first: every ``==`` between members funnels through this method,
        # so the own-type case is by far the most common one.
        if isinstance(d, DeviceType):
            return d

        if isinstance(d, trt.DeviceType):
            if d == trt.DeviceType.GPU:
                return DeviceType.GPU
            elif d == trt.DeviceType.DLA:
                return DeviceType.DLA
            else:
                raise ValueError(
                    "Provided an unsupported device type (support: GPU/DLA)"
                )

        elif ENABLED_FEATURES.torchscript_frontend:
            # The C++ bindings are only importable when the TorchScript frontend is built
            from torch_tensorrt import _C

            if isinstance(d, _C.DeviceType):
                if d == _C.DeviceType.GPU:
                    return DeviceType.GPU
                elif d == _C.DeviceType.DLA:
                    return DeviceType.DLA
                else:
                    raise ValueError(
                        "Provided an unsupported device type (support: GPU/DLA)"
                    )
        # else: # commented out for mypy
        raise TypeError("Provided unsupported source type for DeviceType conversion")

    @classmethod
    def try_from(cls, d: Union[trt.DeviceType, DeviceType]) -> Optional[DeviceType]:
        """Create a Torch-TensorRT device type enum from a TensorRT device type enum.

        Takes a device type enum from tensorrt and create a ``torch_tensorrt.DeviceType``.
        If the source is not supported or the device type is not supported in Torch-TensorRT,
        then ``None`` is returned instead of raising.

        Arguments:
            d (Union(tensorrt.DeviceType, DeviceType)): Device type enum from another library

        Returns:
            Optional(DeviceType): Equivalent ``torch_tensorrt.DeviceType`` to ``d`` or ``None``

        Examples:

            .. code:: py

                torchtrt_dla = torch_tensorrt.DeviceType.try_from(tensorrt.DeviceType.DLA)

        """
        try:
            return DeviceType._from(d)
        except (ValueError, TypeError):
            logging.debug(
                f"Conversion from {d} to torch_tensorrt.DeviceType failed",
                exc_info=True,
            )
            return None

    def to(
        self,
        t: Union[Type[trt.DeviceType], Type[DeviceType]],
        use_default: bool = False,
    ) -> Union[trt.DeviceType, DeviceType]:
        """Convert ``DeviceType`` into the equivalent type in tensorrt

        Converts ``self`` into one of torch or tensorrt equivalent device type.
        If ``self`` is not supported in the target library, then an exception will be raised.
        As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.DeviceType.try_to()``

        Arguments:
            t (Union(Type(tensorrt.DeviceType), Type(DeviceType))): Device type enum from another library to convert to
            use_default (bool): Return the GPU device type of the target library instead of raising when ``self`` has no equivalent (e.g. ``DeviceType.UNKNOWN``)

        Returns:
            Union(tensorrt.DeviceType, DeviceType): Device type equivalent ``torch_tensorrt.DeviceType`` in enum ``t``

        Raises:
            TypeError: Unknown target type
            ValueError: Known target type but ``self`` has no equivalent and ``use_default`` is ``False``

        Examples:

            .. code:: py

                # Succeeds
                trt_dla = torch_tensorrt.DeviceType.DLA.to(tensorrt.DeviceType) # Returns tensorrt.DeviceType.DLA

        """
        if t == trt.DeviceType:
            if self == DeviceType.GPU:
                return trt.DeviceType.GPU
            elif self == DeviceType.DLA:
                return trt.DeviceType.DLA
            elif use_default:
                return trt.DeviceType.GPU
            else:
                raise ValueError(
                    "Provided an unsupported device type (support: GPU/DLA)"
                )

        elif t == DeviceType:
            return self

        elif ENABLED_FEATURES.torchscript_frontend:
            # The C++ bindings are only importable when the TorchScript frontend is built
            from torch_tensorrt import _C

            if t == _C.DeviceType:
                if self == DeviceType.GPU:
                    return _C.DeviceType.GPU
                elif self == DeviceType.DLA:
                    return _C.DeviceType.DLA
                elif use_default:
                    # Honor ``use_default`` here as the tensorrt branch does
                    return _C.DeviceType.GPU
                else:
                    raise ValueError(
                        "Provided an unsupported device type (support: GPU/DLA)"
                    )
        # else: # commented out for mypy
        raise TypeError(
            "Provided unsupported destination type for device type conversion"
        )

    def try_to(
        self,
        t: Union[Type[trt.DeviceType], Type[DeviceType]],
        use_default: bool = False,
    ) -> Optional[Union[trt.DeviceType, DeviceType]]:
        """Convert ``DeviceType`` into the equivalent type in tensorrt

        Converts ``self`` into one of torch or tensorrt equivalent device type.
        If ``self`` is not supported in the target library, then ``None`` will be returned.

        Arguments:
            t (Union(Type(tensorrt.DeviceType), Type(DeviceType))): Device type enum from another library to convert to
            use_default (bool): Return the GPU device type of the target library instead of ``None`` when ``self`` has no equivalent

        Returns:
            Optional(Union(tensorrt.DeviceType, DeviceType)): Device type equivalent ``torch_tensorrt.DeviceType`` in enum ``t``

        Examples:

            .. code:: py

                # Succeeds
                trt_dla = torch_tensorrt.DeviceType.DLA.try_to(tensorrt.DeviceType) # Returns tensorrt.DeviceType.DLA

        """
        try:
            return self.to(t, use_default=use_default)
        except (ValueError, TypeError):
            logging.debug(
                f"torch_tensorrt.DeviceType conversion to target type {t} failed",
                exc_info=True,
            )
            return None

    def __eq__(self, other: Union[trt.DeviceType, DeviceType]) -> bool:
        """Compare with a device type from tensorrt or Torch-TensorRT.

        ``other`` is first converted with ``DeviceType._from``. An operand that cannot be
        converted compares unequal rather than raising, so membership tests and ``!=``
        work with arbitrary objects.
        """
        try:
            other_ = DeviceType._from(other)
        except (ValueError, TypeError):
            return False
        return bool(self.value == other_.value)

    def __hash__(self) -> int:
        """Hash on the member value (required because ``__eq__`` is overridden)."""
        return hash(self.value)
class EngineCapability(Enum):
    """
    EngineCapability determines the restrictions of a network during build time and what runtime it targets.

    Members compare equal (``==``) to the matching ``tensorrt.EngineCapability`` values, because
    ``__eq__`` first converts the other operand with ``EngineCapability._from``. Operands that
    cannot be converted simply compare unequal.
    """

    STANDARD = auto()
    """
    EngineCapability.STANDARD does not provide any restrictions on functionality and the resulting serialized engine can be executed with TensorRT’s standard runtime APIs.

    :meta hide-value:
    """

    SAFETY = auto()
    """
    EngineCapability.SAFETY provides a restricted subset of network operations that are safety certified and the resulting serialized engine can be executed with TensorRT’s safe runtime APIs in the tensorrt.safe namespace.

    :meta hide-value:
    """

    DLA_STANDALONE = auto()
    """
    ``EngineCapability.DLA_STANDALONE`` provides a restricted subset of network operations that are DLA compatible and the resulting serialized engine can be executed using standalone DLA runtime APIs.

    :meta hide-value:
    """

    @classmethod
    def _from(
        cls, c: Union[trt.EngineCapability, EngineCapability]
    ) -> EngineCapability:
        """Create a Torch-TensorRT Engine capability enum from a TensorRT Engine capability enum.

        Takes an engine capability enum from tensorrt and create a ``torch_tensorrt.EngineCapability``.
        If the source is not supported or the engine capability is not supported in Torch-TensorRT,
        then an exception will be raised. As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.EngineCapability.try_from()``

        Arguments:
            c (Union(tensorrt.EngineCapability, EngineCapability)): Engine capability enum from another library

        Returns:
            EngineCapability: Equivalent ``torch_tensorrt.EngineCapability`` to ``c``

        Raises:
            TypeError: Unknown source type
            ValueError: Known source type but unsupported engine capability

        Examples:

            .. code:: py

                torchtrt_ec = torch_tensorrt.EngineCapability._from(tensorrt.EngineCapability.SAFETY)

        """
        # Checked first: every ``==`` between members funnels through this method,
        # so the own-type case is by far the most common one.
        if isinstance(c, EngineCapability):
            return c

        if isinstance(c, trt.EngineCapability):
            if c == trt.EngineCapability.STANDARD:
                return EngineCapability.STANDARD
            elif c == trt.EngineCapability.SAFETY:
                return EngineCapability.SAFETY
            elif c == trt.EngineCapability.DLA_STANDALONE:
                return EngineCapability.DLA_STANDALONE
            else:
                raise ValueError("Provided an unsupported engine capability")

        elif ENABLED_FEATURES.torchscript_frontend:
            # The C++ bindings are only importable when the TorchScript frontend is built
            from torch_tensorrt import _C

            if isinstance(c, _C.EngineCapability):
                if c == _C.EngineCapability.STANDARD:
                    return EngineCapability.STANDARD
                elif c == _C.EngineCapability.SAFETY:
                    return EngineCapability.SAFETY
                elif c == _C.EngineCapability.DLA_STANDALONE:
                    return EngineCapability.DLA_STANDALONE
                else:
                    raise ValueError("Provided an unsupported engine capability")
        # else: # commented out for mypy
        raise TypeError(
            "Provided unsupported source type for EngineCapability conversion"
        )

    @classmethod
    def try_from(
        cls,
        c: Union[trt.EngineCapability, EngineCapability],
    ) -> Optional[EngineCapability]:
        """Create a Torch-TensorRT engine capability enum from a TensorRT engine capability enum.

        Takes an engine capability enum from tensorrt and create a ``torch_tensorrt.EngineCapability``.
        If the source is not supported or the engine capability level is not supported in Torch-TensorRT,
        then ``None`` is returned instead of raising.

        Arguments:
            c (Union(tensorrt.EngineCapability, EngineCapability)): Engine capability enum from another library

        Returns:
            Optional(EngineCapability): Equivalent ``torch_tensorrt.EngineCapability`` to ``c`` or ``None``

        Examples:

            .. code:: py

                torchtrt_safety_ec = torch_tensorrt.EngineCapability.try_from(tensorrt.EngineCapability.SAFETY)

        """
        try:
            return EngineCapability._from(c)
        except (ValueError, TypeError):
            logging.debug(
                f"Conversion from {c} to torch_tensorrt.EngineCapablity failed",
                exc_info=True,
            )
            return None

    def to(
        self, t: Union[Type[trt.EngineCapability], Type[EngineCapability]]
    ) -> Union[trt.EngineCapability, EngineCapability]:
        """Convert ``EngineCapability`` into the equivalent type in tensorrt

        Converts ``self`` into one of torch or tensorrt equivalent engine capability.
        If ``self`` is not supported in the target library, then an exception will be raised.
        As such it is not recommended to use this method directly.

        Alternatively use ``torch_tensorrt.EngineCapability.try_to()``

        Arguments:
            t (Union(Type(tensorrt.EngineCapability), Type(EngineCapability))): Engine capability enum from another library to convert to

        Returns:
            Union(tensorrt.EngineCapability, EngineCapability): Engine capability equivalent ``torch_tensorrt.EngineCapability`` in enum ``t``

        Raises:
            TypeError: Unknown target type
            ValueError: Known target type but ``self`` has no equivalent

        Examples:

            .. code:: py

                # Succeeds
                torchtrt_dla_ec = torch_tensorrt.EngineCapability.DLA_STANDALONE.to(tensorrt.EngineCapability) # Returns tensorrt.EngineCapability.DLA_STANDALONE

        """
        if t == trt.EngineCapability:
            if self == EngineCapability.STANDARD:
                return trt.EngineCapability.STANDARD
            elif self == EngineCapability.SAFETY:
                return trt.EngineCapability.SAFETY
            elif self == EngineCapability.DLA_STANDALONE:
                return trt.EngineCapability.DLA_STANDALONE
            else:
                raise ValueError("Provided an unsupported engine capability")

        elif t == EngineCapability:
            return self

        elif ENABLED_FEATURES.torchscript_frontend:
            # The C++ bindings are only importable when the TorchScript frontend is built
            from torch_tensorrt import _C

            if t == _C.EngineCapability:
                if self == EngineCapability.STANDARD:
                    return _C.EngineCapability.STANDARD
                elif self == EngineCapability.SAFETY:
                    return _C.EngineCapability.SAFETY
                elif self == EngineCapability.DLA_STANDALONE:
                    return _C.EngineCapability.DLA_STANDALONE
                else:
                    raise ValueError("Provided an unsupported engine capability")
        # else: # commented out for mypy
        raise TypeError(
            "Provided unsupported destination type for engine capability type conversion"
        )

    def try_to(
        self, t: Union[Type[trt.EngineCapability], Type[EngineCapability]]
    ) -> Optional[Union[trt.EngineCapability, EngineCapability]]:
        """Convert ``EngineCapability`` into the equivalent type in tensorrt

        Converts ``self`` into one of torch or tensorrt equivalent engine capability.
        If ``self`` is not supported in the target library, then ``None`` will be returned.

        Arguments:
            t (Union(Type(tensorrt.EngineCapability), Type(EngineCapability))): Engine capability enum from another library to convert to

        Returns:
            Optional(Union(tensorrt.EngineCapability, EngineCapability)): Engine capability equivalent ``torch_tensorrt.EngineCapability`` in enum ``t``

        Examples:

            .. code:: py

                # Succeeds
                trt_dla_ec = torch_tensorrt.EngineCapability.DLA_STANDALONE.try_to(tensorrt.EngineCapability) # Returns tensorrt.EngineCapability.DLA_STANDALONE

        """
        try:
            return self.to(t)
        except (ValueError, TypeError):
            logging.debug(
                f"torch_tensorrt.EngineCapablity conversion to target type {t} failed",
                exc_info=True,
            )
            return None

    def __eq__(self, other: Union[trt.EngineCapability, EngineCapability]) -> bool:
        """Compare with an engine capability from tensorrt or Torch-TensorRT.

        ``other`` is first converted with ``EngineCapability._from``. An operand that
        cannot be converted compares unequal rather than raising, so membership tests
        and ``!=`` work with arbitrary objects.
        """
        try:
            other_ = EngineCapability._from(other)
        except (ValueError, TypeError):
            return False
        return bool(self.value == other_.value)

    def __hash__(self) -> int:
        """Hash on the member value (required because ``__eq__`` is overridden)."""
        return hash(self.value)
class Platform(Enum):
    """
    Target operating system / CPU architecture pair for a Torch-TensorRT program
    """

    LINUX_X86_64 = auto()
    """
    OS: Linux, CPU Arch: x86_64

    :meta hide-value:
    """

    LINUX_AARCH64 = auto()
    """
    OS: Linux, CPU Arch: aarch64

    :meta hide-value:
    """

    WIN_X86_64 = auto()
    """
    OS: Windows, CPU Arch: x86_64

    :meta hide-value:
    """

    # Returned by ``current_platform`` when no OS/architecture pair above matches
    UNKNOWN = auto()

    @classmethod
    def current_platform(cls) -> Platform:
        """
        Detect the platform Torch-TensorRT is currently running on

        The OS name and machine string reported by the standard ``platform`` module are
        lower-cased and matched by prefix.

        Returns:
            Platform: Matching member, or ``Platform.UNKNOWN`` when nothing matches
        """
        import platform

        os_name = platform.system().lower()
        arch = platform.machine().lower()

        if os_name.startswith("linux"):
            if arch.startswith("aarch64"):
                return cls.LINUX_AARCH64
            if arch.startswith("x86_64"):
                return cls.LINUX_X86_64
        elif os_name.startswith("windows") and arch.startswith("amd64"):
            # Windows reports 64 bit x86 machines as "AMD64"
            return cls.WIN_X86_64

        return cls.UNKNOWN

    def __str__(self) -> str:
        """Return the bare member name (e.g. ``LINUX_X86_64``)."""
        return f"{self.name}"

    @needs_torch_tensorrt_runtime  # type: ignore
    def _to_serialized_rt_platform(self) -> str:
        """Return the string produced by the ``torch.ops.tensorrt._platform_*`` op matching ``self``.

        The "unknown" op is always evaluated first and its result is kept when ``self``
        is not one of the three known platforms.
        """
        serialized: str = torch.ops.tensorrt._platform_unknown()

        if self is Platform.LINUX_X86_64:
            serialized = torch.ops.tensorrt._platform_linux_x86_64()
        elif self is Platform.LINUX_AARCH64:
            serialized = torch.ops.tensorrt._platform_linux_aarch64()
        elif self is Platform.WIN_X86_64:
            serialized = torch.ops.tensorrt._platform_win_x86_64()

        return serialized