.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_hybrid_frontend_learning_hybrid_frontend_through_example_tutorial.py:


Learning Hybrid Frontend Syntax Through Example
===============================================
**Author:** `Nathan Inkawhich `_

This document highlights the syntax of the Hybrid Frontend through a
non-code-intensive example. The Hybrid Frontend is one of the new features
of PyTorch 1.0 and provides an avenue for developers to transition their
models from **eager** to **graph** mode. PyTorch users are very familiar
with eager mode, as it provides the ease-of-use and flexibility that we all
enjoy as researchers. Caffe2 users are more acquainted with graph mode,
which has the benefits of speed, optimization opportunities, and
functionality in C++ runtime environments. The hybrid frontend bridges the
gap between the two modes by allowing researchers to develop and refine
their models in eager mode (i.e. PyTorch), then gradually transition the
proven model to graph mode for production, when speed and resource
consumption become critical.

Hybrid Frontend Information
---------------------------

The process for transitioning a model to graph mode is as follows. First,
the developer constructs, trains, and tests the model in eager mode. Then
they incrementally **trace** and **script** each function/module of the
model with the Just-In-Time (JIT) compiler, at each step verifying that the
output is correct. Finally, when each of the components of the top-level
model has been traced and scripted, the model itself is traced. At that
point the model has been transitioned to graph mode and has a complete
Python-free representation. With this representation, the model runtime can
take advantage of high-performance Caffe2 operators and graph-based
optimizations.

Before we continue, it is important to understand the idea of tracing and
scripting, and why they are separate. The goal of **trace** and **script**
is the same: to create a graph representation of the operations taking
place in a given function. The difference comes from the flexibility of
eager mode, which allows for **data-dependent control flow** within the
model architecture. When a function does NOT have a data-dependent control
flow, it may be *traced* with ``torch.jit.trace``. However, when the
function *has* a data-dependent control flow, it must be *scripted* with
``torch.jit.script``. We will leave the details of the inner workings of
the hybrid frontend for another document, but the code example below will
show the syntax of how to trace and script different pure Python functions
and torch Modules. Hopefully, you will find that using the hybrid frontend
is non-intrusive, as it mostly involves adding decorators to the existing
function and class definitions.
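To see why this distinction matters, consider a minimal sketch with a toy
function of our own (it is not part of the motivating example below) whose
branch depends on the *value* of the input. Tracing records only the
operations executed for the example input, so one branch gets baked in;
scripting compiles the control flow itself.

.. code-block:: default

    import torch

    def choose(x):
        # The branch taken depends on the value of x at runtime.
        if bool(x.sum() > 0):
            return x
        return -x

    # Tracing runs the function once and records only the branch taken for
    # the example input (PyTorch emits a TracerWarning about this).
    traced = torch.jit.trace(choose, torch.ones(2))

    # Scripting compiles the if-statement, so both branches are preserved.
    scripted = torch.jit.script(choose)

    neg = -torch.ones(2)
    print(traced(neg))    # tensor([-1., -1.])  -- the baked-in branch
    print(scripted(neg))  # tensor([1., 1.])    -- the correct result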
Motivating Example
------------------

In this example we will implement a strange math function that may be
logically broken up into four parts that do, and do not, contain
data-dependent control flows. The purpose here is to show a
non-code-intensive example where the use of the JIT is highlighted.

This example is a stand-in representation of a useful model, whose
implementation has been divided into various pure Python functions and
modules. The function we seek to implement, :math:`Y(x)`, is defined for
:math:`x \in \mathbb{N}` as

.. math::

   z(x) = \Biggl\lfloor \frac{\sqrt{\prod_{i=1}^{|2x|} i}}{5} \Biggr\rfloor

.. math::

   Y(x) = \begin{cases}
      \frac{z(x)}{2} & \text{if } z(x)\%2 == 0, \\
      -z(x) & \text{otherwise}
   \end{cases}

.. math::

   \begin{array}{| r | r | r | r | r | r | r | r |}
      \hline
      x    &1 &2 &3  &4  &5   &6     &7      \\
      \hline
      Y(x) &0 &0 &-5 &20 &190 &-4377 &-59051 \\
      \hline
   \end{array}

As mentioned, the computation is split into four parts. Part one is the
simple tensor calculation of :math:`|2x|`, which can be traced. Part two is
the iterative product calculation that represents a data-dependent control
flow to be scripted (the number of loop iterations depends on the input at
runtime). Part three is a trace-able :math:`\lfloor \sqrt{a}/5 \rfloor`
calculation. Finally, part four handles the output cases depending on the
value of :math:`z(x)` and must be scripted due to the data dependency. Now,
let's see how this looks in code.

Part 1 - Tracing a pure python function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We can implement part one as a pure Python function as below. Notice that
to trace this function we call ``torch.jit.trace`` and pass in the function
to be traced. Since tracing requires a dummy input of the expected runtime
type and shape, we also pass in ``torch.rand(())`` to generate a
single-valued torch tensor.

.. GENERATED FROM PYTHON SOURCE LINES 101-111

.. code-block:: default

    import torch

    def fn(x):
        return torch.abs(2*x)

    # This is how you define a traced function
    # Pass in both the function to be traced and an example input to ``torch.jit.trace``
    traced_fn = torch.jit.trace(fn, torch.rand(()))

.. GENERATED FROM PYTHON SOURCE LINES 112-121

Part 2 - Scripting a pure python function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We can also implement part two as a pure Python function, where we
iteratively compute the product. Since the number of iterations depends on
the value of the input, we have a data-dependent control flow, so the
function must be scripted. We can script Python functions simply with the
``@torch.jit.script`` decorator.

.. GENERATED FROM PYTHON SOURCE LINES 121-132

.. code-block:: default

    # This is how you define a script function
    # Apply this decorator directly to the function
    @torch.jit.script
    def script_fn(x):
        z = torch.ones([1], dtype=torch.int64)
        for i in range(int(x)):
            z = z * (i + 1)
        return z
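Following the incremental-verification workflow described earlier, we can
immediately check these two pieces against eager execution. The test values
below are our own, chosen only for illustration.

.. code-block:: default

    # The traced function should match the eager function exactly.
    x = torch.tensor(3.)
    print(traced_fn(x), fn(x))         # both print tensor(6.)

    # The scripted loop computes the factorial of the (integer) input value.
    print(script_fn(torch.tensor(4)))  # 4! -> tensor([24])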
.. GENERATED FROM PYTHON SOURCE LINES 133-141

Part 3 - Tracing a nn.Module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Next, we will implement part three of the computation within the forward
function of a ``torch.nn.Module``. This module may be traced, but rather
than adding a decorator here, we will handle the tracing where the Module
is constructed. Thus, the class definition is not changed at all.

.. GENERATED FROM PYTHON SOURCE LINES 141-149

.. code-block:: default

    # This is a normal module that can be traced.
    class TracedModule(torch.nn.Module):
        def forward(self, x):
            x = x.type(torch.float32)
            return torch.floor(torch.sqrt(x) / 5.)

.. GENERATED FROM PYTHON SOURCE LINES 150-158

Part 4 - Scripting a nn.Module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In the final part of the computation we have a ``torch.nn.Module`` that
must be scripted. To accommodate this, we inherit from
``torch.jit.ScriptModule`` and add the ``@torch.jit.script_method``
decorator to the forward function.

.. GENERATED FROM PYTHON SOURCE LINES 158-171

.. code-block:: default

    # This is how you define a scripted module.
    # The module should inherit from ScriptModule and the forward should have the
    # script_method decorator applied to it.
    class ScriptModule(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            r = -x
            if int(torch.fmod(x, 2.0)) == 0.0:
                r = x / 2.0
            return r
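As with the functions, it is worth exercising these two modules in
isolation before wiring them into the full computation. The quick check
below is our own illustration with arbitrary test values; the tutorial
itself performs the tracing of ``TracedModule`` inside the top-level net
that follows.

.. code-block:: default

    # Trace the module standalone, just to verify its behavior.
    tm = torch.jit.trace(TracedModule(), torch.rand(()))
    print(tm(torch.tensor([3628800])))  # floor(sqrt(3628800)/5) = tensor([380.])

    # The scripted module halves even inputs and negates odd ones.
    sm = ScriptModule()
    print(sm(torch.tensor([380.])))     # tensor([190.])
    print(sm(torch.tensor([5.])))       # tensor([-5.])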
.. GENERATED FROM PYTHON SOURCE LINES 172-195

Top-Level Module
~~~~~~~~~~~~~~~~

Now we will put together the pieces of the computation via a top-level
module called ``Net``. In the constructor, we will instantiate the
``TracedModule`` and ``ScriptModule`` as attributes. This must be done
because we ultimately want to trace/script the top-level module, and having
the traced/scripted modules as attributes allows the Net to inherit the
required submodules' parameters. Notice that this is where we actually
trace the ``TracedModule``, by calling ``torch.jit.trace()`` and providing
the necessary dummy input. Also notice that the ``ScriptModule`` is
constructed as normal because we handled the scripting in the class
definition. Here we can also print the graphs created for each individual
part of the computation. The printed graphs allow us to see how the JIT
ultimately interpreted the functions as graph computations.

Finally, we define the ``forward`` function for the Net module, where we
run the input data ``x`` through the four parts of the computation. There
is no strange syntax here, and we call the traced and scripted modules and
functions as expected.

.. GENERATED FROM PYTHON SOURCE LINES 195-228

.. code-block:: default

    # This is a demonstration net that calls all of the different types of
    # methods and functions
    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # Modules must be attributes on the Module because if you want to trace
            # or script this Module, we must be able to inherit the submodules'
            # params.
            self.traced_module = torch.jit.trace(TracedModule(), torch.rand(()))
            self.script_module = ScriptModule()

            print('traced_fn graph', traced_fn.graph)
            print('script_fn graph', script_fn.graph)
            print('TracedModule graph', self.traced_module.__getattr__('forward').graph)
            print('ScriptModule graph', self.script_module.__getattr__('forward').graph)

        def forward(self, x):
            # Call a traced function
            x = traced_fn(x)

            # Call a script function
            x = script_fn(x)

            # Call a traced submodule
            x = self.traced_module(x)

            # Call a scripted submodule
            x = self.script_module(x)

            return x

.. GENERATED FROM PYTHON SOURCE LINES 229-237

Running the Model
~~~~~~~~~~~~~~~~~

All that's left to do is construct the Net and compute the output through
the forward function. Here, we use :math:`x=5` as the test input value and
expect :math:`Y(x)=190`. Also, check out the graphs that were printed
during the construction of the Net.

.. GENERATED FROM PYTHON SOURCE LINES 237-243

.. code-block:: default

    # Instantiate this net and run it
    n = Net()
    print(n(torch.tensor([5])))  # 190.

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    traced_fn graph graph(%x : Float(requires_grad=0, device=cpu)):
      %1 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:105:0
      %2 : Float(requires_grad=0, device=cpu) = aten::mul(%x, %1) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:105:0
      %3 : Float(requires_grad=0, device=cpu) = aten::abs(%2) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:105:0
      return (%3)

    script_fn graph graph(%x.1 : Tensor):
      %13 : bool = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:127:4
      %4 : NoneType = prim::Constant()
      %3 : int = prim::Constant[value=4]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:126:30
      %1 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:126:20
      %2 : int[] = prim::ListConstruct(%1)
      %z.1 : Tensor = aten::ones(%2, %3, %4, %4, %4) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:126:8
      %10 : int = aten::Int(%x.1) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:127:19
      %z : Tensor = prim::Loop(%10, %13, %z.1) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:127:4
        block0(%i.1 : int, %z.11 : Tensor):
          %17 : int = aten::add(%i.1, %1) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:128:17
          %z.5 : Tensor = aten::mul(%z.11, %17) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:128:12
          -> (%13, %z.5)
      return (%z)

    TracedModule graph graph(%self : __torch__.TracedModule, %x.1 : Float(requires_grad=0, device=cpu)):
      %4 : int = prim::Constant[value=6]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
      %5 : bool = prim::Constant[value=0]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
      %6 : bool = prim::Constant[value=0]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
      %7 : NoneType = prim::Constant()
      %x : Float(requires_grad=0, device=cpu) = aten::to(%x.1, %4, %5, %6, %7) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
      %9 : Float(requires_grad=0, device=cpu) = aten::sqrt(%x) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
      %10 : Double(requires_grad=0, device=cpu) = prim::Constant[value={5}]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
      %11 : Float(requires_grad=0, device=cpu) = aten::div(%9, %10) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
      %12 : Float(requires_grad=0, device=cpu) = aten::floor(%11) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
      return (%12)

    ScriptModule graph graph(%self : __torch__.ScriptModule, %x.1 : Tensor):
      %5 : float = prim::Constant[value=2.]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:29
      %9 : float = prim::Constant[value=0.]() # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:38
      %r.1 : Tensor = aten::neg(%x.1) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:165:12
      %6 : Tensor = aten::fmod(%x.1, %5) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:15
      %8 : int = aten::Int(%6) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:11
      %10 : bool = aten::eq(%8, %9) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:11
      %r : Tensor = prim::If(%10) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:8
        block0():
          %r.3 : Tensor = aten::div(%x.1, %5) # /var/lib/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:167:16
          -> (%r.3)
        block1():
          -> (%r.1)
      return (%r)

    tensor([190.])
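As an additional end-to-end check, we can compare the pipeline against a
plain-Python reference implementation of :math:`Y(x)`, written directly
from the formulas above. This helper is our own addition, not part of the
original example.

.. code-block:: default

    import math

    # Reference helper (hypothetical): z(x) = floor(sqrt((|2x|)!) / 5),
    # then halve z if it is even, negate it otherwise.
    def y_reference(x):
        z = math.floor(math.sqrt(math.factorial(abs(2 * x))) / 5.0)
        return z / 2.0 if z % 2 == 0 else float(-z)

    print(y_reference(5))  # 190.0, matching the Net output above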
.. GENERATED FROM PYTHON SOURCE LINES 244-254

Tracing the Top-Level Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~

The last part of the example is to trace the top-level module, ``Net``. As
mentioned previously, since the traced/scripted modules are attributes of
Net, we are able to trace ``Net``, as it inherits the parameters of the
traced/scripted submodules. Note that the syntax for tracing Net is
identical to the syntax for tracing ``TracedModule``. Also, check out the
graph that is created.

.. GENERATED FROM PYTHON SOURCE LINES 254-260

.. code-block:: default

    n_traced = torch.jit.trace(n, torch.tensor([5]))
    print(n_traced(torch.tensor([5])))
    print('n_traced graph', n_traced.graph)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([190.])
    n_traced graph graph(%self : __torch__.Net, %x.1 : Long(1, strides=[1], requires_grad=0, device=cpu)):
      %script_module : __torch__.ScriptModule = prim::GetAttr[name="script_module"](%self)
      %traced_module : __torch__.TracedModule = prim::GetAttr[name="traced_module"](%self)
      %10 : Function = prim::Constant[name="fn"]()
      %x : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::CallFunction(%10, %x.1)
      %12 : Function = prim::Constant[name="script_fn"]()
      %13 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::CallFunction(%12, %x)
      %14 : Float(1, strides=[1], requires_grad=0, device=cpu) = prim::CallMethod[name="forward"](%traced_module, %13)
      %15 : Float(1, strides=[1], requires_grad=0, device=cpu) = prim::CallMethod[name="forward"](%script_module, %14)
      return (%15)
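At this point the traced model can also be serialized into a self-contained
archive and reloaded without any of the original Python class definitions.
A brief sketch using the TorchScript save/load API (the file name here is
arbitrary):

.. code-block:: default

    # Save the fully traced model and load it back.
    torch.jit.save(n_traced, 'net_traced.pt')
    loaded = torch.jit.load('net_traced.pt')
    print(loaded(torch.tensor([5])))  # tensor([190.])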
.. GENERATED FROM PYTHON SOURCE LINES 261-271

Hopefully, this document can serve as an introduction to the hybrid
frontend as well as a syntax reference guide for more experienced users.
There are a few things to keep in mind when using the hybrid frontend.
Traced and scripted methods must be written in a restricted subset of
Python, as features like generators, nested ``def`` statements, and many
Python data structures are not supported. As a workaround, the scripting
model *is* designed to work with both traced and non-traced code, which
means you can call non-traced code from traced functions. However, such a
model may not be exported to ONNX.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.111 seconds)

.. _sphx_glr_download_beginner_hybrid_frontend_learning_hybrid_frontend_through_example_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: learning_hybrid_frontend_through_example_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: learning_hybrid_frontend_through_example_tutorial.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_