torch_tensorrt
Functions
- torch_tensorrt.compile(module: Any, ir: str = 'default', inputs: Optional[Sequence[torch_tensorrt._Input.Input | torch.Tensor | torch_tensorrt.fx.input_tensor_spec.InputTensorSpec]] = None, enabled_precisions: Optional[Set[torch.dtype | torch_tensorrt._C.dtype]] = None, **kwargs: Any) -> Union[torch.nn.modules.module.Module, torch.jit._script.ScriptModule, torch.fx.graph_module.GraphModule, Callable[[...], Any]]
Compile a PyTorch module for NVIDIA GPUs using TensorRT
Takes an existing PyTorch module and a set of settings to configure the compiler and, using the path specified in ir, lowers and compiles the module to TensorRT, returning a PyTorch Module.
Converts specifically the forward method of a Module.
- Parameters
module (Union(torch.nn.Module, torch.jit.ScriptModule)) – Source module
- Keyword Arguments
inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]) –
Required. List of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum to select the device type.
inputs=[
    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
    torch_tensorrt.Input(
        min_shape=(1, 224, 224, 3),
        opt_shape=(1, 512, 512, 3),
        max_shape=(1, 1024, 1024, 3),
        dtype=torch.int32,
        format=torch.channels_last
    ), # Dynamic input shape for input #2
    torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
]
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))) – The set of datatypes that TensorRT can use when selecting kernels
ir (str) – The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
**kwargs – Additional settings for the specific requested strategy (See submodules for more info)
- Returns
Compiled Module, when run it will execute via TensorRT
- Return type
torch.nn.Module
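The call described above can be sketched as follows. This is a minimal illustration rather than a reference recipe: the toy model, the helper names `torch_tensorrt_available` and `compile_small_model`, and the choice of `ir="ts"` and float32 precision are all assumptions made for the example. The heavy imports sit inside the function so the snippet can be loaded on machines without a GPU.

```python
import importlib.util


def torch_tensorrt_available() -> bool:
    """Return True if both torch and torch_tensorrt can be located."""
    return all(
        importlib.util.find_spec(name) is not None
        for name in ("torch", "torch_tensorrt")
    )


def compile_small_model():
    """Compile a toy model through the TorchScript path (ir="ts").

    Assumes a CUDA-capable GPU and working torch / torch_tensorrt installs.
    """
    import torch
    import torch_tensorrt

    # Hypothetical toy model; any module with a forward method would do.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.ReLU(),
    ).eval().cuda()

    trt_module = torch_tensorrt.compile(
        model,
        ir="ts",
        inputs=[torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
        enabled_precisions={torch.float32},
    )

    # The compiled module is called like the original one.
    return trt_module(torch.randn(1, 3, 224, 224, device="cuda"))
```

On a suitable machine, check `torch_tensorrt_available()` first and then call `compile_small_model()`; nothing is compiled when the snippet is merely imported.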
- torch_tensorrt.convert_method_to_trt_engine(module: Any, method_name: str = 'forward', inputs: Optional[Sequence[torch_tensorrt._Input.Input | torch.Tensor]] = None, ir: str = 'default', enabled_precisions: Optional[Set[torch.dtype | torch_tensorrt._C.dtype]] = None, **kwargs: Any) -> bytes
Convert a TorchScript module method to a serialized TensorRT engine
Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
- Parameters
module (Union(torch.nn.Module, torch.jit.ScriptModule)) – Source module
- Keyword Arguments
inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]) –
Required. List of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum to select the device type.
inputs=[
    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
    torch_tensorrt.Input(
        min_shape=(1, 224, 224, 3),
        opt_shape=(1, 512, 512, 3),
        max_shape=(1, 1024, 1024, 3),
        dtype=torch.int32,
        format=torch.channels_last
    ), # Dynamic input shape for input #2
    torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
]
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))) – The set of datatypes that TensorRT can use when selecting kernels
ir (str) – The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
**kwargs – Additional settings for the specific requested strategy (See submodules for more info)
- Returns
Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
- Return type
bytes
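Because the return value is a plain `bytes` object, saving it to a file needs nothing beyond the standard library. The sketch below uses placeholder bytes in place of a real engine, and the helper names `save_engine` and `load_engine` are hypothetical. Deserializing the saved bytes back into an executable engine would go through the TensorRT runtime and is not shown here.

```python
import tempfile
from pathlib import Path


def save_engine(engine: bytes, path: Path) -> int:
    """Write serialized engine bytes to disk and return the byte count."""
    return path.write_bytes(engine)


def load_engine(path: Path) -> bytes:
    """Read serialized engine bytes back from disk."""
    return path.read_bytes()


# Placeholder standing in for the result of convert_method_to_trt_engine(...).
fake_engine = b"\x00serialized-engine-placeholder"

with tempfile.TemporaryDirectory() as tmp:
    engine_path = Path(tmp) / "model.engine"  # hypothetical file name
    written = save_engine(fake_engine, engine_path)
    restored = load_engine(engine_path)
```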
Classes
- class torch_tensorrt.Input(*args: Any, **kwargs: Any)
Defines an input to a module in terms of expected shape, data type and tensor format.
- Variables
shape_mode (torch_tensorrt.Input._ShapeMode) – Whether the input is statically or dynamically shaped
shape (Tuple or Dict) –
Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }
dtype (torch_tensorrt.dtype) – The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
format (torch_tensorrt.TensorFormat) – The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
- __init__(*args: Any, **kwargs: Any) -> None
__init__ Method for torch_tensorrt.Input
Input accepts one of a few construction patterns
- Parameters
shape (Tuple or List, optional) – Static shape of input tensor
- Keyword Arguments
shape (Tuple or List, optional) – Static shape of input tensor
min_shape (Tuple or List, optional) – Min size of input tensor’s shape range. Note: All three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined. Providing them implicitly sets the Input’s shape_mode to DYNAMIC.
opt_shape (Tuple or List, optional) – Opt size of input tensor’s shape range. Note: All three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined. Providing them implicitly sets the Input’s shape_mode to DYNAMIC.
max_shape (Tuple or List, optional) – Max size of input tensor’s shape range. Note: All three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined. Providing them implicitly sets the Input’s shape_mode to DYNAMIC.
dtype (torch.dtype or torch_tensorrt.dtype) – Expected data type for input tensor (default: torch_tensorrt.dtype.float32)
format (torch.memory_format or torch_tensorrt.TensorFormat) – The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
tensor_domain (Tuple(float, float), optional) – The domain of allowed values for the tensor, in interval notation: [tensor_domain[0], tensor_domain[1]). Note: Entering None (or not specifying) will set the bound to [0, 2)
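The half-open interval used by tensor_domain includes the lower bound and excludes the upper one. The following sketch illustrates that convention only; `in_tensor_domain` is a hypothetical helper and not part of torch_tensorrt.

```python
from typing import Optional, Tuple

# Default documented above for tensor_domain=None.
DEFAULT_DOMAIN: Tuple[float, float] = (0.0, 2.0)


def in_tensor_domain(
    value: float, tensor_domain: Optional[Tuple[float, float]] = None
) -> bool:
    """Check membership in the half-open interval [low, high)."""
    low, high = DEFAULT_DOMAIN if tensor_domain is None else tensor_domain
    return low <= value < high


lower_bound_included = in_tensor_domain(0.0)   # lower bound is inside
upper_bound_included = in_tensor_domain(2.0)   # upper bound is outside
```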
Examples
Input([1,3,32,32], dtype=torch.float32, format=torch.channels_last)
Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW
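The construction patterns above can be summarized as a small decision procedure. The sketch below is a standalone illustration of the documented rules, not the class's actual implementation: `resolve_shape` is a hypothetical function, and the string labels "STATIC" and "DYNAMIC" stand in for the shape_mode values.

```python
from typing import Dict, Sequence, Tuple, Union

Shape = Tuple[int, ...]
RANGE_KEYS = ("min_shape", "opt_shape", "max_shape")


def resolve_shape(
    *args: Sequence[int], **kwargs: Sequence[int]
) -> Tuple[str, Union[Shape, Dict[str, Shape]]]:
    """Return (shape_mode, shape) for one of the documented patterns."""
    given = [key for key in RANGE_KEYS if key in kwargs]
    if given:
        # Dynamic pattern: all three range keys, nothing positional, no `shape`.
        if len(given) != len(RANGE_KEYS):
            raise ValueError("min_shape, opt_shape and max_shape must all be provided")
        if args or "shape" in kwargs:
            raise ValueError("a shape range cannot be combined with a static shape")
        return "DYNAMIC", {key: tuple(kwargs[key]) for key in RANGE_KEYS}
    if len(args) == 1 and "shape" not in kwargs:
        # Static pattern, shape given positionally.
        return "STATIC", tuple(args[0])
    if not args and "shape" in kwargs:
        # Static pattern, shape given as a keyword.
        return "STATIC", tuple(kwargs["shape"])
    raise ValueError("expected a static shape or a min/opt/max shape range")


static_mode, static_shape = resolve_shape([1, 3, 32, 32])
dynamic_mode, dynamic_shape = resolve_shape(
    min_shape=(1, 3, 32, 32), opt_shape=[2, 3, 32, 32], max_shape=(3, 3, 32, 32)
)
```

The dynamic result uses the same dict form as the shape attribute, with one tuple each under "min_shape", "opt_shape" and "max_shape".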
- dtype: _enums.dtype = <dtype.unknown: 6>
The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
- example_tensor(optimization_profile_field: Optional[str] = None) -> torch.Tensor
Get an example tensor of the shape specified by the Input object
- Parameters
optimization_profile_field (Optional(str)) – Name of the field to use for shape in the case the Input is dynamically shaped
- Returns
A PyTorch Tensor
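For a dynamically shaped Input, the field name picks which of the three shapes the example tensor takes. The sketch below shows only that selection step on plain tuples. `select_example_shape` is a hypothetical helper, and ignoring the field for a static shape is an assumption of this sketch; the real method may reject that combination.

```python
from typing import Dict, Optional, Tuple, Union

Shape = Tuple[int, ...]


def select_example_shape(
    shape: Union[Shape, Dict[str, Shape]],
    optimization_profile_field: Optional[str] = None,
) -> Shape:
    """Pick the concrete shape an example tensor would be built with."""
    if isinstance(shape, dict):
        # Dynamic input: the caller has to say which profile entry to use.
        if optimization_profile_field is None:
            raise ValueError("a field name is needed for a dynamically shaped input")
        if optimization_profile_field not in shape:
            raise ValueError(f"unknown field {optimization_profile_field!r}")
        return tuple(shape[optimization_profile_field])
    # Static input: a single shape, so the field is ignored (assumption).
    return tuple(shape)


dynamic = {
    "min_shape": (1, 3, 32, 32),
    "opt_shape": (2, 3, 32, 32),
    "max_shape": (3, 3, 32, 32),
}
opt_example_shape = select_example_shape(dynamic, "opt_shape")
```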
- format: _enums.TensorFormat = <TensorFormat.contiguous: 0>
The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
- classmethod from_tensor(t: torch.Tensor, disable_memory_format_check: bool = False) -> torch_tensorrt._Input.Input
Produce an Input which contains the information of the given PyTorch tensor.
- Parameters
t (torch.Tensor) – A PyTorch tensor.
disable_memory_format_check (bool) – Whether to skip validation of the memory format of the input tensor
- Returns
An Input object.
- classmethod from_tensors(ts: Sequence[torch.Tensor], disable_memory_format_check: bool = False) -> List[torch_tensorrt._Input.Input]
Produce a list of Inputs which contain the information of all the given PyTorch tensors.
- Parameters
ts (Sequence[torch.Tensor]) – A list of PyTorch tensors.
disable_memory_format_check (bool) – Whether to skip validation of the memory formats of the input tensors
- Returns
A list of Inputs.
- shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = None
Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }
- shape_mode: Optional[_ShapeMode] = None
Whether the input is statically or dynamically shaped
- class torch_tensorrt.Device(*args: Any, **kwargs: Any)
Defines a device that can be used to specify target devices for engines
- Variables
device_type (torch_tensorrt.DeviceType) – Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.
gpu_id (int) – Device ID for target GPU
dla_core (int) – Core ID for target DLA core
allow_gpu_fallback (bool) – Whether to allow falling back to GPU if DLA cannot support an op
- __init__(*args: Any, **kwargs: Any)
__init__ Method for torch_tensorrt.Device
Device accepts one of a few construction patterns
- Parameters
spec (str) – String with device spec e.g. “dla:0” for dla, core_id 0
- Keyword Arguments
gpu_id (int) – ID of target GPU (overridden with the ID of the GPU managing the DLA if dla_core is specified). If specified, no positional arguments should be provided.
dla_core (int) – ID of target DLA core. If specified, no positional arguments should be provided.
allow_gpu_fallback (bool) – Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA)
Examples
Device("gpu:1")
Device("cuda:1")
Device("dla:0", allow_gpu_fallback=True)
Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)
Device(dla_core=0, allow_gpu_fallback=True)
Device(gpu_id=1)
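The two construction styles shown in the examples (a spec string, or keyword IDs) can be illustrated with a small standalone parser. This is a sketch of the documented behavior only: `DeviceSpec`, `parse_device_spec` and `from_ids` are hypothetical names, the device type is represented as a plain string, and the sketch leaves gpu_id at -1 for DLA specs rather than resolving the GPU that manages the DLA.

```python
from typing import NamedTuple


class DeviceSpec(NamedTuple):
    device_type: str      # "GPU" or "DLA"
    gpu_id: int = -1      # -1 mirrors the documented attribute default
    dla_core: int = -1


def parse_device_spec(spec: str) -> DeviceSpec:
    """Parse strings such as "gpu:1", "cuda:1" or "dla:0"."""
    kind, sep, index = spec.strip().lower().partition(":")
    if not sep or not index.isdigit():
        raise ValueError(f"expected '<kind>:<index>', got {spec!r}")
    if kind in ("gpu", "cuda"):
        return DeviceSpec("GPU", gpu_id=int(index))
    if kind == "dla":
        return DeviceSpec("DLA", dla_core=int(index))
    raise ValueError(f"unknown device kind {kind!r}")


def from_ids(gpu_id: int = -1, dla_core: int = -1) -> DeviceSpec:
    """Keyword style: the device type follows from whether dla_core is set."""
    device_type = "DLA" if dla_core >= 0 else "GPU"
    return DeviceSpec(device_type, gpu_id=gpu_id, dla_core=dla_core)


gpu_spec = parse_device_spec("cuda:1")
dla_spec = parse_device_spec("dla:0")
```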
- allow_gpu_fallback: bool = False
Whether to allow falling back to GPU if DLA cannot support an op
- device_type: Optional[tensorrt_bindings.tensorrt.DeviceType] = None
Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.
- dla_core: int = -1
Core ID for target DLA core
- gpu_id: int = -1
Device ID for target GPU
Enums
- class torch_tensorrt.dtype
Enum to specify operating precision for engine execution
Members:
float : 32 bit floating point number
float32 : 32 bit floating point number
half : 16 bit floating point number
float16 : 16 bit floating point number
int8 : 8 bit integer number
int32 : 32 bit integer number
long : 64 bit integer number
int64 : 64 bit integer number
bool : Boolean value
unknown : Unknown data type
- class torch_tensorrt.DeviceType
Device types that TensorRT can execute on
Members:
GPU : GPU device
DLA : DLA core
- class torch_tensorrt.EngineCapability
Enum to specify engine capability settings (selections of kernels to meet safety requirements)
Members:
safe_gpu : Use safety GPU kernels only
safe_dla : Use safety DLA kernels only
default : Use default behavior
- class torch_tensorrt.TensorFormat
Enum to specify the memory layout of tensors
Members:
contiguous : Contiguous memory layout (NCHW / Linear)
channels_last : Channels last memory layout (NHWC)