Source code for torch_tensorrt._enums
from __future__ import annotations
import logging
from enum import Enum, auto
from typing import Any, Optional, Type, Union
import numpy as np
import tensorrt as trt
import torch
from torch_tensorrt._features import ENABLED_FEATURES, needs_torch_tensorrt_runtime
[docs]class dtype(Enum):
"""Enum to describe data types to Torch-TensorRT, has compatibility with torch, tensorrt and numpy dtypes"""
# Supported types in Torch-TensorRT
unknown = auto()
"""Sentinel value
:meta hide-value:
"""
u8 = auto()
"""Unsigned 8 bit integer, equivalent to ``dtype.uint8``
:meta hide-value:
"""
i8 = auto()
"""Signed 8 bit integer, equivalent to ``dtype.int8``, when enabled as a kernel precision typically requires the model to support quantization
:meta hide-value:
"""
i32 = auto()
"""Signed 32 bit integer, equivalent to ``dtype.int32`` and ``dtype.int``
:meta hide-value:
"""
i64 = auto()
"""Signed 64 bit integer, equivalent to ``dtype.int64`` and ``dtype.long``
:meta hide-value:
"""
f16 = auto()
"""16 bit floating-point number, equivalent to ``dtype.half``, ``dtype.fp16`` and ``dtype.float16``
:meta hide-value:
"""
f32 = auto()
"""32 bit floating-point number, equivalent to ``dtype.float``, ``dtype.fp32`` and ``dtype.float32``
:meta hide-value:
"""
f64 = auto()
"""64 bit floating-point number, equivalent to ``dtype.double``, ``dtype.fp64`` and ``dtype.float64``
:meta hide-value:
"""
b = auto()
"""Boolean value, equivalent to ``dtype.bool``
:meta hide-value:
"""
bf16 = auto()
"""16 bit "Brain" floating-point number, equivalent to ``dtype.bfloat16``
:meta hide-value:
"""
f8 = auto()
"""8 bit floating-point number, equivalent to ``dtype.fp8`` and ``dtype.float8``
:meta hide-value:
"""
uint8 = u8
int8 = i8
int32 = i32
long = i64
int64 = i64
float8 = f8
fp8 = f8
half = f16
fp16 = f16
float16 = f16
float = f32
fp32 = f32
float32 = f32
double = f64
fp64 = f64
float64 = f64
bfloat16 = bf16
@staticmethod
def _is_np_obj(t: Any) -> bool:
if isinstance(t, np.dtype):
return True
elif isinstance(t, type):
if issubclass(t, np.generic):
return True
return False
@classmethod
def _from(
cls,
t: Union[torch.dtype, trt.DataType, np.dtype, dtype, type],
use_default: bool = False,
) -> dtype:
"""Create a Torch-TensorRT dtype from another library's dtype system.
Takes a dtype enum from one of numpy, torch, and tensorrt and create a ``torch_tensorrt.dtype``.
If the source dtype system is not supported or the type is not supported in Torch-TensorRT,
then an exception will be raised. As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.dtype.try_from()``
Arguments:
t (Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype)): Data type enum from another library
use_default (bool): In some cases a catch all type (such as ``torch_tensorrt.dtype.f32``) is sufficient, so instead of throwing an exception, return default value.
Returns:
dtype: Equivalent ``torch_tensorrt.dtype`` to ``t``
Raises:
TypeError: Unsupported data type or unknown source
Examples:
.. code:: py
# Succeeds
float_dtype = torch_tensorrt.dtype._from(torch.float) # Returns torch_tensorrt.dtype.f32
# Throws exception
float_dtype = torch_tensorrt.dtype._from(torch.complex128)
"""
# TODO: Ideally implemented with match statement but need to wait for Py39 EoL
if isinstance(t, torch.dtype):
if t == torch.uint8:
return dtype.u8
elif t == torch.int8:
return dtype.i8
elif t == torch.long:
return dtype.i64
elif t == torch.int32:
return dtype.i32
elif t == torch.float8_e4m3fn:
return dtype.f8
elif t == torch.half:
return dtype.f16
elif t == torch.float:
return dtype.f32
elif t == torch.float64:
return dtype.f64
elif t == torch.bool:
return dtype.b
elif t == torch.bfloat16:
return dtype.bf16
elif use_default:
logging.warning(
f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float"
)
return dtype.float
else:
raise TypeError(
f"Provided an unsupported data type as a data type for translation (support: bool, int, long, half, float, bfloat16), got: {t}"
)
elif isinstance(t, trt.DataType):
if t == trt.DataType.UINT8:
return dtype.u8
elif t == trt.DataType.INT8:
return dtype.i8
elif t == trt.DataType.FP8:
return dtype.f8
elif t == trt.DataType.INT32:
return dtype.i32
elif t == trt.DataType.INT64:
return dtype.i64
elif t == trt.DataType.HALF:
return dtype.f16
elif t == trt.DataType.FLOAT:
return dtype.f32
elif t == trt.DataType.BOOL:
return dtype.b
elif t == trt.DataType.BF16:
return dtype.bf16
else:
raise TypeError(
f"Provided an unsupported data type as a data type for translation (support: bool, int, half, float, bfloat16), got: {t}"
)
elif dtype._is_np_obj(t):
if t == np.uint8:
return dtype.u8
elif t == np.int8:
return dtype.i8
elif t == np.int32:
return dtype.i32
elif t == np.int64:
return dtype.i64
elif t == np.float16:
return dtype.f16
elif t == np.float32:
return dtype.f32
elif t == np.float64:
return dtype.f64
elif t == np.bool_:
return dtype.b
# TODO: Consider using ml_dtypes when issues like this are resolved:
# https://github.com/pytorch/pytorch/issues/109873
# elif t == ml_dtypes.bfloat16:
# return dtype.bf16
elif use_default:
logging.warning(
f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float"
)
return dtype.float
else:
raise TypeError(
"Provided an unsupported data type as an input data type (support: bool, int, long, half, float, bfloat16), got: "
+ str(t)
)
elif isinstance(t, dtype):
return t
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if isinstance(t, _C.dtype):
if t == _C.dtype.long:
return dtype.i64
elif t == _C.dtype.int32:
return dtype.i32
elif t == _C.dtype.int8:
return dtype.i8
elif t == _C.dtype.half:
return dtype.f16
elif t == _C.dtype.float:
return dtype.f32
elif t == _C.dtype.double:
return dtype.f64
elif t == _C.dtype.bool:
return dtype.b
elif t == _C.dtype.unknown:
return dtype.unknown
else:
raise TypeError(
f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {t}"
)
# else: # commented out for mypy
raise TypeError(
f"Provided unsupported source type for dtype conversion (got: {t})"
)
[docs] @classmethod
def try_from(
cls,
t: Union[torch.dtype, trt.DataType, np.dtype, dtype],
use_default: bool = False,
) -> Optional[dtype]:
"""Create a Torch-TensorRT dtype from another library's dtype system.
Takes a dtype enum from one of numpy, torch, and tensorrt and create a ``torch_tensorrt.dtype``.
If the source dtype system is not supported or the type is not supported in Torch-TensorRT,
then returns ``None``.
Arguments:
t (Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype)): Data type enum from another library
use_default (bool): In some cases a catch all type (such as ``torch_tensorrt.dtype.f32``) is sufficient, so instead of throwing an exception, return default value.
Returns:
Optional(dtype): Equivalent ``torch_tensorrt.dtype`` to ``t`` or ``None``
Examples:
.. code:: py
# Succeeds
float_dtype = torch_tensorrt.dtype.try_from(torch.float) # Returns torch_tensorrt.dtype.f32
# Unsupported type
float_dtype = torch_tensorrt.dtype.try_from(torch.complex128) # Returns None
"""
try:
casted_format = dtype._from(t, use_default=use_default)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(
f"Conversion from {t} to torch_tensorrt.dtype failed", exc_info=True
)
return None
[docs] def to(
self,
t: Union[Type[torch.dtype], Type[trt.DataType], Type[np.dtype], Type[dtype]],
use_default: bool = False,
) -> Union[torch.dtype, trt.DataType, np.dtype, dtype]:
"""Convert dtype into the equivalent type in [torch, numpy, tensorrt]
Converts ``self`` into one of numpy, torch, and tensorrt equivalent dtypes.
If ``self`` is not supported in the target library, then an exception will be raised.
As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.dtype.try_to()``
Arguments:
t (Union(Type(torch.dtype), Type(tensorrt.DataType), Type(numpy.dtype), Type(dtype))): Data type enum from another library to convert to
use_default (bool): In some cases a catch all type (such as ``torch.float``) is sufficient, so instead of throwing an exception, return default value.
Returns:
Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype): dtype equivalent ``torch_tensorrt.dtype`` from library enum ``t``
Raises:
TypeError: Unsupported data type or unknown target
Examples:
.. code:: py
# Succeeds
float_dtype = torch_tensorrt.dtype.f32.to(torch.dtype) # Returns torch.float
# Failure
float_dtype = torch_tensorrt.dtype.bf16.to(numpy.dtype) # Throws exception
"""
# TODO: Ideally implemented with match statement but need to wait for Py39 EoL
if t == torch.dtype:
if self == dtype.u8:
return torch.uint8
elif self == dtype.i8:
return torch.int8
elif self == dtype.i32:
return torch.int
elif self == dtype.i64:
return torch.long
elif self == dtype.f8:
return torch.float8_e4m3fn
elif self == dtype.f16:
return torch.half
elif self == dtype.f32:
return torch.float
elif self == dtype.f64:
return torch.double
elif self == dtype.b:
return torch.bool
elif self == dtype.bf16:
return torch.bfloat16
elif use_default:
logging.warning(
f"Given dtype that does not have direct mapping to torch ({self}), defaulting to torch.float"
)
return torch.float
else:
raise TypeError(f"Unsupported torch dtype (had: {self})")
elif t == trt.DataType:
if self == dtype.u8:
return trt.DataType.UINT8
if self == dtype.i8:
return trt.DataType.INT8
elif self == dtype.i32:
return trt.DataType.INT32
elif self == dtype.f8:
return trt.DataType.FP8
elif self == dtype.i64:
return trt.DataType.INT64
elif self == dtype.f16:
return trt.DataType.HALF
elif self == dtype.f32:
return trt.DataType.FLOAT
elif self == dtype.b:
return trt.DataType.BOOL
elif self == dtype.bf16:
return trt.DataType.BF16
elif use_default:
return trt.DataType.FLOAT
else:
raise TypeError("Unsupported tensorrt dtype")
elif t == np.dtype:
if self == dtype.u8:
return np.uint8
elif self == dtype.i8:
return np.int8
elif self == dtype.i32:
return np.int32
elif self == dtype.i64:
return np.int64
elif self == dtype.f16:
return np.float16
elif self == dtype.f32:
return np.float32
elif self == dtype.f64:
return np.float64
elif self == dtype.b:
return np.bool_
# TODO: Consider using ml_dtypes when issues like this are resolved:
# https://github.com/pytorch/pytorch/issues/109873
# elif self == dtype.bf16:
# return ml_dtypes.bfloat16
elif use_default:
return np.float32
else:
raise TypeError("Unsupported numpy dtype")
elif t == dtype:
return self
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if t == _C.dtype:
if self == dtype.i64:
return _C.dtype.long
elif self == dtype.i8:
return _C.dtype.int8
elif self == dtype.i32:
return _C.dtype.int32
elif self == dtype.f16:
return _C.dtype.half
elif self == dtype.f32:
return _C.dtype.float
elif self == dtype.f64:
return _C.dtype.double
elif self == dtype.b:
return _C.dtype.bool
elif self == dtype.unknown:
return _C.dtype.unknown
else:
raise TypeError(
f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {self}"
)
# else: # commented out for mypy
raise TypeError(
f"Provided unsupported destination type for dtype conversion {t}"
)
[docs] def try_to(
self,
t: Union[Type[torch.dtype], Type[trt.DataType], Type[np.dtype], Type[dtype]],
use_default: bool,
) -> Optional[Union[torch.dtype, trt.DataType, np.dtype, dtype]]:
"""Convert dtype into the equivalent type in [torch, numpy, tensorrt]
Converts ``self`` into one of numpy, torch, and tensorrt equivalent dtypes.
If ``self`` is not supported in the target library, then returns ``None``.
Arguments:
t (Union(Type(torch.dtype), Type(tensorrt.DataType), Type(numpy.dtype), Type(dtype))): Data type enum from another library to convert to
use_default (bool): In some cases a catch all type (such as ``torch.float``) is sufficient, so instead of throwing an exception, return default value.
Returns:
Optional(Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype)): dtype equivalent ``torch_tensorrt.dtype`` from library enum ``t``
Examples:
.. code:: py
# Succeeds
float_dtype = torch_tensorrt.dtype.f32.to(torch.dtype) # Returns torch.float
# Failure
float_dtype = torch_tensorrt.dtype.bf16.to(numpy.dtype) # Returns None
"""
try:
casted_format = self.to(t, use_default)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(
f"torch_tensorrt.dtype conversion to target type {t} failed",
exc_info=True,
)
return None
def __eq__(self, other: Union[torch.dtype, trt.DataType, np.dtype, dtype]) -> bool:
other_ = dtype._from(other)
return bool(self.value == other_.value)
def __hash__(self) -> int:
return hash(self.value)
# Putting aliases here that mess with mypy
bool = b
int = i32
[docs]class memory_format(Enum):
""""""
# TensorRT supported memory layouts
linear = auto()
"""Row major linear format.
For a tensor with dimensions {N, C, H, W}, the W axis always has unit stride, and the stride of every other axis is at least the product of the next dimension times the next stride. the strides are the same as for a C array with dimensions [N][C][H][W].
Equivient to ``memory_format.contiguous``
:meta hide-value:
"""
chw2 = auto()
"""Two wide channel vectorized row major format.
This format is bound to FP16 in TensorRT. It is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to a C array with dimensions [N][(C+1)/2][H][W][2], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][c/2][h][w][c%2].
:meta hide-value:
"""
hwc8 = auto()
"""Eight channel format where C is padded to a multiple of 8.
This format is bound to FP16. It is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to the array with dimensions [N][H][W][(C+7)/8*8], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][h][w][c].
:meta hide-value:
"""
chw4 = auto()
"""Four wide channel vectorized row major format. This format is bound to INT8. It is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to a C array with dimensions [N][(C+3)/4][H][W][4], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][c/4][h][w][c%4].
:meta hide-value:
"""
chw16 = auto()
"""Sixteen wide channel vectorized row major format.
This format is bound to FP16. It is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to a C array with dimensions [N][(C+15)/16][H][W][16], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][c/16][h][w][c%16].
:meta hide-value:
"""
chw32 = auto()
"""Thirty-two wide channel vectorized row major format.
This format is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to a C array with dimensions [N][(C+31)/32][H][W][32], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][c/32][h][w][c%32].
:meta hide-value:
"""
dhwc8 = auto()
"""Eight channel format where C is padded to a multiple of 8.
This format is bound to FP16, and it is only available for dimensions >= 4.
For a tensor with dimensions {N, C, D, H, W}, the memory layout is equivalent to an array with dimensions [N][D][H][W][(C+7)/8*8], with the tensor coordinates (n, c, d, h, w) mapping to array subscript [n][d][h][w][c].
:meta hide-value:
"""
cdhw32 = auto()
"""Thirty-two wide channel vectorized row major format with 3 spatial dimensions.
This format is bound to FP16 and INT8. It is only available for dimensions >= 4.
For a tensor with dimensions {N, C, D, H, W}, the memory layout is equivalent to a C array with dimensions [N][(C+31)/32][D][H][W][32], with the tensor coordinates (n, d, c, h, w) mapping to array subscript [n][c/32][d][h][w][c%32].
:meta hide-value:
"""
hwc = auto()
"""Non-vectorized channel-last format. This format is bound to FP32 and is only available for dimensions >= 3.
Equivient to ``memory_format.channels_last``
:meta hide-value:
"""
dla_linear = auto()
""" DLA planar format. Row major format. The stride for stepping along the H axis is rounded up to 64 bytes.
This format is bound to FP16/Int8 and is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to a C array with dimensions [N][C][H][roundUp(W, 64/elementSize)] where elementSize is 2 for FP16 and 1 for Int8, with the tensor coordinates (n, c, h, w) mapping to array subscript [n][c][h][w].
:meta hide-value:
"""
dla_hwc4 = auto()
"""DLA image format. channel-last format. C can only be 1, 3, 4. If C == 3 it will be rounded to 4. The stride for stepping along the H axis is rounded up to 32 bytes.
This format is bound to FP16/Int8 and is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, with C’ is 1, 4, 4 when C is 1, 3, 4 respectively, the memory layout is equivalent to a C array with dimensions [N][H][roundUp(W, 32/C’/elementSize)][C’] where elementSize is 2 for FP16 and 1 for Int8, C’ is the rounded C. The tensor coordinates (n, c, h, w) maps to array subscript [n][h][w][c].
:meta hide-value:
"""
hwc16 = auto()
"""Sixteen channel format where C is padded to a multiple of 16. This format is bound to FP16. It is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to the array with dimensions [N][H][W][(C+15)/16*16], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][h][w][c].
:meta hide-value:
"""
dhwc = auto()
"""Non-vectorized channel-last format. This format is bound to FP32. It is only available for dimensions >= 4.
Equivient to ``memory_format.channels_last_3d``
:meta hide-value:
"""
# PyTorch aliases for TRT layouts
contiguous = linear
channels_last = hwc
channels_last_3d = dhwc
@classmethod
def _from(
cls, f: Union[torch.memory_format, trt.TensorFormat, memory_format]
) -> memory_format:
"""Create a Torch-TensorRT memory format enum from another library memory format enum.
Takes a memory format enum from one of torch, and tensorrt and create a ``torch_tensorrt.memory_format``.
If the source is not supported or the memory format is not supported in Torch-TensorRT,
then an exception will be raised. As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.memory_format.try_from()``
Arguments:
f (Union(torch.memory_format, tensorrt.TensorFormat, memory_format)): Memory format enum from another library
Returns:
memory_format: Equivalent ``torch_tensorrt.memory_format`` to ``f``
Raises:
TypeError: Unsupported memory format or unknown source
Examples:
.. code:: py
torchtrt_linear = torch_tensorrt.memory_format._from(torch.contiguous)
"""
# TODO: Ideally implemented with match statement but need to wait for Py39 EoL
if isinstance(f, torch.memory_format):
if f == torch.contiguous_format:
return memory_format.contiguous
elif f == torch.channels_last:
return memory_format.channels_last
elif f == torch.channels_last_3d:
return memory_format.channels_last_3d
else:
raise TypeError(
f"Provided an unsupported memory format for tensor, got: {dtype}"
)
elif isinstance(f, trt.DataType):
if f == trt.TensorFormat.LINEAR:
return memory_format.linear
elif f == trt.TensorFormat.CHW2:
return memory_format.chw2
elif f == trt.TensorFormat.HWC8:
return memory_format.hwc8
elif f == trt.TensorFormat.CHW4:
return memory_format.chw4
elif f == trt.TensorFormat.CHW16:
return memory_format.chw16
elif f == trt.TensorFormat.CHW32:
return memory_format.chw32
elif f == trt.TensorFormat.DHWC8:
return memory_format.dhwc8
elif f == trt.TensorFormat.CDHW32:
return memory_format.cdhw32
elif f == trt.TensorFormat.HWC:
return memory_format.hwc
elif f == trt.TensorFormat.DLA_LINEAR:
return memory_format.dla_linear
elif f == trt.TensorFormat.DLA_HWC4:
return memory_format.dla_hwc4
elif f == trt.TensorFormat.HWC16:
return memory_format.hwc16
elif f == trt.TensorFormat.DHWC:
return memory_format.dhwc
else:
raise TypeError(
f"Provided an unsupported tensor format for tensor, got: {dtype}"
)
elif isinstance(f, memory_format):
return f
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if isinstance(f, _C.TensorFormat):
if f == _C.TensorFormat.contiguous:
return memory_format.contiguous
elif f == _C.TensorFormat.channels_last:
return memory_format.channels_last
else:
raise ValueError(
"Provided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channel_last)"
)
# else: # commented out for mypy
raise TypeError("Provided unsupported source type for memory_format conversion")
[docs] @classmethod
def try_from(
cls, f: Union[torch.memory_format, trt.TensorFormat, memory_format]
) -> Optional[memory_format]:
"""Create a Torch-TensorRT memory format enum from another library memory format enum.
Takes a memory format enum from one of torch, and tensorrt and create a ``torch_tensorrt.memory_format``.
If the source is not supported or the memory format is not supported in Torch-TensorRT,
then ``None`` will be returned.
Arguments:
f (Union(torch.memory_format, tensorrt.TensorFormat, memory_format)): Memory format enum from another library
Returns:
Optional(memory_format): Equivalent ``torch_tensorrt.memory_format`` to ``f``
Examples:
.. code:: py
torchtrt_linear = torch_tensorrt.memory_format.try_from(torch.contiguous)
"""
try:
casted_format = memory_format._from(f)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(
f"Conversion from {f} to torch_tensorrt.memory_format failed",
exc_info=True,
)
return None
[docs] def to(
self,
t: Union[
Type[torch.memory_format], Type[trt.TensorFormat], Type[memory_format]
],
) -> Union[torch.memory_format, trt.TensorFormat, memory_format]:
"""Convert ``memory_format`` into the equivalent type in torch or tensorrt
Converts ``self`` into one of torch or tensorrt equivalent memory format.
If ``self`` is not supported in the target library, then an exception will be raised.
As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.memory_format.try_to()``
Arguments:
t (Union(Type(torch.memory_format), Type(tensorrt.TensorFormat), Type(memory_format))): Memory format type enum from another library to convert to
Returns:
Union(torch.memory_format, tensorrt.TensorFormat, memory_format): Memory format equivalent ``torch_tensorrt.memory_format`` in enum ``t``
Raises:
TypeError: Unknown target type or unsupported memory format
Examples:
.. code:: py
# Succeeds
tf = torch_tensorrt.memory_format.linear.to(torch.dtype) # Returns torch.contiguous
"""
if t == torch.memory_format:
if self == memory_format.contiguous:
return torch.contiguous_format
elif self == memory_format.channels_last:
return torch.channels_last
elif self == memory_format.channels_last_3d:
return torch.channels_last_3d
else:
raise TypeError("Unsupported torch dtype")
elif t == trt.TensorFormat:
if self == memory_format.linear:
return trt.TensorFormat.LINEAR
elif self == memory_format.chw2:
return trt.TensorFormat.CHW2
elif self == memory_format.hwc8:
return trt.TensorFormat.HWC8
elif self == memory_format.chw4:
return trt.TensorFormat.CHW4
elif self == memory_format.chw16:
return trt.TensorFormat.CHW16
elif self == memory_format.chw32:
return trt.TensorFormat.CHW32
elif self == memory_format.dhwc8:
return trt.TensorFormat.DHWC8
elif self == memory_format.cdhw32:
return trt.TensorFormat.CDHW32
elif self == memory_format.hwc:
return trt.TensorFormat.HWC
elif self == memory_format.dla_linear:
return trt.TensorFormat.DLA_LINEAR
elif self == memory_format.dla_hwc4:
return trt.TensorFormat.DLA_HWC4
elif self == memory_format.hwc16:
return trt.TensorFormat.HWC16
elif self == memory_format.dhwc:
return trt.TensorFormat.DHWC
else:
raise TypeError("Unsupported tensorrt memory format")
elif t == memory_format:
return self
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if t == _C.TensorFormat:
if self == memory_format.contiguous:
return _C.TensorFormat.contiguous
elif self == memory_format.channels_last:
return _C.TensorFormat.channels_last
else:
raise ValueError(
"Provided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channel_last)"
)
# else: # commented out for mypy
raise TypeError(
"Provided unsupported destination type for memory format conversion"
)
[docs] def try_to(
self,
t: Union[
Type[torch.memory_format], Type[trt.TensorFormat], Type[memory_format]
],
) -> Optional[Union[torch.memory_format, trt.TensorFormat, memory_format]]:
"""Convert ``memory_format`` into the equivalent type in torch or tensorrt
Converts ``self`` into one of torch or tensorrt equivalent memory format.
If ``self`` is not supported in the target library, then ``None`` will be returned
Arguments:
t (Union(Type(torch.memory_format), Type(tensorrt.TensorFormat), Type(memory_format))): Memory format type enum from another library to convert to
Returns:
Optional(Union(torch.memory_format, tensorrt.TensorFormat, memory_format)): Memory format equivalent ``torch_tensorrt.memory_format`` in enum ``t``
Examples:
.. code:: py
# Succeeds
tf = torch_tensorrt.memory_format.linear.to(torch.dtype) # Returns torch.contiguous
"""
try:
casted_format = self.to(t)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(
f"torch_tensorrt.memory_format conversion to target type {t} failed",
exc_info=True,
)
return None
def __eq__(
self, other: Union[torch.memory_format, trt.TensorFormat, memory_format]
) -> bool:
other_ = memory_format._from(other)
return self.value == other_.value
def __hash__(self) -> int:
return hash(self.value)
[docs]class DeviceType(Enum):
"""Type of device TensorRT will target"""
UNKNOWN = auto()
"""
Sentinel value
:meta hide-value:
"""
GPU = auto()
"""
Target is a GPU
:meta hide-value:
"""
DLA = auto()
"""
Target is a DLA core
:meta hide-value:
"""
@classmethod
def _from(cls, d: Union[trt.DeviceType, DeviceType]) -> DeviceType:
"""Create a Torch-TensorRT device type enum from a TensorRT device type enum.
Takes a device type enum from tensorrt and create a ``torch_tensorrt.DeviceType``.
If the source is not supported or the device type is not supported in Torch-TensorRT,
then an exception will be raised. As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.DeviceType.try_from()``
Arguments:
d (Union(tensorrt.DeviceType, DeviceType)): Device type enum from another library
Returns:
DeviceType: Equivalent ``torch_tensorrt.DeviceType`` to ``d``
Raises:
TypeError: Unknown source type or unsupported device type
Examples:
.. code:: py
torchtrt_dla = torch_tensorrt.DeviceType._from(tensorrt.DeviceType.DLA)
"""
if isinstance(d, trt.DeviceType):
if d == trt.DeviceType.GPU:
return DeviceType.GPU
elif d == trt.DeviceType.DLA:
return DeviceType.DLA
else:
raise ValueError(
"Provided an unsupported device type (support: GPU/DLA)"
)
elif isinstance(d, DeviceType):
return d
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if isinstance(d, _C.DeviceType):
if d == _C.DeviceType.GPU:
return DeviceType.GPU
elif d == _C.DeviceType.DLA:
return DeviceType.DLA
else:
raise ValueError(
"Provided an unsupported device type (support: GPU/DLA)"
)
# else: # commented out for mypy
raise TypeError("Provided unsupported source type for DeviceType conversion")
[docs] @classmethod
def try_from(cls, d: Union[trt.DeviceType, DeviceType]) -> Optional[DeviceType]:
"""Create a Torch-TensorRT device type enum from a TensorRT device type enum.
Takes a device type enum from tensorrt and create a ``torch_tensorrt.DeviceType``.
If the source is not supported or the device type is not supported in Torch-TensorRT,
then an exception will be raised. As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.DeviceType.try_from()``
Arguments:
d (Union(tensorrt.DeviceType, DeviceType)): Device type enum from another library
Returns:
DeviceType: Equivalent ``torch_tensorrt.DeviceType`` to ``d``
Examples:
.. code:: py
torchtrt_dla = torch_tensorrt.DeviceType._from(tensorrt.DeviceType.DLA)
"""
try:
casted_format = DeviceType._from(d)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(
f"Conversion from {d} to torch_tensorrt.DeviceType failed",
exc_info=True,
)
return None
[docs] def to(
self,
t: Union[Type[trt.DeviceType], Type[DeviceType]],
use_default: bool = False,
) -> Union[trt.DeviceType, DeviceType]:
"""Convert ``DeviceType`` into the equivalent type in tensorrt
Converts ``self`` into one of torch or tensorrt equivalent device type.
If ``self`` is not supported in the target library, then an exception will be raised.
As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.DeviceType.try_to()``
Arguments:
t (Union(Type(tensorrt.DeviceType), Type(DeviceType))): Device type enum from another library to convert to
Returns:
Union(tensorrt.DeviceType, DeviceType): Device type equivalent ``torch_tensorrt.DeviceType`` in enum ``t``
Raises:
TypeError: Unknown target type or unsupported device type
Examples:
.. code:: py
# Succeeds
trt_dla = torch_tensorrt.DeviceType.DLA.to(tensorrt.DeviceType) # Returns tensorrt.DeviceType.DLA
"""
if t == trt.DeviceType:
if self == DeviceType.GPU:
return trt.DeviceType.GPU
elif self == DeviceType.DLA:
return trt.DeviceType.DLA
elif use_default:
return trt.DeviceType.GPU
else:
raise ValueError(
"Provided an unsupported device type (support: GPU/DLA)"
)
elif t == DeviceType:
return self
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if t == _C.DeviceType:
if self == DeviceType.GPU:
return _C.DeviceType.GPU
elif self == DeviceType.DLA:
return _C.DeviceType.DLA
else:
raise ValueError(
"Provided an unsupported device type (support: GPU/DLA)"
)
# else: # commented out for mypy
raise TypeError(
"Provided unsupported destination type for device type conversion"
)
[docs] def try_to(
self,
t: Union[Type[trt.DeviceType], Type[DeviceType]],
use_default: bool = False,
) -> Optional[Union[trt.DeviceType, DeviceType]]:
"""Convert ``DeviceType`` into the equivalent type in tensorrt
Converts ``self`` into one of torch or tensorrt equivalent memory format.
If ``self`` is not supported in the target library, then ``None`` will be returned.
Arguments:
t (Union(Type(tensorrt.DeviceType), Type(DeviceType))): Device type enum from another library to convert to
Returns:
Optional(Union(tensorrt.DeviceType, DeviceType)): Device type equivalent ``torch_tensorrt.DeviceType`` in enum ``t``
Examples:
.. code:: py
# Succeeds
trt_dla = torch_tensorrt.DeviceType.DLA.to(tensorrt.DeviceType) # Returns tensorrt.DeviceType.DLA
"""
try:
casted_format = self.to(t, use_default=use_default)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(
f"torch_tensorrt.DeviceType conversion to target type {t} failed",
exc_info=True,
)
return None
def __eq__(self, other: Union[trt.DeviceType, DeviceType]) -> bool:
other_ = DeviceType._from(other)
return bool(self.value == other_.value)
def __hash__(self) -> int:
return hash(self.value)
[docs]class EngineCapability(Enum):
"""
EngineCapability determines the restrictions of a network during build time and what runtime it targets.
"""
STANDARD = auto()
"""
EngineCapability.STANDARD does not provide any restrictions on functionality and the resulting serialized engine can be executed with TensorRT’s standard runtime APIs.
:meta hide-value:
"""
SAFETY = auto()
"""
EngineCapability.SAFETY provides a restricted subset of network operations that are safety certified and the resulting serialized engine can be executed with TensorRT’s safe runtime APIs in the tensorrt.safe namespace.
:meta hide-value:
"""
DLA_STANDALONE = auto()
"""
``EngineCapability.DLA_STANDALONE`` provides a restricted subset of network operations that are DLA compatible and the resulting serialized engine can be executed using standalone DLA runtime APIs.
:meta hide-value:
"""
@classmethod
def _from(
cls, c: Union[trt.EngineCapability, EngineCapability]
) -> EngineCapability:
"""Create a Torch-TensorRT Engine capability enum from a TensorRT Engine capability enum.
Takes a device type enum from tensorrt and create a ``torch_tensorrt.EngineCapability``.
If the source is not supported or the engine capability is not supported in Torch-TensorRT,
then an exception will be raised. As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.EngineCapability.try_from()``
Arguments:
c (Union(tensorrt.EngineCapability, EngineCapability)): Engine capability enum from another library
Returns:
EngineCapability: Equivalent ``torch_tensorrt.EngineCapability`` to ``c``
Raises:
TypeError: Unknown source type or unsupported engine capability
Examples:
.. code:: py
torchtrt_ec = torch_tensorrt.EngineCapability._from(tensorrt.EngineCapability.SAFETY)
"""
if isinstance(c, trt.EngineCapability):
if c == trt.EngineCapability.STANDARD:
return EngineCapability.STANDARD
elif c == trt.EngineCapability.SAFETY:
return EngineCapability.SAFETY
elif c == trt.EngineCapability.DLA_STANDALONE:
return EngineCapability.DLA_STANDALONE
else:
raise ValueError("Provided an unsupported engine capability")
elif isinstance(c, EngineCapability):
return c
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if isinstance(c, _C.EngineCapability):
if c == _C.EngineCapability.STANDARD:
return EngineCapability.STANDARD
elif c == _C.EngineCapability.SAFETY:
return EngineCapability.SAFETY
elif c == _C.EngineCapability.DLA_STANDALONE:
return EngineCapability.DLA_STANDALONE
else:
raise ValueError("Provided an unsupported engine capability")
# else: # commented out for mypy
raise TypeError(
"Provided unsupported source type for EngineCapability conversion"
)
[docs] @classmethod
def try_from(
c: Union[trt.EngineCapability, EngineCapability]
) -> Optional[EngineCapability]:
"""Create a Torch-TensorRT engine capability enum from a TensorRT engine capability enum.
Takes a device type enum from tensorrt and create a ``torch_tensorrt.EngineCapability``.
If the source is not supported or the engine capability level is not supported in Torch-TensorRT,
then an exception will be raised. As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.EngineCapability.try_from()``
Arguments:
c (Union(tensorrt.EngineCapability, EngineCapability)): Engine capability enum from another library
Returns:
EngineCapability: Equivalent ``torch_tensorrt.EngineCapability`` to ``c``
Examples:
.. code:: py
torchtrt_safety_ec = torch_tensorrt.EngineCapability._from(tensorrt.EngineCapability.SAEFTY)
"""
try:
casted_format = EngineCapability._from(c)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(
f"Conversion from {c} to torch_tensorrt.EngineCapablity failed",
exc_info=True,
)
return None
[docs] def to(
self, t: Union[Type[trt.EngineCapability], Type[EngineCapability]]
) -> Union[trt.EngineCapability, EngineCapability]:
"""Convert ``EngineCapability`` into the equivalent type in tensorrt
Converts ``self`` into one of torch or tensorrt equivalent engine capability.
If ``self`` is not supported in the target library, then an exception will be raised.
As such it is not recommended to use this method directly.
Alternatively use ``torch_tensorrt.EngineCapability.try_to()``
Arguments:
t (Union(Type(tensorrt.EngineCapability), Type(EngineCapability))): Engine capability enum from another library to convert to
Returns:
Union(tensorrt.EngineCapability, EngineCapability): Engine capability equivalent ``torch_tensorrt.EngineCapability`` in enum ``t``
Raises:
TypeError: Unknown target type or unsupported engine capability
Examples:
.. code:: py
# Succeeds
torchtrt_dla_ec = torch_tensorrt.EngineCapability.DLA_STANDALONE.to(tensorrt.EngineCapability) # Returns tensorrt.EngineCapability.DLA
"""
if t == trt.EngineCapability:
if self == EngineCapability.STANDARD:
return trt.EngineCapability.STANDARD
elif self == EngineCapability.SAFETY:
return trt.EngineCapability.SAFETY
elif self == EngineCapability.DLA_STANDALONE:
return trt.EngineCapability.DLA_STANDALONE
else:
raise ValueError("Provided an unsupported engine capability")
elif t == EngineCapability:
return self
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if t == _C.EngineCapability:
if self == EngineCapability.STANDARD:
return _C.EngineCapability.STANDARD
elif self == EngineCapability.SAFETY:
return _C.EngineCapability.SAFETY
elif self == EngineCapability.DLA_STANDALONE:
return _C.EngineCapability.DLA_STANDALONE
else:
raise ValueError("Provided an unsupported engine capability")
# else: # commented out for mypy
raise TypeError(
"Provided unsupported destination type for engine capability type conversion"
)
[docs] def try_to(
self, t: Union[Type[trt.EngineCapability], Type[EngineCapability]]
) -> Optional[Union[trt.EngineCapability, EngineCapability]]:
"""Convert ``EngineCapability`` into the equivalent type in tensorrt
Converts ``self`` into one of torch or tensorrt equivalent engine capability.
If ``self`` is not supported in the target library, then ``None`` will be returned.
Arguments:
t (Union(Type(tensorrt.EngineCapability), Type(EngineCapability))): Engine capability enum from another library to convert to
Returns:
Optional(Union(tensorrt.EngineCapability, EngineCapability)): Engine capability equivalent ``torch_tensorrt.EngineCapability`` in enum ``t``
Examples:
.. code:: py
# Succeeds
trt_dla_ec = torch_tensorrt.EngineCapability.DLA.to(tensorrt.EngineCapability) # Returns tensorrt.EngineCapability.DLA_STANDALONE
"""
try:
casted_format = self.to(t)
return casted_format
except (ValueError, TypeError) as e:
logging.debug(
f"torch_tensorrt.EngineCapablity conversion to target type {t} failed",
exc_info=True,
)
return None
def __eq__(self, other: Union[trt.EngineCapability, EngineCapability]) -> bool:
other_ = EngineCapability._from(other)
return bool(self.value == other_.value)
def __hash__(self) -> int:
return hash(self.value)
class Platform(Enum):
"""
Specifies a target OS and CPU architecture that a Torch-TensorRT program targets
"""
LINUX_X86_64 = auto()
"""
OS: Linux, CPU Arch: x86_64
:meta hide-value:
"""
LINUX_AARCH64 = auto()
"""
OS: Linux, CPU Arch: aarch64
:meta hide-value:
"""
WIN_X86_64 = auto()
"""
OS: Windows, CPU Arch: x86_64
:meta hide-value:
"""
UNKNOWN = auto()
@classmethod
def current_platform(cls) -> Platform:
"""
Returns an enum for the current platform Torch-TensorRT is running on
Returns:
Platform: Current platform
"""
import platform
if platform.system().lower().startswith("linux"):
# linux
if platform.machine().lower().startswith("aarch64"):
return Platform.LINUX_AARCH64
elif platform.machine().lower().startswith("x86_64"):
return Platform.LINUX_X86_64
elif platform.system().lower().startswith("windows"):
# Windows...
if platform.machine().lower().startswith("amd64"):
return Platform.WIN_X86_64
return Platform.UNKNOWN
def __str__(self) -> str:
return str(self.name)
@needs_torch_tensorrt_runtime # type: ignore
def _to_serialized_rt_platform(self) -> str:
val: str = torch.ops.tensorrt._platform_unknown()
if self == Platform.LINUX_X86_64:
val = torch.ops.tensorrt._platform_linux_x86_64()
elif self == Platform.LINUX_AARCH64:
val = torch.ops.tensorrt._platform_linux_aarch64()
elif self == Platform.WIN_X86_64:
val = torch.ops.tensorrt._platform_win_x86_64()
return val