
Learning Hybrid Frontend Syntax Through Example

Author: Nathan Inkawhich

This document is meant to highlight the syntax of the Hybrid Frontend through a non-code-intensive example. The Hybrid Frontend is one of the new features of PyTorch 1.0 and provides an avenue for developers to transition their models from eager mode to graph mode. PyTorch users are very familiar with eager mode, as it provides the ease of use and flexibility that we all enjoy as researchers. Caffe2 users are more acquainted with graph mode, which has the benefits of speed, optimization opportunities, and functionality in C++ runtime environments. The hybrid frontend bridges the gap between the two modes by allowing researchers to develop and refine their models in eager mode (i.e. PyTorch), then gradually transition the proven model to graph mode for production, when speed and resource consumption become critical.

Hybrid Frontend Information

The process for transitioning a model to graph mode is as follows. First, the developer constructs, trains, and tests the model in eager mode. Then they incrementally trace or script each function/module of the model with the Just-In-Time (JIT) compiler, verifying at each step that the output is correct. Finally, when each of the components of the top-level model has been traced or scripted, the model itself is traced. At that point the model has been transitioned to graph mode and has a complete Python-free representation. With this representation, the model runtime can take advantage of high-performance Caffe2 operators and graph-based optimizations.

Before we continue, it is important to understand the ideas of tracing and scripting, and why they are separate. Tracing and scripting share the same goal: to create a graph representation of the operations taking place in a given function. The difference arises from the flexibility of eager mode, which allows data-dependent control flow within the model architecture. Tracing runs the function once on an example input and records the operations that were executed, so any branch or loop whose behavior depends on the data is frozen to the path taken for that input. Scripting instead compiles the function's Python source, so the control flow itself is preserved. Consequently, when a function does NOT have data-dependent control flow, it may be traced with torch.jit.trace; when it does, it must be scripted with torch.jit.script. We will leave the details of the inner workings of the hybrid frontend for another document, but the code example below shows the syntax for tracing and scripting different pure Python functions and torch Modules. Hopefully, you will find that using the hybrid frontend is non-intrusive, as it mostly involves adding decorators to existing function and class definitions.
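To make the trace/script distinction concrete, here is a minimal sketch using a hypothetical function, double_if_positive, that is not part of the tutorial's model. Tracing records only the branch taken for the example input, whereas scripting compiles both branches, so the two disagree on an input that takes the other branch. Tracing emits a TracerWarning about the tensor-to-bool conversion; the sketch silences it only to keep the output short.

```python
import warnings

import torch


def double_if_positive(x):
    # Data-dependent control flow: the branch depends on the values in x.
    if bool(x.sum() > 0):
        return x * 2
    else:
        return -x


# Trace with an example input whose sum is positive, so only the
# ``x * 2`` branch is recorded in the graph.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hides the expected TracerWarning
    traced = torch.jit.trace(double_if_positive, torch.tensor([1.0]))

# Scripting compiles the Python source, so both branches are kept.
scripted = torch.jit.script(double_if_positive)

x = torch.tensor([-3.0])
print(double_if_positive(x))  # eager:    tensor([3.])
print(traced(x))              # traced:   tensor([-6.]), the frozen branch
print(scripted(x))            # scripted: tensor([3.])
```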

Motivating Example

In this example we will implement a contrived math function that can be logically broken up into four parts, some of which contain data-dependent control flow and some of which do not. The purpose here is to show a non-code-intensive example that highlights the use of the JIT. This example is a stand-in for a useful model whose implementation has been divided into various pure Python functions and modules.

The function we seek to implement, \(Y(x)\), is defined for \(x \in \mathbb{N}\) as

\[z(x) = \Biggl \lfloor \frac{\sqrt{\prod_{i=1}^{|2 x|}i}}{5} \Biggr \rfloor\]
\[Y(x) = \begin{cases} \frac{z(x)}{2} & \text{if } z(x) \bmod 2 = 0, \\ -z(x) & \text{otherwise} \end{cases}\]
\[\begin{array}{| r | r |} \hline x &1 &2 &3 &4 &5 &6 &7 \\ \hline Y(x) &0 &0 &-5 &20 &190 &-4377 &-59051 \\ \hline \end{array}\]

As mentioned, the computation is split into four parts. Part one is the simple tensor calculation of \(|2x|\), which can be traced. Part two is the iterative product calculation, which represents a data-dependent control flow and must be scripted (the number of loop iterations depends on the input at runtime). Part three is the traceable \(\lfloor \sqrt{a}/5 \rfloor\) calculation. Finally, part four handles the output cases depending on the parity of \(z(x)\) and must be scripted because of the data dependency. Now, let’s see how this looks in code.
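The "verify at each step" workflow needs a trusted reference to compare against. The sketch below is a hypothetical pure-Python implementation of \(Y(x)\), not part of the tutorial; it uses exact integer arithmetic (math.factorial and math.isqrt, Python 3.8+) and relies on \(\lfloor \sqrt{n}/5 \rfloor = \lfloor \lfloor \sqrt{n} \rfloor / 5 \rfloor\) for non-negative \(n\). It reproduces the table of values given above.

```python
import math


def reference_y(x: int) -> int:
    """Exact reference for Y(x); a hypothetical helper, not part of the tutorial."""
    product = math.factorial(abs(2 * x))  # parts 1 and 2: the product of 1..|2x|
    z = math.isqrt(product) // 5          # part 3: floor(sqrt(product) / 5)
    return z // 2 if z % 2 == 0 else -z   # part 4: halve even values, negate odd ones


print([reference_y(x) for x in range(1, 8)])
# [0, 0, -5, 20, 190, -4377, -59051]
```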

Part 1 - Tracing a pure python function

We can implement part one as a pure Python function, as shown below. To trace this function, we call torch.jit.trace and pass in the function to be traced. Because tracing requires an example input of the expected runtime type and shape, we also pass torch.rand(()), which generates a zero-dimensional (single-valued) tensor.

import torch

def fn(x):
    return torch.abs(2*x)

# This is how you define a traced function
# Pass in both the function to be traced and an example input to ``torch.jit.trace``
traced_fn = torch.jit.trace(fn, torch.rand(()))
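Following the verify-at-each-step workflow, a quick check like the one below (a sketch, not part of the original tutorial) compares the traced function against eager execution. Because fn is purely elementwise, the recorded graph is not tied to the shape of the scalar example input, so it can also be checked on vectors and matrices. The .code attribute shows a Python-like rendering of the recorded graph.

```python
import torch


def fn(x):
    return torch.abs(2 * x)


traced_fn = torch.jit.trace(fn, torch.rand(()))

# Compare traced and eager results on inputs of different shapes than the
# scalar used for tracing; this works here because fn is purely elementwise.
for sample in (torch.tensor(-2.5), torch.tensor([-1.5, 2.0]), torch.randn(4, 3)):
    assert torch.equal(traced_fn(sample), fn(sample))

print(traced_fn.code)                        # Python-like source of the graph
print(traced_fn(torch.tensor([-1.5, 2.0])))  # tensor([3., 4.])
```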

Part 2 - Scripting a pure python function

We can also implement part two as a pure Python function that iteratively computes the product. Since the number of iterations depends on the value of the input, the function has a data-dependent control flow and must be scripted. We can script a Python function simply by applying the @torch.jit.script decorator.

# This is how you define a script function
# Apply this decorator directly to the function
@torch.jit.script
def script_fn(x):
    z = torch.ones([1], dtype=torch.int64)
    for i in range(int(x)):
        z = z * (i + 1)
    return z
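To see why part two must be scripted, the sketch below (an illustration, not from the original tutorial) compiles the same loop both ways. Tracing evaluates int(x) on the example input, so the loop is unrolled for that one value and the traced result no longer depends on x; PyTorch emits TracerWarnings about this, which the sketch suppresses. The scripted version keeps the loop and matches math.factorial.

```python
import math
import warnings

import torch


def product_loop(x):
    z = torch.ones([1], dtype=torch.int64)
    for i in range(int(x)):  # iteration count depends on the *value* of x
        z = z * (i + 1)
    return z


script_fn = torch.jit.script(product_loop)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hides the expected TracerWarnings
    traced_loop = torch.jit.trace(product_loop, torch.tensor(3))

# The scripted loop computes n! for every input...
for n in range(15):
    assert script_fn(torch.tensor(n)).item() == math.factorial(n)

# ...while the trace replays the three multiplications recorded for x = 3.
print(script_fn(torch.tensor(5)))    # tensor([120])
print(traced_loop(torch.tensor(5)))  # tensor([6])
```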

Part 3 - Tracing a nn.Module

Next, we will implement part 3 of the computation within the forward function of a torch.nn.Module. This module may be traced, but rather than adding a decorator here, we will handle the tracing where the Module is constructed. Thus, the class definition is not changed at all.

# This is a normal module that can be traced.
class TracedModule(torch.nn.Module):
    def forward(self, x):
        x = x.type(torch.float32)
        return torch.floor(torch.sqrt(x) / 5.)
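In the tutorial, TracedModule is traced later, inside Net. As a standalone sketch (not the tutorial's code), torch.jit.trace also accepts a check_inputs argument: a list of argument tuples that the tracer uses to re-check the trace against eager execution, raising an error if they diverge. The check inputs below, 10! and 8!, are the values part three actually receives for x = 5 and x = 4.

```python
import torch


class TracedModule(torch.nn.Module):
    def forward(self, x):
        x = x.type(torch.float32)
        return torch.floor(torch.sqrt(x) / 5.)


traced_module = torch.jit.trace(
    TracedModule(),
    torch.rand(()),
    # Extra inputs used to validate the trace against eager mode.
    check_inputs=[(torch.tensor(3628800.0),), (torch.tensor(40320.0),)],
)

# A traced module is a ScriptModule, so it can be saved or scripted further.
print(isinstance(traced_module, torch.jit.ScriptModule))  # True
print(traced_module(torch.tensor(3628800.0)))             # tensor(380.)
```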

Part 4 - Scripting a nn.Module

In the final part of the computation, we have a torch.nn.Module that must be scripted. To accommodate this, we inherit from torch.jit.ScriptModule and add the @torch.jit.script_method decorator to the forward function.

# This is how you define a scripted module.
# The module should inherit from ScriptModule and the forward should have the
#   script_method decorator applied to it.
class ScriptModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        r = -x
        if int(torch.fmod(x, 2.0)) == 0.0:
            r = x / 2.0
        return r
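Subclassing torch.jit.ScriptModule and decorating forward with @torch.jit.script_method is the API this tutorial was written against. In more recent PyTorch releases, a common alternative is to keep an ordinary nn.Module and pass an instance to torch.jit.script, which compiles forward and the methods it calls. The sketch below applies that style to part four; the class name ParityModule is hypothetical.

```python
import torch


class ParityModule(torch.nn.Module):
    def forward(self, x):
        r = -x
        if int(torch.fmod(x, 2.0)) == 0:  # data-dependent branch
            r = x / 2.0
        return r


# Compile an instance of a plain nn.Module; the branch is preserved.
scripted = torch.jit.script(ParityModule())
print(scripted(torch.tensor([380.0])))  # tensor([190.])
print(scripted(torch.tensor([5.0])))    # tensor([-5.])
```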

Top-Level Module

Now we will put together the pieces of the computation in a top-level module called Net. In the constructor, we instantiate the TracedModule and ScriptModule as attributes. This is necessary because we ultimately want to trace or script the top-level module, and having the traced/scripted modules as attributes allows Net to inherit the submodules’ parameters. Note that this is where we actually trace the TracedModule, by calling torch.jit.trace() with the necessary example input. Also note that the ScriptModule is constructed as usual, because the scripting was handled in the class definition.

Here we also print the graphs created for each individual part of the computation. The printed graphs show how the JIT ultimately interpreted the functions as graph computations.

Finally, we define the forward function of Net, which runs the input data x through the four parts of the computation. No special syntax is needed: the traced and scripted modules and functions are called like any other Python callables.

# This is a demonstration net that calls all of the different types of
# methods and functions
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Submodules must be attributes of this Module so that, when this
        # Module is itself traced or scripted, it can inherit the submodules'
        # params.
        self.traced_module = torch.jit.trace(TracedModule(), torch.rand(()))
        self.script_module = ScriptModule()

        print('traced_fn graph', traced_fn.graph)
        print('script_fn graph', script_fn.graph)
        print('TracedModule graph', self.traced_module.__getattr__('forward').graph)
        print('ScriptModule graph', self.script_module.__getattr__('forward').graph)

    def forward(self, x):
        # Call a traced function
        x = traced_fn(x)

        # Call a script function
        x = script_fn(x)

        # Call a traced submodule
        x = self.traced_module(x)

        # Call a scripted submodule
        x = self.script_module(x)

        return x
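Net calls its traced and scripted pieces from ordinary Python, but the pieces can also call each other directly: a scripted function may call a traced one, and vice versa. The sketch below (illustrative names, not from the tutorial) places a traced elementwise step inside a scripted loop whose iteration count is a runtime value.

```python
import torch


def scale_abs(x):
    return torch.abs(2 * x)


# Trace the elementwise step once with an example input...
traced_scale_abs = torch.jit.trace(scale_abs, torch.rand(3))


# ...then call it from a scripted function with a data-dependent loop.
@torch.jit.script
def repeat_scale_abs(x: torch.Tensor, n: int) -> torch.Tensor:
    for _ in range(n):
        x = traced_scale_abs(x)
    return x


print(repeat_scale_abs(torch.tensor([-1.0, 0.5, 2.0]), 2))  # tensor([4., 2., 8.])
```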

Running the Model

All that’s left to do is construct the Net and compute the output through the forward function. Here, we use \(x=5\) as the test input value and expect \(Y(x)=190\). Also, check out the graphs that were printed during the construction of the Net.

# Instantiate this net and run it
n = Net()
print(n(torch.tensor([5]))) # 190.

Out:

traced_fn graph graph(%x : Float(requires_grad=0, device=cpu)):
  %1 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:105:0
  %2 : Float(requires_grad=0, device=cpu) = aten::mul(%x, %1) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:105:0
  %3 : Float(requires_grad=0, device=cpu) = aten::abs(%2) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:105:0
  return (%3)

script_fn graph graph(%x.1 : Tensor):
  %13 : bool = prim::Constant[value=1]() # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:127:4
  %4 : NoneType = prim::Constant()
  %3 : int = prim::Constant[value=4]() # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:126:30
  %1 : int = prim::Constant[value=1]() # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:126:20
  %2 : int[] = prim::ListConstruct(%1)
  %z.1 : Tensor = aten::ones(%2, %3, %4, %4, %4) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:126:8
  %10 : int = aten::Int(%x.1) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:127:19
  %z : Tensor = prim::Loop(%10, %13, %z.1) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:127:4
    block0(%i.1 : int, %z.11 : Tensor):
      %17 : int = aten::add(%i.1, %1) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:128:17
      %z.5 : Tensor = aten::mul(%z.11, %17) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:128:12
      -> (%13, %z.5)
  return (%z)

TracedModule graph graph(%self : __torch__.TracedModule,
      %x.1 : Float(requires_grad=0, device=cpu)):
  %4 : Device = prim::Constant[value="cpu"]() # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
  %5 : int = prim::Constant[value=6]() # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
  %6 : bool = prim::Constant[value=0]() # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
  %7 : bool = prim::Constant[value=0]() # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
  %8 : NoneType = prim::Constant()
  %x : Float(requires_grad=0, device=cpu) = aten::to(%x.1, %4, %5, %6, %7, %8) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:145:0
  %10 : Float(requires_grad=0, device=cpu) = aten::sqrt(%x) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
  %11 : Double(requires_grad=0, device=cpu) = prim::Constant[value={5}]() # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
  %12 : Float(requires_grad=0, device=cpu) = aten::div(%10, %11) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
  %13 : Float(requires_grad=0, device=cpu) = aten::floor(%12) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:146:0
  return (%13)

ScriptModule graph graph(%self : __torch__.ScriptModule,
      %x.1 : Tensor):
  %5 : float = prim::Constant[value=2.]() # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:29
  %9 : float = prim::Constant[value=0.]() # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:38
  %r.1 : Tensor = aten::neg(%x.1) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:165:12
  %6 : Tensor = aten::fmod(%x.1, %5) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:15
  %8 : int = aten::Int(%6) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:11
  %10 : bool = aten::eq(%8, %9) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:11
  %r : Tensor = prim::If(%10) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:166:8
    block0():
      %r.3 : Tensor = aten::div(%x.1, %5) # /var/lib/jenkins/workspace/beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py:167:16
      -> (%r.3)
    block1():
      -> (%r.1)
  return (%r)

tensor([190.])

Tracing the Top-Level Model

The last part of the example is to trace the top-level module, Net. As mentioned previously, because the traced/scripted modules are attributes of Net, we can trace Net, and it inherits the parameters of the traced/scripted submodules. Note that the syntax for tracing Net is identical to the syntax for tracing TracedModule. Also, check out the graph that is created: the previously traced and scripted pieces appear as prim::CallFunction and prim::CallMethod nodes rather than being re-recorded, so the scripted control flow is preserved.

n_traced = torch.jit.trace(n, torch.tensor([5]))
print(n_traced(torch.tensor([5])))
print('n_traced graph', n_traced.graph)

Out:

tensor([190.])
n_traced graph graph(%self : __torch__.Net,
      %x.1 : Long(1, strides=[1], requires_grad=0, device=cpu)):
  %script_module : __torch__.ScriptModule = prim::GetAttr[name="script_module"](%self)
  %traced_module : __torch__.TracedModule = prim::GetAttr[name="traced_module"](%self)
  %10 : Function = prim::Constant[name="fn"]()
  %x : Tensor = prim::CallFunction(%10, %x.1)
  %12 : Function = prim::Constant[name="script_fn"]()
  %13 : Tensor = prim::CallFunction(%12, %x)
  %14 : Tensor = prim::CallMethod[name="forward"](%traced_module, %13)
  %15 : Tensor = prim::CallMethod[name="forward"](%script_module, %14)
  return (%15)
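The point of reaching graph mode is to run the model without Python. One way to hand a traced or scripted model to another runtime is to serialize it with torch.jit.save and reload it with torch.jit.load (LibTorch's C++ torch::jit::load reads the same format). The sketch below round-trips a small traced module through an in-memory buffer; FloorSqrtDiv5 is a hypothetical stand-in for part three, not Net itself.

```python
import io

import torch


class FloorSqrtDiv5(torch.nn.Module):  # hypothetical stand-in for part three
    def forward(self, x):
        return torch.floor(torch.sqrt(x) / 5.0)


traced = torch.jit.trace(FloorSqrtDiv5(), torch.rand(2))

# Serialize to an in-memory buffer; a file path works the same way.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

loaded = torch.jit.load(buffer)
x = torch.tensor([3628800.0, 40320.0])
print(loaded(x))  # same values as traced(x): 380. and 40.
```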

Hopefully, this document can serve as an introduction to the hybrid frontend as well as a syntax reference guide for more experienced users. There are also a few things to keep in mind when using the hybrid frontend. Traced and scripted methods must be written in a restricted subset of Python: features such as generators, nested defs, and arbitrary Python data structures are not supported. As a workaround, the scripting model is designed to work with both traced and non-traced code, which means you can call non-traced code from traced functions. However, such a model may not be exported to ONNX.
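In current PyTorch, one way to leave code that TorchScript cannot compile as ordinary Python is the @torch.jit.ignore decorator: the compiled model calls back into the Python interpreter at that point. As the torch.jit.ignore documentation notes, a model containing such calls cannot be exported, so saving it fails. The sketch below, with hypothetical names, hides a generator inside an ignored method.

```python
import io

import torch


class WithPythonHelper(torch.nn.Module):
    @torch.jit.ignore
    def python_helper(self, x: torch.Tensor) -> torch.Tensor:
        # A generator is not compilable by TorchScript, so this method is
        # left as ordinary Python and dispatched to the interpreter.
        def doubled(values):
            for v in values:
                yield v * 2
        return torch.stack(list(doubled(x.unbind(0))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.python_helper(x) + 1


scripted = torch.jit.script(WithPythonHelper())
print(scripted(torch.tensor([1.0, 2.0, 3.0])))  # tensor([3., 5., 7.])

# Exporting fails because forward still calls into Python.
try:
    torch.jit.save(scripted, io.BytesIO())
    export_failed = False
except Exception:
    export_failed = True
print(export_failed)  # True
```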

Total running time of the script: ( 0 minutes 0.505 seconds)
