torch_tensorrt
Functions
- torch_tensorrt.compile(module: Any, ir: str = 'default', inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None, arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, enabled_precisions: Optional[Set[Union[torch.dtype, torch_tensorrt.dtype]]] = None, **kwargs: Any) -> Union[Module, ScriptModule, GraphModule, Callable[..., Any]]
Compile a PyTorch module for NVIDIA GPUs using TensorRT.
Takes an existing PyTorch module and a set of settings to configure the compiler, uses the path specified in ir to lower and compile the module to TensorRT, and returns a PyTorch module. Specifically, the forward method of the module is converted.
- Parameters
module (Union(torch.nn.Module, torch.jit.ScriptModule)) – Source module
- Keyword Arguments
inputs (List[Union(Input, torch.Tensor)]) –
List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input sizes can be specified as torch sizes, tuples or lists. Dtypes can be specified using torch datatypes or torch_tensorrt datatypes, and either torch devices or the torch_tensorrt device type enum can be used to select the device type.
inputs=[
    torch_tensorrt.Input((1, 3, 224, 224)),  # Static NCHW input shape for input #1
    torch_tensorrt.Input(
        min_shape=(1, 224, 224, 3),
        opt_shape=(1, 512, 512, 3),
        max_shape=(1, 1024, 1024, 3),
        dtype=torch.int32,
        format=torch.channels_last,
    ),  # Dynamic input shape for input #2
    torch.randn((1, 3, 224, 244)),  # Use an example tensor and let torch_tensorrt infer settings
]
arg_inputs (Tuple[Any, ...]) – Same as inputs; an alias that reads more naturally alongside kwarg_inputs.
kwarg_inputs (dict[Any, ...]) – Optional keyword-argument inputs to the module's forward function.
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))) – The set of datatypes that TensorRT can use when selecting kernels
ir (str) – The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
**kwargs – Additional settings for the specific requested strategy (See submodules for more info)
- Returns
Compiled Module, when run it will execute via TensorRT
- Return type
torch.nn.Module
- torch_tensorrt.convert_method_to_trt_engine(module: Any, method_name: str = 'forward', inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None, arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, ir: str = 'default', enabled_precisions: Optional[Set[Union[torch.dtype, torch_tensorrt.dtype]]] = None, **kwargs: Any) -> bytes
Convert a TorchScript module method to a serialized TensorRT engine
Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
- Parameters
module (Union(torch.nn.Module, torch.jit.ScriptModule)) – Source module
- Keyword Arguments
inputs (List[Union(Input, torch.Tensor)]) –
List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input sizes can be specified as torch sizes, tuples or lists. Dtypes can be specified using torch datatypes or torch_tensorrt datatypes, and either torch devices or the torch_tensorrt device type enum can be used to select the device type.
inputs=[
    torch_tensorrt.Input((1, 3, 224, 224)),  # Static NCHW input shape for input #1
    torch_tensorrt.Input(
        min_shape=(1, 224, 224, 3),
        opt_shape=(1, 512, 512, 3),
        max_shape=(1, 1024, 1024, 3),
        dtype=torch.int32,
        format=torch.channels_last,
    ),  # Dynamic input shape for input #2
    torch.randn((1, 3, 224, 244)),  # Use an example tensor and let torch_tensorrt infer settings
]
arg_inputs (Tuple[Any, ...]) – Same as inputs; an alias that reads more naturally alongside kwarg_inputs.
kwarg_inputs (dict[Any, ...]) – Optional keyword-argument inputs to the module's forward function.
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))) – The set of datatypes that TensorRT can use when selecting kernels
ir (str) – The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
**kwargs – Additional settings for the specific requested strategy (See submodules for more info)
- Returns
Serialized TensorRT engine, which can either be saved to a file or deserialized via TensorRT APIs
- Return type
bytes
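Because convert_method_to_trt_engine returns a plain bytes object, persisting the engine needs nothing beyond ordinary file I/O. The sketch below is a minimal illustration under stated assumptions: the helper names save_engine and read_engine and the file name model.engine are hypothetical, and placeholder bytes stand in for a real engine so the snippet runs without a GPU.

```python
import tempfile
from pathlib import Path


def save_engine(engine: bytes, path: Path) -> None:
    """Write a serialized TensorRT engine (raw bytes) to disk unchanged."""
    path.write_bytes(engine)


def read_engine(path: Path) -> bytes:
    """Read the serialized engine back as raw bytes."""
    return path.read_bytes()


if __name__ == "__main__":
    # Placeholder bytes standing in for the value returned by
    # torch_tensorrt.convert_method_to_trt_engine(...); not a real engine.
    fake_engine = b"\x00\x01serialized-engine"
    with tempfile.TemporaryDirectory() as tmp:
        engine_path = Path(tmp) / "model.engine"  # hypothetical file name
        save_engine(fake_engine, engine_path)
        assert read_engine(engine_path) == fake_engine
```

Deserializing the bytes is then a job for TensorRT's runtime (for example tensorrt.Runtime(logger).deserialize_cuda_engine(data)), which needs a GPU and is therefore left out of the sketch.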
- torch_tensorrt.save(module: Any, file_path: str = '', *, output_format: str = 'exported_program', inputs: Optional[Sequence[Tensor]] = None, arg_inputs: Optional[Sequence[Tensor]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, retrace: bool = False) -> None
Save the model to disk in the specified output format.
- Parameters
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)) – Compiled Torch-TensorRT module
inputs (torch.Tensor) – Torch input tensors
arg_inputs (Tuple[Any, ...]) – Same as inputs; an alias that reads more naturally alongside kwarg_inputs.
kwarg_inputs (dict[Any, ...]) – Optional keyword-argument inputs to the module's forward function.
output_format (str) – Format to save the model in. Options: exported_program or torchscript.
retrace (bool) – When the module is an fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) before saving it. This flag is experimental.
- torch_tensorrt.load(file_path: str = '') -> Any
Load either a TorchScript model or an ExportedProgram.
Loads a TorchScript or ExportedProgram file from disk. The file type is detected using try/except.
- Parameters
file_path (str) – Path to file on the disk
- Raises
ValueError – If the file does not exist or is neither a TorchScript file nor an ExportedProgram file
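The try/except detection described above can be pictured as a loop over candidate loaders that returns the first success and raises ValueError when none applies. This is a hypothetical pure-Python sketch, not the library's implementation: load_with_fallback and both fake loaders are invented for illustration (real code would call something like torch.jit.load and torch.export.load), and the file extensions they check exist only for the demo.

```python
from typing import Any, Callable, Sequence


def load_with_fallback(file_path: str, loaders: Sequence[Callable[[str], Any]]) -> Any:
    """Try each loader in order and return the first successful result."""
    for loader in loaders:
        try:
            return loader(file_path)
        except Exception:
            continue  # Not this format; try the next loader.
    raise ValueError(
        f"{file_path!r} is neither a TorchScript file nor an ExportedProgram file"
    )


def fake_torchscript_loader(path: str) -> str:
    # Hypothetical stand-in for a TorchScript loader such as torch.jit.load.
    if not path.endswith(".ts"):
        raise RuntimeError("not TorchScript")
    return "torchscript"


def fake_exported_program_loader(path: str) -> str:
    # Hypothetical stand-in for an ExportedProgram loader such as torch.export.load.
    if not path.endswith(".ep"):
        raise RuntimeError("not an ExportedProgram")
    return "exported_program"
```

Ordering matters in this design: the first loader that succeeds wins, so a format that can be recognized cheaply is best tried first.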
Classes
- class torch_tensorrt.MutableTorchTensorRTModule(pytorch_model: Module, *, device: Optional[Union[Device, torch.device, str]] = None, disable_tf32: bool = False, assume_dynamic_shape_support: bool = False, sparse_weights: bool = False, enabled_precisions: Set[Union[torch.dtype, torch_tensorrt.dtype]] = {dtype.f32}, engine_capability: EngineCapability = EngineCapability.STANDARD, make_refittable: bool = False, debug: bool = False, num_avg_timing_iters: int = 1, workspace_size: int = 0, dla_sram_size: int = 1048576, dla_local_dram_size: int = 1073741824, dla_global_dram_size: int = 536870912, truncate_double: bool = False, require_full_compilation: bool = False, min_block_size: int = 5, torch_executed_ops: Optional[Collection[Union[Callable[..., Any], str]]] = None, torch_executed_modules: Optional[List[str]] = None, pass_through_build_failures: bool = False, max_aux_streams: Optional[int] = None, version_compatible: bool = False, optimization_level: Optional[int] = None, use_python_runtime: bool = False, use_fast_partitioner: bool = True, enable_experimental_decompositions: bool = False, dryrun: bool = False, hardware_compatible: bool = False, timing_cache_path: str = '/tmp/torch_tensorrt_engine_cache/timing_cache.bin', **kwargs: Any)
Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module. All TensorRT compilation and refitting processes are handled automatically as you work with the module. Any changes to its attributes or loading a different state_dict will trigger refitting or recompilation, which will be managed during the next forward pass.
The MutableTorchTensorRTModule takes a PyTorch module and a set of configuration settings for the compiler. Once compilation is complete, the module maintains the connection between the TensorRT graph module and the original PyTorch module. Any modifications made to the MutableTorchTensorRTModule will be reflected in both the TensorRT graph module and the original PyTorch module.
- __init__(pytorch_model: Module, *, device: Optional[Union[Device, torch.device, str]] = None, disable_tf32: bool = False, assume_dynamic_shape_support: bool = False, sparse_weights: bool = False, enabled_precisions: Set[Union[torch.dtype, torch_tensorrt.dtype]] = {dtype.f32}, engine_capability: EngineCapability = EngineCapability.STANDARD, make_refittable: bool = False, debug: bool = False, num_avg_timing_iters: int = 1, workspace_size: int = 0, dla_sram_size: int = 1048576, dla_local_dram_size: int = 1073741824, dla_global_dram_size: int = 536870912, truncate_double: bool = False, require_full_compilation: bool = False, min_block_size: int = 5, torch_executed_ops: Optional[Collection[Union[Callable[..., Any], str]]] = None, torch_executed_modules: Optional[List[str]] = None, pass_through_build_failures: bool = False, max_aux_streams: Optional[int] = None, version_compatible: bool = False, optimization_level: Optional[int] = None, use_python_runtime: bool = False, use_fast_partitioner: bool = True, enable_experimental_decompositions: bool = False, dryrun: bool = False, hardware_compatible: bool = False, timing_cache_path: str = '/tmp/torch_tensorrt_engine_cache/timing_cache.bin', **kwargs: Any) -> None
- Parameters
pytorch_model (torch.nn.Module) – Source module that needs to be accelerated
- Keyword Arguments
device (Union(Device, torch.device, dict)) –
Target device for TensorRT engines to run on
device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
disable_tf32 (bool) – Force FP32 layers to use the traditional FP32 format instead of the default TF32 behavior, which rounds the inputs to 10-bit mantissas before multiplying but accumulates the sum using 23-bit mantissas
assume_dynamic_shape_support (bool) – Setting this to True enables the converters to work for both dynamic and static shapes. Default: False
sparse_weights (bool) – Enable sparsity for convolution and fully connected layers.
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))) – The set of datatypes that TensorRT can use when selecting kernels
make_refittable (bool) – Enable refitting
debug (bool) – Enable debuggable engine
engine_capability (EngineCapability) – Restrict kernel selection to safe GPU kernels or safe DLA kernels
num_avg_timing_iters (int) – Number of averaging timing iterations used to select kernels
workspace_size (int) – Maximum size of workspace given to TensorRT
dla_sram_size (int) – Fast software-managed RAM used by DLA to communicate within a layer.
dla_local_dram_size (int) – Host RAM used by DLA to share intermediate tensor data across operations
dla_global_dram_size (int) – Host RAM used by DLA to store weights and metadata for execution
truncate_double (bool) – Truncate weights provided in double (float64) to float32
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)) – Calibrator object which will provide data to the PTQ system for INT8 Calibration
require_full_compilation (bool) – Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
min_block_size (int) – The minimum number of contiguous TensorRT-convertible operations required for a set of operations to run in TensorRT
torch_executed_ops (Collection[Target]) – Set of aten operators that must be run in PyTorch. An error will be thrown if this set is not empty but require_full_compilation is True
torch_executed_modules (List[str]) – List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but require_full_compilation is True
pass_through_build_failures (bool) – Error out if there are issues during compilation (only applicable to torch.compile workflows)
max_aux_streams (Optional[int]) – Maximum number of auxiliary streams in the engine
version_compatible (bool) – Build the TensorRT engines compatible with future versions of TensorRT (Restrict to lean runtime operators to provide version forward compatibility for the engines)
optimization_level (Optional[int]) – Setting a higher optimization level allows TensorRT to spend longer engine building time searching for more optimization options. The resulting engine may have better performance compared to an engine built with a lower optimization level. The default optimization level is 3. Valid values include integers from 0 to the maximum optimization level, which is currently 5. Setting it to be greater than the maximum level results in identical behavior to the maximum level.
use_python_runtime (bool) – Return a graph using a pure Python runtime; this reduces options for serialization
use_fast_partitioner (bool) – Use the adjacency-based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (False) if looking for best performance
enable_experimental_decompositions (bool) – Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the portion of the graph that runs in TensorRT.
dryrun (bool) – Toggle for “Dryrun” mode, running everything except conversion to TRT and logging outputs
hardware_compatible (bool) – Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str) – Path to the timing cache if it exists, or where it will be saved after compilation
lazy_engine_init (bool) – Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
**kwargs (Any) – Additional keyword arguments
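To make the disable_tf32 description concrete, the following pure-Python sketch shows what rounding an FP32 operand to a 10-bit mantissa does to its value. It assumes round-to-nearest with ties away from zero and ignores NaN/Inf; round_to_tf32 is a hypothetical helper that only illustrates the numeric effect, not how TensorRT performs the conversion.

```python
import struct


def round_to_tf32(x: float) -> float:
    """Round an FP32 value to TF32 precision (10 explicit mantissa bits).

    Assumes round-to-nearest, ties away from zero; NaN/Inf are not handled.
    """
    # Reinterpret the value as its 32-bit IEEE-754 single-precision bit pattern.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # FP32 stores 23 mantissa bits and TF32 keeps the top 10, so 13 bits are
    # dropped. Adding half of the dropped range (bit 12) before masking rounds
    # to nearest instead of truncating.
    bits = (bits + 0x1000) & 0xFFFFE000
    return struct.unpack("<f", struct.pack("<I", bits))[0]
```

For instance, 1 + 2**-12 is exactly representable in FP32 but collapses to 1.0 after this rounding; that lost precision is what setting disable_tf32=True keeps.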
- Returns
MutableTorchTensorRTModule
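The interplay of min_block_size and torch_executed_ops can be illustrated on a straight-line chain of operators. This is a simplified, hypothetical model in pure Python: op names are plain strings, the graph is linear, and partition is not a Torch-TensorRT function (the real partitioners operate on FX graphs). Runs of convertible ops shorter than min_block_size fall back to PyTorch, and ops listed in torch_executed_ops always do.

```python
from typing import FrozenSet, List, Sequence, Tuple

Segment = Tuple[str, List[str]]  # ("trt" | "torch", ops in that segment)


def _append(segments: List[Segment], target: str, ops: List[str]) -> None:
    # Merge with the previous segment when it runs on the same backend.
    if segments and segments[-1][0] == target:
        segments[-1][1].extend(ops)
    else:
        segments.append((target, ops))


def partition(
    ops: Sequence[str],
    convertible: FrozenSet[str],
    min_block_size: int = 5,
    torch_executed_ops: FrozenSet[str] = frozenset(),
) -> List[Segment]:
    """Assign each op of a linear graph to TensorRT ("trt") or PyTorch ("torch")."""
    segments: List[Segment] = []
    run: List[str] = []  # current run of contiguous TensorRT-convertible ops

    def flush() -> None:
        # A run only becomes a TensorRT block if it meets min_block_size.
        if run:
            target = "trt" if len(run) >= min_block_size else "torch"
            _append(segments, target, list(run))
            run.clear()

    for op in ops:
        if op in convertible and op not in torch_executed_ops:
            run.append(op)
        else:
            flush()
            _append(segments, "torch", [op])
    flush()
    return segments
```

With the default min_block_size=5, the chain conv, relu, conv, relu, add, custom, relu, relu yields one TensorRT block of five ops followed by a PyTorch segment, because the trailing two-op run is too short to be worth its own engine.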
- compile() -> None
(Re)compile the TRT graph module using the PyTorch module. This function should be called whenever the weight structure changes (shape, more layers, …). MutableTorchTensorRTModule automatically detects such changes and calls this function to recompile. If it fails to catch the changes, call this function manually to recompile the TRT graph module.
- refit_gm() -> None
Refit the TRT graph module with any updates. This function should be called whenever the weight values change but the weight structure remains the same. MutableTorchTensorRTModule automatically catches weight value updates and calls this function to refit the module. If it fails to catch the changes, call this function manually to update the TRT graph module.
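The rule separating compile() from refit_gm() can be summarized as a comparison between two weight snapshots: a structural difference calls for recompilation, a value-only difference for refitting. The sketch below is a hypothetical pure-Python model of that decision, assuming weights are represented as a dict mapping parameter name to (shape, values); required_update is not part of torch_tensorrt, and this is not how MutableTorchTensorRTModule detects changes internally.

```python
from typing import Dict, Sequence, Tuple

# Hypothetical weight snapshot: parameter name -> (shape, flattened values).
Weights = Dict[str, Tuple[Tuple[int, ...], Sequence[float]]]


def required_update(old: Weights, new: Weights) -> str:
    """Return "none", "refit" or "recompile" for a change in weights."""
    # Added or removed parameters change the structure: recompile.
    if old.keys() != new.keys():
        return "recompile"
    # A new shape for any existing parameter also changes the structure.
    if any(old[name][0] != new[name][0] for name in old):
        return "recompile"
    # Same structure, different values: refitting is enough.
    if any(list(old[name][1]) != list(new[name][1]) for name in old):
        return "refit"
    return "none"
```

Refitting is preferred whenever it applies because it updates weights in the existing engine rather than rebuilding it.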
- class torch_tensorrt.Input(*args: Any, **kwargs: Any)
Defines an input to a module in terms of expected shape, data type and tensor format.
- Variables
shape_mode (torch_tensorrt.Input._ShapeMode) – Is input statically or dynamically shaped
shape (Tuple or Dict) –
Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
{"min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple}
dtype (torch_tensorrt.dtype) – The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
format (torch_tensorrt.TensorFormat) – The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
- __init__(*args: Any, **kwargs: Any) -> None
__init__ Method for torch_tensorrt.Input
Input accepts one of a few construction patterns
- Parameters
shape (Tuple or List, optional) – Static shape of input tensor
- Keyword Arguments
shape (Tuple or List, optional) – Static shape of input tensor
min_shape (Tuple or List, optional) – Min size of input tensor’s shape range Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input’s shape_mode to DYNAMIC
opt_shape (Tuple or List, optional) – Optimal size of input tensor’s shape range Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input’s shape_mode to DYNAMIC
max_shape (Tuple or List, optional) – Max size of input tensor’s shape range Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input’s shape_mode to DYNAMIC
dtype (torch.dtype or torch_tensorrt.dtype) – Expected data type for input tensor (default: torch_tensorrt.dtype.float32)
format (torch.memory_format or torch_tensorrt.TensorFormat) – The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
tensor_domain (Tuple(float, float), optional) – The domain of allowed values for the tensor, in interval notation: [tensor_domain[0], tensor_domain[1]). Note: Entering “None” (or not specifying) will set the bound to [0, 2)
torch_tensor (torch.Tensor) – Holds a corresponding torch tensor with this Input.
name (str, optional) – Name of this input in the input nn.Module’s forward function. Used to specify dynamic shapes for the corresponding input in dynamo tracer.
Examples
Input([1,3,32,32], dtype=torch.float32, format=torch.channels_last)
Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW
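The construction rules above (one static shape, or a complete min/opt/max range with no positional shape and no shape keyword) can be checked mechanically. The helper below is hypothetical and not part of torch_tensorrt; it reproduces only the documented shape rules and accepts and ignores other keywords such as dtype or format.

```python
from typing import Any, Dict, Tuple

_RANGE_KEYS = ("min_shape", "opt_shape", "max_shape")


def classify_input_spec(*args: Any, **kwargs: Any) -> Tuple[str, Dict[str, Tuple[int, ...]]]:
    """Return ("STATIC", {...}) or ("DYNAMIC", {...}) per the documented rules."""
    given = [k for k in _RANGE_KEYS if k in kwargs]
    if given:
        # Dynamic: all three range keys, no positional shape, no shape= keyword.
        if len(given) != len(_RANGE_KEYS):
            raise ValueError("min_shape, opt_shape and max_shape must all be provided")
        if args or "shape" in kwargs:
            raise ValueError("a shape range cannot be combined with a static shape")
        return "DYNAMIC", {k: tuple(kwargs[k]) for k in _RANGE_KEYS}
    # Static: exactly one shape, given either positionally or as shape=.
    if len(args) == 1 and "shape" not in kwargs:
        return "STATIC", {"shape": tuple(args[0])}
    if not args and "shape" in kwargs:
        return "STATIC", {"shape": tuple(kwargs["shape"])}
    raise ValueError("provide exactly one static shape or a full min/opt/max range")
```

Lists and tuples are normalized to tuples, mirroring the examples above, which mix both forms.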
- example_tensor(optimization_profile_field: Optional[str] = None) -> Tensor
Get an example tensor of the shape specified by the Input object
- Parameters
optimization_profile_field (Optional(str)) – Name of the field to use for shape in the case the Input is dynamically shaped
- Returns
A PyTorch Tensor
- classmethod from_tensor(t: Tensor, disable_memory_format_check: bool = False) -> Input
Produce an Input which contains the information of the given PyTorch tensor.
- Parameters
t (torch.Tensor) – A PyTorch tensor.
disable_memory_format_check (bool) – Whether to skip validation of the input tensor's memory format
- Returns
An Input object.
- classmethod from_tensors(ts: Sequence[Tensor], disable_memory_format_check: bool = False) -> List[Input]
Produce a list of Inputs which contain the information of all the given PyTorch tensors.
- Parameters
ts (Iterable[torch.Tensor]) – A list of PyTorch tensors.
disable_memory_format_check (bool) – Whether to skip validation of the input tensors' memory formats
- Returns
A list of Inputs.
- dtype: dtype = 1
The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
- format: memory_format = 1
The expected format of the input tensor (default: torch_tensorrt.memory_format.linear)
- class torch_tensorrt.Device(*args: Any, **kwargs: Any)
Defines a device that can be used to specify target devices for engines
- Variables
device_type (DeviceType) – Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
gpu_id (int) – Device ID for target GPU
dla_core (int) – Core ID for target DLA core
allow_gpu_fallback (bool) – Whether falling back to GPU if DLA cannot support an op should be allowed
- __init__(*args: Any, **kwargs: Any)
__init__ Method for torch_tensorrt.Device
Device accepts one of a few construction patterns
- Parameters
spec (str) – String with device spec e.g. “dla:0” for dla, core_id 0
- Keyword Arguments
gpu_id (int) – ID of target GPU (will get overridden if dla_core is specified to the GPU managing DLA). If specified, no positional arguments should be provided
dla_core (int) – ID of target DLA core. If specified, no positional arguments should be provided.
allow_gpu_fallback (bool) – Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA)
Examples
Device("gpu:1")
Device("cuda:1")
Device("dla:0", allow_gpu_fallback=True)
Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)
Device(dla_core=0, allow_gpu_fallback=True)
Device(gpu_id=1)
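The string forms in these examples follow a simple "<type>:<index>" pattern. The parser below is a hypothetical illustration of how such specs map onto device_type, gpu_id and dla_core; it is not torch_tensorrt's parser, and it leaves gpu_id unset for a DLA spec rather than guessing which GPU manages the DLA core.

```python
from typing import Optional, Tuple


def parse_device_spec(spec: str) -> Tuple[str, Optional[int], Optional[int]]:
    """Parse "gpu:N", "cuda:N" or "dla:N" into (device_type, gpu_id, dla_core)."""
    kind, sep, index = spec.partition(":")
    if not sep or not index.isdigit():
        raise ValueError(f"expected '<type>:<index>', got {spec!r}")
    kind = kind.lower()
    if kind in ("gpu", "cuda"):
        # "gpu" and "cuda" are treated as synonyms for a GPU target.
        return "GPU", int(index), None
    if kind == "dla":
        # The device type is DLA because a DLA core was given; the managing
        # GPU is not encoded in the string, so gpu_id stays None here.
        return "DLA", None, int(index)
    raise ValueError(f"unknown device type {kind!r}")
```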
- device_type: DeviceType = 1
Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
- dla_core: int = -1
Core ID for target DLA core
- gpu_id: int = -1
Device ID for target GPU
Enums
- class torch_tensorrt.dtype(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Enum to describe data types to Torch-TensorRT, has compatibility with torch, tensorrt and numpy dtypes
- to(t: Union[Type[torch.dtype], Type[tensorrt.DataType], Type[numpy.dtype], Type[torch_tensorrt.dtype]], use_default: bool = False) -> Union[torch.dtype, tensorrt.DataType, numpy.dtype, torch_tensorrt.dtype]
Convert dtype into the equivalent type in [torch, numpy, tensorrt]
Converts self into one of numpy, torch, and tensorrt equivalent dtypes. If self is not supported in the target library, then an exception will be raised. As such, it is not recommended to use this method directly. Alternatively, use torch_tensorrt.dtype.try_to().
- Parameters
t (Union(Type(torch.dtype), Type(tensorrt.DataType), Type(numpy.dtype), Type(dtype))) – Data type enum from another library to convert to
use_default (bool) – In some cases a catch-all type (such as torch.float) is sufficient, so instead of throwing an exception, return a default value.
- Returns
Equivalent of this torch_tensorrt.dtype in library enum t
- Return type
Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype)
- Raises
TypeError – Unsupported data type or unknown target
Examples
# Succeeds
float_dtype = torch_tensorrt.dtype.f32.to(torch.dtype)  # Returns torch.float
# Failure
float_dtype = torch_tensorrt.dtype.bf16.to(numpy.dtype)  # Throws exception
- classmethod try_from(t: Union[torch.dtype, tensorrt.DataType, numpy.dtype, torch_tensorrt.dtype], use_default: bool = False) -> Optional[dtype]
Create a Torch-TensorRT dtype from another library’s dtype system.
Takes a dtype enum from one of numpy, torch, and tensorrt and creates a torch_tensorrt.dtype. If the source dtype system is not supported or the type is not supported in Torch-TensorRT, then returns None.
- Parameters
t (Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype)) – Data type enum from another library
use_default (bool) – In some cases a catch-all type (such as torch_tensorrt.dtype.f32) is sufficient, so instead of throwing an exception, return a default value.
- Returns
Equivalent torch_tensorrt.dtype to t or None
- Return type
Optional(dtype)
Examples
# Succeeds
float_dtype = torch_tensorrt.dtype.try_from(torch.float)  # Returns torch_tensorrt.dtype.f32
# Unsupported type
float_dtype = torch_tensorrt.dtype.try_from(torch.complex128)  # Returns None
- try_to(t: Union[Type[torch.dtype], Type[tensorrt.DataType], Type[numpy.dtype], Type[torch_tensorrt.dtype]], use_default: bool) -> Optional[Union[torch.dtype, tensorrt.DataType, numpy.dtype, torch_tensorrt.dtype]]
Convert dtype into the equivalent type in [torch, numpy, tensorrt]
Converts self into one of numpy, torch, and tensorrt equivalent dtypes. If self is not supported in the target library, then returns None.
- Parameters
t (Union(Type(torch.dtype), Type(tensorrt.DataType), Type(numpy.dtype), Type(dtype))) – Data type enum from another library to convert to
use_default (bool) – In some cases a catch-all type (such as torch.float) is sufficient, so instead of throwing an exception, return a default value.
- Returns
Equivalent of this torch_tensorrt.dtype in library enum t
- Return type
Optional(Union(torch.dtype, tensorrt.DataType, numpy.dtype, dtype))
Examples
# Succeeds
float_dtype = torch_tensorrt.dtype.f32.try_to(torch.dtype, use_default=False)  # Returns torch.float
# Failure
float_dtype = torch_tensorrt.dtype.bf16.try_to(numpy.dtype, use_default=False)  # Returns None
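The only difference between to() and try_to() is how a missing equivalent is reported: the first raises TypeError, the second returns None. The sketch below models that contract with a hypothetical two-member enum and string placeholders for the target types; MiniDtype, _EQUIVALENTS and the module-level to/try_to functions are illustrative and not part of torch_tensorrt. The same raise-versus-None split applies to the DeviceType, EngineCapability and memory_format conversions in this reference.

```python
from enum import Enum
from typing import Dict, Optional


class MiniDtype(Enum):
    """Hypothetical stand-in for torch_tensorrt.dtype with only two members."""
    f32 = "f32"
    bf16 = "bf16"


# Assumed per-library tables; a missing entry means "no equivalent".
_EQUIVALENTS: Dict[str, Dict[MiniDtype, str]] = {
    "torch": {MiniDtype.f32: "torch.float32", MiniDtype.bf16: "torch.bfloat16"},
    "numpy": {MiniDtype.f32: "numpy.float32"},  # NumPy has no bfloat16
}


def to(d: MiniDtype, library: str) -> str:
    """Strict conversion: raise TypeError when there is no equivalent."""
    try:
        return _EQUIVALENTS[library][d]
    except KeyError:
        raise TypeError(f"{d.name} is not supported by {library}") from None


def try_to(d: MiniDtype, library: str) -> Optional[str]:
    """Lenient conversion: return None instead of raising."""
    try:
        return to(d, library)
    except TypeError:
        return None
```

Building the lenient variant on top of the strict one keeps a single source of truth for the mapping table.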
- b
Boolean value, equivalent to dtype.bool
- bf16
16-bit “Brain” floating-point number, equivalent to dtype.bfloat16
- f16
16-bit floating-point number, equivalent to dtype.half, dtype.fp16 and dtype.float16
- f32
32-bit floating-point number, equivalent to dtype.float, dtype.fp32 and dtype.float32
- f64
64-bit floating-point number, equivalent to dtype.double, dtype.fp64 and dtype.float64
- f8
8-bit floating-point number, equivalent to dtype.fp8 and dtype.float8
- i32
Signed 32-bit integer, equivalent to dtype.int32 and dtype.int
- i64
Signed 64-bit integer, equivalent to dtype.int64 and dtype.long
- i8
Signed 8-bit integer, equivalent to dtype.int8; when enabled as a kernel precision, it typically requires the model to support quantization
- u8
Unsigned 8-bit integer, equivalent to dtype.uint8
- unknown
Sentinel value
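Because each short name and its long-form spellings refer to the same data type, an alias is interchangeable with its canonical member. A minimal sketch of that relationship using the standard library's enum aliasing (members assigned the same value become aliases of the first one) is shown below; AliasedDtype is hypothetical, and whether torch_tensorrt.dtype is implemented with enum aliases is an assumption.

```python
from enum import Enum


class AliasedDtype(Enum):
    """Hypothetical mirror of a few torch_tensorrt.dtype members and aliases."""
    f16 = 1
    half = 1      # same value as f16, so Enum makes this an alias of f16
    fp16 = 1
    float16 = 1
    f32 = 2
    float = 2
    fp32 = 2
    float32 = 2
    i64 = 3
    int64 = 3
    long = 3
```

Aliases resolve to the canonical member (AliasedDtype.half is AliasedDtype.f16), and iterating over the enum yields only the canonical names f16, f32 and i64.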
- class torch_tensorrt.DeviceType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Type of device TensorRT will target
- to(t: Union[Type[tensorrt.DeviceType], Type[torch_tensorrt.DeviceType]], use_default: bool = False) -> Union[tensorrt.DeviceType, torch_tensorrt.DeviceType]
Convert DeviceType into the equivalent type in tensorrt
Converts self into the equivalent torch or tensorrt device type. If self is not supported in the target library, then an exception will be raised. As such, it is not recommended to use this method directly. Alternatively, use torch_tensorrt.DeviceType.try_to().
- Parameters
t (Union(Type(tensorrt.DeviceType), Type(DeviceType))) – Device type enum from another library to convert to
- Returns
Equivalent of this torch_tensorrt.DeviceType in enum t
- Return type
Union(tensorrt.DeviceType, DeviceType)
- Raises
TypeError – Unknown target type or unsupported device type
Examples
# Succeeds
trt_dla = torch_tensorrt.DeviceType.DLA.to(tensorrt.DeviceType)  # Returns tensorrt.DeviceType.DLA
- classmethod try_from(d: Union[tensorrt.DeviceType, torch_tensorrt.DeviceType]) -> Optional[DeviceType]
Create a Torch-TensorRT device type enum from a TensorRT device type enum.
Takes a device type enum from tensorrt and creates a torch_tensorrt.DeviceType. If the source is not supported or the device type is not supported in Torch-TensorRT, then None is returned.
- Parameters
d (Union(tensorrt.DeviceType, DeviceType)) – Device type enum from another library
- Returns
Equivalent torch_tensorrt.DeviceType to d
- Return type
Optional(DeviceType)
Examples
torchtrt_dla = torch_tensorrt.DeviceType.try_from(tensorrt.DeviceType.DLA)
- try_to(t: Union[Type[tensorrt.DeviceType], Type[torch_tensorrt.DeviceType]], use_default: bool = False) -> Optional[Union[tensorrt.DeviceType, torch_tensorrt.DeviceType]]
Convert DeviceType into the equivalent type in tensorrt
Converts self into the equivalent torch or tensorrt device type. If self is not supported in the target library, then None will be returned.
- Parameters
t (Union(Type(tensorrt.DeviceType), Type(DeviceType))) – Device type enum from another library to convert to
- Returns
Equivalent of this torch_tensorrt.DeviceType in enum t
- Return type
Optional(Union(tensorrt.DeviceType, DeviceType))
Examples
# Succeeds
trt_dla = torch_tensorrt.DeviceType.DLA.try_to(tensorrt.DeviceType)  # Returns tensorrt.DeviceType.DLA
- DLA
Target is a DLA core
- GPU
Target is a GPU
- UNKNOWN
Sentinel value
- class torch_tensorrt.EngineCapability(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
EngineCapability determines the restrictions of a network during build time and what runtime it targets.
- to(t: Union[Type[tensorrt.EngineCapability], Type[torch_tensorrt.EngineCapability]]) -> Union[tensorrt.EngineCapability, torch_tensorrt.EngineCapability]
Convert EngineCapability into the equivalent type in tensorrt
Converts self into the equivalent torch or tensorrt engine capability. If self is not supported in the target library, then an exception will be raised. As such, it is not recommended to use this method directly. Alternatively, use torch_tensorrt.EngineCapability.try_to().
- Parameters
t (Union(Type(tensorrt.EngineCapability), Type(EngineCapability))) – Engine capability enum from another library to convert to
- Returns
Equivalent of this torch_tensorrt.EngineCapability in enum t
- Return type
Union(tensorrt.EngineCapability, EngineCapability)
- Raises
TypeError – Unknown target type or unsupported engine capability
Examples
# Succeeds
trt_dla_ec = torch_tensorrt.EngineCapability.DLA_STANDALONE.to(tensorrt.EngineCapability)  # Returns tensorrt.EngineCapability.DLA_STANDALONE
- classmethod try_from(c: Union[tensorrt.EngineCapability, torch_tensorrt.EngineCapability]) -> Optional[EngineCapability]
Create a Torch-TensorRT engine capability enum from a TensorRT engine capability enum.
Takes an engine capability enum from tensorrt and creates a torch_tensorrt.EngineCapability. If the source is not supported or the engine capability level is not supported in Torch-TensorRT, then None is returned.
- Parameters
c (Union(tensorrt.EngineCapability, EngineCapability)) – Engine capability enum from another library
- Returns
Equivalent torch_tensorrt.EngineCapability to c
- Return type
Optional(EngineCapability)
Examples
torchtrt_safety_ec = torch_tensorrt.EngineCapability.try_from(tensorrt.EngineCapability.SAFETY)
- try_to(t: Union[Type[tensorrt.EngineCapability], Type[torch_tensorrt.EngineCapability]]) -> Optional[Union[tensorrt.EngineCapability, torch_tensorrt.EngineCapability]]
Convert EngineCapability into the equivalent type in tensorrt
Converts self into the equivalent torch or tensorrt engine capability. If self is not supported in the target library, then None will be returned.
- Parameters
t (Union(Type(tensorrt.EngineCapability), Type(EngineCapability))) – Engine capability enum from another library to convert to
- Returns
Equivalent of this torch_tensorrt.EngineCapability in enum t
- Return type
Optional(Union(tensorrt.EngineCapability, EngineCapability))
Examples
# Succeeds
trt_dla_ec = torch_tensorrt.EngineCapability.DLA_STANDALONE.try_to(tensorrt.EngineCapability)  # Returns tensorrt.EngineCapability.DLA_STANDALONE
- DLA_STANDALONE
EngineCapability.DLA_STANDALONE provides a restricted subset of network operations that are DLA compatible, and the resulting serialized engine can be executed using standalone DLA runtime APIs.
- SAFETY
EngineCapability.SAFETY provides a restricted subset of network operations that are safety certified, and the resulting serialized engine can be executed with TensorRT’s safe runtime APIs in the tensorrt.safe namespace.
- STANDARD
EngineCapability.STANDARD does not provide any restrictions on functionality, and the resulting serialized engine can be executed with TensorRT’s standard runtime APIs.
- class torch_tensorrt.memory_format(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
- to(t: Union[Type[torch.memory_format], Type[tensorrt.TensorFormat], Type[torch_tensorrt.memory_format]]) -> Union[torch.memory_format, tensorrt.TensorFormat, torch_tensorrt.memory_format]
Convert memory_format into the equivalent type in torch or tensorrt
Converts self into the equivalent torch or tensorrt memory format. If self is not supported in the target library, then an exception will be raised. As such, it is not recommended to use this method directly. Alternatively, use torch_tensorrt.memory_format.try_to().
- Parameters
t (Union(Type(torch.memory_format), Type(tensorrt.TensorFormat), Type(memory_format))) – Memory format type enum from another library to convert to
- Returns
Equivalent of this torch_tensorrt.memory_format in enum t
- Return type
Union(torch.memory_format, tensorrt.TensorFormat, memory_format)
- Raises
TypeError – Unknown target type or unsupported memory format
Examples
# Succeeds
tf = torch_tensorrt.memory_format.linear.to(torch.memory_format)  # Returns torch.contiguous_format
- classmethod try_from(f: Union[torch.memory_format, tensorrt.TensorFormat, torch_tensorrt.memory_format]) -> Optional[memory_format]
Create a Torch-TensorRT memory format enum from another library memory format enum.
Takes a memory format enum from torch or tensorrt and creates a torch_tensorrt.memory_format. If the source is not supported or the memory format is not supported in Torch-TensorRT, then None will be returned.
- Parameters
f (Union(torch.memory_format, tensorrt.TensorFormat, memory_format)) – Memory format enum from another library
- Returns
Equivalent torch_tensorrt.memory_format to f
- Return type
Optional(memory_format)
Examples
torchtrt_linear = torch_tensorrt.memory_format.try_from(torch.contiguous_format)
- try_to(t: Union[Type[memory_format], Type[TensorFormat], Type[memory_format]]) Optional[Union[memory_format, TensorFormat, memory_format]] [source]¶
Convert memory_format into the equivalent type in torch or tensorrt.
Converts self into the equivalent torch or tensorrt memory format. If self is not supported in the target library, None is returned.
- Parameters
t (Union(Type(torch.memory_format), Type(tensorrt.TensorFormat), Type(memory_format))) – Memory format type enum from another library to convert to
- Returns
Memory format in enum t equivalent to this torch_tensorrt.memory_format
- Return type
Optional(Union(torch.memory_format, tensorrt.TensorFormat, memory_format))
Examples
import torch
import torch_tensorrt
# Succeeds
tf = torch_tensorrt.memory_format.linear.try_to(torch.memory_format)  # Returns torch.contiguous_format
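The contract shared by to(), try_to(), and try_from() can be illustrated with a small plain-Python mock. This is a hypothetical sketch, not the Torch-TensorRT implementation: formats are modelled as strings, the helper names to_torch, try_to_torch, and try_from_torch are invented for illustration, and the table assumes that only the three aliases documented on this page (linear, hwc, dhwc) have torch equivalents.

```python
from typing import Dict, Optional

# Hypothetical table: torch_tensorrt format name -> torch memory format name.
# Assumption: only the aliases documented on this page have torch equivalents.
_TO_TORCH: Dict[str, str] = {
    "linear": "torch.contiguous_format",  # memory_format.contiguous
    "hwc": "torch.channels_last",  # memory_format.channels_last
    "dhwc": "torch.channels_last_3d",  # memory_format.channels_last_3d
}
_FROM_TORCH: Dict[str, str] = {v: k for k, v in _TO_TORCH.items()}


def to_torch(fmt: str) -> str:
    """Mimics to(): raise TypeError when the format has no torch equivalent."""
    if fmt not in _TO_TORCH:
        raise TypeError(f"memory_format.{fmt} is not supported in torch")
    return _TO_TORCH[fmt]


def try_to_torch(fmt: str) -> Optional[str]:
    """Mimics try_to(): same conversion, but return None instead of raising."""
    try:
        return to_torch(fmt)
    except TypeError:
        return None


def try_from_torch(name: str) -> Optional[str]:
    """Mimics try_from(): reverse lookup, None for unsupported formats."""
    return _FROM_TORCH.get(name)


print(to_torch("hwc"))                            # torch.channels_last
print(try_to_torch("chw32"))                      # None
print(try_from_torch("torch.contiguous_format"))  # linear
```

Wrapping the raising conversion in try/except is one simple way to derive the Optional-returning variant from the raising one.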
- cdhw32
Thirty-two wide channel vectorized row major format with 3 spatial dimensions.
This format is bound to FP16 and INT8. It is only available for dimensions >= 4.
For a tensor with dimensions {N, C, D, H, W}, the memory layout is equivalent to a C array with dimensions [N][(C+31)/32][D][H][W][32], with the tensor coordinates (n, c, d, h, w) mapping to array subscript [n][c/32][d][h][w][c%32].
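The cdhw32 subscript rule can be turned into a flat element offset by walking [n][c/32][d][h][w][c%32] from the outermost axis inward. A minimal pure-Python sketch, assuming c/32 and c%32 denote integer division and remainder; the helper names cdhw32_offset and cdhw32_size are hypothetical.

```python
def cdhw32_offset(n: int, c: int, d: int, h: int, w: int,
                  C: int, D: int, H: int, W: int) -> int:
    """Flat element offset of (n, c, d, h, w) in a cdhw32 buffer.

    Layout: C array [N][(C+31)//32][D][H][W][32], subscript [n][c//32][d][h][w][c%32].
    """
    V = 32
    c_vecs = (C + V - 1) // V  # number of 32-channel vectors, rounded up
    idx = n
    idx = idx * c_vecs + c // V  # which channel vector
    idx = idx * D + d
    idx = idx * H + h
    idx = idx * W + w
    idx = idx * V + c % V  # lane inside the vector
    return idx


def cdhw32_size(N: int, C: int, D: int, H: int, W: int) -> int:
    """Number of elements allocated, including padding channels."""
    return N * ((C + 31) // 32) * D * H * W * 32


print(cdhw32_offset(0, 32, 0, 0, 0, C=40, D=2, H=3, W=4))  # 768
print(cdhw32_size(1, 40, 2, 3, 4))                         # 1536
```

Consecutive channels inside one 32-wide vector are adjacent in memory, while stepping to the next w skips a whole vector of 32 elements.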
- chw16
Sixteen wide channel vectorized row major format.
This format is bound to FP16. It is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to a C array with dimensions [N][(C+15)/16][H][W][16], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][c/16][h][w][c%16].
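Since chw2, chw4, chw16, and chw32 differ only in the vector width, one offset function parameterized by that width covers all four. This is a sketch under the same integer-division reading of c/16 and c%16; the name vectorized_chw_offset and the parameter V are hypothetical.

```python
def vectorized_chw_offset(n: int, c: int, h: int, w: int,
                          C: int, H: int, W: int, V: int) -> int:
    """Flat element offset of (n, c, h, w) in a V-wide channel-vectorized layout
    [N][(C+V-1)//V][H][W][V], indexed as [n][c//V][h][w][c%V]."""
    c_vecs = (C + V - 1) // V  # channel vectors per image, rounded up
    return (((n * c_vecs + c // V) * H + h) * W + w) * V + c % V


# chw16 with C = 20: channels 0..15 fill the first vector, 16..19 start the second.
print(vectorized_chw_offset(0, 15, 0, 0, C=20, H=2, W=3, V=16))  # 15
print(vectorized_chw_offset(0, 16, 0, 0, C=20, H=2, W=3, V=16))  # 96
```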
- chw2
Two wide channel vectorized row major format.
This format is bound to FP16 in TensorRT. It is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to a C array with dimensions [N][(C+1)/2][H][W][2], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][c/2][h][w][c%2].
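Listing the logical coordinates in storage order makes the chw2 interleaving visible: pairs of channels alternate at each pixel, and an odd C leaves one padding slot per pixel in the last pair. A sketch for a single image (N = 1); the function name chw2_memory_order is hypothetical, and padding slots are shown as None.

```python
from itertools import product
from typing import List, Optional, Tuple


def chw2_memory_order(C: int, H: int, W: int) -> List[Optional[Tuple[int, int, int]]]:
    """(c, h, w) coordinates in storage order for one image in chw2 layout
    [(C+1)//2][H][W][2]; padding slots (c >= C) are reported as None."""
    order: List[Optional[Tuple[int, int, int]]] = []
    # product() advances its last argument fastest, matching C row-major order.
    for vec, h, w, lane in product(range((C + 1) // 2), range(H), range(W), range(2)):
        c = vec * 2 + lane
        order.append((c, h, w) if c < C else None)
    return order


print(chw2_memory_order(C=3, H=1, W=2))
# [(0, 0, 0), (1, 0, 0), (0, 0, 1), (1, 0, 1), (2, 0, 0), None, (2, 0, 1), None]
```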
- chw32
Thirty-two wide channel vectorized row major format.
This format is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to a C array with dimensions [N][(C+31)/32][H][W][32], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][c/32][h][w][c%32].
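Because vectorized formats allocate whole vectors, a small C is padded heavily. The ratio of allocated to logical elements reduces to the padded channel count over C, since the N, H, and W factors cancel. A small arithmetic sketch with a hypothetical helper name:

```python
def padding_overhead(C: int, V: int) -> float:
    """Allocated / logical element ratio when C is rounded up to a multiple of V."""
    padded_c = (C + V - 1) // V * V
    return padded_c / C


print(padding_overhead(3, 32))   # roughly 10.67 (32 slots for 3 channels)
print(padding_overhead(64, 32))  # 1.0
```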
- chw4
Four wide channel vectorized row major format. This format is bound to INT8. It is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to a C array with dimensions [N][(C+3)/4][H][W][4], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][c/4][h][w][c%4].
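Packing a logical NCHW tensor into chw4 means scattering each element to [n][c/4][h][w][c%4] and leaving the padding lanes empty. A pure-Python sketch using nested lists in place of real tensors; the function name pack_chw4 and the choice of 0 for padding lanes are assumptions, not something the format definition specifies.

```python
from typing import List


def pack_chw4(x: List[List[List[List[int]]]]) -> List[int]:
    """Pack a nested NCHW tensor x[n][c][h][w] into a flat chw4 buffer laid out as
    [N][(C+3)//4][H][W][4]; padding lanes (c >= C) are left as 0."""
    N, C, H, W = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    c_vecs = (C + 3) // 4
    buf = [0] * (N * c_vecs * H * W * 4)
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    # Subscript [n][c//4][h][w][c%4] flattened in row-major order.
                    off = (((n * c_vecs + c // 4) * H + h) * W + w) * 4 + c % 4
                    buf[off] = x[n][c][h][w]
    return buf


# One image, two channels, a 1x2 spatial grid.
print(pack_chw4([[[[1, 2]], [[3, 4]]]]))  # [1, 3, 0, 0, 2, 4, 0, 0]
```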
- dhwc
Non-vectorized channel-last format. This format is bound to FP32. It is only available for dimensions >= 4.
Equivalent to memory_format.channels_last_3d.
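Expressed as per-axis strides over the logical (N, C, D, H, W) indices, dhwc storage (a C array [N][D][H][W][C]) gives the channel axis unit stride. A sketch with hypothetical helper names; the resulting tuple is expected to match what PyTorch reports via .stride() for a channels_last_3d tensor, but treat that correspondence as an assumption to check.

```python
from typing import Sequence, Tuple


def dhwc_strides(N: int, C: int, D: int, H: int, W: int) -> Tuple[int, ...]:
    """Element strides for logical axes (N, C, D, H, W) stored as [N][D][H][W][C].

    N is accepted for symmetry with the shape but does not affect any stride.
    """
    return (D * H * W * C, 1, H * W * C, W * C, C)


def offset(index: Sequence[int], strides: Sequence[int]) -> int:
    """Flat offset = sum of index * stride over all axes."""
    return sum(i * s for i, s in zip(index, strides))


s = dhwc_strides(2, 3, 4, 5, 6)
print(s)                           # (360, 1, 90, 18, 3)
print(offset((0, 2, 1, 0, 0), s))  # 92
```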
- dhwc8
Eight channel format where C is padded to a multiple of 8.
This format is bound to FP16, and it is only available for dimensions >= 4.
For a tensor with dimensions {N, C, D, H, W}, the memory layout is equivalent to an array with dimensions [N][D][H][W][(C+7)/8*8], with the tensor coordinates (n, c, d, h, w) mapping to array subscript [n][d][h][w][c].
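For dhwc8 the channel axis stays innermost but is padded to a multiple of 8, so stepping one position in w advances by the padded channel count rather than by C. A sketch with a hypothetical function name:

```python
def dhwc8_offset(n: int, c: int, d: int, h: int, w: int,
                 C: int, D: int, H: int, W: int) -> int:
    """Offset of (n, c, d, h, w) in a dhwc8 buffer [N][D][H][W][(C+7)//8*8],
    subscript [n][d][h][w][c]."""
    c_padded = (C + 7) // 8 * 8  # channels padded up to a multiple of 8
    return (((n * D + d) * H + h) * W + w) * c_padded + c


print(dhwc8_offset(0, 0, 0, 0, 1, C=3, D=2, H=2, W=2))  # 8
```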
- dla_hwc4
DLA image format: a channel-last format in which C can only be 1, 3, or 4. If C == 3, it is rounded up to 4. The stride for stepping along the H axis is rounded up to 32 bytes.
This format is bound to FP16/Int8 and is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, let C’ be the rounded C, that is, 1, 4, and 4 when C is 1, 3, and 4, respectively. The memory layout is then equivalent to a C array with dimensions [N][H][roundUp(W, 32/C’/elementSize)][C’], where elementSize is 2 for FP16 and 1 for Int8. The tensor coordinates (n, c, h, w) map to array subscript [n][h][w][c].
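The dla_hwc4 padding rule can be checked numerically: after rounding C up to C’ and W up to a multiple of 32/C’/elementSize, every H row spans a multiple of 32 bytes. A sketch, assuming roundUp(x, k) means rounding x up to the nearest multiple of k; the function name dla_hwc4_row is hypothetical.

```python
from typing import Tuple


def dla_hwc4_row(C: int, W: int, element_size: int) -> Tuple[int, int, int]:
    """Return (C', padded W, bytes per H row) for dla_hwc4.

    element_size is 2 for FP16 and 1 for Int8.
    """
    c_prime = {1: 1, 3: 4, 4: 4}[C]        # raises KeyError for any other C
    step = 32 // c_prime // element_size   # W is rounded up to a multiple of this
    w_padded = (W + step - 1) // step * step
    return c_prime, w_padded, w_padded * c_prime * element_size


print(dla_hwc4_row(C=3, W=10, element_size=2))  # (4, 12, 96)
```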
- dla_linear
DLA planar format. Row major format. The stride for stepping along the H axis is rounded up to 64 bytes.
This format is bound to FP16/Int8 and is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to a C array with dimensions [N][C][H][roundUp(W, 64/elementSize)] where elementSize is 2 for FP16 and 1 for Int8, with the tensor coordinates (n, c, h, w) mapping to array subscript [n][c][h][w].
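Likewise for dla_linear, each row of W elements is padded so that it spans a multiple of 64 bytes. A sketch of the resulting element offset, under the same roundUp assumption and with a hypothetical function name:

```python
def dla_linear_offset(n: int, c: int, h: int, w: int,
                      C: int, H: int, W: int, element_size: int) -> int:
    """Element offset of (n, c, h, w) in [N][C][H][roundUp(W, 64/element_size)].

    element_size is 2 for FP16 and 1 for Int8.
    """
    step = 64 // element_size
    w_padded = (W + step - 1) // step * step  # row length in elements
    return ((n * C + c) * H + h) * w_padded + w


print(dla_linear_offset(0, 0, 1, 0, C=3, H=4, W=10, element_size=2))  # 32
```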
- hwc
Non-vectorized channel-last format. This format is bound to FP32 and is only available for dimensions >= 3.
Equivalent to memory_format.channels_last.
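Converting between linear (NCHW) and hwc (NHWC) storage is a pure reordering: each element keeps its value and moves from its contiguous offset to its channel-last offset. A minimal sketch on flat Python lists, with a hypothetical function name:

```python
from typing import List, TypeVar

T = TypeVar("T")


def nchw_to_nhwc(buf: List[T], N: int, C: int, H: int, W: int) -> List[T]:
    """Reorder a flat linear (NCHW) buffer into hwc (NHWC) order."""
    out: List[T] = list(buf)  # same size; every slot is overwritten below
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    src = ((n * C + c) * H + h) * W + w  # linear / contiguous
                    dst = ((n * H + h) * W + w) * C + c  # hwc / channels_last
                    out[dst] = buf[src]
    return out


print(nchw_to_nhwc(list(range(6)), N=1, C=2, H=1, W=3))  # [0, 3, 1, 4, 2, 5]
```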
- hwc16
Sixteen channel format where C is padded to a multiple of 16. This format is bound to FP16. It is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to the array with dimensions [N][H][W][(C+15)/16*16], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][h][w][c].
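hwc16 and chw16 allocate the same number of elements, N * H * W * ((C+15)/16) * 16, but order them differently. When C <= 16 there is only one channel vector, and the two mappings coincide; for larger C they generally differ. A sketch that compares the offsets directly (function names are hypothetical):

```python
def chw16_offset(n: int, c: int, h: int, w: int, C: int, H: int, W: int) -> int:
    """Offset in chw16: [N][(C+15)//16][H][W][16], subscript [n][c//16][h][w][c%16]."""
    c_vecs = (C + 15) // 16
    return (((n * c_vecs + c // 16) * H + h) * W + w) * 16 + c % 16


def hwc16_offset(n: int, c: int, h: int, w: int, C: int, H: int, W: int) -> int:
    """Offset in hwc16: [N][H][W][(C+15)//16*16], subscript [n][h][w][c]."""
    c_padded = (C + 15) // 16 * 16
    return ((n * H + h) * W + w) * c_padded + c


def layouts_match(C: int, H: int, W: int) -> bool:
    """True if every logical (c, h, w) of one image has the same offset in both layouts."""
    return all(
        chw16_offset(0, c, h, w, C, H, W) == hwc16_offset(0, c, h, w, C, H, W)
        for c in range(C) for h in range(H) for w in range(W)
    )


print(layouts_match(C=16, H=2, W=2))  # True
print(layouts_match(C=20, H=2, W=2))  # False
```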
- hwc8
Eight channel format where C is padded to a multiple of 8.
This format is bound to FP16. It is only available for dimensions >= 3.
For a tensor with dimensions {N, C, H, W}, the memory layout is equivalent to the array with dimensions [N][H][W][(C+7)/8*8], with the tensor coordinates (n, c, h, w) mapping to array subscript [n][h][w][c].
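Going the other way, a flat hwc8 offset can be decoded back into logical coordinates by peeling off the innermost (padded) channel axis first with divmod. Offsets that land in the padding channels have no logical element. A sketch with a hypothetical function name:

```python
from typing import Optional, Tuple


def hwc8_coords(offset: int, C: int, H: int, W: int) -> Optional[Tuple[int, int, int, int]]:
    """Decode a flat hwc8 offset into (n, c, h, w), or None for a padding slot.

    Inverse of the mapping [N][H][W][(C+7)//8*8] with subscript [n][h][w][c].
    """
    c_padded = (C + 7) // 8 * 8
    pixel, c = divmod(offset, c_padded)  # innermost axis: padded channels
    row, w = divmod(pixel, W)            # next: width
    n, h = divmod(row, H)                # then height; the quotient is the batch index
    return None if c >= C else (n, c, h, w)


print(hwc8_coords(17, C=3, H=2, W=2))  # (0, 1, 1, 0)
print(hwc8_coords(3, C=3, H=2, W=2))   # None
```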
- linear
Row major linear format.
For a tensor with dimensions {N, C, H, W}, the W axis always has unit stride, and the stride of every other axis is at least the product of the next dimension times the next stride. The strides are the same as for a C array with dimensions [N][C][H][W].
Equivalent to memory_format.contiguous.
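The linear-format stride rule can be written down directly: the dense case yields exactly the C-array strides, while larger strides (for example, rows padded out for alignment) still satisfy the "at least" condition. A sketch with hypothetical helper names:

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Strides of a dense C array: the last axis has unit stride and each other
    stride equals the next dimension times the next stride."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = shape[i + 1] * strides[i + 1]
    return tuple(strides)


def satisfies_linear(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """Check the linear rule: unit stride on the last axis, and every other stride
    at least the next dimension times the next stride (padding allowed)."""
    if strides[-1] != 1:
        return False
    return all(strides[i] >= shape[i + 1] * strides[i + 1] for i in range(len(shape) - 1))


print(contiguous_strides((2, 3, 4, 5)))                # (60, 20, 5, 1)
print(satisfies_linear((2, 3, 4, 5), (96, 32, 8, 1)))  # True: W rows padded to 8
print(satisfies_linear((2, 3, 4, 5), (60, 20, 4, 1)))  # False: H stride below W * 1
```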