
Creating a TorchScript Module

TorchScript is a way to create serializable and optimizable models from PyTorch code. PyTorch has detailed documentation on how to do this (https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html), but briefly, here is the key background information and the process:

PyTorch programs are based around Modules, which can be used to compose higher-level modules. A Module contains a constructor, which sets up the module's parameters and submodules, and a forward function, which describes how to use those parameters and submodules when the module is invoked.

For example, we can define a LeNet module like this:

import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNetFeatExtractor(nn.Module):
    def __init__(self):
        super(LeNetFeatExtractor, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        return x


class LeNetClassifier(nn.Module):
    def __init__(self):
        super(LeNetClassifier, self).__init__()
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.feat = LeNetFeatExtractor()
        self.classifer = LeNetClassifier()

    def forward(self, x):
        x = self.feat(x)
        x = self.classifer(x)
        return x
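The 16 * 6 * 6 input size of fc1 only works for a 1 × 32 × 32 input image, which is the shape used in the tracing example. The sketch below walks through that size arithmetic using the standard output-size formula for an unpadded convolution. It also assumes the default F.max_pool2d behavior, where the stride equals the kernel size and the result is floored. The helper names conv2d_out and max_pool2d_out are hypothetical and exist only for illustration.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # Output size of nn.Conv2d along one spatial dimension (dilation assumed to be 1)
    return (size + 2 * padding - kernel) // stride + 1


def max_pool2d_out(size, kernel):
    # F.max_pool2d defaults stride to kernel_size with no padding and floor rounding
    return (size - kernel) // kernel + 1


h = 32                    # assumed input height/width (square image)
h = conv2d_out(h, 3)      # conv1: 32 -> 30
h = max_pool2d_out(h, 2)  # pool:  30 -> 15
h = conv2d_out(h, 3)      # conv2: 15 -> 13
h = max_pool2d_out(h, 2)  # pool:  13 -> 6 (floored)
flat = 16 * h * h         # 16 channels from conv2, flattened by torch.flatten(x, 1)
print(h, flat)            # 6 576
```

With a different input resolution, fc1's in_features would need to change to match.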


Obviously, you may want to consolidate such a simple model into a single module, but splitting it this way shows the composability of PyTorch.

From here, there are two pathways for going from PyTorch Python code to TorchScript code: tracing and scripting.

Tracing follows the path of execution when the module is called and records what happens. To trace an instance of our LeNet module, we can call torch.jit.trace with an example input.

import torch

model = LeNet()
input_data = torch.empty([1, 1, 32, 32])
traced_model = torch.jit.trace(model, input_data)

Scripting actually inspects your code with a compiler and generates an equivalent TorchScript program. The difference is that tracing only follows the execution of your module for the example input, so it cannot capture constructs such as data-dependent control flow. Because the compiler works from the Python source code, it can include these components. We can run the script compiler on our LeNet module by calling torch.jit.script:

import torch

model = LeNet()
script_model = torch.jit.script(model)
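The following minimal sketch makes the tracing/scripting difference concrete. It uses a hypothetical module, SignGate, whose forward branches on the values of its input, modeled on the decision-gate example in the PyTorch TorchScript tutorial. When traced with a positive example input, the module records only the branch that actually ran. When scripted, both branches are compiled.

```python
import warnings

import torch
import torch.nn as nn


class SignGate(nn.Module):
    """Hypothetical module with data-dependent control flow."""

    def forward(self, x):
        if x.sum() > 0:
            out = x * 2
        else:
            out = -x
        return out


gate = SignGate()
pos = torch.ones(3)
neg = -torch.ones(3)

with warnings.catch_warnings():
    # Tracing warns that converting a tensor to a Python bool may make the trace incorrect
    warnings.simplefilter("ignore")
    traced_gate = torch.jit.trace(gate, pos)  # records only the `x * 2` path

scripted_gate = torch.jit.script(gate)  # compiles both branches of the if/else

print(traced_gate(neg))    # tensor([-2., -2., -2.])  (the traced branch is baked in)
print(scripted_gate(neg))  # tensor([1., 1., 1.])     (the condition is re-evaluated)
```

LeNet has no such branching, so both paths produce equivalent programs for it.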

There are reasons to use one path or the other, and the PyTorch documentation has information on how to choose. From a Torch-TensorRT perspective, traced modules have better support (i.e. your module is more likely to compile) because they do not include all the complexities of a complete programming language, though both paths are supported.

After scripting or tracing your module, you are given back a TorchScript Module. This contains the code and parameters used to run the module, stored in an intermediate representation (IR) that Torch-TensorRT can consume.

Here is what the LeNet traced module IR looks like:

graph(%self.1 : __torch__.___torch_mangle_10.LeNet,
    %input.1 : Float(1, 1, 32, 32)):
    %129 : __torch__.___torch_mangle_9.LeNetClassifier = prim::GetAttr[name="classifer"](%self.1)
    %119 : __torch__.___torch_mangle_5.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self.1)
    %137 : Tensor = prim::CallMethod[name="forward"](%119, %input.1)
    %138 : Tensor = prim::CallMethod[name="forward"](%129, %137)
    return (%138)

and the LeNet scripted module IR:

graph(%self : __torch__.LeNet,
    %x.1 : Tensor):
    %2 : __torch__.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self)
    %x.3 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # x.py:38:12
    %5 : __torch__.LeNetClassifier = prim::GetAttr[name="classifer"](%self)
    %x.5 : Tensor = prim::CallMethod[name="forward"](%5, %x.3) # x.py:39:12
    return (%x.5)

You can see that the IR preserves the module structure of our Python code.
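You can print IR like the listings shown here from any TorchScript Module yourself. The sketch below uses a hypothetical two-level module, TinyNet, as a small stand-in for LeNet. It reads the graph attribute to get the IR and the code attribute to get a Python-like rendering of the same forward method.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical two-level module, standing in for LeNet."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return F.relu(self.fc(x))


scripted = torch.jit.script(TinyNet())
traced = torch.jit.trace(TinyNet(), torch.randn(1, 4))

print(scripted.graph)  # IR for forward; the submodule call appears as prim::CallMethod
print(scripted.code)   # the same method rendered as Python-like TorchScript source
print(traced.graph)    # traced IR; tensor types may include the example input's shape
```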

Working with TorchScript in Python

TorchScript Modules are run the same way you run normal PyTorch modules. You can run the forward pass by calling the forward method or by calling the module directly, e.g. torch_script_module(in_tensor). The JIT compiler will compile and optimize the module on the fly and then return the results.
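The minimal sketch below runs a scripted module both ways and checks the result against the original eager module. It uses a small nn.Sequential as a hypothetical stand-in for LeNet, but any scriptable module behaves the same way.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; eval() returns the module itself
eager = nn.Sequential(nn.Linear(8, 4), nn.ReLU()).eval()
scripted = torch.jit.script(eager)

x = torch.randn(2, 8)
with torch.no_grad():
    out_call = scripted(x)             # calling the module directly
    out_forward = scripted.forward(x)  # calling forward explicitly
    out_eager = eager(x)               # reference result from the eager module

print(out_call.shape)  # torch.Size([2, 4])
```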

Saving TorchScript Module to Disk

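For either traced or scripted modules, you can save the module to disk with the following command:

import torch

model = LeNet()
script_model = torch.jit.script(model)
# The file name is only an example; any writable path works
script_model.save("lenet_scripted.ts")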
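One way to confirm that a saved module is usable is to reload it with torch.jit.load and compare its outputs. The sketch below does this round trip using ScriptModule.save and torch.jit.load. It makes two assumptions: the model is a small nn.Sequential standing in for LeNet, and the file name model_scripted.ts is hypothetical and written to a temporary directory.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for LeNet
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU()).eval()
scripted = torch.jit.script(model)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_scripted.ts")  # hypothetical file name
    scripted.save(path)          # serialize code and parameters to disk
    loaded = torch.jit.load(path)  # deserialize into a ScriptModule in memory

x = torch.randn(2, 8)
with torch.no_grad():
    same = torch.allclose(scripted(x), loaded(x))
print(same)  # True
```

The reloaded module does not need the original Python class definitions, which is what makes it loadable outside the training script.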

