torch_tensorrt
Functions
- torch_tensorrt.compile(module: typing.Any, ir='default', inputs=[], enabled_precisions={<dtype.float: 1>}, **kwargs)
Compile a PyTorch module for NVIDIA GPUs using TensorRT
Takes an existing PyTorch module and a set of settings to configure the compiler and, using the path specified in ir, lowers and compiles the module to TensorRT, returning a PyTorch module.
Converts specifically the forward method of a Module.
- Parameters
module (Union(torch.nn.Module, torch.jit.ScriptModule)) – Source module
- Keyword Arguments
inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]) –
Required. List of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum to select device type.
inputs=[
    torch_tensorrt.Input((1, 3, 224, 224)),  # Static NCHW input shape for input #1
    torch_tensorrt.Input(
        min_shape=(1, 224, 224, 3),
        opt_shape=(1, 512, 512, 3),
        max_shape=(1, 1024, 1024, 3),
        dtype=torch.int32,
        format=torch.channels_last,
    ),  # Dynamic input shape for input #2
    torch.randn((1, 3, 224, 244)),  # Use an example tensor and let torch_tensorrt infer settings
]
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))) – The set of datatypes that TensorRT can use when selecting kernels
ir (str) – The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
**kwargs – Additional settings for the specific requested strategy (See submodules for more info)
- Returns
Compiled module; when run, it executes via TensorRT
- Return type
torch.nn.Module
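A minimal usage sketch of the call documented above, assuming torch and torch_tensorrt are installed and a CUDA-capable GPU is available. The helper name compile_for_tensorrt and the small convolutional model are illustrative placeholders, not part of the library. The imports are deferred into the function body so that the snippet can be loaded on a machine without a GPU.

```python
def compile_for_tensorrt(batch_size: int = 1):
    """Compile a small placeholder model with torch_tensorrt.compile.

    Hypothetical helper for illustration only. Calling it requires torch,
    torch_tensorrt and a CUDA-capable GPU.
    """
    import torch
    import torch_tensorrt

    # Placeholder model standing in for the user's own module.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.ReLU(),
    ).eval().cuda()

    return torch_tensorrt.compile(
        model,
        ir="default",  # let Torch-TensorRT decide the compilation path
        inputs=[
            # One statically shaped NCHW input.
            torch_tensorrt.Input((batch_size, 3, 224, 224), dtype=torch.float32),
        ],
        enabled_precisions={torch.float32},
    )
```

The returned object is used like any other torch.nn.Module, for example by calling it with a CUDA tensor of the declared shape.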
- torch_tensorrt.convert_method_to_trt_engine(module: typing.Any, method_name: str, ir='default', inputs=[], enabled_precisions={<dtype.float: 1>}, **kwargs)
Convert a TorchScript module method to a serialized TensorRT engine
Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
- Parameters
module (Union(torch.nn.Module, torch.jit.ScriptModule)) – Source module
method_name (str) – Name of the method to convert
- Keyword Arguments
inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]) –
Required. List of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum to select device type.
inputs=[
    torch_tensorrt.Input((1, 3, 224, 224)),  # Static NCHW input shape for input #1
    torch_tensorrt.Input(
        min_shape=(1, 224, 224, 3),
        opt_shape=(1, 512, 512, 3),
        max_shape=(1, 1024, 1024, 3),
        dtype=torch.int32,
        format=torch.channels_last,
    ),  # Dynamic input shape for input #2
    torch.randn((1, 3, 224, 244)),  # Use an example tensor and let torch_tensorrt infer settings
]
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))) – The set of datatypes that TensorRT can use when selecting kernels
ir (str) – The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
**kwargs – Additional settings for the specific requested strategy (See submodules for more info)
- Returns
Serialized TensorRT engine, which can either be saved to a file or deserialized via TensorRT APIs
- Return type
bytes
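Since the return value is a plain bytes object, saving it to a file needs nothing beyond the standard library. The sketch below uses placeholder bytes in place of a real engine, and the helper names save_engine, load_engine and round_trip are hypothetical, chosen for this example only.

```python
import tempfile
from pathlib import Path


def save_engine(engine_bytes: bytes, path: Path) -> None:
    """Write a serialized TensorRT engine to disk unchanged."""
    path.write_bytes(engine_bytes)


def load_engine(path: Path) -> bytes:
    """Read a serialized TensorRT engine back from disk."""
    return path.read_bytes()


def round_trip(engine_bytes: bytes) -> bytes:
    """Save to a temporary file and load again, to check nothing is altered."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        engine_path = Path(tmp_dir) / "model.engine"
        save_engine(engine_bytes, engine_path)
        return load_engine(engine_path)


# Placeholder bytes stand in for the value returned by
# torch_tensorrt.convert_method_to_trt_engine(...).
placeholder_engine = b"\x00\x01placeholder"
restored = round_trip(placeholder_engine)
```

In real use, the bytes passed to save_engine would be the result of convert_method_to_trt_engine, and the file name and extension are free to choose.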
Classes
- class torch_tensorrt.Input(*args, **kwargs)
Defines an input to a module in terms of expected shape, data type and tensor format.
- Variables
shape_mode (torch_tensorrt.Input._ShapeMode) – Is input statically or dynamically shaped
shape (Tuple or Dict) –
Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }
dtype (torch_tensorrt.dtype) – The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
format (torch_tensorrt.TensorFormat) – The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
- __init__(*args, **kwargs)
__init__ Method for torch_tensorrt.Input
Input accepts one of a few construction patterns
- Parameters
shape (Tuple or List, optional) – Static shape of input tensor
- Keyword Arguments
shape (Tuple or List, optional) – Static shape of input tensor
min_shape (Tuple or List, optional) – Min size of input tensor’s shape range. Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined, and implicitly this sets Input’s shape_mode to DYNAMIC
opt_shape (Tuple or List, optional) – Opt size of input tensor’s shape range. Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined, and implicitly this sets Input’s shape_mode to DYNAMIC
max_shape (Tuple or List, optional) – Max size of input tensor’s shape range. Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined, and implicitly this sets Input’s shape_mode to DYNAMIC
dtype (torch.dtype or torch_tensorrt.dtype) – Expected data type for input tensor (default: torch_tensorrt.dtype.float32)
format (torch.memory_format or torch_tensorrt.TensorFormat) – The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
tensor_domain (Tuple(float, float), optional) – The domain of allowed values for the tensor, in interval notation: [tensor_domain[0], tensor_domain[1]). Note: Entering “None” (or not specifying) will set the bound to [0, 2)
Examples
Input([1,3,32,32], dtype=torch.float32, format=torch.channels_last)
Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW
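The construction rules above (one static shape, or all three of min_shape, opt_shape and max_shape with no positional arguments) can be modelled in plain Python. This is a sketch under stated assumptions, not the library's implementation: resolve_shape is a hypothetical helper, and the per-dimension ordering check min <= opt <= max is an assumption that is not spelled out in the text above.

```python
from typing import Dict, Sequence, Tuple, Union

Shape = Tuple[int, ...]
RANGE_KEYS = ("min_shape", "opt_shape", "max_shape")


def resolve_shape(
    *args: Sequence[int], **kwargs: Sequence[int]
) -> Tuple[str, Union[Shape, Dict[str, Shape]]]:
    """Return (shape_mode, shape) for the given construction pattern.

    Hypothetical stand-in for illustration; not torch_tensorrt's own code.
    """
    given = [key for key in RANGE_KEYS if key in kwargs]

    if given:
        # Dynamic: all three bounds, no positional arguments, no `shape`.
        if len(given) != 3 or args or "shape" in kwargs:
            raise ValueError("dynamic inputs need exactly min_shape, opt_shape and max_shape")
        shapes = {key: tuple(kwargs[key]) for key in RANGE_KEYS}
        lo, opt, hi = (shapes[key] for key in RANGE_KEYS)
        if not (len(lo) == len(opt) == len(hi)):
            raise ValueError("all three shapes must have the same rank")
        # Assumption: every dimension satisfies min <= opt <= max.
        if any(not (a <= b <= c) for a, b, c in zip(lo, opt, hi)):
            raise ValueError("expected min_shape <= opt_shape <= max_shape per dimension")
        return "DYNAMIC", shapes

    # Static: shape given either positionally or as the `shape` keyword.
    if len(args) == 1 and "shape" not in kwargs:
        return "STATIC", tuple(args[0])
    if not args and "shape" in kwargs:
        return "STATIC", tuple(kwargs["shape"])
    raise ValueError("provide exactly one static shape or a full dynamic range")
```

Passing only some of the three range keywords, or mixing them with a positional shape, raises ValueError in this sketch, mirroring the note attached to min_shape, opt_shape and max_shape.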
- dtype = <dtype.unknown: 6>
The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
- Type
(torch_tensorrt.dtype)
- example_tensor(optimization_profile_field: Optional[str] = None) -> torch.Tensor
Get an example tensor of the shape specified by the Input object
- Parameters
optimization_profile_field (Optional(str)) – Name of the field to use for shape in the case the Input is dynamically shaped
- Returns
A PyTorch Tensor
- format = <TensorFormat.contiguous: 0>
The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
- Type
(torch_tensorrt.TensorFormat)
- classmethod from_tensor(t: torch.Tensor) -> torch_tensorrt._Input.Input
Produce an Input which contains the information of the given PyTorch tensor.
- Parameters
t (torch.Tensor) – A PyTorch tensor.
- Returns
An Input object.
- classmethod from_tensors(ts: torch.Tensor) -> List[torch_tensorrt._Input.Input]
Produce a list of Inputs which contain the information of all the given PyTorch tensors.
- Parameters
ts (Iterable[torch.Tensor]) – A list of PyTorch tensors.
- Returns
A list of Inputs.
- shape = None
Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }
- Type
(Tuple or Dict)
- shape_mode = None
Is input statically or dynamically shaped
- Type
(torch_tensorrt.Input._ShapeMode)
- class torch_tensorrt.Device(*args, **kwargs)
Defines a device that can be used to specify target devices for engines
- Variables
device_type (torch_tensorrt.DeviceType) – Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.
gpu_id (int) – Device ID for target GPU
dla_core (int) – Core ID for target DLA core
allow_gpu_fallback (bool) – Whether falling back to GPU should be allowed if DLA cannot support an op
- __init__(*args, **kwargs)
__init__ Method for torch_tensorrt.Device
Device accepts one of a few construction patterns
- Parameters
spec (str) – String with device spec, e.g. “dla:0” for DLA, core_id 0
- Keyword Arguments
gpu_id (int) – ID of target GPU (overridden with the ID of the GPU managing the DLA if dla_core is specified). If specified, no positional arguments should be provided.
dla_core (int) – ID of target DLA core. If specified, no positional arguments should be provided.
allow_gpu_fallback (bool) – Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA)
Examples
Device("gpu:1")
Device("cuda:1")
Device("dla:0", allow_gpu_fallback=True)
Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)
Device(dla_core=0, allow_gpu_fallback=True)
Device(gpu_id=1)
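The string form of the device spec used in the examples above ("gpu:1", "cuda:1", "dla:0") follows a kind:index pattern. The following is a rough plain-Python model of how such a string maps to a device type and an index. parse_device_spec and ParsedDevice are hypothetical names for this sketch and do not reflect how torch_tensorrt parses the string internally.

```python
from typing import NamedTuple


class ParsedDevice(NamedTuple):
    device_type: str  # "GPU" or "DLA"
    index: int        # GPU ID or DLA core ID, depending on device_type


def parse_device_spec(spec: str) -> ParsedDevice:
    """Parse strings such as "gpu:1", "cuda:1" or "dla:0".

    Hypothetical helper for illustration; not torch_tensorrt's parser.
    """
    kind, sep, index = spec.strip().lower().partition(":")
    if not sep or not index.isdigit():
        raise ValueError(f"expected '<kind>:<index>', got {spec!r}")
    if kind in ("gpu", "cuda"):
        return ParsedDevice("GPU", int(index))
    if kind == "dla":
        return ParsedDevice("DLA", int(index))
    raise ValueError(f"unknown device kind {kind!r}")
```

In this model "gpu" and "cuda" are treated as synonyms, matching the two GPU examples listed above.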
- allow_gpu_fallback = False
(bool) Whether falling back to GPU if DLA cannot support an op should be allowed
- device_type = None
Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.
- Type
(torch_tensorrt.DeviceType)
- dla_core = -1
(int) Core ID for target DLA core
- gpu_id = -1
(int) Device ID for target GPU
- class torch_tensorrt.TRTModuleNext(serialized_engine: bytearray = bytearray(b''), name: str = '', input_binding_names: List[str] = [], output_binding_names: List[str] = [], target_device: Optional[torch_tensorrt._Device.Device] = None)
TRTModuleNext is a PyTorch module which encompasses an arbitrary TensorRT Engine.
This module is backed by the Torch-TensorRT runtime and is fully compatible with both FX / Python deployments (just import torch_tensorrt as part of the application) as well as TorchScript / C++ deployments, since TRTModule can be passed to torch.jit.trace and then saved.
The forward function is simply forward(*args: torch.Tensor) -> Tuple[torch.Tensor], where the internal implementation is
return Tuple(torch.ops.tensorrt.execute_engine(list(inputs), self.engine))
> Note: TRTModuleNext only supports engines built with explicit batch
- Variables
name (str) – Name of module (for easier debugging)
engine (torch.classes.tensorrt.Engine) – Torch-TensorRT TensorRT Engine instance, manages [de]serialization, device configuration, profiling
input_binding_names (List[str]) – List of input TensorRT engine binding names in the order they would be passed to the TRT modules
output_binding_names (List[str]) – List of output TensorRT engine binding names in the order they should be returned
- __init__(serialized_engine: bytearray = bytearray(b''), name: str = '', input_binding_names: List[str] = [], output_binding_names: List[str] = [], target_device: Optional[torch_tensorrt._Device.Device] = None)
__init__ method for torch_tensorrt.TRTModuleNext
Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch torch.nn.Module around it.
If binding names are not provided, it is assumed that the engine binding names follow this convention:
- [symbol].[index in input / output array]
ex. [x.0, x.1, x.2] -> [y.0]
- Parameters
name (str) – Name for module
serialized_engine (bytearray) – Serialized TensorRT engine in the form of a bytearray
input_binding_names (List[str]) – List of input TensorRT engine binding names in the order they would be passed to the TRT modules
output_binding_names (List[str]) – List of output TensorRT engine binding names in the order they should be returned
target_device (torch_tensorrt.Device) – Device to instantiate TensorRT engine on. Must be a compatible device, i.e. same GPU model / compute capability as was used to build the engine
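Example
import io
from torch_tensorrt import TRTModuleNext
# trt_engine is assumed to be an already built TensorRT engine
with io.BytesIO() as engine_bytes:
    engine_bytes.write(trt_engine.serialize())
    engine_str = engine_bytes.getvalue()
# Keyword names follow the __init__ signature documented above
trt_module = TRTModuleNext(
    engine_str,
    name="my_module",
    input_binding_names=["x"],
    output_binding_names=["output"],
)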
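The default binding-name convention described for __init__, [symbol].[index in input / output array], can be illustrated without TensorRT. The helper default_binding_names below is a hypothetical name for this sketch. It only reproduces the naming pattern from the text and makes no claim about how the runtime derives the names.

```python
from typing import List


def default_binding_names(symbol: str, count: int) -> List[str]:
    """Build names of the form "<symbol>.<index>" for `count` bindings."""
    return [f"{symbol}.{i}" for i in range(count)]


# Three inputs and one output, matching the example in the text:
# [x.0, x.1, x.2] -> [y.0]
input_names = default_binding_names("x", 3)
output_names = default_binding_names("y", 1)
```

Lists of this form are what the input_binding_names and output_binding_names parameters would hold when an engine follows the convention.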
- dump_layer_info()
Dump layer information encoded by the TensorRT engine in this module to STDOUT
- enable_profiling(profiling_results_dir: Optional[str] = None)
Enable the profiler to collect latency information about the execution of the engine
Traces can be visualized using https://ui.perfetto.dev/ or compatible alternatives
- Keyword Arguments
profiling_results_dir (str) – Absolute path to the directory to store results of profiling.
- forward(*inputs)
Implementation of the forward pass for a TensorRT engine
- Parameters
*inputs (torch.Tensor) – Inputs to the forward function, must all be torch.Tensor
- Returns
Result of the engine computation
- Return type
torch.Tensor or Tuple(torch.Tensor)
- get_extra_state()
Returns any extra state to include in the module’s state_dict. Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().
Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.
- Returns
Any extra state to store in the module’s state_dict
- Return type
object
- get_layer_info() -> str
Get a JSON string containing the layer information encoded by the TensorRT engine in this module
- Returns
A JSON string which contains the layer information of the engine encapsulated in this module
- Return type
str
- set_extra_state(state)
This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.
- Parameters
state (dict) – Extra state from the state_dict
Enums
- class torch_tensorrt.dtype
Enum to specify operating precision for engine execution
Members:
float : 32 bit floating point number
float32 : 32 bit floating point number
half : 16 bit floating point number
float16 : 16 bit floating point number
int8 : 8 bit integer number
int32 : 32 bit integer number
long : 64 bit integer number
int64 : 64 bit integer number
bool : Boolean value
unknown : Unknown data type
- class torch_tensorrt.DeviceType
Enum to specify device kinds to build TensorRT engines for
Members:
GPU : Specify using GPU to execute TensorRT Engine
DLA : Specify using DLA to execute TensorRT Engine (Jetson Only)
- class torch_tensorrt.EngineCapability
Enum to specify engine capability settings (selections of kernels to meet safety requirements)
Members:
safe_gpu : Use safety GPU kernels only
safe_dla : Use safety DLA kernels only
default : Use default behavior
- class torch_tensorrt.TensorFormat
Enum to specify the memory layout of tensors
Members:
contiguous : Contiguous memory layout (NCHW / Linear)
channels_last : Channels last memory layout (NHWC)
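To make the difference between the two layouts concrete, the sketch below computes element strides for a 4-D tensor that is indexed as (N, C, H, W) but stored either in NCHW order (contiguous) or in NHWC order (channels last). This is ordinary stride arithmetic written in plain Python. The function names are hypothetical and no torch_tensorrt call is involved.

```python
from typing import Tuple

Strides = Tuple[int, int, int, int]


def contiguous_strides(n: int, c: int, h: int, w: int) -> Strides:
    """Element strides for (N, C, H, W) indexing with NCHW memory order."""
    # W varies fastest, then H, then C, then N.
    return (c * h * w, h * w, w, 1)


def channels_last_strides(n: int, c: int, h: int, w: int) -> Strides:
    """Element strides for (N, C, H, W) indexing with NHWC memory order."""
    # C varies fastest, then W, then H, then N.
    return (h * w * c, 1, w * c, c)


nchw = contiguous_strides(1, 3, 224, 224)
nhwc = channels_last_strides(1, 3, 224, 224)
```

In the channels-last case the channel stride is 1, so all channel values of one pixel sit next to each other in memory, whereas in the contiguous case neighbouring elements belong to the same row of the same channel.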