TorchScript
===========

.. toctree::
   :maxdepth: 1
   :caption: Builtin Functions
   :hidden:

   torch.jit.supported_ops <jit_builtin_functions>

.. contents:: :local:

.. automodule:: torch.jit
.. currentmodule:: torch.jit

TorchScript is a way to create serializable and optimizable models from PyTorch code.
Any TorchScript program can be saved from a Python
process and loaded in a process where there is no Python dependency.
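
As a rough illustration of the save-and-load round trip described above, the sketch below scripts a
small module, serializes it to an in-memory buffer, and loads it back. The ``Doubler`` module is a
hypothetical example, not part of the library; a file path could be used in place of the buffer.

::

    import io
    import torch

    class Doubler(torch.nn.Module):  # hypothetical module used only for illustration
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    scripted = torch.jit.script(Doubler())

    # Serialize the compiled module to an in-memory buffer
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)

    # Rewind the buffer and load the module back
    buffer.seek(0)
    restored = torch.jit.load(buffer)

    example = torch.ones(3)
    print(torch.equal(restored(example), scripted(example)))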

We provide tools to incrementally transition a model from a pure Python program
to a TorchScript program that can be run independently from Python, such as in a standalone C++ program.
This makes it possible to train models in PyTorch using familiar tools in Python and then export
the model via TorchScript to a production environment where Python programs may be disadvantageous
for performance and multi-threading reasons.

For a gentle introduction to TorchScript, see the `Introduction to TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_ tutorial.

For an end-to-end example of converting a PyTorch model to TorchScript and running it in C++, see the
`Loading a PyTorch Model in C++ <https://pytorch.org/tutorials/advanced/cpp_export.html>`_ tutorial.

Creating TorchScript Code
--------------------------

.. autoclass:: ScriptModule()
    :members:

.. autofunction:: script(obj)

.. autofunction:: trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)

.. autofunction:: trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)

.. autofunction:: save

.. autofunction:: load


Mixing Tracing and Scripting
----------------------------

In many cases either tracing or scripting is an easier approach for converting a model to TorchScript.
Tracing and scripting can be composed to suit the particular requirements
of a part of a model.

Scripted functions can call traced functions. This is particularly useful when you need
to use control-flow around a simple feed-forward model. For instance the beam search
of a sequence to sequence model will typically be written in script but can call an
encoder module generated using tracing.


.. testsetup::

    # These are hidden from the docs, but these are necessary for `doctest`
    # since the `inspect` module doesn't play nicely with the execution
    # environment for `doctest`
    import torch

    original_script = torch.jit.script
    def script_wrapper(obj, *args, **kwargs):
        obj.__module__ = 'FakeMod'
        return original_script(obj, *args, **kwargs)

    torch.jit.script = script_wrapper

    original_trace = torch.jit.trace
    def trace_wrapper(obj, *args, **kwargs):
        obj.__module__ = 'FakeMod'
        return original_trace(obj, *args, **kwargs)

    torch.jit.trace = trace_wrapper


Example (calling a traced function in script):

.. testcode::

    import torch

    def foo(x, y):
        return 2 * x + y

    traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

    @torch.jit.script
    def bar(x):
        return traced_foo(x, x)

Traced functions can call script functions. This is useful when a small part of
a model requires some control-flow even though most of the model is just a feed-forward
network. Control-flow inside of a script function called by a traced function is
preserved correctly.

Example (calling a script function in a traced function):

.. testcode::

    import torch

    @torch.jit.script
    def foo(x, y):
        if x.max() > y.max():
            r = x
        else:
            r = y
        return r


    def bar(x, y, z):
        return foo(x, y) + z

    traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

This composition also works for ``nn.Module``\s, where it can be used to generate
a submodule using tracing that can be called from the methods of a script module.

Example (using a traced module):

.. testcode::
    :skipif: torchvision is None

    import torch
    import torchvision

    class MyScriptModule(torch.nn.Module):
        def __init__(self):
            super(MyScriptModule, self).__init__()
            self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                            .resize_(1, 3, 1, 1))
            self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                          torch.rand(1, 3, 224, 224))

        def forward(self, input):
            return self.resnet(input - self.means)

    my_script_module = torch.jit.script(MyScriptModule())

Migrating to PyTorch 1.2 Recursive Scripting API
------------------------------------------------
This section details the changes to TorchScript in PyTorch 1.2. If you are new to TorchScript you can
skip this section. There are two main changes to the TorchScript API with PyTorch 1.2.

1. :func:`torch.jit.script <torch.jit.script>` will now attempt to recursively compile functions,
   methods, and classes that it encounters. Once you call ``torch.jit.script``,
   compilation is "opt-out", rather than "opt-in".

2. ``torch.jit.script(nn_module_instance)`` is now the preferred way to create
   ``ScriptModule``\s, instead of inheriting from ``torch.jit.ScriptModule``.

These changes combine to provide a simpler, easier-to-use API for converting
your ``nn.Module``\s into ``ScriptModule``\s, ready to be optimized and executed in a
non-Python environment.

The new usage looks like this:

.. testcode::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

    my_model = Model()
    my_scripted_model = torch.jit.script(my_model)


* The module's ``forward`` is compiled by default. Methods called from ``forward`` are lazily compiled in the order they are used in ``forward``.
* To compile a method other than ``forward`` that is not called from ``forward``, add ``@torch.jit.export``.
* To stop the compiler from compiling a method and leave it as a call to Python, add ``@torch.jit.ignore``.
* Most attribute types can be inferred, so ``torch.jit.Attribute`` is not necessary. For empty container types, annotate their types using `PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_ class annotations.
* Constants can be marked with a ``Final`` class annotation instead of adding the name of the member to ``__constants__``.
* Python 3 type hints can be used in place of ``torch.jit.annotate``.
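
The decorators mentioned in the list above can be sketched together as follows. This is a minimal
example under assumed names: the ``Model`` class and its ``double`` and ``unique_sorted`` methods
are hypothetical and exist only to show which methods end up compiled.

::

    import torch
    import torch.nn as nn

    class Model(nn.Module):  # hypothetical module used only for illustration
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Compiled by default; the ignored method below stays a call into Python
            return self.unique_sorted(x) + 1

        @torch.jit.export
        def double(self, x: torch.Tensor) -> torch.Tensor:
            # Not called from `forward`, so the decorator requests compilation
            return x * 2

        @torch.jit.ignore
        def unique_sorted(self, x: torch.Tensor) -> torch.Tensor:
            # Left as plain Python, so Python-only constructs are fine here
            return torch.tensor(sorted(set(x.tolist())))

    scripted = torch.jit.script(Model())
    print(scripted(torch.tensor([2.0, 1.0, 2.0])))
    print(scripted.double(torch.ones(2)))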

As a result of these changes, the following items are considered deprecated and should not appear in new code:

* The ``@torch.jit.script_method`` decorator
* Classes that inherit from ``torch.jit.ScriptModule``
* The ``torch.jit.Attribute`` wrapper class
* The ``__constants__`` array
* The ``torch.jit.annotate`` function

Modules
~~~~~~~
.. warning::

    The :func:`@torch.jit.ignore <torch.jit.ignore>` annotation's behavior changes in
    PyTorch 1.2. Before PyTorch 1.2 the @ignore decorator was used to make a function
    or method callable from code that is exported. To get this functionality back,
    use ``@torch.jit.unused()``. ``@torch.jit.ignore`` is now equivalent
    to ``@torch.jit.ignore(drop=False)``. See :func:`@torch.jit.ignore <torch.jit.ignore>`
    and :func:`@torch.jit.unused<torch.jit.unused>` for details.

When passed to the :func:`torch.jit.script <torch.jit.script>` function, a ``torch.nn.Module``\'s data is
copied to a ``ScriptModule`` and the TorchScript compiler compiles the module.
The module's ``forward`` is compiled by default. Methods called from ``forward`` are
lazily compiled in the order they are used in ``forward``, as well as any
``@torch.jit.export`` methods.

.. autofunction:: export

Functions
~~~~~~~~~
Functions don't change much; they can be decorated with :func:`@torch.jit.ignore <torch.jit.ignore>` if needed.

.. testcode::

    # Same behavior as pre-PyTorch 1.2
    @torch.jit.script
    def some_fn():
        return 2

    # Marks a function as ignored, if nothing
    # ever calls it then this has no effect
    @torch.jit.ignore
    def some_fn2():
        return 2

    # Doesn't do anything, this function is already
    # the main entry point
    @torch.jit.export
    def some_fn3():
        return 2


TorchScript Classes
~~~~~~~~~~~~~~~~~~~
Everything in a user-defined `TorchScript Class`_ is exported by default; functions
can be decorated with :func:`@torch.jit.ignore <torch.jit.ignore>` if needed.

Attributes
~~~~~~~~~~
The TorchScript compiler needs to know the types of `module attributes`_. Most types
can be inferred from the value of the member. Empty lists and dicts cannot have their
types inferred and must have their types annotated with `PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_ class annotations.

Old API:

.. testcode::

    from typing import Dict
    import torch

    class MyModule(torch.jit.ScriptModule):
        def __init__(self):
            super(MyModule, self).__init__()
            self.my_dict = torch.jit.Attribute({}, Dict[str, int])
            self.my_int = torch.jit.Attribute(20, int)

    m = MyModule()

New API:

.. testcode::

    from typing import Dict

    class MyModule(torch.nn.Module):
        my_dict: Dict[str, int]

        def __init__(self):
            super(MyModule, self).__init__()
            # This type cannot be inferred and must be specified
            self.my_dict = {}

            # The attribute type here is inferred to be `int`
            self.my_int = 20

        def forward(self):
            pass

    m = torch.jit.script(MyModule())

Python 2
^^^^^^^^
If you are stuck on Python 2 and cannot use the class annotation syntax, you can use the ``__annotations__`` class member to directly apply type annotations.

.. testcode::

    from typing import Dict

    class MyModule(torch.jit.ScriptModule):
        __annotations__ = {'my_dict': Dict[str, int]}

        def __init__(self):
            super(MyModule, self).__init__()
            self.my_dict = {}
            self.my_int = 20

Constants
~~~~~~~~~
The ``Final`` type constructor can be used to mark members as `constant`_. If members are not marked constant, they will be copied to the resulting ``ScriptModule`` as an attribute. Using ``Final`` opens opportunities for optimization if the value is known to be fixed and gives additional type safety.

Old API:

.. testcode::

    class MyModule(torch.jit.ScriptModule):
        __constants__ = ['my_constant']

        def __init__(self):
            super(MyModule, self).__init__()
            self.my_constant = 2

        def forward(self):
            pass
    m = MyModule()

New API:

::

    try:
        from typing_extensions import Final
    except ImportError:
        # If you don't have `typing_extensions` installed, you can use a
        # polyfill from `torch.jit`.
        from torch.jit import Final

    class MyModule(torch.nn.Module):

        my_constant: Final[int]

        def __init__(self):
            super(MyModule, self).__init__()
            self.my_constant = 2

        def forward(self):
            pass

    m = torch.jit.script(MyModule())

.. _Python 3 type hints:

Variables
~~~~~~~~~
Containers are assumed to have type ``Tensor`` and be non-optional (see
`Default Types`_ for more information). Previously, ``torch.jit.annotate`` was used to
tell the TorchScript compiler what the type should be. Python 3 style type hints are
now supported.

.. testcode::

    import torch
    from typing import Dict, Optional

    @torch.jit.script
    def make_dict(flag: bool):
        x: Dict[str, int] = {}
        x['hi'] = 2
        b: Optional[int] = None
        if flag:
            b = 2
        return x, b



TorchScript Language Reference
-------------------------------

TorchScript is a statically typed subset of Python that can either be written directly (using
the :func:`@torch.jit.script <torch.jit.script>` decorator) or generated automatically from Python code via
tracing. When using tracing, code is automatically converted into this subset of
Python by recording only the actual operators on tensors and simply executing and
discarding the other surrounding Python code.
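
The difference between the two approaches shows up most clearly with data-dependent control flow.
The sketch below uses a hypothetical function ``shift`` and compares tracing with scripting.
Tracing records only the tensor operations executed for the example input, so the branch taken at
trace time is baked in, whereas scripting compiles the ``if`` statement itself. Tracing this
function is expected to emit a ``TracerWarning`` for exactly that reason.

::

    import torch

    def shift(x):  # hypothetical function used only for illustration
        # Data-dependent control flow
        if x.sum() > 0:
            return x + 10
        return x - 10

    # The example input is positive, so only `x + 10` is recorded
    traced = torch.jit.trace(shift, torch.ones(3))

    # Scripting compiles both branches
    scripted = torch.jit.script(shift)

    negative = -torch.ones(3)
    print(traced(negative))    # follows the branch recorded at trace time
    print(scripted(negative))  # evaluates the condition on the actual input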

When writing TorchScript directly using the ``@torch.jit.script`` decorator, the programmer must
only use the subset of Python supported in TorchScript. This section documents
what is supported in TorchScript as if it were a language reference for a
standalone language. Any features of Python not mentioned in this reference are not
part of TorchScript. See `Builtin Functions`_ for a complete reference of available
PyTorch tensor methods, modules, and functions.

As a subset of Python, any valid TorchScript function is also a valid Python
function. This makes it possible to `disable TorchScript`_ and debug the
function using standard Python tools like ``pdb``. The reverse is not true: there
are many valid Python programs that are not valid TorchScript programs.
Instead, TorchScript focuses specifically on the features of Python that are
needed to represent neural network models in PyTorch.

.. _types:
.. _supported type:

Types
~~~~~

The largest difference between TorchScript and the full Python language is that
TorchScript only supports a small set of types that are needed to express neural
net models. In particular, TorchScript supports:

.. csv-table::
   :header: "Type", "Description"

   "``Tensor``", "A PyTorch tensor of any dtype, dimension, or backend"
   "``Tuple[T0, T1, ...]``", "A tuple containing subtypes ``T0``, ``T1``, etc. (e.g. ``Tuple[Tensor, Tensor]``)"
   "``bool``", "A boolean value"
   "``int``", "A scalar integer"
   "``float``", "A scalar floating point number"
   "``str``", "A string"
   "``List[T]``", "A list of which all members are type ``T``"
   "``Optional[T]``", "A value which is either None or type ``T``"
   "``Dict[K, V]``", "A dict with key type ``K`` and value type ``V``. Only ``str``, ``int``, and ``float`` are allowed as key types."
   "``T``", "A `TorchScript Class`_"
   "``NamedTuple[T0, T1, ...]``", "A :func:`collections.namedtuple <collections.namedtuple>` tuple type"


Unlike Python, each variable in a TorchScript function must have a single static type.
This makes it easier to optimize TorchScript functions.

Example (a type mismatch):

.. testcode::

    import torch

    @torch.jit.script
    def an_error(x):
        if x:
            r = torch.rand(1)
        else:
            r = 4
        return r


.. testoutput::

     Traceback (most recent call last):
       ...
     RuntimeError: ...

     Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
     @torch.jit.script
     def an_error(x):
         if x:
         ~~~~~...  <--- HERE
             r = torch.rand(1)
         else:
             r = 4
         return r
     ...


Default Types
^^^^^^^^^^^^^

By default, all parameters to a TorchScript function are assumed to be Tensor.
To specify that an argument to a TorchScript function is another type, it is possible to use
MyPy-style type annotations using the types listed above.

.. testcode::

    import torch

    @torch.jit.script
    def foo(x, tup):
        # type: (int, Tuple[Tensor, Tensor]) -> Tensor
        t0, t1 = tup
        return t0 + t1 + x

    print(foo(3, (torch.rand(3), torch.rand(3))))

.. testoutput::
    :hide:

    ...

.. note::
  It is also possible to annotate types with Python 3 type hints from the
  ``typing`` module.

  .. testcode::

    import torch
    from typing import Tuple

    @torch.jit.script
    def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        t0, t1 = tup
        return t0 + t1 + x

    print(foo(3, (torch.rand(3), torch.rand(3))))

  .. testoutput::
    :hide:

    ...

  In our examples, we use comment-based type hints to ensure Python 2
  compatibility as well.


An empty list is assumed to be ``List[Tensor]`` and empty dicts
``Dict[str, Tensor]``. To instantiate an empty list or dict of other types,
use `Python 3 type hints`_. If you are on Python 2, you can use ``torch.jit.annotate``.

Example (type annotations for Python 3):

.. testcode::

    import torch
    import torch.nn as nn
    from typing import Dict, List, Tuple

    class EmptyDataStructures(torch.nn.Module):
        def __init__(self):
            super(EmptyDataStructures, self).__init__()

        def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
            # This annotates the list to be a `List[Tuple[int, float]]`
            my_list: List[Tuple[int, float]] = []
            for i in range(10):
                my_list.append((i, x.item()))

            my_dict: Dict[str, int] = {}
            return my_list, my_dict

    x = torch.jit.script(EmptyDataStructures())


Example (``torch.jit.annotate`` for Python 2):

.. testcode::

    import torch
    import torch.nn as nn
    from typing import Dict, List, Tuple

    class EmptyDataStructures(torch.nn.Module):
        def __init__(self):
            super(EmptyDataStructures, self).__init__()

        def forward(self, x):
            # type: (Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]

            # This annotates the list to be a `List[Tuple[int, float]]`
            my_list = torch.jit.annotate(List[Tuple[int, float]], [])
            for i in range(10):
                my_list.append((i, float(x.item())))

            my_dict = torch.jit.annotate(Dict[str, int], {})
            return my_list, my_dict

    x = torch.jit.script(EmptyDataStructures())



Optional Type Refinement
^^^^^^^^^^^^^^^^^^^^^^^^

TorchScript will refine the type of a variable of type ``Optional[T]`` when
a comparison to ``None`` is made inside the conditional of an if-statement or checked in an ``assert``.
The compiler can reason about multiple ``None`` checks that are combined with
``and``, ``or``, and ``not``. Refinement will also occur for else blocks of if-statements
that are not explicitly written.

The ``None`` check must be within the if-statement's condition; assigning
a ``None`` check to a variable and using it in the if-statement's condition will
not refine the types of variables in the check.
Only local variables will be refined; an attribute like ``self.x`` will not be refined and must be
assigned to a local variable first.


Example (refining types on parameters and locals):

.. testcode::

    import torch
    import torch.nn as nn
    from typing import Optional

    class M(nn.Module):
        z: Optional[int]

        def __init__(self, z):
            super(M, self).__init__()
            # If `z` is None, its type cannot be inferred, so it must
            # be specified (above)
            self.z = z

        def forward(self, x, y, z):
            # type: (Optional[int], Optional[int], Optional[int]) -> int
            if x is None:
                x = 1
                x = x + 1

            # Refinement for an attribute by assigning it to a local
            z = self.z
            if y is not None and z is not None:
                x = y + z

            # Refinement via an `assert`
            assert z is not None
            x += z
            return x

    module = torch.jit.script(M(2))
    module = torch.jit.script(M(None))
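
Refinement in an else block that is not explicitly written can be sketched with an early return.
In this minimal example the function name ``add_one_or_zero`` is hypothetical; after the ``None``
branch returns, the rest of the function acts as the implicit else block and ``x`` is treated as
``int`` there.

::

    import torch
    from typing import Optional

    @torch.jit.script
    def add_one_or_zero(x: Optional[int]) -> int:
        if x is None:
            return 0
        # Implicit else block: `x` is refined from Optional[int] to int
        return x + 1

    print(add_one_or_zero(None))
    print(add_one_or_zero(3))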

.. _TorchScript Class:
.. _TorchScript Classes:

TorchScript Classes
^^^^^^^^^^^^^^^^^^^
Python classes can be used in TorchScript if they are annotated with :func:`@torch.jit.script <torch.jit.script>`,
similar to how you would declare a TorchScript function:

.. testcode::
    :skipif: True  # TODO: fix the source file resolving so this can be tested

    @torch.jit.script
    class Foo:
      def __init__(self, x, y):
        self.x = x

      def aug_add_x(self, inc):
        self.x += inc


This subset is restricted:

* All functions must be valid TorchScript functions (including ``__init__()``).
* Classes must be new-style classes, as we use ``__new__()`` to construct them with pybind11.
* TorchScript classes are statically typed. Members can only be declared by assigning to
  self in the ``__init__()`` method.

    For example, assigning to ``self`` outside of the ``__init__()`` method: ::

        @torch.jit.script
        class Foo:
          def assign_x(self):
            self.x = torch.rand(2, 3)

    Will result in: ::

        RuntimeError:
        Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
        def assign_x(self):
          self.x = torch.rand(2, 3)
          ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

* No expressions except method definitions are allowed in the body of the class.
* No support for inheritance or any other polymorphism strategy, except for inheriting
  from ``object`` to specify a new-style class.

After a class is defined, it can be used in both TorchScript and Python interchangeably
like any other TorchScript type:

::

    # Declare a TorchScript class
    @torch.jit.script
    class Pair:
      def __init__(self, first, second):
        self.first = first
        self.second = second

    @torch.jit.script
    def sum_pair(p):
      # type: (Pair) -> Tensor
      return p.first + p.second

    p = Pair(torch.rand(2, 3), torch.rand(2, 3))
    print(sum_pair(p))


Named Tuples
^^^^^^^^^^^^
Types produced by :func:`collections.namedtuple <collections.namedtuple>` can be used in TorchScript.

.. testcode::

    import torch
    import collections

    Point = collections.namedtuple('Point', ['x', 'y'])

    @torch.jit.script
    def total(point):
        # type: (Point) -> Tensor
        return point.x + point.y

    p = Point(x=torch.rand(3), y=torch.rand(3))
    print(total(p))

.. testoutput::
    :hide:

    ...


Expressions
~~~~~~~~~~~

The following Python expressions are supported.

Literals
^^^^^^^^
::

    True
    False
    None
    'string literals'
    "string literals"
    3  # interpreted as int
    3.4  # interpreted as a float

List Construction
"""""""""""""""""
An empty list is assumed to have type ``List[Tensor]``.
The types of other list literals are derived from the type of the members.
See `Default Types`_ for more details.

::

    [3, 4]
    []
    [torch.rand(3), torch.rand(4)]



Tuple Construction
""""""""""""""""""
::

    (3, 4)
    (3,)


Dict Construction
"""""""""""""""""
An empty dict is assumed to have type ``Dict[str, Tensor]``.
The types of other dict literals are derived from the type of the members.
See `Default Types`_ for more details.

::

    {'hello': 3}
    {}
    {'a': torch.rand(3), 'b': torch.rand(4)}


Variables
^^^^^^^^^
See `Variable Resolution`_ for how variables are resolved.

::

    my_variable_name

Arithmetic Operators
^^^^^^^^^^^^^^^^^^^^
::

    a + b
    a - b
    a * b
    a / b
    a ^ b
    a @ b

Comparison Operators
^^^^^^^^^^^^^^^^^^^^
::

    a == b
    a != b
    a < b
    a > b
    a <= b
    a >= b

Logical Operators
^^^^^^^^^^^^^^^^^
::

    a and b
    a or b
    not b

Subscripts and Slicing
^^^^^^^^^^^^^^^^^^^^^^
::

    t[0]
    t[-1]
    t[0:2]
    t[1:]
    t[:1]
    t[:]
    t[0, 1]
    t[0, 1:2]
    t[0, :1]
    t[-1, 1:, 0]
    t[1:, -1, 0]
    t[i:j, i]

Function Calls
^^^^^^^^^^^^^^
Calls to `builtin functions`_:

::

    torch.rand(3, dtype=torch.float)

Calls to other script functions:

.. testcode::

    import torch

    @torch.jit.script
    def foo(x):
        return x + 1

    @torch.jit.script
    def bar(x):
        return foo(x)

Method Calls
^^^^^^^^^^^^
Calls to methods of builtin types like tensor: ``x.mm(y)``

On modules, methods must be compiled before they can be called. The TorchScript
compiler recursively compiles methods it sees when compiling other methods. By default,
compilation starts on the ``forward`` method. Any methods called by ``forward`` will
be compiled, and any methods called by those methods, and so on. To start compilation at
a method other than ``forward``, use the :func:`@torch.jit.export <torch.jit.export>` decorator
(``forward`` implicitly is marked ``@torch.jit.export``).

Calling a submodule directly (e.g. ``self.resnet(input)``) is equivalent to
calling its ``forward`` method (e.g. ``self.resnet.forward(input)``).

.. testcode::
    :skipif: torchvision is None

    import torch
    import torch.nn as nn
    import torchvision

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            means = torch.tensor([103.939, 116.779, 123.68])
            self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
            resnet = torchvision.models.resnet18()
            self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))

        def helper(self, input):
            return self.resnet(input - self.means)

        def forward(self, input):
            return self.helper(input)

        # Since nothing in the model calls `top_level_method`, the compiler
        # must be explicitly told to compile this method
        @torch.jit.export
        def top_level_method(self, input):
            return self.other_helper(input)

        def other_helper(self, input):
            return input + 10

    # `my_script_module` will have the compiled methods `forward`, `helper`,
    # `top_level_method`, and `other_helper`
    my_script_module = torch.jit.script(MyModule())


Ternary Expressions
^^^^^^^^^^^^^^^^^^^
::

    x if x > y else y

Casts
^^^^^
::

    float(ten)
    int(3.5)
    bool(ten)
    str(2)

Accessing Module Parameters
^^^^^^^^^^^^^^^^^^^^^^^^^^^
::

    self.my_parameter
    self.my_submodule.my_parameter


Statements
~~~~~~~~~~

TorchScript supports the following types of statements:

Simple Assignments
^^^^^^^^^^^^^^^^^^
::

    a = b
    a += b # short-hand for a = a + b, does not operate in-place on a
    a -= b

Pattern Matching Assignments
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
::

    a, b = tuple_or_list
    a, b, *c = a_tuple

Multiple Assignments
::

    a = b, c = tup

Print Statements
^^^^^^^^^^^^^^^^
::

    print("the result of an add:", a + b)

If Statements
^^^^^^^^^^^^^
::

    if a < 4:
        r = -a
    elif a < 3:
        r = a + a
    else:
        r = 3 * a

In addition to bools, floats, ints, and Tensors can be used in a conditional
and will be implicitly cast to a boolean.
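
The implicit conversion can be sketched as follows. The function name ``count_truthy`` is
hypothetical, and the Tensor argument is assumed to hold a single element so that its truth value
is unambiguous.

::

    import torch

    @torch.jit.script
    def count_truthy(n: int, f: float, t: torch.Tensor) -> int:
        count = 0
        if n:  # int used directly as a condition
            count += 1
        if f:  # float used directly as a condition
            count += 1
        if t:  # single-element Tensor used directly as a condition
            count += 1
        return count

    print(count_truthy(0, 0.5, torch.tensor([1.0])))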

While Loops
^^^^^^^^^^^
::

    a = 0
    while a < 4:
        print(a)
        a += 1


For loops with range
^^^^^^^^^^^^^^^^^^^^
::

    x = 0
    for i in range(10):
        x *= i

For loops over tuples
^^^^^^^^^^^^^^^^^^^^^
These unroll the loop, generating a body for
each member of the tuple. The body must type-check correctly for each member.

::

    tup = (3, torch.rand(4))
    for x in tup:
        print(x)


For loops over constant nn.ModuleList
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To use a ``nn.ModuleList`` inside a compiled method, it must be marked
constant by adding the name of the attribute to the ``__constants__``
list for the type. For loops over a ``nn.ModuleList`` will unroll the body of the
loop at compile time, with each member of the constant module list.

.. testcode::

    import torch
    import torch.nn as nn

    class SubModule(torch.nn.Module):
        def __init__(self):
            super(SubModule, self).__init__()
            self.weight = nn.Parameter(torch.randn(2))

        def forward(self, input):
            return self.weight + input

    class MyModule(torch.nn.Module):
        __constants__ = ['mods']

        def __init__(self):
            super(MyModule, self).__init__()
            self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

        def forward(self, v):
            for module in self.mods:
                v = module(v)
            return v


    m = torch.jit.script(MyModule())



Break and Continue
^^^^^^^^^^^^^^^^^^
::

    for i in range(5):
        if i == 1:
            continue
        if i == 3:
            break
        print(i)

Return
^^^^^^
::

    return a, b

Variable Resolution
~~~~~~~~~~~~~~~~~~~

TorchScript supports a subset of Python's variable resolution (i.e. scoping)
rules. Local variables behave the same as in Python, except for the restriction
that a variable must have the same type along all paths through a function.
If a variable has a different type on different branches of an if statement, it
is an error to use it after the end of the if statement.

Similarly, a variable is not allowed to be used if it is only *defined* along some
paths through the function.

Example:

.. testcode::

    @torch.jit.script
    def foo(x):
        if x < 0:
            y = 4
        print(y)

.. testoutput::

     Traceback (most recent call last):
       ...
     RuntimeError: ...

     y is not defined in the false branch...
     @torch.jit.script...
     def foo(x):
         if x < 0:
         ~~~~~~~~~...  <--- HERE
             y = 4
         print(y)
     ...

Non-local variables are resolved to Python values at compile time when the
function is defined. These values are then converted into TorchScript values using
the rules described in `Use of Python Values`_.

Use of Python Values
~~~~~~~~~~~~~~~~~~~~

To make writing TorchScript more convenient, we allow script code to refer
to Python values in the surrounding scope. For instance, any time there is a
reference to ``torch``, the TorchScript compiler is actually resolving it to the
``torch`` Python module when the function is declared.  These Python values are
not a first class part of TorchScript. Instead they are de-sugared at compile-time
into the primitive types that TorchScript supports. This depends
on the dynamic type of the Python value referenced when compilation occurs.
This section describes the rules that are used when accessing Python values in TorchScript.

Functions
^^^^^^^^^

TorchScript can call Python functions. This functionality is very useful when
incrementally converting a model to TorchScript. The model can be moved function-by-function
to TorchScript, leaving calls to Python functions in place. This way you can incrementally
check the correctness of the model as you go.

.. autofunction:: ignore


Attribute Lookup On Python Modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TorchScript can look up attributes on modules. `Builtin functions`_ like ``torch.add``
are accessed this way. This allows TorchScript to call functions defined in
other modules.

.. _constant:

Python-defined Constants
^^^^^^^^^^^^^^^^^^^^^^^^
TorchScript also provides a way to use constants that are defined in Python.
These can be used to hard-code hyper-parameters into the function, or to
define universal constants. There are two ways of specifying that a Python
value should be treated as a constant.

1. Values looked up as attributes of a module are assumed to be constant:

.. testcode::

    import math
    import torch

    @torch.jit.script
    def fn():
        return math.pi

2. Attributes of a ScriptModule can be marked constant by annotating them with ``Final[T]``

::

    import torch
    import torch.nn as nn

    class Foo(nn.Module):
        # `Final` from the `typing_extensions` module can also be used
        a : torch.jit.Final[int]

        def __init__(self):
            super(Foo, self).__init__()
            self.a = 1 + 4

        def forward(self, input):
            return self.a + input

    f = torch.jit.script(Foo())

Supported constant Python types are

* ``int``
* ``float``
* ``bool``
* ``torch.device``
* ``torch.layout``
* ``torch.dtype``
* tuples containing supported types
* ``torch.nn.ModuleList`` which can be used in a TorchScript for loop

.. note::
    If you are on Python 2, you can mark an attribute as a constant by adding
    its name to the ``__constants__`` property of the class:

    .. testcode::

        import torch
        import torch.nn as nn

        class Foo(nn.Module):
            __constants__ = ['a']

            def __init__(self):
                super(Foo, self).__init__()
                self.a = 1 + 4

            def forward(self, input):
                return self.a + input

        f = torch.jit.script(Foo())

    |

.. _module attributes:

Module Attributes
^^^^^^^^^^^^^^^^^

The ``torch.nn.Parameter`` wrapper and ``register_buffer`` can be used to assign
tensors to a module. Other values assigned to a module that is compiled
will be added to the compiled module if their types can be inferred. All `types`_
available in TorchScript can be used as module attributes. Tensor attributes are
semantically the same as buffers. The type of empty containers and ``None``
values cannot be inferred and must be specified via
`PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_ class annotations.

Example:

.. testcode::

    from typing import List, Dict

    class Foo(nn.Module):
        # `words` is initialized as an empty list, so its type must be specified
        words: List[str]

        # The type could potentially be inferred if `a_dict` (below) was not
        # empty, but this annotation ensures `some_dict` will be made into the
        # proper type
        some_dict: Dict[str, int]

        def __init__(self, a_dict):
            super(Foo, self).__init__()
            self.words = []
            self.some_dict = a_dict

            # `int`s can be inferred
            self.my_int = 10

        def forward(self, input):
            # type: (str) -> int
            self.words.append(input)
            return self.some_dict[input] + self.my_int

    f = torch.jit.script(Foo({'hi': 2}))
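
The same rule applies to attributes that start out as ``None``. The following
is a sketch only, with a hypothetical module ``RunningSum`` whose ``previous``
attribute is annotated as ``Optional[torch.Tensor]`` so that its type is known
to the compiler:

.. code-block:: python

    from typing import Optional

    import torch
    import torch.nn as nn


    class RunningSum(nn.Module):
        # Starts as `None`, so the annotation supplies the type.
        previous: Optional[torch.Tensor]

        def __init__(self):
            super().__init__()
            self.previous = None

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Copy the attribute into a local variable before the `None`
            # check, so that the Optional type can be refined.
            previous = self.previous
            if previous is not None:
                x = x + previous
            self.previous = x
            return x


    m = torch.jit.script(RunningSum())
    first = m(torch.ones(2))
    second = m(torch.ones(2))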


.. note::
    If you are on Python 2, you can mark an attribute's type by adding it to
    the ``__annotations__`` class property as a dictionary of attribute name to
    type:

    .. testcode::

        from typing import List, Dict

        class Foo(nn.Module):
            __annotations__ = {'words': List[str], 'some_dict': Dict[str, int]}

            def __init__(self, a_dict):
                super(Foo, self).__init__()
                self.words = []
                self.some_dict = a_dict

                # `int`s can be inferred
                self.my_int = 10

            def forward(self, input):
                # type: (str) -> int
                self.words.append(input)
                return self.some_dict[input] + self.my_int

        f = torch.jit.script(Foo({'hi': 2}))

    |

Debugging
~~~~~~~~~

.. _`disable TorchScript`:

Disable JIT for Debugging
^^^^^^^^^^^^^^^^^^^^^^^^^
.. envvar:: PYTORCH_JIT

    Setting the environment variable ``PYTORCH_JIT=0`` will disable all script
    and tracing annotations. If there is a hard-to-debug error in one of your
    TorchScript models, you can use this flag to force everything to run using native
    Python. Since TorchScript (scripting and tracing) is disabled with this flag,
    you can use tools like ``pdb`` to debug the model code.

    Given an example script::

        @torch.jit.script
        def scripted_fn(x : torch.Tensor):
            for i in range(12):
                x = x + x
            return x


        def fn(x):
            x = torch.neg(x)
            import pdb; pdb.set_trace()
            return scripted_fn(x)

        traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
        traced_fn(torch.rand(3, 4))

    Debugging this script with ``pdb`` works except for when we invoke the :func:`@torch.jit.script <torch.jit.script>`
    function. We can globally disable JIT, so that we can call the ``@torch.jit.script``
    function as a normal Python function and not compile it. If the above script
    is called ``disable_jit_example.py``, we can invoke it like so::

        $ PYTORCH_JIT=0 python disable_jit_example.py

    and we will be able to step into the ``@torch.jit.script`` function as a normal Python
    function. To disable the TorchScript compiler for a specific function, see
    :func:`@torch.jit.ignore <torch.jit.ignore>`.


Inspecting Code
^^^^^^^^^^^^^^^

TorchScript provides a code pretty-printer for all ``ScriptModule`` instances. This
pretty-printer gives an interpretation of the script method's code as valid
Python syntax. For example:

.. testcode::

    @torch.jit.script
    def foo(len):
        # type: (int) -> torch.Tensor
        rv = torch.zeros(3, 4)
        for i in range(len):
            if i < 10:
                rv = rv - 1.0
            else:
                rv = rv + 1.0
        return rv

    print(foo.code)

.. testoutput::
    :hide:

    ...

A ``ScriptModule`` with a single ``forward`` method will have an attribute
``code``, which you can use to inspect the ``ScriptModule``'s code.
If the ``ScriptModule`` has more than one method, you will need to access
``.code`` on the method itself and not the module. We can inspect the
code of a method named ``bar`` on a ScriptModule by accessing ``.bar.code``.
The example above produces this output: ::

    def foo(len: int) -> Tensor:
        rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
        rv0 = rv
        for i in range(len):
            if torch.lt(i, 10):
                rv1 = torch.sub(rv0, 1., 1)
            else:
                rv1 = torch.add(rv0, 1., 1)
            rv0 = rv1
        return rv0

This is TorchScript's compilation of the code for the ``foo`` function.
You can use this to ensure TorchScript (tracing or scripting) has captured
your model code correctly.
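
For a module with more than one compiled method, the lookup described above
can be sketched as follows. ``TwoMethods`` is a hypothetical module, and
``@torch.jit.export`` is used here so that ``bar`` is compiled alongside
``forward``:

.. code-block:: python

    import torch
    import torch.nn as nn


    class TwoMethods(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + 1

        @torch.jit.export
        def bar(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2


    m = torch.jit.script(TwoMethods())

    # Access `.code` on each method rather than on the module.
    forward_code = m.forward.code
    bar_code = m.bar.code
    print(forward_code)
    print(bar_code)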


Interpreting Graphs
^^^^^^^^^^^^^^^^^^^
TorchScript also has a representation at a lower level than the code
pretty-printer, in the form of IR graphs.

TorchScript uses a static single assignment (SSA) intermediate representation
(IR) to represent computation. The instructions in this format consist of
ATen (the C++ backend of PyTorch) operators and other primitive operators,
including control flow operators for loops and conditionals. As an example:

.. testcode::

    @torch.jit.script
    def foo(len):
        # type: (int) -> torch.Tensor
        rv = torch.zeros(3, 4)
        for i in range(len):
            if i < 10:
                rv = rv - 1.0
            else:
                rv = rv + 1.0
        return rv

    print(foo.graph)

.. testoutput::
    :hide:

    ...

``graph`` follows the same rules described in the `Inspecting Code`_ section
with regard to ``forward`` method lookup.

The example script above produces the graph::

    graph(%len.1 : int):
      %24 : int = prim::Constant[value=1]()
      %17 : bool = prim::Constant[value=1]() # test.py:10:5
      %12 : bool? = prim::Constant()
      %10 : Device? = prim::Constant()
      %6 : int? = prim::Constant()
      %1 : int = prim::Constant[value=3]() # test.py:9:22
      %2 : int = prim::Constant[value=4]() # test.py:9:25
      %20 : int = prim::Constant[value=10]() # test.py:11:16
      %23 : float = prim::Constant[value=1]() # test.py:12:23
      %4 : int[] = prim::ListConstruct(%1, %2)
      %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
      %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
        block0(%i.1 : int, %rv.14 : Tensor):
          %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
          %rv.13 : Tensor = prim::If(%21) # test.py:11:9
            block0():
              %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
              -> (%rv.3)
            block1():
              %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
              -> (%rv.6)
          -> (%17, %rv.13)
      return (%rv)


Take the instruction ``%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10`` for
example.

* ``%rv.1 : Tensor`` means we assign the output to a (unique) value named ``rv.1``, that the value is of ``Tensor`` type, and that we do not know its concrete shape.
* ``aten::zeros`` is the operator (equivalent to ``torch.zeros``) and the input list ``(%4, %6, %6, %10, %12)`` specifies which values in scope should be passed as inputs. The schema for built-in functions like ``aten::zeros`` can be found at `Builtin Functions`_.
* ``# test.py:9:10`` is the location in the original source file that generated this instruction. In this case, it is a file named ``test.py``, on line 9, and at character 10.

Notice that operators can also have associated ``blocks``, namely the
``prim::Loop`` and ``prim::If`` operators. In the graph print-out, these
operators are formatted to reflect their equivalent source code forms
to facilitate easy debugging.

Graphs can be inspected as shown to confirm that the computation described
by a ``ScriptModule`` is correct, in both automated and manual fashion, as
described below.
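
One possible way to automate such an inspection is to walk the graph's nodes
from Python. The sketch below rebuilds the example function with the argument
renamed to ``n`` and collects operator names. The checks a real test would
make are an assumption and depend on the model:

.. code-block:: python

    import torch


    @torch.jit.script
    def foo(n: int) -> torch.Tensor:
        rv = torch.zeros(3, 4)
        for i in range(n):
            if i < 10:
                rv = rv - 1.0
            else:
                rv = rv + 1.0
        return rv


    # Operator names of the top-level instructions, in graph order.
    top_level_kinds = [node.kind() for node in foo.graph.nodes()]

    # Search nested blocks as well, e.g. the `prim::If` inside the loop body.
    if_nodes = foo.graph.findAllNodes("prim::If", recurse=True)

    print(top_level_kinds)
    print(len(if_nodes))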


Tracing Edge Cases
^^^^^^^^^^^^^^^^^^
In some edge cases, the trace of a given Python function/module will not be
representative of the underlying code. These cases can include:

* Tracing of control flow that is dependent on inputs (e.g. tensor shapes)
* Tracing of in-place operations of tensor views (e.g. indexing on the left-hand side of an assignment)

Note that these cases may in fact be traceable in the future.
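
To make the first case concrete, here is a minimal sketch with a hypothetical
function ``shift`` whose branch depends on the values in its input. Tracing
records only the branch taken for the example input, so the traced function
and the Python function disagree on other inputs:

.. code-block:: python

    import torch


    def shift(x):
        # The branch taken depends on the values stored in `x`.
        if x.sum() > 0:
            return x + 1
        return x - 1


    positive = torch.ones(2)
    negative = -torch.ones(2)

    # Tracing runs `shift` once and records the `x + 1` branch only.
    # A TracerWarning about converting a tensor to a Python boolean is expected.
    traced = torch.jit.trace(shift, positive)

    eager_result = shift(negative)    # takes the `x - 1` branch
    traced_result = traced(negative)  # replays the recorded `x + 1` branch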


Automatic Trace Checking
^^^^^^^^^^^^^^^^^^^^^^^^
One way to automatically catch many errors in traces is by using ``check_inputs``
on the ``torch.jit.trace()`` API. ``check_inputs`` takes a list of tuples
of inputs that will be used to re-trace the computation and verify the
results. For example::

    def loop_in_traced_fn(x):
        result = x[0]
        for i in range(x.size(0)):
            result = result * x[i]
        return result

    inputs = (torch.rand(3, 4, 5),)
    check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

    traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

This gives us the following diagnostic information::

    ERROR: Graphs differed across invocations!
    Graph diff:

                graph(%x : Tensor) {
                %1 : int = prim::Constant[value=0]()
                %2 : int = prim::Constant[value=0]()
                %result.1 : Tensor = aten::select(%x, %1, %2)
                %4 : int = prim::Constant[value=0]()
                %5 : int = prim::Constant[value=0]()
                %6 : Tensor = aten::select(%x, %4, %5)
                %result.2 : Tensor = aten::mul(%result.1, %6)
                %8 : int = prim::Constant[value=0]()
                %9 : int = prim::Constant[value=1]()
                %10 : Tensor = aten::select(%x, %8, %9)
            -   %result : Tensor = aten::mul(%result.2, %10)
            +   %result.3 : Tensor = aten::mul(%result.2, %10)
            ?          ++
                %12 : int = prim::Constant[value=0]()
                %13 : int = prim::Constant[value=2]()
                %14 : Tensor = aten::select(%x, %12, %13)
            +   %result : Tensor = aten::mul(%result.3, %14)
            +   %16 : int = prim::Constant[value=0]()
            +   %17 : int = prim::Constant[value=3]()
            +   %18 : Tensor = aten::select(%x, %16, %17)
            -   %15 : Tensor = aten::mul(%result, %14)
            ?     ^                                 ^
            +   %19 : Tensor = aten::mul(%result, %18)
            ?     ^                                 ^
            -   return (%15);
            ?             ^
            +   return (%19);
            ?             ^
                }


This message indicates to us that the computation differed between when
we first traced it and when we traced it with the ``check_inputs``. Indeed,
the loop within the body of ``loop_in_traced_fn`` depends on the shape
of the input ``x``, and thus when we try another ``x`` with a different
shape, the trace differs.

In this case, data-dependent control flow like this can be captured using
:func:`torch.jit.script` instead:

.. testcode::

    def fn(x):
        result = x[0]
        for i in range(x.size(0)):
            result = result * x[i]
        return result

    inputs = (torch.rand(3, 4, 5),)
    check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

    scripted_fn = torch.jit.script(fn)
    print(scripted_fn.graph)

    for input_tuple in [inputs] + check_inputs:
        torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))

.. testoutput::
    :hide:

    ...


This produces::

    graph(%x : Tensor) {
        %5 : bool = prim::Constant[value=1]()
        %1 : int = prim::Constant[value=0]()
        %result.1 : Tensor = aten::select(%x, %1, %1)
        %4 : int = aten::size(%x, %1)
        %result : Tensor = prim::Loop(%4, %5, %result.1)
        block0(%i : int, %7 : Tensor) {
            %10 : Tensor = aten::select(%x, %1, %i)
            %result.2 : Tensor = aten::mul(%7, %10)
            -> (%5, %result.2)
        }
        return (%result);
    }

Tracer Warnings
^^^^^^^^^^^^^^^
The tracer produces warnings for several problematic patterns in traced
computation. As an example, take a trace of a function that contains an
in-place assignment on a slice (a view) of a Tensor:

.. testcode::

    def fill_row_zero(x):
        x[0] = torch.rand(*x.shape[1:2])
        return x

    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
    print(traced.graph)

.. testoutput::
    :hide:

    ...

This produces several warnings and a graph which simply returns the input::

    fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
        x[0] = torch.rand(*x.shape[1:2])
    fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
    Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
        traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
    graph(%0 : Float(3, 4)) {
        return (%0);
    }

We can fix this by modifying the code to not use the in-place update, but
rather build up the result tensor out-of-place with ``torch.cat``:

.. testcode::

    def fill_row_zero(x):
        # Keep every row after the first so the output shape matches the input
        x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:]), dim=0)
        return x

    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
    print(traced.graph)

.. testoutput::
    :hide:

    ...

.. _Builtin functions:

Builtin Functions
~~~~~~~~~~~~~~~~~

TorchScript supports a subset of the builtin tensor and neural network
functions that PyTorch provides. Most methods on Tensor as well as functions in
the ``torch`` namespace, all functions in ``torch.nn.functional`` and all
modules from ``torch.nn`` are supported in TorchScript, excluding those in the
list below. For unsupported modules, we suggest using :func:`torch.jit.trace`.

Unsupported ``torch.nn`` Modules  ::

    torch.nn.modules.adaptive.AdaptiveLogSoftmaxWithLoss
    torch.nn.modules.normalization.CrossMapLRN2d
    torch.nn.modules.rnn.RNN
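
Following that suggestion, a module from this list can be traced instead of
scripted. The sketch below traces a ``torch.nn.RNN``. The sizes and variable
names are arbitrary choices for the example:

.. code-block:: python

    import torch
    import torch.nn as nn

    rnn = nn.RNN(input_size=10, hidden_size=20)
    rnn.eval()

    # Example inputs: (seq_len, batch, input_size) and
    # (num_layers, batch, hidden_size).
    example_input = torch.rand(5, 3, 10)
    example_hidden = torch.zeros(1, 3, 20)

    traced_rnn = torch.jit.trace(rnn, (example_input, example_hidden))

    # Compare the traced module with the original on a fresh input.
    new_input = torch.rand(5, 3, 10)
    traced_out, traced_hidden = traced_rnn(new_input, example_hidden)
    eager_out, eager_hidden = rnn(new_input, example_hidden)

As with any trace, the result reflects the operations recorded for the example
inputs, so it is worth checking it against inputs the model will see in use.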


See :ref:`builtin-functions` for a full reference of supported functions.


Frequently Asked Questions
--------------------------

Q: I would like to train a model on GPU and do inference on CPU. What are the
best practices?
   First convert your model from GPU to CPU and then save it, like so: ::

      # Trace and save the GPU version first: `.cpu()` moves the module's
      # parameters in place, so `gpu_model` is on the CPU afterwards.
      traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
      torch.jit.save(traced_gpu, "gpu.pth")

      cpu_model = gpu_model.cpu()
      sample_input_cpu = sample_input_gpu.cpu()
      traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
      torch.jit.save(traced_cpu, "cpu.pth")

      # ... later, when using the model:

      if use_gpu:
        model = torch.jit.load("gpu.pth")
      else:
        model = torch.jit.load("cpu.pth")

      model(input)

   This is recommended because the tracer may witness tensor creation on a
   specific device, so casting an already-loaded model may have unexpected
   effects. Casting the model *before* saving it ensures that the tracer has
   the correct device information.


Q: How do I store attributes on a ``ScriptModule``?

    Say we have a model like:

    .. testcode::

        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.x = 2

            def forward(self):
                return self.x

        m = torch.jit.script(Model())



    If ``Model`` is instantiated it will result in a compilation error
    since the compiler doesn't know about ``x``. There are 4 ways to inform the
    compiler of attributes on ``ScriptModule``:

    1. ``nn.Parameter`` - Values wrapped in ``nn.Parameter`` will work as they
    do on ``nn.Module``\s

    2. ``register_buffer`` - Tensors registered with ``register_buffer`` will work as
    they do on ``nn.Module``\s. This is equivalent to an attribute (see 4) of type
    ``Tensor``.

    3. Constants - Annotating a class member as ``Final`` (or adding it to a list called
    ``__constants__`` at the class definition level) will mark the contained names
    as constants. Constants are saved directly in the code of the model. See
    `Python-defined Constants`_ for details.

    4. Attributes - Values that are a `supported type`_ can be added as mutable
    attributes. Most types can be inferred but some may need to be specified, see
    `Module Attributes`_ for details.
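
    As an illustration only, the sketch below uses all four mechanisms in one
    hypothetical module. The attribute names ``weight``, ``offset``, ``scale``
    and ``history`` are assumptions made for the example:

    .. code-block:: python

        from typing import List

        import torch
        import torch.nn as nn


        class Model(nn.Module):
            scale: torch.jit.Final[int]  # 3. constant
            history: List[int]           # 4. attribute; an empty list needs a type

            def __init__(self):
                super().__init__()
                self.weight = nn.Parameter(torch.ones(3))       # 1. parameter
                self.register_buffer("offset", torch.zeros(3))  # 2. buffer
                self.scale = 2
                self.history = []

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Attributes are mutable, so the list can grow on each call.
                self.history.append(x.size(0))
                return (x * self.weight + self.offset) * self.scale


        m = torch.jit.script(Model())
        out = m(torch.ones(3))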



Q: I would like to trace a module's method but I keep getting this error:

``RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient``

    This error usually means that the method you are tracing uses a module's parameters and
    you are passing the module's method instead of the module instance (e.g. ``my_module_instance.forward`` vs ``my_module_instance``).
      - Invoking ``trace`` with a module's method captures module parameters (which may require gradients) as **constants**.
      - On the other hand, invoking ``trace`` with a module instance (e.g. ``my_module``) creates a new module and correctly copies parameters into the new module, so they can accumulate gradients if required.

    To trace a specific method on a module, see :func:`torch.jit.trace_module <torch.jit.trace_module>`.
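
    A minimal sketch of that approach, assuming a hypothetical module ``Net``
    with an extra method ``embed``. ``trace_module`` takes a dictionary that
    maps each method name to its example inputs:

    .. code-block:: python

        import torch
        import torch.nn as nn


        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(4, 2)

            def forward(self, x):
                return self.linear(x)

            def embed(self, x):
                return torch.relu(self.linear(x))


        net = Net()
        example = torch.rand(3, 4)

        # Pass the module instance, not a bound method, so that parameters
        # are kept as parameters of the traced module.
        traced = torch.jit.trace_module(net, {"forward": example, "embed": example})

        embedded = traced.embed(example)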