.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/export-to-executorch-tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_export-to-executorch-tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_export-to-executorch-tutorial.py:


Exporting to ExecuTorch Tutorial
================================

**Author:** `Angela Yi <https://github.com/angelayi>`__

.. GENERATED FROM PYTHON SOURCE LINES 16-29

ExecuTorch is a unified ML stack for lowering PyTorch models to edge devices.
It introduces improved entry points to perform model-, device-, and/or
use-case-specific optimizations such as backend delegation, user-defined
compiler transformations, default or user-defined memory planning, and more.

At a high level, the workflow looks as follows:

.. image:: ../executorch_stack.png
  :width: 560

In this tutorial, we will cover the APIs in the "Program preparation" steps to
lower a PyTorch model to a format that can be loaded onto a device and run on
the ExecuTorch runtime.

.. GENERATED FROM PYTHON SOURCE LINES 31-36

Prerequisites
-------------

To run this tutorial, you’ll first need to
`Set up your ExecuTorch environment <../getting-started-setup.html>`__.

.. GENERATED FROM PYTHON SOURCE LINES 38-50

Exporting a Model
-----------------

Note: The Export APIs are still undergoing changes to align better with the
longer-term state of export. Please refer to this
`issue <https://github.com/pytorch/executorch/issues/290>`__ for more details.

The first step of lowering to ExecuTorch is to export the given model (any
callable or ``torch.nn.Module``) to a graph representation. This is done via
``torch.export``, which takes in a ``torch.nn.Module``, a tuple of
positional arguments, optionally a dictionary of keyword arguments (not shown
in the example), and a list of dynamic shapes (covered later).

.. GENERATED FROM PYTHON SOURCE LINES 50-72

.. code-block:: default


    import torch
    from torch.export import export, ExportedProgram


    class SimpleConv(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=3, padding=1
            )
            self.relu = torch.nn.ReLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            a = self.conv(x)
            return self.relu(a)


    example_args = (torch.randn(1, 3, 256, 256),)
    aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
    print(aten_dialect)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:64 in forward, code: a = self.conv(x)
                conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]);  x = p_conv_weight = p_conv_bias = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:65 in forward, code: return self.relu(a)
                relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(conv2d);  conv2d = None
                return (relu,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='relu'), target=None)])
    Range constraints: {}





.. GENERATED FROM PYTHON SOURCE LINES 73-83

The output of ``torch.export.export`` is a fully flattened graph (meaning the
graph does not contain any module hierarchy, except in the case of control
flow operators). Additionally, the graph is purely functional, meaning it does
not contain operations with side effects such as mutations or aliasing.

A more detailed specification of the result of ``torch.export`` can be found
`here <https://pytorch.org/docs/main/export.html>`__.

The graph returned by ``torch.export`` only contains functional ATen operators
(~2000 ops), which we will call the ``ATen Dialect``.
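
As a quick illustration (a minimal sketch, not part of the original tutorial),
we can iterate over the nodes of the graph exported above to see exactly which
ATen operators it contains:

.. code-block:: default

    # Print every operator call captured in the ATen Dialect graph.
    for node in aten_dialect.graph.nodes:
        if node.op == "call_function":
            print(node.target)  # e.g. aten.conv2d.default, aten.relu.default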

.. GENERATED FROM PYTHON SOURCE LINES 85-91

Expressing Dynamism
^^^^^^^^^^^^^^^^^^^

By default, the export flow will trace the program assuming that all input
shapes are static, so if we run the program with input shapes that are
different from the ones used while tracing, we will run into an error:

.. GENERATED FROM PYTHON SOURCE LINES 91-115

.. code-block:: default


    import traceback as tb


    class Basic(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y


    example_args = (torch.randn(3, 3), torch.randn(3, 3))
    aten_dialect: ExportedProgram = export(Basic(), example_args)

    # Works correctly
    print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))

    # Errors
    try:
        print(aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2)))
    except Exception:
        tb.print_exc()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[2., 2., 2.],
            [2., 2., 2.],
            [2., 2., 2.]])
    Traceback (most recent call last):
      File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 111, in <module>
        print(aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2)))
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 822, in call_wrapped
        return self._wrapped_call(self, *args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 400, in __call__
        raise e
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 387, in __call__
        return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
        return inner()
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1772, in inner
        args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
        return fn(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_unlift.py", line 49, in _check_input_constraints_pre_hook
        _check_input_constraints_for_graph(
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/utils.py", line 360, in _check_input_constraints_for_graph
        raise RuntimeError(
    RuntimeError: Expected input at *args[0].shape[1] to be equal to 3, but got 2




.. GENERATED FROM PYTHON SOURCE LINES 116-118

To express that some input shapes are dynamic, we can pass dynamic
shapes to the export flow. This is done through the ``Dim`` API:

.. GENERATED FROM PYTHON SOURCE LINES 118-138

.. code-block:: default


    from torch.export import Dim


    class Basic(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y


    example_args = (torch.randn(3, 3), torch.randn(3, 3))
    dim1_x = Dim("dim1_x", min=1, max=10)
    dynamic_shapes = {"x": {1: dim1_x}, "y": {1: dim1_x}}
    aten_dialect: ExportedProgram = export(
        Basic(), example_args, dynamic_shapes=dynamic_shapes
    )
    print(aten_dialect)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, s0]", y: "f32[3, s0]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:127 in forward, code: return x + y
                add: "f32[3, s0]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
                return (add,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
    Range constraints: {s0: VR[1, 10]}





.. GENERATED FROM PYTHON SOURCE LINES 139-147

Note that the inputs ``x`` and ``y`` now have shapes ``(3, s0)``, with
``s0`` being a symbol representing the range of values that this dimension
can take.

Additionally, we can see in the **Range constraints** that the value of
``s0`` has the range [1, 10], as specified by our dynamic shapes.

Now let's try running the model with different shapes:

.. GENERATED FROM PYTHON SOURCE LINES 147-165

.. code-block:: default


    # Works correctly
    print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))
    print(aten_dialect.module()(torch.ones(3, 2), torch.ones(3, 2)))

    # Errors because it violates our constraint that input 0, dim 1 <= 10
    try:
        print(aten_dialect.module()(torch.ones(3, 15), torch.ones(3, 15)))
    except Exception:
        tb.print_exc()

    # Errors because it violates our constraint that input 0, dim 1 == input 1, dim 1
    try:
        print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 2)))
    except Exception:
        tb.print_exc()






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[2., 2., 2.],
            [2., 2., 2.],
            [2., 2., 2.]])
    tensor([[2., 2.],
            [2., 2.],
            [2., 2.]])
    Traceback (most recent call last):
      File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 154, in <module>
        print(aten_dialect.module()(torch.ones(3, 15), torch.ones(3, 15)))
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 822, in call_wrapped
        return self._wrapped_call(self, *args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 400, in __call__
        raise e
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 387, in __call__
        return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
        return inner()
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1772, in inner
        args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
        return fn(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_unlift.py", line 49, in _check_input_constraints_pre_hook
        _check_input_constraints_for_graph(
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/utils.py", line 347, in _check_input_constraints_for_graph
        raise RuntimeError(
    RuntimeError: Expected input at *args[0].shape[1] to be <= 10, but got 15
    Traceback (most recent call last):
      File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 160, in <module>
        print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 2)))
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 822, in call_wrapped
        return self._wrapped_call(self, *args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 400, in __call__
        raise e
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 387, in __call__
        return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
        return inner()
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1772, in inner
        args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
        return fn(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_unlift.py", line 49, in _check_input_constraints_pre_hook
        _check_input_constraints_for_graph(
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/utils.py", line 302, in _check_input_constraints_for_graph
        raise RuntimeError(
    RuntimeError: Expected input at *args[1].shape[1] to be equal to 3, but got 2




.. GENERATED FROM PYTHON SOURCE LINES 166-176

Addressing Untraceable Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^

As our goal is to capture the entire computational graph from a PyTorch
program, we might ultimately run into untraceable parts of a program. To
address these issues, the
`torch.export documentation <https://pytorch.org/docs/main/export.html#limitations-of-torch-export>`__
and the
`torch.export tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__
are the best places to look.
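
For example, data-dependent control flow is a common source of untraceable
code. As a minimal sketch (not part of the original tutorial, and assuming a
recent PyTorch that exposes ``torch.cond``), such a branch can be rewritten
with the ``torch.cond`` operator so that both branches are captured in the
exported graph:

.. code-block:: default

    class ControlFlow(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            def true_fn(x):
                return torch.sin(x)

            def false_fn(x):
                return torch.cos(x)

            # Both branches are traced into the graph; which one runs is
            # decided at runtime based on the predicate.
            return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))


    exported: ExportedProgram = export(ControlFlow(), (torch.randn(3, 3),))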

.. GENERATED FROM PYTHON SOURCE LINES 178-195

Performing Quantization
-----------------------

To quantize a model, we first need to capture the graph with
``torch.export.export_for_training``, perform quantization, and then
call ``torch.export``. ``torch.export.export_for_training`` returns a
graph containing ATen operators which are Autograd-safe, meaning they are
safe for eager-mode training, which is needed for quantization. We will call
the graph at this level the ``Pre-Autograd ATen Dialect`` graph.

Compared to
`FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__,
we call two new APIs, ``prepare_pt2e`` and ``convert_pt2e``, instead of
``prepare_fx`` and ``convert_fx``. The main difference is that
``prepare_pt2e`` takes a backend-specific ``Quantizer`` as an argument, which
annotates the nodes in the graph with the information needed to quantize the
model properly for that backend.

.. GENERATED FROM PYTHON SOURCE LINES 195-220

.. code-block:: default


    from torch.export import export_for_training

    example_args = (torch.randn(1, 3, 256, 256),)
    pre_autograd_aten_dialect = export_for_training(SimpleConv(), example_args).module()
    print("Pre-Autograd ATen Dialect Graph")
    print(pre_autograd_aten_dialect)

    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
        XNNPACKQuantizer,
    )

    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
    # Calibrate here by running the prepared graph on a representative sample
    # dataset, e.g. ``prepared_graph(*example_args)`` (omitted in this example)
    converted_graph = convert_pt2e(prepared_graph)
    print("Quantized Graph")
    print(converted_graph)

    aten_dialect: ExportedProgram = export(converted_graph, example_args)
    print("ATen Dialect Graph")
    print(aten_dialect)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Pre-Autograd ATen Dialect Graph
    GraphModule(
      (conv): Module()
    )



    def forward(self, x):
        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        conv_weight = self.conv.weight
        conv_bias = self.conv.bias
        conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias, [1, 1], [1, 1]);  x = conv_weight = conv_bias = None
        relu = torch.ops.aten.relu.default(conv2d);  conv2d = None
        return pytree.tree_unflatten((relu,), self._out_spec)
    
    # To see more debug info, please use `graph_module.print_readable()`
    /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/utils.py:408: UserWarning: must run observer before calling calculate_qparams. Returning default values.
      warnings.warn(
    /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/observer.py:1318: UserWarning: must run observer before calling calculate_qparams.                                    Returning default scale and zero point 
      warnings.warn(
    Quantized Graph
    GraphModule(
      (conv): Module()
    )



    def forward(self, x):
        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        _frozen_param0 = self._frozen_param0
        dequantize_per_tensor_default = torch.ops.quantized_decomposed.dequantize_per_tensor.default(_frozen_param0, 1.0, 0, -127, 127, torch.int8);  _frozen_param0 = None
        conv_bias = self.conv.bias
        quantize_per_tensor_default_1 = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, 1.0, 0, -128, 127, torch.int8);  x = None
        dequantize_per_tensor_default_1 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_1, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor_default_1 = None
        conv2d = torch.ops.aten.conv2d.default(dequantize_per_tensor_default_1, dequantize_per_tensor_default, conv_bias, [1, 1], [1, 1]);  dequantize_per_tensor_default_1 = dequantize_per_tensor_default = conv_bias = None
        relu = torch.ops.aten.relu.default(conv2d);  conv2d = None
        quantize_per_tensor_default_2 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu, 1.0, 0, -128, 127, torch.int8);  relu = None
        dequantize_per_tensor_default_2 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_2, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor_default_2 = None
        return pytree.tree_unflatten((dequantize_per_tensor_default_2,), self._out_spec)
    
    # To see more debug info, please use `graph_module.print_readable()`
    ATen Dialect Graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_conv_bias: "f32[16]", b__frozen_param0: "i8[16, 3, 3, 3]", x: "f32[1, 3, 256, 256]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:64 in forward, code: a = self.conv(x)
                dequantize_per_tensor: "f32[16, 3, 3, 3]" = torch.ops.quantized_decomposed.dequantize_per_tensor.default(b__frozen_param0, 1.0, 0, -127, 127, torch.int8);  b__frozen_param0 = None
            
                 # File: <eval_with_key>.203:9 in forward, code: quantize_per_tensor_default_1 = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, 1.0, 0, -128, 127, torch.int8);  x = None
                quantize_per_tensor: "i8[1, 3, 256, 256]" = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, 1.0, 0, -128, 127, torch.int8);  x = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:64 in forward, code: a = self.conv(x)
                dequantize_per_tensor_1: "f32[1, 3, 256, 256]" = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor = None
                conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(dequantize_per_tensor_1, dequantize_per_tensor, p_conv_bias, [1, 1], [1, 1]);  dequantize_per_tensor_1 = dequantize_per_tensor = p_conv_bias = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:65 in forward, code: return self.relu(a)
                relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(conv2d);  conv2d = None
                quantize_per_tensor_1: "i8[1, 16, 256, 256]" = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu, 1.0, 0, -128, 127, torch.int8);  relu = None
            
                 # File: <eval_with_key>.203:14 in forward, code: dequantize_per_tensor_default_2 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_2, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor_default_2 = None
                dequantize_per_tensor_2: "f32[1, 16, 256, 256]" = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_1, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor_1 = None
                return (dequantize_per_tensor_2,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b__frozen_param0'), target='_frozen_param0', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='dequantize_per_tensor_2'), target=None)])
    Range constraints: {}





.. GENERATED FROM PYTHON SOURCE LINES 221-224

More information on how to quantize a model and how a backend can implement a
``Quantizer`` can be found
`here <https://pytorch.org/docs/main/quantization.html#prototype-pytorch-2-export-quantization>`__.

.. GENERATED FROM PYTHON SOURCE LINES 226-242

Lowering to Edge Dialect
------------------------

After exporting the graph to the ``ATen Dialect``, the next step is to lower
it to the ``Edge Dialect``, which applies specializations that are useful for
edge devices but unnecessary for general (server) environments.
Some of these specializations include:

- DType specialization
- Scalar to tensor conversion
- Converting all ops to the ``executorch.exir.dialects.edge`` namespace.

Note that this dialect is still backend (or target) agnostic.

The lowering is done through the ``to_edge`` API.

.. GENERATED FROM PYTHON SOURCE LINES 242-252

.. code-block:: default


    from executorch.exir import EdgeProgramManager, to_edge

    example_args = (torch.randn(1, 3, 256, 256),)
    aten_dialect: ExportedProgram = export(SimpleConv(), example_args)

    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    print("Edge Dialect Graph")
    print(edge_program.exported_program())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Edge Dialect Graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:64 in forward, code: a = self.conv(x)
                aten_convolution_default: "f32[1, 16, 256, 256]" = executorch_exir_dialects_edge__ops_aten_convolution_default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  x = p_conv_weight = p_conv_bias = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:65 in forward, code: return self.relu(a)
                aten_relu_default: "f32[1, 16, 256, 256]" = executorch_exir_dialects_edge__ops_aten_relu_default(aten_convolution_default);  aten_convolution_default = None
                return (aten_relu_default,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_relu_default'), target=None)])
    Range constraints: {}





.. GENERATED FROM PYTHON SOURCE LINES 253-257

``to_edge()`` returns an ``EdgeProgramManager`` object, which contains the
exported programs that will be placed on the device. This data structure
allows users to export multiple programs and combine them into one binary. If
there is only one program, it will by default be saved under the name
"forward".

.. GENERATED FROM PYTHON SOURCE LINES 257-282

.. code-block:: default



    class Encode(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.linear(x, torch.randn(5, 10))


    class Decode(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.linear(x, torch.randn(10, 5))


    encode_args = (torch.randn(1, 10),)
    aten_encode: ExportedProgram = export(Encode(), encode_args)

    decode_args = (torch.randn(1, 5),)
    aten_decode: ExportedProgram = export(Decode(), decode_args)

    edge_program: EdgeProgramManager = to_edge(
        {"encode": aten_encode, "decode": aten_decode}
    )
    for method in edge_program.methods:
        print(f"Edge Dialect graph of {method}")
        print(edge_program.exported_program(method))





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Edge Dialect graph of encode
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[1, 10]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:261 in forward, code: return torch.nn.functional.linear(x, torch.randn(5, 10))
                aten_randn_default: "f32[5, 10]" = executorch_exir_dialects_edge__ops_aten_randn_default([5, 10], device = device(type='cpu'), pin_memory = False)
                aten_permute_copy_default: "f32[10, 5]" = executorch_exir_dialects_edge__ops_aten_permute_copy_default(aten_randn_default, [1, 0]);  aten_randn_default = None
                aten_mm_default: "f32[1, 5]" = executorch_exir_dialects_edge__ops_aten_mm_default(x, aten_permute_copy_default);  x = aten_permute_copy_default = None
                return (aten_mm_default,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_mm_default'), target=None)])
    Range constraints: {}

    Edge Dialect graph of decode
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[1, 5]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:266 in forward, code: return torch.nn.functional.linear(x, torch.randn(10, 5))
                aten_randn_default: "f32[10, 5]" = executorch_exir_dialects_edge__ops_aten_randn_default([10, 5], device = device(type='cpu'), pin_memory = False)
                aten_permute_copy_default: "f32[5, 10]" = executorch_exir_dialects_edge__ops_aten_permute_copy_default(aten_randn_default, [1, 0]);  aten_randn_default = None
                aten_mm_default: "f32[1, 10]" = executorch_exir_dialects_edge__ops_aten_mm_default(x, aten_permute_copy_default);  x = aten_permute_copy_default = None
                return (aten_mm_default,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_mm_default'), target=None)])
    Range constraints: {}





.. GENERATED FROM PYTHON SOURCE LINES 283-292

We can also run additional passes on the exported program through
the ``transform`` API. In-depth documentation on how to write
transformations can be found
`here <../compiler-custom-compiler-passes.html>`__.

Note that since the graph is now in the Edge Dialect, all passes must also
result in a valid Edge Dialect graph. In particular, the operators are now in
the ``executorch.exir.dialects.edge`` namespace rather than the
``torch.ops.aten`` namespace.

.. GENERATED FROM PYTHON SOURCE LINES 292-317

.. code-block:: default


    example_args = (torch.randn(1, 3, 256, 256),)
    aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    print("Edge Dialect Graph")
    print(edge_program.exported_program())

    from executorch.exir.dialects._ops import ops as exir_ops
    from executorch.exir.pass_base import ExportPass


    class ConvertReluToSigmoid(ExportPass):
        def call_operator(self, op, args, kwargs, meta):
            if op == exir_ops.edge.aten.relu.default:
                return super().call_operator(
                    exir_ops.edge.aten.sigmoid.default, args, kwargs, meta
                )
            else:
                return super().call_operator(op, args, kwargs, meta)


    transformed_edge_program = edge_program.transform((ConvertReluToSigmoid(),))
    print("Transformed Edge Dialect Graph")
    print(transformed_edge_program.exported_program())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Edge Dialect Graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:64 in forward, code: a = self.conv(x)
                aten_convolution_default: "f32[1, 16, 256, 256]" = executorch_exir_dialects_edge__ops_aten_convolution_default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  x = p_conv_weight = p_conv_bias = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:65 in forward, code: return self.relu(a)
                aten_relu_default: "f32[1, 16, 256, 256]" = executorch_exir_dialects_edge__ops_aten_relu_default(aten_convolution_default);  aten_convolution_default = None
                return (aten_relu_default,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_relu_default'), target=None)])
    Range constraints: {}

    Transformed Edge Dialect Graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:64 in forward, code: a = self.conv(x)
                aten_convolution_default: "f32[1, 16, 256, 256]" = executorch_exir_dialects_edge__ops_aten_convolution_default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  x = p_conv_weight = p_conv_bias = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:65 in forward, code: return self.relu(a)
                aten_sigmoid_default: "f32[1, 16, 256, 256]" = executorch_exir_dialects_edge__ops_aten_sigmoid_default(aten_convolution_default);  aten_convolution_default = None
                return (aten_sigmoid_default,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_sigmoid_default'), target=None)])
    Range constraints: {}





.. GENERATED FROM PYTHON SOURCE LINES 318-322

Note: if you see an error like ``torch._export.verifier.SpecViolationError:
Operator torch._ops.aten._native_batch_norm_legit_functional.default is not
Aten Canonical``,
please file an issue at https://github.com/pytorch/executorch/issues and we'll
be happy to help!

.. GENERATED FROM PYTHON SOURCE LINES 325-340

Delegating to a Backend
-----------------------

We can now delegate parts of the graph, or the whole graph, to a third-party
backend through the ``to_backend`` API. In-depth documentation on the
specifics of backend delegation, including how to delegate to a backend and
how to implement a backend, can be found
`here <../compiler-delegate-and-partitioner.html>`__.

There are three ways to use this API:

1. We can lower the whole module.
2. We can take the lowered module, and insert it in another larger module.
3. We can partition the module into subgraphs that are lowerable, and then
   lower those subgraphs to a backend.

.. GENERATED FROM PYTHON SOURCE LINES 342-348

Lowering the Whole Module
^^^^^^^^^^^^^^^^^^^^^^^^^

To lower an entire module, we can pass ``to_backend`` the backend name, the
module to be lowered, and a list of compile specs to help the backend with the
lowering process.

.. GENERATED FROM PYTHON SOURCE LINES 348-385

.. code-block:: default



    class LowerableModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.sin(x)


    # Export and lower the module to Edge Dialect
    example_args = (torch.ones(1),)
    aten_dialect: ExportedProgram = export(LowerableModule(), example_args)
    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    to_be_lowered_module = edge_program.exported_program()

    from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend

    # Import the backend
    from executorch.exir.backend.test.backend_with_compiler_demo import (  # noqa
        BackendWithCompilerDemo,
    )

    # Lower the module
    lowered_module: LoweredBackendModule = to_backend(
        "BackendWithCompilerDemo", to_be_lowered_module, []
    )
    print(lowered_module)
    print(lowered_module.backend_id)
    print(lowered_module.processed_bytes)
    print(lowered_module.original_module)

    # Serialize and save it to a file
    save_path = "delegate.pte"
    with open(save_path, "wb") as f:
        f.write(lowered_module.buffer())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    LoweredBackendModule()
    BackendWithCompilerDemo
    b'1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#'
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[1]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:355 in forward, code: return torch.sin(x)
                aten_sin_default: "f32[1]" = executorch_exir_dialects_edge__ops_aten_sin_default(x);  x = None
                return (aten_sin_default,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_sin_default'), target=None)])
    Range constraints: {}





.. GENERATED FROM PYTHON SOURCE LINES 386-394

In this call, ``to_backend`` will return a ``LoweredBackendModule``. Some
important attributes of the ``LoweredBackendModule`` are:

- ``backend_id``: the name of the backend this lowered module will run on at
  runtime
- ``processed_bytes``: a binary blob that tells the backend how to run this
  program at runtime
- ``original_module``: the original exported module

.. GENERATED FROM PYTHON SOURCE LINES 396-401

Compose the Lowered Module into Another Module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In cases where we want to reuse this lowered module in multiple programs, we
can compose it with another module.

.. GENERATED FROM PYTHON SOURCE LINES 401-436

.. code-block:: default



    class NotLowerableModule(torch.nn.Module):
        def __init__(self, bias):
            super().__init__()
            self.bias = bias

        def forward(self, a, b):
            return torch.add(torch.add(a, b), self.bias)


    class ComposedModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.non_lowerable = NotLowerableModule(torch.ones(1) * 0.3)
            self.lowerable = lowered_module

        def forward(self, x):
            a = self.lowerable(x)
            b = self.lowerable(a)
            ret = self.non_lowerable(a, b)
            return a, b, ret


    example_args = (torch.ones(1),)
    aten_dialect: ExportedProgram = export(ComposedModule(), example_args)
    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    exported_program = edge_program.exported_program()
    print("Edge Dialect graph")
    print(exported_program)
    print("Lowered Module within the graph")
    print(exported_program.graph_module.lowered_module_0.backend_id)
    print(exported_program.graph_module.lowered_module_0.processed_bytes)
    print(exported_program.graph_module.lowered_module_0.original_module)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
      getattr_node = gm.graph.get_attr(lifted_node)
    /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph.py:1801: UserWarning: Node non_lowerable_bias target non_lowerable.bias bias of non_lowerable does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
      warnings.warn(
    Edge Dialect graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, c_non_lowerable_bias: "f32[1]", x: "f32[1]"):
                 # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/executorch/exir/lowered_backend_module.py:344 in forward, code: return executorch_call_delegate(self, *args)
                lowered_module_0 = self.lowered_module_0
                executorch_call_delegate: "f32[1]" = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x);  lowered_module_0 = x = None
            
                 # File: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/executorch/exir/lowered_backend_module.py:344 in forward, code: return executorch_call_delegate(self, *args)
                lowered_module_1 = self.lowered_module_0
                executorch_call_delegate_1: "f32[1]" = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, executorch_call_delegate);  lowered_module_1 = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:409 in forward, code: return torch.add(torch.add(a, b), self.bias)
                aten_add_tensor: "f32[1]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(executorch_call_delegate, executorch_call_delegate_1)
                aten_add_tensor_1: "f32[1]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_add_tensor, c_non_lowerable_bias);  aten_add_tensor = c_non_lowerable_bias = None
                return (executorch_call_delegate, executorch_call_delegate_1, aten_add_tensor_1)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_non_lowerable_bias'), target='non_lowerable.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='executorch_call_delegate'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='executorch_call_delegate_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
    Range constraints: {}

    Lowered Module within the graph
    BackendWithCompilerDemo
    b'1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#'
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[1]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:355 in forward, code: return torch.sin(x)
                aten_sin_default: "f32[1]" = executorch_exir_dialects_edge__ops_aten_sin_default(x);  x = None
                return (aten_sin_default,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_sin_default'), target=None)])
    Range constraints: {}





.. GENERATED FROM PYTHON SOURCE LINES 437-441

Notice that there are now ``torch.ops.higher_order.executorch_call_delegate``
nodes in the graph calling ``lowered_module_0``. Additionally, the contents of
``lowered_module_0`` are the same as those of the ``lowered_module`` we
created previously.

.. GENERATED FROM PYTHON SOURCE LINES 443-451

Partition and Lower Parts of a Module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A separate lowering flow is to pass ``to_backend`` the module that we want to
lower along with a backend-specific partitioner. ``to_backend`` will use the
backend-specific partitioner to tag nodes in the module which are lowerable,
partition those nodes into subgraphs, and then create a
``LoweredBackendModule`` for each of those subgraphs.

.. GENERATED FROM PYTHON SOURCE LINES 451-478

.. code-block:: default



    class Foo(torch.nn.Module):
        def forward(self, a, x, b):
            y = torch.mm(a, x)
            z = y + b
            a = z - a
            y = torch.mm(a, x)
            z = y + b
            return z


    example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
    aten_dialect: ExportedProgram = export(Foo(), example_args)
    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    exported_program = edge_program.exported_program()
    print("Edge Dialect graph")
    print(exported_program)

    from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo

    delegated_program = to_backend(exported_program, AddMulPartitionerDemo())
    print("Delegated program")
    print(delegated_program)
    print(delegated_program.graph_module.lowered_module_0.original_module)
    print(delegated_program.graph_module.lowered_module_1.original_module)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Edge Dialect graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, a: "f32[2, 2]", x: "f32[2, 2]", b: "f32[2, 2]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:455 in forward, code: y = torch.mm(a, x)
                aten_mm_default: "f32[2, 2]" = executorch_exir_dialects_edge__ops_aten_mm_default(a, x)
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:456 in forward, code: z = y + b
                aten_add_tensor: "f32[2, 2]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default, b);  aten_mm_default = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:457 in forward, code: a = z - a
                aten_sub_tensor: "f32[2, 2]" = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_add_tensor, a);  aten_add_tensor = a = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:458 in forward, code: y = torch.mm(a, x)
                aten_mm_default_1: "f32[2, 2]" = executorch_exir_dialects_edge__ops_aten_mm_default(aten_sub_tensor, x);  aten_sub_tensor = x = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:459 in forward, code: z = y + b
                aten_add_tensor_1: "f32[2, 2]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default_1, b);  aten_mm_default_1 = b = None
                return (aten_add_tensor_1,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='a'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='b'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
    Range constraints: {}

    Delegated program
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, a: "f32[2, 2]", x: "f32[2, 2]", b: "f32[2, 2]"):
                # No stacktrace found for following nodes
                lowered_module_0 = self.lowered_module_0
                lowered_module_1 = self.lowered_module_1
                executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, a, x, b);  lowered_module_1 = None
                getitem_1: "f32[2, 2]" = executorch_call_delegate_1[0];  executorch_call_delegate_1 = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:457 in forward, code: a = z - a
                aten_sub_tensor: "f32[2, 2]" = executorch_exir_dialects_edge__ops_aten_sub_Tensor(getitem_1, a);  getitem_1 = a = None
            
                # No stacktrace found for following nodes
                executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, aten_sub_tensor, x, b);  lowered_module_0 = aten_sub_tensor = x = b = None
                getitem: "f32[2, 2]" = executorch_call_delegate[0];  executorch_call_delegate = None
                return (getitem,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='a'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='b'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
    Range constraints: {}

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, aten_sub_tensor: "f32[2, 2]", x: "f32[2, 2]", b: "f32[2, 2]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:458 in forward, code: y = torch.mm(a, x)
                aten_mm_default_1: "f32[2, 2]" = executorch_exir_dialects_edge__ops_aten_mm_default(aten_sub_tensor, x);  aten_sub_tensor = x = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:459 in forward, code: z = y + b
                aten_add_tensor_1: "f32[2, 2]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default_1, b);  aten_mm_default_1 = b = None
                return [aten_add_tensor_1]
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='aten_sub_tensor'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='b'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
    Range constraints: {}

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, a: "f32[2, 2]", x: "f32[2, 2]", b: "f32[2, 2]"):
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:455 in forward, code: y = torch.mm(a, x)
                aten_mm_default: "f32[2, 2]" = executorch_exir_dialects_edge__ops_aten_mm_default(a, x);  a = x = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:456 in forward, code: z = y + b
                aten_add_tensor: "f32[2, 2]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default, b);  aten_mm_default = b = None
                return [aten_add_tensor]
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='a'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='b'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor'), target=None)])
    Range constraints: {}





.. GENERATED FROM PYTHON SOURCE LINES 479-485

Notice that there are now 2 ``torch.ops.higher_order.executorch_call_delegate`` nodes in the
graph, each calling a lowered module that contains the operations ``mm`` and
``add``.

Alternatively, a more convenient way to lower parts of a module is to call
``to_backend`` directly on the ``EdgeProgramManager``:

.. GENERATED FROM PYTHON SOURCE LINES 485-506

.. code-block:: default



    class Foo(torch.nn.Module):
        def forward(self, a, x, b):
            y = torch.mm(a, x)
            z = y + b
            a = z - a
            y = torch.mm(a, x)
            z = y + b
            return z


    example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
    aten_dialect: ExportedProgram = export(Foo(), example_args)
    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    exported_program = edge_program.exported_program()
    delegated_program = edge_program.to_backend(AddMulPartitionerDemo())

    print("Delegated program")
    print(delegated_program.exported_program())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Delegated program
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, a: "f32[2, 2]", x: "f32[2, 2]", b: "f32[2, 2]"):
                # No stacktrace found for following nodes
                lowered_module_0 = self.lowered_module_0
                lowered_module_1 = self.lowered_module_1
                executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, a, x, b);  lowered_module_1 = None
                getitem_1: "f32[2, 2]" = executorch_call_delegate_1[0];  executorch_call_delegate_1 = None
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:491 in forward, code: a = z - a
                aten_sub_tensor: "f32[2, 2]" = executorch_exir_dialects_edge__ops_aten_sub_Tensor(getitem_1, a);  getitem_1 = a = None
            
                # No stacktrace found for following nodes
                executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, aten_sub_tensor, x, b);  lowered_module_0 = aten_sub_tensor = x = b = None
                getitem: "f32[2, 2]" = executorch_call_delegate[0];  executorch_call_delegate = None
                return (getitem,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='a'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='b'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
    Range constraints: {}





.. GENERATED FROM PYTHON SOURCE LINES 507-519

Running User-Defined Passes and Memory Planning
-----------------------------------------------

As a final step of lowering, we can use the ``to_executorch()`` API to pass in
backend-specific passes (such as replacing sets of operators with a custom
backend operator) and a memory planning pass, which tells the runtime how to
allocate memory ahead of time when running the program.

A default memory planning pass is provided, but we can also choose a
backend-specific memory planning pass if one exists. More information on
writing a custom memory planning pass can be found
`here <../compiler-memory-planning.html>`__.

.. GENERATED FROM PYTHON SOURCE LINES 519-535

.. code-block:: default


    from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager
    from executorch.exir.passes import MemoryPlanningPass

    executorch_program: ExecutorchProgramManager = edge_program.to_executorch(
        ExecutorchBackendConfig(
            passes=[],  # User-defined passes
            memory_planning_pass=MemoryPlanningPass(),  # Default memory planning pass
        )
    )

    print("ExecuTorch Dialect")
    print(executorch_program.exported_program())

    import executorch.exir as exir  # used by the end-to-end example below





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ExecuTorch Dialect
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, a: "f32[2, 2]", x: "f32[2, 2]", b: "f32[2, 2]"):
                # No stacktrace found for following nodes
                alloc: "f32[2, 2]" = executorch_exir_memory_alloc(((2, 2), torch.float32))
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:489 in forward, code: y = torch.mm(a, x)
                aten_mm_default: "f32[2, 2]" = torch.ops.aten.mm.out(a, x, out = alloc);  alloc = None
            
                # No stacktrace found for following nodes
                alloc_1: "f32[2, 2]" = executorch_exir_memory_alloc(((2, 2), torch.float32))
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:490 in forward, code: z = y + b
                aten_add_tensor: "f32[2, 2]" = torch.ops.aten.add.out(aten_mm_default, b, out = alloc_1);  aten_mm_default = alloc_1 = None
            
                # No stacktrace found for following nodes
                alloc_2: "f32[2, 2]" = executorch_exir_memory_alloc(((2, 2), torch.float32))
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:491 in forward, code: a = z - a
                aten_sub_tensor: "f32[2, 2]" = torch.ops.aten.sub.out(aten_add_tensor, a, out = alloc_2);  aten_add_tensor = a = alloc_2 = None
            
                # No stacktrace found for following nodes
                alloc_3: "f32[2, 2]" = executorch_exir_memory_alloc(((2, 2), torch.float32))
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:492 in forward, code: y = torch.mm(a, x)
                aten_mm_default_1: "f32[2, 2]" = torch.ops.aten.mm.out(aten_sub_tensor, x, out = alloc_3);  aten_sub_tensor = x = alloc_3 = None
            
                # No stacktrace found for following nodes
                alloc_4: "f32[2, 2]" = executorch_exir_memory_alloc(((2, 2), torch.float32))
            
                 # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:493 in forward, code: z = y + b
                aten_add_tensor_1: "f32[2, 2]" = torch.ops.aten.add.out(aten_mm_default_1, b, out = alloc_4);  aten_mm_default_1 = b = alloc_4 = None
                return (aten_add_tensor_1,)
            
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='a'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='b'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
    Range constraints: {}





.. GENERATED FROM PYTHON SOURCE LINES 536-552

Notice that in the graph we now see operators like ``torch.ops.aten.mm.out``
and ``torch.ops.aten.add.out`` rather than ``torch.ops.aten.mm.default`` and
``torch.ops.aten.add.Tensor``.

This is because, between running the backend passes and the memory planning
passes, an out-variant pass is run on the graph to convert all of the
operators to their out variants in preparation for memory planning. Instead
of allocating returned tensors inside the kernel implementations, an
operator's ``out`` variant takes in a preallocated tensor through its ``out``
kwarg and stores the result there, making it easier for memory planners to do
tensor lifetime analysis.

We also insert ``alloc`` nodes into the graph, which are calls to a special
``executorch.exir.memory.alloc`` operator. These tell us how much memory to
allocate for each tensor output by an out-variant operator.
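
To make the out-variant convention concrete, here is a minimal eager-mode
sketch (not part of the original tutorial) of calling an ``out`` variant
directly:

.. code-block:: default

    a, b = torch.ones(2, 2), torch.ones(2, 2)
    # Stand-in for the buffer the memory planner would preallocate at runtime.
    out = torch.empty(2, 2)
    torch.ops.aten.add.out(a, b, out=out)
    print(out)  # tensor of 2s; no allocation happened inside the kernel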


.. GENERATED FROM PYTHON SOURCE LINES 554-561

Saving to a File
----------------

Finally, we can save the ExecuTorch program to a file and load it onto a
device to be run.

Here is an example of the entire end-to-end workflow:

.. GENERATED FROM PYTHON SOURCE LINES 561-593

.. code-block:: default


    import torch
    from torch.export import export, export_for_training, ExportedProgram


    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)


    example_args = (torch.randn(3, 4),)
    pre_autograd_aten_dialect = export_for_training(M(), example_args).module()
    # Optionally do quantization:
    # pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))
    aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
    edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)
    # Optionally do delegation:
    # edge_program = edge_program.to_backend(CustomBackendPartitioner)
    executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
        ExecutorchBackendConfig(
            passes=[],  # User-defined passes
        )
    )

    with open("model.pte", "wb") as file:
        file.write(executorch_program.buffer)
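
To sanity-check the saved file, we can load and run it from Python. This is a
minimal sketch (not part of the original tutorial) assuming the ExecuTorch
pybindings in ``executorch.extension.pybindings.portable_lib`` are available
in your environment; on device, the C++ runtime would be used instead.

.. code-block:: default

    from executorch.extension.pybindings.portable_lib import (
        _load_for_executorch,
    )

    # Load the serialized program and execute its default "forward" method
    # with a sample input of the expected shape.
    executorch_module = _load_for_executorch("model.pte")
    outputs = executorch_module.forward([torch.randn(3, 4)])
    print(outputs)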








.. GENERATED FROM PYTHON SOURCE LINES 594-609

Conclusion
----------

In this tutorial, we went over the APIs and steps required to lower a PyTorch
program to a file that can be run on the ExecuTorch runtime.

Links Mentioned
^^^^^^^^^^^^^^^

- `torch.export Documentation <https://pytorch.org/docs/2.1/export.html>`__
- `Quantization Documentation <https://pytorch.org/docs/main/quantization.html#prototype-pytorch-2-export-quantization>`__
- `IR Spec <../ir-exir.html>`__
- `Writing Compiler Passes + Partitioner Documentation <../compiler-custom-compiler-passes.html>`__
- `Backend Delegation Documentation <../compiler-delegate-and-partitioner.html>`__
- `Memory Planning Documentation <../compiler-memory-planning.html>`__


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 3.236 seconds)


.. _sphx_glr_download_tutorials_export-to-executorch-tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: export-to-executorch-tutorial.py <export-to-executorch-tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: export-to-executorch-tutorial.ipynb <export-to-executorch-tutorial.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_