
Using Torch-TensorRT in Python

The Torch-TensorRT Python API supports a number of unique use cases compared to the CLI and C++ APIs, which solely support TorchScript compilation.

The Torch-TensorRT Python API can accept a torch.nn.Module, torch.jit.ScriptModule, or torch.fx.GraphModule as input. Depending on what is provided, one of the two frontends (TorchScript or FX) will be selected to compile the module. Provided the module type is supported, users may explicitly select which frontend to use via the ir flag of compile. If given a torch.nn.Module and the ir flag is set to either default or torchscript, the module will be run through torch.jit.script to convert the input module into a TorchScript module.
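
For example, to force the TorchScript frontend explicitly, you can pass ir="torchscript" (a minimal sketch; MyModel stands in for your own torch.nn.Module and the input shape is illustrative):

import torch
import torch_tensorrt

model = MyModel().eval()

# ir="torchscript" selects the TorchScript frontend; the nn.Module is
# run through torch.jit.script before conversion
trt_module = torch_tensorrt.compile(
    model,
    ir="torchscript",
    inputs=[torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float)],
)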

To compile your input torch.nn.Module with Torch-TensorRT, all you need to do is provide the module and its inputs to Torch-TensorRT, and you will be returned an optimized TorchScript module to run directly or to embed in another PyTorch module. The inputs argument is a list of torch_tensorrt.Input classes which define the input Tensors' shape, datatype, and memory format. Alternatively, if your input is a more complex data type, such as a tuple or list of Tensors, you can use the input_signature argument to specify a collection-based input, such as (List[Tensor], Tuple[Tensor, Tensor]); see the second sample below for an example. You can also specify settings such as the operating precision of the engine or the target device. After compilation you can save the module just like any other module, to be loaded later in a deployment application. In order to load a TensorRT/TorchScript module, make sure you first import torch_tensorrt.

import torch
import torch_tensorrt

...

model = MyModel().eval()  # torch module needs to be in eval (not training) mode

inputs = [
    torch_tensorrt.Input(
        min_shape=[1, 1, 16, 16],
        opt_shape=[1, 1, 32, 32],
        max_shape=[1, 1, 64, 64],
        dtype=torch.half,
    )
]
enabled_precisions = {torch.float, torch.half}  # Run with fp16

trt_ts_module = torch_tensorrt.compile(
    model, inputs=inputs, enabled_precisions=enabled_precisions
)

# input_data is assumed to come from the elided setup above
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")

# Sample using collection-based inputs via the input_signature argument
import torch
import torch_tensorrt

...

model = MyModel().eval()

# input_signature expects a tuple of individual input arguments to the module
# The module below, for example, would have a forward signature of the form:
# def forward(self, input0: List[torch.Tensor], input1: Tuple[torch.Tensor, torch.Tensor])
input_signature = (
    [
        torch_tensorrt.Input(shape=[64, 64], dtype=torch.half),
        torch_tensorrt.Input(shape=[64, 64], dtype=torch.half),
    ],
    (
        torch_tensorrt.Input(shape=[64, 64], dtype=torch.half),
        torch_tensorrt.Input(shape=[64, 64], dtype=torch.half),
    ),
)
enabled_precisions = {torch.float, torch.half}

trt_ts_module = torch_tensorrt.compile(
    model, input_signature=input_signature, enabled_precisions=enabled_precisions
)

input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")

# Deployment application
import torch
import torch_tensorrt

trt_ts_module = torch.jit.load("trt_ts_module.ts")
# input_data is assumed to be produced by the surrounding application
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)

The Torch-TensorRT Python API also provides torch_tensorrt.ts.compile, which accepts a TorchScript module as input, and torch_tensorrt.fx.compile, which accepts an FX GraphModule as input.
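
For instance, a module that has already been scripted can be handed straight to the TorchScript frontend (a sketch under the same assumptions as above; MyModel and the input shape are placeholders):

import torch
import torch_tensorrt

scripted_model = torch.jit.script(MyModel().eval())

# Compile the TorchScript module directly with the TorchScript frontend
trt_ts_module = torch_tensorrt.ts.compile(
    scripted_model,
    inputs=[torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float)],
)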
