Using Torch-TensorRT in Python

The Torch-TensorRT Python API accepts a `torch.nn.Module` as input. Under the hood, it uses `torch.jit.script` to convert the input module into a TorchScript module. To compile your input `torch.nn.Module` with Torch-TensorRT, all you need to do is provide the module and inputs to Torch-TensorRT, and you will be returned an optimized TorchScript module to run directly or add into another PyTorch module. `inputs` is a list of `torch_tensorrt.Input` objects which define the input's shape, datatype, and memory format. You can also specify settings such as the operating precision for the engine or the target device. After compilation, you can save the module just like any other module and load it in a deployment application. In order to load a TensorRT/TorchScript module, make sure you first import `torch_tensorrt`.

import torch
import torch_tensorrt

...

model = MyModel().eval() # torch module needs to be in eval (not training) mode

# A shape range (min/opt/max) produces an engine that supports dynamic input sizes
inputs = [torch_tensorrt.Input(
            min_shape=[1, 1, 16, 16],
            opt_shape=[1, 1, 32, 32],
            max_shape=[1, 1, 64, 64],
            dtype=torch.half,
        )]
enabled_precisions = {torch.float, torch.half} # Run with fp16

trt_ts_module = torch_tensorrt.compile(model, inputs=inputs, enabled_precisions=enabled_precisions)

input_data = input_data.to('cuda').half() # inputs must live on the GPU and match the engine's fp16 precision
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts") # save the compiled module as a TorchScript program

# Deployment application
import torch
import torch_tensorrt

trt_ts_module = torch.jit.load("trt_ts_module.ts")
input_data = input_data.to('cuda').half() # same device and precision requirements as at compile time
result = trt_ts_module(input_data)

The Torch-TensorRT Python API also provides `torch_tensorrt.ts.compile`, which accepts a TorchScript module as input. The TorchScript module can be obtained via scripting or tracing (refer to creating_torchscript_module_in_python). `torch_tensorrt.ts.compile` accepts a TorchScript module and a list of `torch_tensorrt.Input` objects.
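
For example, a minimal sketch of this path (assuming the same `MyModel` as in the snippet above) scripts the module explicitly and passes the result to `torch_tensorrt.ts.compile`, which takes the same `inputs` and `enabled_precisions` arguments as the top-level API:

import torch
import torch_tensorrt

# Obtain a TorchScript module explicitly (torch.jit.trace would work as well)
scripted_model = torch.jit.script(MyModel().eval())

trt_ts_module = torch_tensorrt.ts.compile(
    scripted_model,
    inputs=[torch_tensorrt.Input([1, 1, 32, 32], dtype=torch.half)], # fixed-shape input
    enabled_precisions={torch.half},
)

The returned module behaves like the one produced by `torch_tensorrt.compile` and can be saved with `torch.jit.save` in the same way.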