.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/Intro_to_TorchScript_tutorial.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_Intro_to_TorchScript_tutorial.py:


Introduction to TorchScript
===========================

**Authors:** James Reed (jamesreed@fb.com), Michael Suo (suo@fb.com), rev2

This tutorial is an introduction to TorchScript, an intermediate
representation of a PyTorch model (subclass of ``nn.Module``) that
can then be run in a high-performance environment such as C++.

In this tutorial we will cover:

1. The basics of model authoring in PyTorch, including:

   - Modules
   - Defining ``forward`` functions
   - Composing modules into a hierarchy of modules

2. Specific methods for converting PyTorch modules to TorchScript, our
   high-performance deployment runtime

   - Tracing an existing module
   - Using scripting to directly compile a module
   - How to compose both approaches
   - Saving and loading TorchScript modules

We hope that after you complete this tutorial, you will go on to
`the follow-on tutorial `_
which will walk you through an example of actually calling a TorchScript
model from C++.

.. code-block:: default

    import torch  # This is all you need to use both PyTorch and TorchScript!
    print(torch.__version__)
    torch.manual_seed(191009)  # set the seed for reproducibility

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2.3.0+cu121

Basics of PyTorch Model Authoring
---------------------------------

Let’s start out by defining a simple ``Module``. A ``Module`` is the
basic unit of composition in PyTorch. It contains:

1. A constructor, which prepares the module for invocation
2. A set of ``Parameters`` and sub-\ ``Modules``. These are initialized
   by the constructor and can be used by the module during invocation.
3. A ``forward`` function. This is the code that is run when the module
   is invoked.

Let’s examine a small example:

.. code-block:: default

    class MyCell(torch.nn.Module):
        def __init__(self):
            super(MyCell, self).__init__()

        def forward(self, x, h):
            new_h = torch.tanh(x + h)
            return new_h, new_h

    my_cell = MyCell()
    x = torch.rand(3, 4)
    h = torch.rand(3, 4)
    print(my_cell(x, h))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (tensor([[0.8219, 0.8990, 0.6670, 0.8277],
            [0.5176, 0.4017, 0.8545, 0.7336],
            [0.6013, 0.6992, 0.2618, 0.6668]]), tensor([[0.8219, 0.8990, 0.6670, 0.8277],
            [0.5176, 0.4017, 0.8545, 0.7336],
            [0.6013, 0.6992, 0.2618, 0.6668]]))

So we’ve:

1. Created a class that subclasses ``torch.nn.Module``.
2. Defined a constructor. The constructor doesn’t do much; it just calls
   the constructor of the parent class via ``super``.
3. Defined a ``forward`` function, which takes two inputs and returns
   two outputs. The actual contents of the ``forward`` function are not
   really important, but it’s sort of a fake `RNN cell `__, that is, a
   function that is applied in a loop.

We instantiated the module and made ``x`` and ``h``, which are just 3x4
matrices of random values. Then we invoked the cell with
``my_cell(x, h)``. This in turn calls our ``forward`` function.

Let’s do something a little more interesting:

.. code-block:: default

    class MyCell(torch.nn.Module):
        def __init__(self):
            super(MyCell, self).__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x, h):
            new_h = torch.tanh(self.linear(x) + h)
            return new_h, new_h

    my_cell = MyCell()
    print(my_cell)
    print(my_cell(x, h))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    MyCell(
      (linear): Linear(in_features=4, out_features=4, bias=True)
    )
    (tensor([[ 0.8573,  0.6190,  0.5774,  0.7869],
            [ 0.3326,  0.0530,  0.0702,  0.8114],
            [ 0.7818, -0.0506,  0.4039,  0.7967]], grad_fn=), tensor([[ 0.8573,  0.6190,  0.5774,  0.7869],
            [ 0.3326,  0.0530,  0.0702,  0.8114],
            [ 0.7818, -0.0506,  0.4039,  0.7967]], grad_fn=))

We’ve redefined our module ``MyCell``, but this time we’ve added a
``self.linear`` attribute, and we invoke ``self.linear`` in the forward
function.

What exactly is happening here? ``torch.nn.Linear`` is a ``Module`` from
the PyTorch standard library. Just like ``MyCell``, it can be invoked
using the call syntax. We are building a hierarchy of ``Module``\ s.

``print`` on a ``Module`` will give a visual representation of the
``Module``\ ’s submodule hierarchy. In our example, we can see our
``Linear`` submodule and its configuration (``in_features``,
``out_features`` and ``bias``).

By composing ``Module``\ s in this way, we can succinctly and readably
author models with reusable components.

You may have noticed ``grad_fn`` on the outputs. This is a detail of
PyTorch’s method of automatic differentiation, called
`autograd `__.
In short, this system allows us to compute derivatives through
potentially complex programs. The design allows for a massive amount of
flexibility in model authoring.

Now let’s examine said flexibility:

.. code-block:: default

    class MyDecisionGate(torch.nn.Module):
        def forward(self, x):
            if x.sum() > 0:
                return x
            else:
                return -x

    class MyCell(torch.nn.Module):
        def __init__(self):
            super(MyCell, self).__init__()
            self.dg = MyDecisionGate()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x, h):
            new_h = torch.tanh(self.dg(self.linear(x)) + h)
            return new_h, new_h

    my_cell = MyCell()
    print(my_cell)
    print(my_cell(x, h))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    MyCell(
      (dg): MyDecisionGate()
      (linear): Linear(in_features=4, out_features=4, bias=True)
    )
    (tensor([[ 0.8346,  0.5931,  0.2097,  0.8232],
            [ 0.2340, -0.1254,  0.2679,  0.8064],
            [ 0.6231,  0.1494, -0.3110,  0.7865]], grad_fn=), tensor([[ 0.8346,  0.5931,  0.2097,  0.8232],
            [ 0.2340, -0.1254,  0.2679,  0.8064],
            [ 0.6231,  0.1494, -0.3110,  0.7865]], grad_fn=))

We’ve once again redefined our ``MyCell`` class, but here we’ve defined
``MyDecisionGate``. This module uses **control flow**. Control flow
consists of things like loops and ``if``-statements.

Many frameworks take the approach of computing symbolic derivatives
given a full program representation. However, in PyTorch, we use a
gradient tape. We record operations as they occur, and replay them
backwards in computing derivatives. In this way, the framework does not
have to explicitly define derivatives for all constructs in the
language.

.. figure:: https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/dynamic_graph.gif
   :alt: How autograd works

   How autograd works

Basics of TorchScript
---------------------

Now let’s take our running example and see how we can apply
TorchScript.

In short, TorchScript provides tools to capture the definition of your
model, even in light of the flexible and dynamic nature of PyTorch.
Let’s begin by examining what we call **tracing**.

Tracing ``Modules``
~~~~~~~~~~~~~~~~~~~

.. code-block:: default

    class MyCell(torch.nn.Module):
        def __init__(self):
            super(MyCell, self).__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x, h):
            new_h = torch.tanh(self.linear(x) + h)
            return new_h, new_h

    my_cell = MyCell()
    x, h = torch.rand(3, 4), torch.rand(3, 4)
    traced_cell = torch.jit.trace(my_cell, (x, h))
    print(traced_cell)
    traced_cell(x, h)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    MyCell(
      original_name=MyCell
      (linear): Linear(original_name=Linear)
    )

    (tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
            [-0.2329, -0.2911,  0.5641,  0.5015],
            [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
            [-0.2329, -0.2911,  0.5641,  0.5015],
            [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=))

We’ve rewound a bit and taken the second version of our ``MyCell``
class. As before, we’ve instantiated it, but this time, we’ve called
``torch.jit.trace``, passed in the ``Module``, and passed in *example
inputs* the network might see.

What exactly has this done? It has invoked the ``Module``, recorded the
operations that occurred when the ``Module`` was run, and created an
instance of ``torch.jit.ScriptModule`` (``TracedModule`` is a subclass
of ``ScriptModule``).

TorchScript records its definitions in an Intermediate Representation
(or IR), commonly referred to in deep learning as a *graph*. We can
examine the graph with the ``.graph`` property:

.. code-block:: default

    print(traced_cell.graph)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    graph(%self.1 : __torch__.MyCell,
          %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
          %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
      %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
      %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
      %11 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
      %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
      %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
      %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
      return (%14)

However, this is a very low-level representation and most of the
information contained in the graph is not useful for end users. Instead,
we can use the ``.code`` property to give a Python-syntax interpretation
of the code:

.. code-block:: default

    print(traced_cell.code)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    def forward(self,
        x: Tensor,
        h: Tensor) -> Tuple[Tensor, Tensor]:
      linear = self.linear
      _0 = torch.tanh(torch.add((linear).forward(x, ), h))
      return (_0, _0)

So **why** did we do all this? There are several reasons:

1. TorchScript code can be invoked in its own interpreter, which is
   basically a restricted Python interpreter. This interpreter does not
   acquire the Global Interpreter Lock, and so many requests can be
   processed on the same instance simultaneously.
2. This format allows us to save the whole model to disk and load it
   into another environment, such as a server written in a language
   other than Python.
3. TorchScript gives us a representation in which we can do compiler
   optimizations on the code to provide more efficient execution.
4. TorchScript allows us to interface with many backend/device runtimes
   that require a broader view of the program than individual operators.

We can see that invoking ``traced_cell`` produces the same results as
the Python module:

.. code-block:: default

    print(my_cell(x, h))
    print(traced_cell(x, h))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
            [-0.2329, -0.2911,  0.5641,  0.5015],
            [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
            [-0.2329, -0.2911,  0.5641,  0.5015],
            [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=))
    (tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
            [-0.2329, -0.2911,  0.5641,  0.5015],
            [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
            [-0.2329, -0.2911,  0.5641,  0.5015],
            [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=))

Using Scripting to Convert Modules
----------------------------------

There’s a reason we used version two of our module, and not the one with
the control-flow-laden submodule. Let’s examine that now:

.. code-block:: default

    class MyDecisionGate(torch.nn.Module):
        def forward(self, x):
            if x.sum() > 0:
                return x
            else:
                return -x

    class MyCell(torch.nn.Module):
        def __init__(self, dg):
            super(MyCell, self).__init__()
            self.dg = dg
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x, h):
            new_h = torch.tanh(self.dg(self.linear(x)) + h)
            return new_h, new_h

    my_cell = MyCell(MyDecisionGate())
    traced_cell = torch.jit.trace(my_cell, (x, h))

    print(traced_cell.dg.code)
    print(traced_cell.code)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:261: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    def forward(self,
        argument_1: Tensor) -> NoneType:
      return None

    def forward(self,
        x: Tensor,
        h: Tensor) -> Tuple[Tensor, Tensor]:
      dg = self.dg
      linear = self.linear
      _0 = (linear).forward(x, )
      _1 = (dg).forward(_0, )
      _2 = torch.tanh(torch.add(_0, h))
      return (_2, _2)

Looking at the ``.code`` output, we can see that the ``if-else`` branch
is nowhere to be found! Why? Tracing does exactly what we said it would:
run the code, record the operations *that happen*, and construct a
``ScriptModule`` that does exactly that. Unfortunately, things like
control flow are erased.

How can we faithfully represent this module in TorchScript? We provide a
**script compiler**, which does direct analysis of your Python source
code to transform it into TorchScript. Let’s convert ``MyDecisionGate``
using the script compiler:

.. code-block:: default

    scripted_gate = torch.jit.script(MyDecisionGate())

    my_cell = MyCell(scripted_gate)
    scripted_cell = torch.jit.script(my_cell)

    print(scripted_gate.code)
    print(scripted_cell.code)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    def forward(self,
        x: Tensor) -> Tensor:
      if bool(torch.gt(torch.sum(x), 0)):
        _0 = x
      else:
        _0 = torch.neg(x)
      return _0

    def forward(self,
        x: Tensor,
        h: Tensor) -> Tuple[Tensor, Tensor]:
      dg = self.dg
      linear = self.linear
      _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
      new_h = torch.tanh(_0)
      return (new_h, new_h)

Hooray! We’ve now faithfully captured the behavior of our program in
TorchScript.
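
To make the difference between the two approaches concrete, here is a
minimal, self-contained sketch that feeds the same input to an eager, a
traced, and a scripted ``MyDecisionGate``. The names ``positive``,
``negative``, ``gate_traced`` and ``gate_scripted`` are illustrative
only. The one deliberate choice is that the trace is recorded with a
negative-sum example, so only the ``return -x`` branch ends up in it.

.. code-block:: python

    import warnings

    import torch


    class MyDecisionGate(torch.nn.Module):
        # Same data-dependent branch as the gate used in this tutorial.
        def forward(self, x):
            if x.sum() > 0:
                return x
            else:
                return -x


    positive = torch.ones(3, 4)   # sum > 0: eager mode takes the "return x" branch
    negative = -torch.ones(3, 4)  # sum < 0: eager mode takes the "return -x" branch

    with warnings.catch_warnings():
        # Tracing emits the TracerWarning shown earlier; it is expected here.
        warnings.simplefilter("ignore")
        # Trace with the negative example, so only "return -x" is recorded.
        gate_traced = torch.jit.trace(MyDecisionGate(), (negative,))
    gate_scripted = torch.jit.script(MyDecisionGate())

    eager_out = MyDecisionGate()(positive)
    traced_out = gate_traced(positive)
    scripted_out = gate_scripted(positive)

    print(torch.equal(scripted_out, eager_out))  # True: the if/else survives scripting
    print(torch.equal(traced_out, eager_out))    # False: the trace always negates

Because the trace was recorded on ``negative``, it keeps returning
``-x`` even for inputs whose sum is positive, while the scripted gate
still evaluates the ``if`` condition at run time.
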

Let’s now try running the program:

.. code-block:: default

    # New inputs
    x, h = torch.rand(3, 4), torch.rand(3, 4)
    print(scripted_cell(x, h))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],
            [ 0.5228,  0.7122,  0.6985, -0.0656],
            [ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=), tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],
            [ 0.5228,  0.7122,  0.6985, -0.0656],
            [ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=))

Mixing Scripting and Tracing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Some situations call for using tracing rather than scripting (e.g. a
module has many architectural decisions that are made based on constant
Python values that we would prefer not to appear in TorchScript). In
this case, scripting can be composed with tracing: ``torch.jit.script``
will inline the code for a traced module, and tracing will inline the
code for a scripted module.

An example of the first case:

.. code-block:: default

    class MyRNNLoop(torch.nn.Module):
        def __init__(self):
            super(MyRNNLoop, self).__init__()
            self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

        def forward(self, xs):
            h, y = torch.zeros(3, 4), torch.zeros(3, 4)
            for i in range(xs.size(0)):
                y, h = self.cell(xs[i], h)
            return y, h

    rnn_loop = torch.jit.script(MyRNNLoop())
    print(rnn_loop.code)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    def forward(self,
        xs: Tensor) -> Tuple[Tensor, Tensor]:
      h = torch.zeros([3, 4])
      y = torch.zeros([3, 4])
      y0 = y
      h0 = h
      for i in range(torch.size(xs, 0)):
        cell = self.cell
        _0 = (cell).forward(torch.select(xs, 0, i), h0, )
        y1, h1, = _0
        y0, h0 = y1, h1
      return (y0, h0)

And an example of the second case:

.. code-block:: default

    class WrapRNN(torch.nn.Module):
        def __init__(self):
            super(WrapRNN, self).__init__()
            self.loop = torch.jit.script(MyRNNLoop())

        def forward(self, xs):
            y, h = self.loop(xs)
            return torch.relu(y)

    traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4),))
    print(traced.code)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    def forward(self,
        xs: Tensor) -> Tensor:
      loop = self.loop
      _0, y, = (loop).forward(xs, )
      return torch.relu(y)

This way, scripting and tracing can each be used where the situation
calls for them, and the two can be combined.

Saving and Loading models
-------------------------

We provide APIs to save and load TorchScript modules to/from disk in an
archive format. This format includes code, parameters, attributes, and
debug information, meaning that the archive is a freestanding
representation of the model that can be loaded in an entirely separate
process. Let’s save and load our wrapped RNN module:

.. code-block:: default

    traced.save('wrapped_rnn.pt')

    loaded = torch.jit.load('wrapped_rnn.pt')

    print(loaded)
    print(loaded.code)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    RecursiveScriptModule(
      original_name=WrapRNN
      (loop): RecursiveScriptModule(
        original_name=MyRNNLoop
        (cell): RecursiveScriptModule(
          original_name=MyCell
          (dg): RecursiveScriptModule(original_name=MyDecisionGate)
          (linear): RecursiveScriptModule(original_name=Linear)
        )
      )
    )
    def forward(self,
        xs: Tensor) -> Tensor:
      loop = self.loop
      _0, y, = (loop).forward(xs, )
      return torch.relu(y)

As you can see, serialization preserves the module hierarchy and the
code we’ve been examining throughout. The model can also be loaded, for
example, `into C++ `__ for Python-free execution.

Further Reading
~~~~~~~~~~~~~~~

We’ve completed our tutorial!
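
One last check before moving on: ``torch.jit.save`` and
``torch.jit.load`` also accept file-like objects, so the save/load round
trip from the previous section can be verified entirely in memory. The
sketch below is an illustration under our own assumptions: it re-creates
the scripted ``MyCell`` from earlier rather than the wrapped RNN, uses an
``io.BytesIO`` buffer in place of ``'wrapped_rnn.pt'``, and the names
``original``, ``restored``, ``x_in`` and ``h_in`` are hypothetical.

.. code-block:: python

    import io

    import torch


    class MyDecisionGate(torch.nn.Module):
        def forward(self, x):
            if x.sum() > 0:
                return x
            else:
                return -x


    class MyCell(torch.nn.Module):
        def __init__(self, dg):
            super().__init__()
            self.dg = dg
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x, h):
            new_h = torch.tanh(self.dg(self.linear(x)) + h)
            return new_h, new_h


    original = torch.jit.script(MyCell(MyDecisionGate()))

    # Serialize into memory instead of a file on disk.
    buffer = io.BytesIO()
    torch.jit.save(original, buffer)
    buffer.seek(0)  # rewind so that load() reads from the start of the archive
    restored = torch.jit.load(buffer)

    x_in, h_in = torch.rand(3, 4), torch.rand(3, 4)
    with torch.no_grad():
        expected = original(x_in, h_in)
        actual = restored(x_in, h_in)

    # Same code and same parameter values, so the outputs should agree.
    print(all(torch.allclose(a, b) for a, b in zip(expected, actual)))  # True

The same pattern is handy in tests, where writing archives to disk is
often unnecessary.
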
For a more involved demonstration, check out the NeurIPS demo for
converting machine translation models using TorchScript:
https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.233 seconds)