
Saving models compiled with Torch-TensorRT

Models compiled with Torch-TensorRT can be saved using the torch_tensorrt.save API.

Dynamo IR

By default, compiling with ir=dynamo in Torch-TensorRT produces a torch.fx.GraphModule object. This object can be saved in either TorchScript (torch.jit.ScriptModule) or ExportedProgram (torch.export.ExportedProgram) format by setting the output_format flag. output_format accepts the following options:

  • exported_program : This is the default. We first perform transformations on the GraphModule and then save it with torch.export.save.

  • torchscript : We trace the GraphModule with torch.jit.trace and save it with torch.jit.save.
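The mapping from output_format to the underlying PyTorch calls can be pictured as a small lookup. The sketch below is illustrative only and is not how torch_tensorrt.save is implemented: `describe_save`, `SAVERS`, and the two stand-in savers are hypothetical names, and the savers return a description instead of writing a file so that the logic can be read without a GPU.

```python
from typing import Callable, Dict, Tuple

# Stand-in savers (hypothetical): each returns a description of the
# PyTorch calls named in the list above instead of touching the disk.
def save_as_exported_program(path: str) -> Tuple[str, str]:
    return ("torch.export.save", path)

def save_as_torchscript(path: str) -> Tuple[str, str]:
    return ("torch.jit.trace + torch.jit.save", path)

# The two output_format values described above.
SAVERS: Dict[str, Callable[[str], Tuple[str, str]]] = {
    "exported_program": save_as_exported_program,
    "torchscript": save_as_torchscript,
}

def describe_save(path: str, output_format: str = "exported_program") -> Tuple[str, str]:
    """Hypothetical helper: report which calls a given output_format implies."""
    if output_format not in SAVERS:
        raise ValueError(
            f"Unsupported output_format {output_format!r}; "
            f"expected one of {sorted(SAVERS)}"
        )
    return SAVERS[output_format](path)
```

Calling `describe_save("trt.ts", output_format="torchscript")` selects the TorchScript path, while omitting output_format falls back to exported_program, mirroring the default described above.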

a) ExportedProgram

Here’s an example usage

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)

# Later, you can load it and run inference
model = torch.export.load("trt.ep").module()
model(*inputs)

b) TorchScript

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)

# Later, you can load it and run inference
model = torch.jit.load("trt.ts").cuda()
model(*inputs)

TorchScript IR

In Torch-TensorRT 1.X, the primary way to compile and run inference was through the TorchScript IR. For ir=ts, this behavior is unchanged in 2.X: compilation returns a ScriptModule, which is saved with torch.jit.save.

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs)  # Output is a ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(*inputs)

Loading the models

We can load torchscript or exported_program models directly with PyTorch's torch.jit.load and torch.export.load APIs. Alternatively, torch_tensorrt.load(file_path) is a light wrapper that can load either of these model types.

Here’s an example usage

import torch
import torch_tensorrt

# file_path can be the trt.ep or trt.ts file saved in the sections above
file_path = "trt.ep"
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# .module() applies to ExportedProgram files; a TorchScript file loads as a callable module directly
model = torch_tensorrt.load(file_path).module()
model(*inputs)
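
One way to picture a wrapper that loads either model type is a try-in-order fallback over the available loaders. The sketch below is a minimal illustration under that assumption and does not claim to reproduce the internals of torch_tensorrt.load; `load_with_fallback` is a hypothetical helper, and the loaders are passed in as plain callables so the control flow can be exercised without PyTorch or a GPU.

```python
from typing import Any, Callable, List, Sequence

def load_with_fallback(file_path: str, loaders: Sequence[Callable[[str], Any]]) -> Any:
    """Hypothetical helper: try each loader in order, return the first success."""
    errors: List[str] = []
    for loader in loaders:
        try:
            return loader(file_path)
        except Exception as exc:  # a real wrapper may catch narrower errors
            name = getattr(loader, "__name__", repr(loader))
            errors.append(f"{name}: {exc}")
    # Every loader failed: report all collected errors together.
    raise ValueError(f"Could not load {file_path!r}: " + "; ".join(errors))
```

With PyTorch installed, passing `[torch.jit.load, torch.export.load]` as `loaders` would try the TorchScript loader first and the ExportedProgram loader second. The order is an assumption of this sketch; the returned object still has to be handled according to its type, since only an ExportedProgram needs the `.module()` call.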
