Saving models compiled with Torch-TensorRT
Models compiled with Torch-TensorRT can be saved using the torch_tensorrt.save API.
Dynamo IR
The output of ir=dynamo compilation with Torch-TensorRT is a torch.fx.GraphModule object by default. This object can be saved in either TorchScript (torch.jit.ScriptModule) or ExportedProgram (torch.export.ExportedProgram) format by specifying the output_format flag. output_format accepts the following options:
exported_program: This is the default. We first perform transformations on the GraphModule and then use torch.export.save to save the module.
torchscript: We trace the GraphModule via torch.jit.trace and save it via torch.jit.save.
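For reference, the minimal sketch below shows the two underlying PyTorch serialization paths these options correspond to, applied to a plain nn.Module rather than a Torch-TensorRT compiled one; the file names plain.ep and plain.ts are placeholders:
import torch

model = torch.nn.Linear(4, 4).eval()
example_inputs = (torch.randn(1, 4),)

# exported_program path: torch.export traces the module into an
# ExportedProgram, which torch.export.save serializes to disk
ep = torch.export.export(model, example_inputs)
torch.export.save(ep, "plain.ep")

# torchscript path: torch.jit.trace records the ops into a ScriptModule,
# which torch.jit.save serializes to disk
ts = torch.jit.trace(model, example_inputs)
torch.jit.save(ts, "plain.ts")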
a) ExportedProgram
Here’s an example usage:
import torch
import torch_tensorrt
model = MyModel().eval().cuda()  # MyModel is a stand-in for your own torch.nn.Module
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# Compile with the dynamo frontend; trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)
# Later, you can load it and run inference
model = torch.export.load("trt.ep").module()
model(*inputs)
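As an optional sanity check, you can compare the reloaded module’s output against the in-memory compiled module before deploying. A minimal sketch, reusing trt_gm and inputs from the example above and assuming the model returns a single tensor (the tolerance is an arbitrary choice):
reloaded = torch.export.load("trt.ep").module()
with torch.no_grad():
    # The compiled and reloaded modules should produce (near-)identical outputs
    assert torch.allclose(trt_gm(*inputs), reloaded(*inputs), atol=1e-3)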
b) TorchScript
import torch
import torch_tensorrt
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
# Later, you can load it and run inference
model = torch.jit.load("trt.ts").cuda()
model(*inputs)
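Note: deserializing a TRT-compiled TorchScript module in a fresh process requires the Torch-TensorRT runtime to be available, since the serialized graph calls TensorRT execution ops; in Python, importing torch_tensorrt before calling torch.jit.load should register these ops.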
TorchScript IR
In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT was via the TorchScript IR. For ir=ts, this behavior is preserved in the 2.X versions as well.
import torch
import torch_tensorrt
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs)  # Output is a torch.jit.ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")
# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(*inputs)
Loading the models
We can load TorchScript or ExportedProgram models directly using PyTorch’s torch.jit.load and torch.export.load APIs, respectively. Alternatively, Torch-TensorRT provides a light wrapper, torch_tensorrt.load(file_path), which can load either of the above model types.
Here’s an example usage:
import torch
import torch_tensorrt
# file_path can be a trt.ep or trt.ts file obtained by saving the model (refer to the sections above)
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
model = torch_tensorrt.load(<file_path>).module()  # .module() applies when the file holds an ExportedProgram (.ep); a loaded ScriptModule (.ts) is callable directly
model(*inputs)
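Since torch_tensorrt.load returns an ExportedProgram for .ep files and a ScriptModule for .ts files, and only the former needs .module() to obtain a callable module, a sketch that handles both cases (assuming the trt.ep / trt.ts files saved above) might look like:
import torch
import torch_tensorrt

loaded = torch_tensorrt.load("trt.ep")  # or "trt.ts"
if isinstance(loaded, torch.export.ExportedProgram):
    # ExportedProgram: unwrap to a callable GraphModule
    model = loaded.module()
else:
    # ScriptModule: callable directly
    model = loaded
model(torch.randn((1, 3, 224, 224)).cuda())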