torch_tensorrt

Functions

torch_tensorrt.set_device(gpu_id: int) → None

torch_tensorrt.compile(module: Any, ir: str = 'default', inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None, enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, **kwargs: Any) → Union[Module, ScriptModule, GraphModule, Callable[[...], Any]]

Compile a PyTorch module for NVIDIA GPUs using TensorRT

Takes an existing PyTorch module and a set of settings to configure the compiler and, using the path specified in ir, lowers and compiles the module to TensorRT, returning a PyTorch Module.

Specifically, converts the forward method of a Module.

Parameters

module (Union(torch.nn.Module, torch.jit.ScriptModule)) – Source module

Keyword Arguments
  • inputs (List[Union(Input, torch.Tensor)]) –

    Required list of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum to select the device type.

    inputs=[
        torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
        torch_tensorrt.Input(
            min_shape=(1, 224, 224, 3),
            opt_shape=(1, 512, 512, 3),
            max_shape=(1, 1024, 1024, 3),
            dtype=torch.int32,
            format=torch.channels_last,
        ), # Dynamic input shape for input #2
        torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
    ]
    

  • enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))) – The set of datatypes that TensorRT can use when selecting kernels

  • ir (str) – The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)

  • **kwargs – Additional settings for the specific requested strategy (See submodules for more info)

Returns

Compiled module; when run, it executes via TensorRT

Return type

torch.nn.Module

torch_tensorrt.convert_method_to_trt_engine(module: Any, method_name: str = 'forward', inputs: Optional[Sequence[Input | torch.Tensor]] = None, ir: str = 'default', enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, **kwargs: Any) → bytes

Convert a TorchScript module method to a serialized TensorRT engine

Converts a specified method of a module to a serialized TensorRT engine, given a set of conversion settings passed as keyword arguments

Parameters

module (Union(torch.nn.Module, torch.jit.ScriptModule)) – Source module

Keyword Arguments
  • inputs (List[Union(Input, torch.Tensor)]) –

    Required list of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum to select the device type.

    inputs=[
        torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
        torch_tensorrt.Input(
            min_shape=(1, 224, 224, 3),
            opt_shape=(1, 512, 512, 3),
            max_shape=(1, 1024, 1024, 3),
            dtype=torch.int32,
            format=torch.channels_last,
        ), # Dynamic input shape for input #2
        torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
    ]
    

  • enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))) – The set of datatypes that TensorRT can use when selecting kernels

  • ir (str) – The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)

  • **kwargs – Additional settings for the specific requested strategy (See submodules for more info)

Returns

Serialized TensorRT engine, which can either be saved to a file or deserialized via TensorRT APIs

Return type

bytes
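Because the engine comes back as plain bytes, saving and reloading it is ordinary binary file I/O. The sketch below is illustrative only: save_engine, load_engine and the model.engine file name are hypothetical, and placeholder bytes stand in for the result of a call such as torch_tensorrt.convert_method_to_trt_engine(module, "forward", inputs=[...]) so that it runs without a GPU or TensorRT installed.

```python
import tempfile
from pathlib import Path


def save_engine(engine: bytes, path: Path) -> Path:
    """Write a serialized TensorRT engine to disk (hypothetical helper)."""
    path.write_bytes(engine)
    return path


def load_engine(path: Path) -> bytes:
    """Read the serialized engine back; deserializing it is left to the TensorRT runtime."""
    return path.read_bytes()


# Placeholder bytes stand in for the output of
# torch_tensorrt.convert_method_to_trt_engine(...), which needs a GPU and TensorRT.
fake_engine = b"\x00serialized-engine-placeholder"

with tempfile.TemporaryDirectory() as tmp:
    engine_path = save_engine(fake_engine, Path(tmp) / "model.engine")
    restored = load_engine(engine_path)
```

The file holds exactly the bytes that were returned, so any TensorRT runtime able to deserialize the engine can load it later.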

torch_tensorrt.get_build_info() → str

Returns a string containing the build information of the torch_tensorrt distribution

Returns

String containing the build information for the torch_tensorrt distribution

Return type

str

torch_tensorrt.dump_build_info() → None

Prints build information about the torch_tensorrt distribution to stdout

Classes

class torch_tensorrt.Input(*args: Any, **kwargs: Any)

Defines an input to a module in terms of expected shape, data type and tensor format.

Variables
  • shape_mode (torch_tensorrt.Input._ShapeMode) – Whether the input is statically or dynamically shaped

  • shape (Tuple or Dict) –

    Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form { "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }

  • dtype (torch_tensorrt.dtype) – The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)

  • format (TensorFormat) – The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)

__init__(*args: Any, **kwargs: Any) → None

__init__ Method for torch_tensorrt.Input

Input accepts one of a few construction patterns

Parameters

shape (Tuple or List, optional) – Static shape of input tensor

Keyword Arguments
  • shape (Tuple or List, optional) – Static shape of input tensor

  • min_shape (Tuple or List, optional) – Minimum size of the input tensor’s shape range. Note: all three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined; this implicitly sets Input’s shape_mode to DYNAMIC

  • opt_shape (Tuple or List, optional) – Optimal size of the input tensor’s shape range. Note: all three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined; this implicitly sets Input’s shape_mode to DYNAMIC

  • max_shape (Tuple or List, optional) – Maximum size of the input tensor’s shape range. Note: all three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined; this implicitly sets Input’s shape_mode to DYNAMIC

  • dtype (torch.dtype or torch_tensorrt.dtype) – Expected data type for input tensor (default: torch_tensorrt.dtype.float32)

  • format (torch.memory_format or TensorFormat) – The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)

  • tensor_domain (Tuple(float, float), optional) – The domain of allowed values for the tensor, in interval notation: [tensor_domain[0], tensor_domain[1]). Note: passing None (or not specifying it) sets the bound to [0, 2)

  • torch_tensor (torch.Tensor) – A torch tensor corresponding to this Input.

  • name (str, optional) – Name of this input in the input nn.Module’s forward function. Used to specify dynamic shapes for the corresponding input in the dynamo tracer.

Examples

  • Input([1,3,32,32], dtype=torch.float32, format=torch.channels_last)

  • Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)

  • Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW
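The rules for dynamic shapes can be checked before compiling. The following is a minimal sketch, not part of torch_tensorrt: check_dynamic_shapes is a hypothetical helper, and beyond the rules stated above it assumes the standard TensorRT optimization-profile requirement that the three shapes share the same rank and satisfy min ≤ opt ≤ max in every dimension. It returns the dict form used by the shape variable.

```python
from typing import Dict, Sequence, Tuple

Shape = Tuple[int, ...]


def check_dynamic_shapes(min_shape: Sequence[int],
                         opt_shape: Sequence[int],
                         max_shape: Sequence[int]) -> Dict[str, Shape]:
    """Validate a dynamic shape range and return it in the dict form used by Input.shape.

    Hypothetical helper: assumes equal rank and min <= opt <= max per dimension.
    """
    shapes = {
        "min_shape": tuple(min_shape),
        "opt_shape": tuple(opt_shape),
        "max_shape": tuple(max_shape),
    }
    ranks = {len(s) for s in shapes.values()}
    if len(ranks) != 1:
        raise ValueError(f"min/opt/max shapes must have the same rank, got {shapes}")
    # Dicts keep insertion order, so each zipped triple is (min, opt, max) for one dimension.
    for dim, (lo, opt, hi) in enumerate(zip(*shapes.values())):
        if not lo <= opt <= hi:
            raise ValueError(
                f"dimension {dim}: expected min <= opt <= max, got {lo}, {opt}, {hi}"
            )
    return shapes


# Mirrors the third example above: a batch dimension that ranges from 1 to 3.
profile = check_dynamic_shapes((1, 3, 32, 32), [2, 3, 32, 32], (3, 3, 32, 32))
```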

dtype: _enums.dtype = <dtype.unknown: 7>

The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)

example_tensor(optimization_profile_field: Optional[str] = None) → Tensor

Get an example tensor of the shape specified by the Input object

Parameters

optimization_profile_field (Optional(str)) – Name of the field to use for shape in the case the Input is dynamically shaped

Returns

A PyTorch Tensor
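For a dynamically shaped Input there is no single shape, so the caller has to name one of the profile fields. A sketch of that selection, assuming the valid field names are the three keys of the shape dict ("min_shape", "opt_shape", "max_shape"); select_example_shape is a hypothetical helper that only picks the shape, whereas example_tensor also builds the tensor.

```python
from typing import Dict, Optional, Tuple, Union

Shape = Tuple[int, ...]
PROFILE_FIELDS = ("min_shape", "opt_shape", "max_shape")


def select_example_shape(shape: Union[Shape, Dict[str, Shape]],
                         optimization_profile_field: Optional[str] = None) -> Shape:
    """Pick the shape an example tensor would be built with (hypothetical helper)."""
    if isinstance(shape, tuple):
        # Static input: the single tuple is used; a profile field is not needed.
        return shape
    if optimization_profile_field not in PROFILE_FIELDS:
        raise ValueError(
            f"dynamic inputs need optimization_profile_field in {PROFILE_FIELDS}, "
            f"got {optimization_profile_field!r}"
        )
    return shape[optimization_profile_field]


dynamic = {"min_shape": (1, 3, 32, 32), "opt_shape": (2, 3, 32, 32), "max_shape": (3, 3, 32, 32)}
opt = select_example_shape(dynamic, "opt_shape")
static = select_example_shape((1, 3, 224, 224))
```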

format: _enums.TensorFormat = <TensorFormat.contiguous: 0>

The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)

classmethod from_tensor(t: Tensor, disable_memory_format_check: bool = False) → Input

Produce an Input which contains the information of the given PyTorch tensor.

Parameters
  • t (torch.Tensor) – A PyTorch tensor.

  • disable_memory_format_check (bool) – If True, skip validating the memory format of the input tensor

Returns

An Input object.
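The memory-format check comes down to comparing a tensor's strides with the layouts listed under TensorFormat. The sketch below shows the arithmetic for a 4-D tensor indexed as NCHW, using plain tuples instead of torch so it runs anywhere; contiguous_strides, channels_last_strides and infer_format are hypothetical helpers, and this is not how the library implements the check. The stride formulas are the ones PyTorch uses for torch.contiguous_format and torch.channels_last.

```python
from typing import Optional, Tuple

Shape = Tuple[int, int, int, int]


def contiguous_strides(shape: Shape) -> Tuple[int, ...]:
    """Element strides of an NCHW tensor stored contiguously (row-major)."""
    _, c, h, w = shape
    return (c * h * w, h * w, w, 1)


def channels_last_strides(shape: Shape) -> Tuple[int, ...]:
    """Element strides of an NCHW-indexed tensor stored as NHWC (channels last)."""
    _, c, h, w = shape
    return (h * w * c, 1, w * c, c)


def infer_format(shape: Shape, strides: Tuple[int, ...]) -> Optional[str]:
    """Map a stride pattern to a TensorFormat member name, or None if neither matches exactly."""
    if strides == contiguous_strides(shape):
        return "contiguous"
    if strides == channels_last_strides(shape):
        return "channels_last"
    return None


shape = (1, 3, 224, 224)
print(infer_format(shape, (150528, 50176, 224, 1)))  # contiguous
print(infer_format(shape, (150528, 1, 672, 3)))      # channels_last
```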

classmethod from_tensors(ts: Sequence[Tensor], disable_memory_format_check: bool = False) → List[Input]

Produce a list of Inputs which contain the information of all the given PyTorch tensors.

Parameters
  • ts (Sequence[torch.Tensor]) – A list of PyTorch tensors.

  • disable_memory_format_check (bool) – If True, skip validating the memory formats of the input tensors

Returns

A list of Inputs.

shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = None

Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form { "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }

shape_mode: Optional[_ShapeMode] = None

Whether the input is statically or dynamically shaped

class torch_tensorrt.Device(*args: Any, **kwargs: Any)

Defines a device that can be used to specify target devices for engines

Variables
  • device_type (DeviceType) – Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.

  • gpu_id (int) – Device ID for target GPU

  • dla_core (int) – Core ID for target DLA core

  • allow_gpu_fallback (bool) – Whether to allow falling back to the GPU for ops that the DLA cannot support

__init__(*args: Any, **kwargs: Any)

__init__ Method for torch_tensorrt.Device

Device accepts one of a few construction patterns

Parameters

spec (str) – String with the device spec, e.g. "dla:0" for DLA core 0

Keyword Arguments
  • gpu_id (int) – ID of the target GPU (overridden with the ID of the GPU managing the DLA if dla_core is specified). If specified, no positional arguments should be provided

  • dla_core (int) – ID of the target DLA core. If specified, no positional arguments should be provided.

  • allow_gpu_fallback (bool) – Allow TensorRT to schedule operations on the GPU if they are not supported on the DLA (ignored if the device type is not DLA)

Examples

  • Device("gpu:1")

  • Device("cuda:1")

  • Device("dla:0", allow_gpu_fallback=True)

  • Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)

  • Device(dla_core=0, allow_gpu_fallback=True)

  • Device(gpu_id=1)
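The string patterns above reduce to a device type plus an index. The parser below is a hypothetical sketch for illustration only (parse_device_spec is not torch_tensorrt's own code), assuming "gpu" and "cuda" are synonyms for a GPU id and "dla" selects a DLA core.

```python
import re
from typing import Tuple

_SPEC = re.compile(r"^(gpu|cuda|dla):(\d+)$")


def parse_device_spec(spec: str) -> Tuple[str, int]:
    """Split a spec such as "gpu:1", "cuda:1" or "dla:0" into (device type, index).

    Hypothetical helper; "gpu" and "cuda" are treated as synonyms.
    """
    match = _SPEC.match(spec.strip().lower())
    if match is None:
        raise ValueError(f"unrecognized device spec: {spec!r}")
    kind, index = match.group(1), int(match.group(2))
    device_type = "DLA" if kind == "dla" else "GPU"
    return device_type, index


for spec in ("gpu:1", "cuda:1", "dla:0"):
    print(spec, "->", parse_device_spec(spec))
```

A fuller version would also handle the keyword-argument patterns, where specifying dla_core implies a DLA device.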

allow_gpu_fallback: bool = False

Whether to allow falling back to the GPU for ops that the DLA cannot support

device_type: Optional[DeviceType] = None

Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.

dla_core: int = -1

Core ID for target DLA core

gpu_id: int = -1

Device ID for target GPU

Enums

class torch_tensorrt.dtype

Enum to specify the operating precision for engine execution

Members:

float : 32 bit floating point number

float32 : 32 bit floating point number

half : 16 bit floating point number

float16 : 16 bit floating point number

int8 : 8 bit integer number

int32 : 32 bit integer number

long : 64 bit integer number

int64 : 64 bit integer number

double : 64 bit floating point number

float64 : 64 bit floating point number

bool : Boolean value

unknown : Unknown data type
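Several members above are aliases for the same precision. A small lookup table derived only from the member descriptions makes the grouping explicit; ALIASES, BIT_WIDTH and canonical_dtype are hypothetical names for illustration, with float32, float16, int64 and float64 chosen as the canonical spellings.

```python
from typing import Dict

# Alias -> canonical member name, derived from the member descriptions above.
ALIASES: Dict[str, str] = {
    "float": "float32",
    "half": "float16",
    "long": "int64",
    "double": "float64",
}

# Bit width of each canonical numeric member (bool and unknown are omitted).
BIT_WIDTH: Dict[str, int] = {
    "int8": 8,
    "float16": 16,
    "float32": 32,
    "int32": 32,
    "int64": 64,
    "float64": 64,
}


def canonical_dtype(name: str) -> str:
    """Resolve an alias such as "half" to its canonical member name."""
    return ALIASES.get(name, name)


print(canonical_dtype("half"), BIT_WIDTH[canonical_dtype("half")])  # float16 16
```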

class torch_tensorrt.DeviceType

Device types that TensorRT can execute on

Members:

GPU : GPU device

DLA : DLA core

class torch_tensorrt.EngineCapability

Enum to specify engine capability settings (selections of kernels to meet safety requirements)

Members:

safe_gpu : Use safety GPU kernels only

safe_dla : Use safety DLA kernels only

default : Use default behavior

class torch_tensorrt.TensorFormat

Enum to specify the memory layout of tensors

Members:

contiguous : Contiguous memory layout (NCHW / Linear)

channels_last : Channels last memory layout (NHWC)

Submodules
