Creating a TorchScript Module
TorchScript is a way to create serializable and optimizable models from PyTorch code. PyTorch has detailed documentation on how to do this at https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html, but briefly, here is the key background information and the process:
PyTorch programs are based around Modules, which can be used to compose higher-level modules. Modules contain a constructor that sets up the module's parameters and submodules, and a forward function that describes how to use those parameters and submodules when the module is invoked.
For example, we can define a LeNet module like this:
import torch  # needed for torch.flatten in LeNetClassifier
import torch.nn as nn
import torch.nn.functional as F


class LeNetFeatExtractor(nn.Module):
    def __init__(self):
        super(LeNetFeatExtractor, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        return x


class LeNetClassifier(nn.Module):
    def __init__(self):
        super(LeNetClassifier, self).__init__()
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.feat = LeNetFeatExtractor()
        self.classifier = LeNetClassifier()

    def forward(self, x):
        x = self.feat(x)
        x = self.classifier(x)
        return x
Obviously, you may want to consolidate such a simple model into a single module, but this example shows how PyTorch modules compose.
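One detail that is easy to get wrong when composing modules like this is the input size of the first linear layer. The following is a small sketch, assuming the 1x1x32x32 input used later in this guide, that rebuilds the two convolution and pooling steps on their own to confirm where the 16 * 6 * 6 in fc1 comes from. The standalone conv1 and conv2 variables are only for illustration and mirror the layers in LeNetFeatExtractor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same layer shapes as LeNetFeatExtractor, rebuilt standalone for illustration.
conv1 = nn.Conv2d(1, 6, 3)
conv2 = nn.Conv2d(6, 16, 3)

x = torch.zeros(1, 1, 32, 32)  # assumed input size: one 32x32 grayscale image

# 3x3 conv without padding: 32 -> 30, then 2x2 max pool: 30 -> 15
x = F.max_pool2d(F.relu(conv1(x)), (2, 2))
# 3x3 conv without padding: 15 -> 13, then 2x2 max pool: 13 -> 6
x = F.max_pool2d(F.relu(conv2(x)), 2)

feature_shape = tuple(x.shape)
flat_features = torch.flatten(x, 1).shape[1]
print(feature_shape, flat_features)
```

If the input resolution changes, the flattened size changes with it, and fc1 would need to be adjusted accordingly.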
From here, there are two pathways for going from PyTorch Python code to TorchScript code: Tracing and Scripting.
Tracing follows the path of execution when the module is called and records what happens.
To trace an instance of our LeNet module, we can call torch.jit.trace with an example input:
import torch
model = LeNet()
input_data = torch.empty([1, 1, 32, 32])
traced_model = torch.jit.trace(model, input_data)
Scripting actually inspects your code with a compiler and generates an equivalent TorchScript program. The difference is that tracing only follows the execution of your module for one example input, so it cannot pick up data-dependent control flow such as if statements or loops; it records only the path that was actually taken. By working from the Python code, the compiler can include these components. We can run the script compiler on our LeNet module by calling torch.jit.script:
import torch
model = LeNet()
script_model = torch.jit.script(model)
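To make the control-flow difference concrete, here is a minimal sketch using a hypothetical SignGate module (it is not part of LeNet and exists only for this illustration). The module branches on the values of its input. Tracing it with a positive input records only that branch, while scripting keeps the if statement.

```python
import torch
import torch.nn as nn


class SignGate(nn.Module):  # hypothetical module, for illustration only
    def forward(self, x):
        # Data-dependent branch: the path taken depends on the input values.
        if x.sum() > 0:
            return x * 2
        return torch.zeros_like(x)


gate = SignGate()
positive = torch.ones(3)
negative = -torch.ones(3)

# Tracing with a positive input records only the `x * 2` path.
# PyTorch emits a TracerWarning about the tensor-to-bool conversion here.
traced_gate = torch.jit.trace(gate, positive)

# Scripting compiles the `if` statement itself.
scripted_gate = torch.jit.script(gate)

traced_out = traced_gate(negative)      # replays the recorded path: x * 2
scripted_out = scripted_gate(negative)  # takes the zeros branch
eager_out = gate(negative)              # reference result from plain PyTorch

print(traced_out, scripted_out, eager_out)
```

For the negative input, the scripted module agrees with eager PyTorch, while the traced module silently reuses the branch it saw during tracing. LeNet has no such branches, so both pathways produce equivalent programs for it.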
There are reasons to use one path or the other; the PyTorch documentation has information on how to choose. From a Torch-TensorRT perspective, there is better support for traced modules (i.e., your module is more likely to compile) because a trace doesn't include all the complexities of a complete programming language, though both paths are supported.
After scripting or tracing your module, you are given back a TorchScript Module. This contains the code and parameters used to run the module, stored in an intermediate representation (IR) that Torch-TensorRT can consume.
Here is what the LeNet traced module IR looks like:
graph(%self.1 : __torch__.___torch_mangle_10.LeNet,
      %input.1 : Float(1, 1, 32, 32)):
  %129 : __torch__.___torch_mangle_9.LeNetClassifier = prim::GetAttr[name="classifier"](%self.1)
  %119 : __torch__.___torch_mangle_5.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self.1)
  %137 : Tensor = prim::CallMethod[name="forward"](%119, %input.1)
  %138 : Tensor = prim::CallMethod[name="forward"](%129, %137)
  return (%138)
and the LeNet scripted module IR:
graph(%self : __torch__.LeNet,
      %x.1 : Tensor):
  %2 : __torch__.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self)
  %x.3 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # x.py:38:12
  %5 : __torch__.LeNetClassifier = prim::GetAttr[name="classifier"](%self)
  %x.5 : Tensor = prim::CallMethod[name="forward"](%5, %x.3) # x.py:39:12
  return (%x.5)
You can see that the IR preserves the module structure we have in our Python code: each submodule is fetched with prim::GetAttr and invoked with prim::CallMethod.
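IR listings like the ones above can be printed from a TorchScript module's graph attribute. The sketch below uses two hypothetical modules, Inner and Outer, as small stand-ins for the LeNet classes so that it runs on its own. The exact text printed may vary between PyTorch versions.

```python
import torch
import torch.nn as nn


class Inner(nn.Module):  # hypothetical submodule, for illustration only
    def forward(self, x):
        return torch.relu(x)


class Outer(nn.Module):  # hypothetical parent module, for illustration only
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        return self.inner(x)


scripted_outer = torch.jit.script(Outer())

# `graph` holds the IR of the forward method; str() gives its text form.
ir_text = str(scripted_outer.graph)
# `code` is a Python-like rendering of the same method.
code_text = scripted_outer.code

print(ir_text)
print(code_text)
```

The same attributes are available on traced modules, so both pathways can be inspected in the same way.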
Working with TorchScript in Python
TorchScript Modules are run the same way you run normal PyTorch modules. You can run the forward pass using the forward method or by just calling the module: torch_script_module(in_tensor). The JIT compiler will compile and optimize the module on the fly and then return the results.
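As a minimal sketch of both calling styles, the example below scripts a single nn.Linear layer (a hypothetical stand-in chosen to keep the example short; any TorchScript module behaves the same way) and compares the results with the original eager module.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(4, 2)  # hypothetical stand-in for a real model
torch_script_module = torch.jit.script(layer)
in_tensor = torch.randn(3, 4)

with torch.no_grad():
    out_call = torch_script_module(in_tensor)             # calling the module
    out_forward = torch_script_module.forward(in_tensor)  # explicit forward
    out_eager = layer(in_tensor)                          # original PyTorch module

print(torch.allclose(out_call, out_forward), torch.allclose(out_call, out_eager))
```

Calling the module directly is the more common style, as with regular PyTorch modules.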
Saving TorchScript Module to Disk
For either traced or scripted modules, you can save the module to disk with the save method:
import torch
model = LeNet()
script_model = torch.jit.script(model)
script_model.save("lenet_scripted.ts")
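A saved module can be read back with torch.jit.load, without needing the original Python class definitions. The following sketch round-trips a small hypothetical nn.Sequential model through a temporary directory (the file name example_scripted.ts is an assumption for this illustration) and checks that the reloaded module produces the same output.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical small model, used here instead of LeNet to keep the sketch short.
module = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
scripted = torch.jit.script(module)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example_scripted.ts")
    scripted.save(path)             # serialize code and parameters
    reloaded = torch.jit.load(path)  # deserialize into a TorchScript module

sample = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(scripted(sample), reloaded(sample))

print(same)
```

The same save and load calls apply to traced modules.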