Introduction to TorchScript¶
James Reed (email@example.com), Michael Suo (firstname.lastname@example.org), rev2
This tutorial is an introduction to TorchScript, an intermediate
representation of a PyTorch model (subclass of nn.Module) that
can then be run in a high-performance environment such as C++.
In this tutorial we will cover:
- The basics of model authoring in PyTorch, including:
- Composing modules into a hierarchy of modules
- Specific methods for converting PyTorch modules to TorchScript, our high-performance deployment runtime
- Tracing an existing module
- Using scripting to directly compile a module
- How to compose both approaches
- Saving and loading TorchScript modules
We hope that after you complete this tutorial, you will proceed to the follow-on tutorial, which will walk you through an example of actually calling a TorchScript model from C++.
import torch  # This is all you need to use both PyTorch and TorchScript!
print(torch.__version__)
Basics of TorchScript¶
Now let’s take our running example and see how we can apply TorchScript.
In short, TorchScript provides tools to capture the definition of your model, even in light of the flexible and dynamic nature of PyTorch. Let’s begin by examining what we call tracing.
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
We’ve rewound a bit and taken the second version of our MyCell
class. As before, we’ve instantiated it, but this time, we’ve called
torch.jit.trace, passed in the Module, and passed in example
inputs the network might see.
What exactly has this done? It has invoked the Module, recorded the
operations that occurred when the Module was run, and created an
instance of torch.jit.ScriptModule (of which TracedModule is an
instance).
TorchScript records its definitions in an Intermediate Representation
(or IR), commonly referred to in deep learning as a graph. We can
examine the graph with the .graph property:
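# For example, dump the recorded IR for the traced_cell produced above:
print(traced_cell.graph)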
However, this is a very low-level representation and most of the
information contained in the graph is not useful for end users. Instead,
we can use the
.code property to give a Python-syntax interpretation
of the code:
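# For example, print the Python-syntax view of the traced computation:
print(traced_cell.code)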
So why did we do all this? There are several reasons:
- TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, and so many requests can be processed on the same instance simultaneously (see the sketch after this list).
- This format allows us to save the whole model to disk and load it into another environment, such as in a server written in a language other than Python
- TorchScript gives us a representation in which we can do compiler optimizations on the code to provide more efficient execution
- TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.
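As a rough illustration of the first point above (a minimal sketch, not a benchmark; the thread count and loop size are arbitrary), several Python threads can invoke the same traced module, and the TorchScript interpreter does the heavy lifting without holding the GIL:

import threading

# Reuses traced_cell, x, and h from the tracing example above.
def worker():
    for _ in range(100):
        traced_cell(x, h)  # runs in the TorchScript interpreter

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()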
We can see that invoking
traced_cell produces the same results as
the Python module:
print(my_cell(x, h))
print(traced_cell(x, h))
Using Scripting to Convert Modules¶
There’s a reason we used version two of our module, and not the one with the control-flow-laden submodule. Let’s examine that now:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)
Looking at the .code output, we can see that the if-else branch
is nowhere to be found! Why? Tracing does exactly what we said it would:
run the code, record the operations that happen and construct a
ScriptModule that does exactly that. Unfortunately, things like control
flow are erased.
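One way to see the consequence (a minimal sketch using my_cell and traced_cell as defined just above; the fresh random inputs are arbitrary): because tracing recorded only the branch taken for the example inputs, the traced module can disagree with the eager module whenever the other branch should be taken.

# Compare the eager module against its trace on new random inputs. The traced
# gate always replays the branch recorded during tracing, so the outputs can
# differ whenever x.sum() falls on the other side of zero inside the gate.
x2, h2 = torch.rand(3, 4), torch.rand(3, 4)
print(torch.allclose(my_cell(x2, h2)[0], traced_cell(x2, h2)[0]))  # may print False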
How can we faithfully represent this module in TorchScript? We provide a
script compiler, which does direct analysis of your Python source
code to transform it into TorchScript. Let's convert MyDecisionGate
using the script compiler:
scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)
Hooray! We’ve now faithfully captured the behavior of our program in TorchScript. Let’s now try running the program:
# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell(x, h)
Mixing Scripting and Tracing¶
Some situations call for using tracing rather than scripting (e.g. a
module has many architectural decisions that are made based on constant
Python values that we would like not to appear in TorchScript). In this
case, scripting can be composed with tracing: torch.jit.script will
inline the code for a traced module, and tracing will inline the code
for a scripted module.
An example of the first case:
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
And an example of the second case:
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
In this way, scripting and tracing can each be used when the situation calls for them, and they can also be used together.
Saving and Loading models¶
We provide APIs to save and load TorchScript modules to/from disk in an archive format. This format includes code, parameters, attributes, and debug information, meaning that the archive is a freestanding representation of the model that can be loaded in an entirely separate process. Let’s save and load our wrapped RNN module:
traced.save('wrapped_rnn.zip')

loaded = torch.jit.load('wrapped_rnn.zip')

print(loaded)
print(loaded.code)
As you can see, serialization preserves the module hierarchy and the code we've been examining throughout. The model can also be loaded, for example, into C++ for Python-free execution.
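As a quick sanity check (a minimal sketch reusing the traced and loaded modules from above; the input shape simply matches the tracing example), the round-tripped module should behave identically to the original:

xs = torch.rand(10, 3, 4)
# The archive carries the same code and parameters, so the outputs match.
print(torch.allclose(traced(xs), loaded(xs)))  # expect True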
We’ve completed our tutorial! For a more involved demonstration, check out the NeurIPS demo for converting machine translation models using TorchScript: https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ