.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/export-to-executorch-tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_export-to-executorch-tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_export-to-executorch-tutorial.py:


Exporting to ExecuTorch Tutorial
================================

**Author:** `Angela Yi <https://github.com/angelayi>`__

.. GENERATED FROM PYTHON SOURCE LINES 16-29

ExecuTorch is a unified ML stack for lowering PyTorch models to edge devices.
It introduces improved entry points for model-, device-, and/or
use-case-specific optimizations such as backend delegation, user-defined
compiler transformations, default or user-defined memory planning, and more.

At a high level, the workflow looks as follows:

.. image:: ../executorch_stack.png
  :width: 560

In this tutorial, we will cover the APIs in the "Program preparation" steps to
lower a PyTorch model to a format that can be loaded onto a device and run on
the ExecuTorch runtime.

.. GENERATED FROM PYTHON SOURCE LINES 31-36

Prerequisites
-------------

To run this tutorial, you’ll first need to
`Set up your ExecuTorch environment <../getting-started-setup.html>`__.

.. GENERATED FROM PYTHON SOURCE LINES 38-53

Exporting a Model
-----------------

Note: The export APIs are still changing to better align with the
longer-term state of export. Please refer to this
`issue <https://github.com/pytorch/executorch/issues/290>`__ for more details.

The first step of lowering to ExecuTorch is to export the given model (any
callable or ``torch.nn.Module``) to a graph representation. This is done in
two stages: first ``torch._export.capture_pre_autograd_graph``, then
``torch.export``.

Both APIs take a model (any callable or ``torch.nn.Module``), a tuple of
positional arguments, and, optionally, a dictionary of keyword arguments (not
shown in the example) and a list of constraints (covered later).

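The tutorial code passes only positional arguments. As a minimal sketch of the
keyword-argument form, assuming a recent PyTorch release in which
``torch.export.export`` accepts a ``kwargs`` dictionary and the exported graph
is invoked through ``ExportedProgram.module()``, the hypothetical ``AddBias``
module below receives its ``bias`` input by keyword:

.. code-block:: python

    import torch
    from torch.export import ExportedProgram, export


    class AddBias(torch.nn.Module):
        # Hypothetical module used only to illustrate keyword arguments.
        def forward(self, x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
            return x + bias


    x = torch.randn(2, 4)
    bias = torch.randn(4)

    # Positional arguments go in a tuple; keyword arguments go in a dict.
    ep: ExportedProgram = export(AddBias(), (x,), kwargs={"bias": bias})

    # Call the exported graph with the same argument structure used at export time.
    out = ep.module()(x, bias=bias)
    print(torch.allclose(out, x + bias))
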
.. GENERATED FROM PYTHON SOURCE LINES 53-81

.. code-block:: default


    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.export import export, ExportedProgram


    class SimpleConv(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=3, padding=1
            )
            self.relu = torch.nn.ReLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            a = self.conv(x)
            return self.relu(a)


    example_args = (torch.randn(1, 3, 256, 256),)
    pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
    print("Pre-Autograd ATen Dialect Graph")
    print(pre_autograd_aten_dialect)

    aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
    print("ATen Dialect Graph")
    print(aten_dialect)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Pre-Autograd ATen Dialect Graph
    GraphModule()



    def forward(self, x):
        arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        _param_constant0 = self._param_constant0
        _param_constant1 = self._param_constant1
        conv2d_default = torch.ops.aten.conv2d.default(arg0, _param_constant0, _param_constant1, [1, 1], [1, 1]);  arg0 = _param_constant0 = _param_constant1 = None
        relu_default = torch.ops.aten.relu.default(conv2d_default);  conv2d_default = None
        return pytree.tree_unflatten([relu_default], self._out_spec)
    
    # To see more debug info, please use `graph_module.print_readable()`
    ATen Dialect Graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
                convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = arg0_1 = arg1_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
                relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(convolution);  convolution = None
                return (relu,)
            
    Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['relu'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []





.. GENERATED FROM PYTHON SOURCE LINES 82-107

The output of ``torch._export.capture_pre_autograd_graph`` is a fully
flattened graph (meaning the graph does not contain any module hierarchy,
except in the case of control flow operators). Furthermore, the captured graph
contains only ATen operators (~3000 ops) that are Autograd safe, that is, safe
for eager mode training.

The output of ``torch.export`` further compiles the graph to a lower and
cleaner representation. Specifically, it has the following:

- The graph is purely functional, meaning it does not contain operations with
  side effects such as mutations or aliasing.
- The graph contains only a small defined
  `Core ATen IR <https://pytorch.org/docs/stable/torch.compiler_ir.html#core-aten-ir>`__
  operator set (~180 ops), along with registered custom operators.
- The nodes in the graph contain metadata captured during tracing, such as a
  stack trace from the user's code.

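To see the "purely functional" property in practice, the following sketch
exports a hypothetical module that mutates an intermediate tensor in place and
lists the operators in the resulting graph. It assumes a recent PyTorch
release, where ``torch.export.export`` may return a graph that still contains
in-place operators depending on the version, and
``ExportedProgram.run_decompositions()`` produces the functional Core ATen
graph described above. (In the 2.1-era API used in this tutorial,
``torch.export`` already returned that form.)

.. code-block:: python

    import torch
    from torch.export import export


    class AddOneInPlace(torch.nn.Module):
        # Hypothetical module that mutates an intermediate tensor in place.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = x.clone()
            y.add_(1)
            return y


    ep = export(AddOneInPlace(), (torch.randn(2, 3),))
    # Lower to the Core ATen operator set; this also functionalizes the graph.
    core = ep.run_decompositions()

    ops = [str(n.target) for n in core.graph.nodes if n.op == "call_function"]
    print(ops)

    # In-place ATen overloads carry a trailing underscore in the op name,
    # e.g. "aten.add_.Tensor"; none should remain after functionalization.
    in_place = [
        op for op in ops
        if len(op.split(".")) == 3 and op.split(".")[1].endswith("_")
    ]
    print(in_place)
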
More details about the result of ``torch.export`` can be found in the
`torch.export documentation <https://pytorch.org/docs/2.1/export.html>`__.

Since the result of ``torch.export`` is a graph containing the Core ATen
operators, we will call this the ``ATen Dialect``, and since
``torch._export.capture_pre_autograd_graph`` returns a graph containing the
set of ATen operators which are Autograd safe, we will call it the
``Pre-Autograd ATen Dialect``.

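The difference between the two dialects shows up directly in the operator
names. The sketch below collects the distinct operator overloads in a graph
before and after decomposition for a small module with the same structure as
``SimpleConv``. It assumes a recent PyTorch release, in which
``torch.export.export`` returns a graph at roughly the pre-autograd level and
``run_decompositions()`` lowers it to Core ATen; the exact behavior depends on
the version.

.. code-block:: python

    import torch
    from torch.export import export


    class SmallConv(torch.nn.Module):
        # Same structure as SimpleConv in this tutorial, redefined so the snippet stands alone.
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
            self.relu = torch.nn.ReLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.relu(self.conv(x))


    def op_names(ep) -> list:
        # Distinct operator targets of all call_function nodes in the graph.
        return sorted({str(n.target) for n in ep.graph.nodes if n.op == "call_function"})


    ep = export(SmallConv(), (torch.randn(1, 3, 32, 32),))
    print("exported:", op_names(ep))

    core = ep.run_decompositions()
    print("core ATen:", op_names(core))

On recent releases the first list may contain ``aten.conv2d.default`` while the
decomposed list contains ``aten.convolution.default``, mirroring the
``conv2d_default`` and ``convolution`` nodes in the outputs above.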
.. GENERATED FROM PYTHON SOURCE LINES 109-115

Expressing Dynamism
^^^^^^^^^^^^^^^^^^^

By default, the export flow traces the program assuming that all input
shapes are static, so if we run the program with input shapes that differ
from the ones used during tracing, we will run into an error:

.. GENERATED FROM PYTHON SOURCE LINES 115-136

.. code-block:: default


    import traceback as tb


    def f(x, y):
        return x + y


    example_args = (torch.randn(3, 3), torch.randn(3, 3))
    pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
    aten_dialect: ExportedProgram = export(f, example_args)

    # Works correctly
    print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3)))

    # Errors
    try:
        print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))
    except Exception:
        tb.print_exc()





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[2., 2., 2.],
            [2., 2., 2.],
            [2., 2., 2.]])
    Traceback (most recent call last):
      File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 132, in <module>
        print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 369, in __call__
        self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 664, in _check_input_constraints
        _assertion_graph(*args)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 728, in call_wrapped
        return self._wrapped_call(self, *args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 307, in __call__
        raise e
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 294, in __call__
        return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
        return forward_call(*args, **kwargs)
      File "<eval_with_key>.107", line 11, in forward
        _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, 'Input arg1_1.shape[1] is specialized at 3');  scalar_tensor = None
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 516, in __call__
        return self._op(*args, **kwargs or {})
    RuntimeError: Input arg1_1.shape[1] is specialized at 3




.. GENERATED FROM PYTHON SOURCE LINES 137-139

To express that some input shapes are dynamic, we can pass constraints to
the export flow. This is done through the ``dynamic_dim`` API:

.. GENERATED FROM PYTHON SOURCE LINES 139-165

.. code-block:: default


    from torch.export import dynamic_dim


    def f(x, y):
        return x + y


    example_args = (torch.randn(3, 3), torch.randn(3, 3))
    constraints = [
        # Input 0, dimension 1 is dynamic
        dynamic_dim(example_args[0], 1),
        # Input 0, dimension 1 must be greater than or equal to 1
        1 <= dynamic_dim(example_args[0], 1),
        # Input 0, dimension 1 must be less than or equal to 10
        dynamic_dim(example_args[0], 1) <= 10,
        # Input 1, dimension 1 is equal to input 0, dimension 1
        dynamic_dim(example_args[1], 1) == dynamic_dim(example_args[0], 1),
    ]
    pre_autograd_aten_dialect = capture_pre_autograd_graph(
        f, example_args, constraints=constraints
    )
    aten_dialect: ExportedProgram = export(f, example_args, constraints=constraints)
    print("ATen Dialect Graph")
    print(aten_dialect)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ATen Dialect Graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, s0], arg1_1: f32[3, s0]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:144, code: return x + y
                add: f32[3, s0] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return (add,)
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {s0: RangeConstraint(min_val=2, max_val=10)}
    Equality constraints: [(InputDim(input_name='arg1_1', dim=1), InputDim(input_name='arg0_1', dim=1))]





.. GENERATED FROM PYTHON SOURCE LINES 166-177

Note that the inputs ``arg0_1`` and ``arg1_1`` now have shapes (3, s0),
with ``s0`` being a symbol representing that this dimension can take a range
of values.

Additionally, the **Range constraints** show that ``s0`` has the range
[2, 10]. The upper bound of 10 comes from our constraints; the lower bound is
2 rather than the 1 we specified because PyTorch's symbolic shape tracing
specializes dimensions of size 0 and 1, so a dynamic dimension is assumed to be
at least 2. We also see in the **Equality constraints** the tuple
``(InputDim(input_name='arg1_1', dim=1), InputDim(input_name='arg0_1', dim=1))``,
meaning that input 1's dimension 1 is equal to input 0's dimension 1, which was
also specified by our constraints.

Now let's try running the model with different shapes:

.. GENERATED FROM PYTHON SOURCE LINES 177-195

.. code-block:: default


    # Works correctly
    print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3)))
    print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))

    # Errors because it violates our constraint that input 0, dim 1 <= 10
    try:
        print(aten_dialect(torch.ones(3, 15), torch.ones(3, 15)))
    except Exception:
        tb.print_exc()

    # Errors because it violates our constraint that input 0, dim 1 == input 1, dim 1
    try:
        print(aten_dialect(torch.ones(3, 3), torch.ones(3, 2)))
    except Exception:
        tb.print_exc()






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[2., 2., 2.],
            [2., 2., 2.],
            [2., 2., 2.]])
    tensor([[2., 2.],
            [2., 2.],
            [2., 2.]])
    Traceback (most recent call last):
      File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 184, in <module>
        print(aten_dialect(torch.ones(3, 15), torch.ones(3, 15)))
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 369, in __call__
        self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 664, in _check_input_constraints
        _assertion_graph(*args)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 728, in call_wrapped
        return self._wrapped_call(self, *args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 307, in __call__
        raise e
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 294, in __call__
        return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
        return forward_call(*args, **kwargs)
      File "<eval_with_key>.141", line 17, in forward
        _assert_async_2 = torch.ops.aten._assert_async.msg(scalar_tensor_2, 'Input arg0_1.shape[1] is outside of specified dynamic range [2, 10]');  scalar_tensor_2 = None
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 516, in __call__
        return self._op(*args, **kwargs or {})
    RuntimeError: Input arg0_1.shape[1] is outside of specified dynamic range [2, 10]
    Traceback (most recent call last):
      File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 190, in <module>
        print(aten_dialect(torch.ones(3, 3), torch.ones(3, 2)))
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 369, in __call__
        self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 664, in _check_input_constraints
        _assertion_graph(*args)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 728, in call_wrapped
        return self._wrapped_call(self, *args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 307, in __call__
        raise e
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 294, in __call__
        return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
        return forward_call(*args, **kwargs)
      File "<eval_with_key>.146", line 11, in forward
        _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, 'Input arg1_1.shape[1] is not equal to input arg0_1.shape[1]');  scalar_tensor = None
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 516, in __call__
        return self._op(*args, **kwargs or {})
    RuntimeError: Input arg1_1.shape[1] is not equal to input arg0_1.shape[1]

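In later PyTorch releases, ``dynamic_dim`` and the ``constraints`` argument
were deprecated in favor of ``torch.export.Dim`` and the ``dynamic_shapes``
argument. The sketch below expresses the same constraints with that newer API,
assuming a PyTorch version that provides it (2.2 or later, to the best of our
knowledge). Because recent releases of ``torch.export.export`` expect an
``nn.Module``, ``f`` is wrapped in a hypothetical ``AddModule``. Reusing one
``Dim`` object for both inputs encodes the equality constraint, and ``min=2``
matches the lower bound shown in the range constraints above.

.. code-block:: python

    import torch
    from torch.export import Dim, export


    class AddModule(torch.nn.Module):
        # Hypothetical wrapper around f(x, y) = x + y.
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y


    # One symbolic dimension shared by both inputs: same range, and x.shape[1] == y.shape[1].
    n = Dim("n", min=2, max=10)
    example_args = (torch.randn(3, 3), torch.randn(3, 3))
    ep = export(AddModule(), example_args, dynamic_shapes={"x": {1: n}, "y": {1: n}})
    mod = ep.module()

    # Within the range, any width works.
    print(mod(torch.ones(3, 5), torch.ones(3, 5)).shape)  # torch.Size([3, 5])

    # Both calls below violate the constraints and are expected to raise.
    errors = []
    for bad_args in [
        (torch.ones(3, 15), torch.ones(3, 15)),  # width above max=10
        (torch.ones(3, 3), torch.ones(3, 2)),  # widths not equal
    ]:
        try:
            mod(*bad_args)
        except Exception as exc:  # the exact exception type varies across versions
            errors.append(exc)
    print(len(errors))

The exception types and messages on newer releases differ from the
``RuntimeError`` tracebacks shown above.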



.. GENERATED FROM PYTHON SOURCE LINES 196-206

Addressing Untraceable Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Because our goal is to capture the entire computational graph of a PyTorch
program, we might run into parts of a program that cannot be traced. The
`torch.export documentation <https://pytorch.org/docs/2.1/export.html#limitations-of-torch-export>`__
and the
`torch.export tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__
are the best places to look for how to address these issues.

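A common source of untraceable code is Python control flow that depends on
tensor values. As a minimal sketch (the module names are hypothetical, and the
exact exception type and message vary across PyTorch versions), exporting a
data-dependent ``if`` fails, while an equivalent ``torch.where`` formulation
keeps the computation in the graph and exports cleanly:

.. code-block:: python

    import torch
    from torch.export import export


    class DataDependentBranch(torch.nn.Module):
        # Hypothetical: which branch runs depends on the *values* in x.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if x.sum() > 0:
                return x.sin()
            return x.cos()


    class BranchFree(torch.nn.Module):
        # Same computation expressed without Python control flow.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.where(x.sum() > 0, x.sin(), x.cos())


    example_args = (torch.randn(4),)

    try:
        export(DataDependentBranch(), example_args)
        branch_exported = True
    except Exception as exc:
        branch_exported = False
        print("export failed:", type(exc).__name__)

    ep = export(BranchFree(), example_args)
    x = torch.randn(4)
    print(torch.allclose(ep.module()(x), DataDependentBranch()(x)))

Rewriting with ``torch.where`` computes both branches, which is fine for cheap
element-wise operations like these.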
.. GENERATED FROM PYTHON SOURCE LINES 208-223

Performing Quantization
-----------------------

To quantize a model, we perform quantization between the calls to
``torch._export.capture_pre_autograd_graph`` and ``torch.export``, that is, in
the ``Pre-Autograd ATen Dialect``. This is because quantization must operate at
a level that is safe for eager mode training.

Compared to
`FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__,
we call two new APIs, ``prepare_pt2e`` and ``convert_pt2e``, instead of
``prepare_fx`` and ``convert_fx``. The key difference is that
``prepare_pt2e`` takes a backend-specific ``Quantizer`` as an argument, which
annotates the nodes in the graph with the information needed to quantize the
model properly for that backend.

.. GENERATED FROM PYTHON SOURCE LINES 223-246

.. code-block:: default


    example_args = (torch.randn(1, 3, 256, 256),)
    pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
    print("Pre-Autograd ATen Dialect Graph")
    print(pre_autograd_aten_dialect)

    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
        XNNPACKQuantizer,
    )

    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
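    # NOTE: calibration is skipped here for brevity. In practice, run representative
    # inputs through prepared_graph (e.g. prepared_graph(*example_args)) before
    # calling convert_pt2e. Skipping it triggers the observer warnings and the
    # default scale/zero point (1.0, 0) seen in the output below.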
    converted_graph = convert_pt2e(prepared_graph)
    print("Quantized Graph")
    print(converted_graph)

    aten_dialect: ExportedProgram = export(converted_graph, example_args)
    print("ATen Dialect Graph")
    print(aten_dialect)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Pre-Autograd ATen Dialect Graph
    GraphModule()



    def forward(self, x):
        arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        _param_constant0 = self._param_constant0
        _param_constant1 = self._param_constant1
        conv2d_default = torch.ops.aten.conv2d.default(arg0, _param_constant0, _param_constant1, [1, 1], [1, 1]);  arg0 = _param_constant0 = _param_constant1 = None
        relu_default = torch.ops.aten.relu.default(conv2d_default);  conv2d_default = None
        return pytree.tree_unflatten([relu_default], self._out_spec)
    
    # To see more debug info, please use `graph_module.print_readable()`
    /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/observer.py:1220: UserWarning: must run observer before calling calculate_qparams.                                    Returning default scale and zero point 
      warnings.warn(
    /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/utils.py:339: UserWarning: must run observer before calling calculate_qparams. Returning default values.
      warnings.warn(
    Quantized Graph
    GraphModule()



    def forward(self, x):
        arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg0, 1.0, 0, -128, 127, torch.int8);  arg0 = None
        dequantize_per_tensor_default = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor_default = None
        _param_constant0 = self._param_constant0
        quantize_per_tensor_default_1 = torch.ops.quantized_decomposed.quantize_per_tensor.default(_param_constant0, 1.0, 0, -127, 127, torch.int8);  _param_constant0 = None
        dequantize_per_tensor_default_1 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_1, 1.0, 0, -127, 127, torch.int8);  quantize_per_tensor_default_1 = None
        _param_constant1 = self._param_constant1
        conv2d_default = torch.ops.aten.conv2d.default(dequantize_per_tensor_default, dequantize_per_tensor_default_1, _param_constant1, [1, 1], [1, 1]);  dequantize_per_tensor_default = dequantize_per_tensor_default_1 = _param_constant1 = None
        relu_default = torch.ops.aten.relu.default(conv2d_default);  conv2d_default = None
        quantize_per_tensor_default_2 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu_default, 1.0, 0, -128, 127, torch.int8);  relu_default = None
        dequantize_per_tensor_default_2 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_2, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor_default_2 = None
        return pytree.tree_unflatten([dequantize_per_tensor_default_2], self._out_spec)
    
    # To see more debug info, please use `graph_module.print_readable()`
    ATen Dialect Graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
                # No stacktrace found for following nodes
                quantize_per_tensor: i8[1, 3, 256, 256] = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg2_1, 1.0, 0, -128, 127, torch.int8);  arg2_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
                dequantize_per_tensor: f32[1, 3, 256, 256] = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor = None
                quantize_per_tensor_1: i8[16, 3, 3, 3] = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg0_1, 1.0, 0, -127, 127, torch.int8);  arg0_1 = None
                dequantize_per_tensor_1: f32[16, 3, 3, 3] = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_1, 1.0, 0, -127, 127, torch.int8);  quantize_per_tensor_1 = None
                convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(dequantize_per_tensor, dequantize_per_tensor_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  dequantize_per_tensor = dequantize_per_tensor_1 = arg1_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
                relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(convolution);  convolution = None
                quantize_per_tensor_2: i8[1, 16, 256, 256] = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu, 1.0, 0, -128, 127, torch.int8);  relu = None
            
                # No stacktrace found for following nodes
                dequantize_per_tensor_2: f32[1, 16, 256, 256] = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_2, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor_2 = None
                return (dequantize_per_tensor_2,)
            
    Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['dequantize_per_tensor_2'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []





.. GENERATED FROM PYTHON SOURCE LINES 247-250

More information on how to quantize a model, and on how a backend can implement
a ``Quantizer``, can be found in the
`PyTorch 2 Export Quantization documentation <https://pytorch.org/docs/main/quantization.html#prototype-pytorch-2-export-quantization>`__.

.. GENERATED FROM PYTHON SOURCE LINES 252-268

Lowering to Edge Dialect
------------------------

After exporting and lowering the graph to the ``ATen Dialect``, the next step
is to lower to the ``Edge Dialect``, in which specializations that are useful
for edge devices but not necessary for general (server) environments will be
applied.
Some of these specializations include:

- DType specialization
- Scalar to tensor conversion
- Converting all ops to the ``executorch.exir.dialects.edge`` namespace.

Note that this dialect is still backend (or target) agnostic.

The lowering is done through the ``to_edge`` API.

.. GENERATED FROM PYTHON SOURCE LINES 268-284

.. code-block:: default


    from executorch.exir import EdgeProgramManager, to_edge

    example_args = (torch.randn(1, 3, 256, 256),)
    pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
    print("Pre-Autograd ATen Dialect Graph")
    print(pre_autograd_aten_dialect)

    aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
    print("ATen Dialect Graph")
    print(aten_dialect)

    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    print("Edge Dialect Graph")
    print(edge_program.exported_program())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Pre-Autograd ATen Dialect Graph
    GraphModule()



    def forward(self, x):
        arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        _param_constant0 = self._param_constant0
        _param_constant1 = self._param_constant1
        conv2d_default = torch.ops.aten.conv2d.default(arg0, _param_constant0, _param_constant1, [1, 1], [1, 1]);  arg0 = _param_constant0 = _param_constant1 = None
        relu_default = torch.ops.aten.relu.default(conv2d_default);  conv2d_default = None
        return pytree.tree_unflatten([relu_default], self._out_spec)
    
    # To see more debug info, please use `graph_module.print_readable()`
    ATen Dialect Graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
                convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = arg0_1 = arg1_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
                relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(convolution);  convolution = None
                return (relu,)
            
    Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['relu'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []

    Edge Dialect Graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
                aten_convolution_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_convolution_default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = arg0_1 = arg1_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
                aten_relu_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_relu_default(aten_convolution_default);  aten_convolution_default = None
                return (aten_relu_default,)
            
    Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['aten_relu_default'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []





.. GENERATED FROM PYTHON SOURCE LINES 285-289

``to_edge()`` returns an ``EdgeProgramManager`` object, which contains the
exported programs that will be placed on the device. This data structure
allows users to export multiple programs and combine them into one binary. If
there is only one program, it is saved under the default method name "forward".

.. GENERATED FROM PYTHON SOURCE LINES 289-318

.. code-block:: default



    def encode(x):
        return torch.nn.functional.linear(x, torch.randn(5, 10))


    def decode(x):
        return torch.nn.functional.linear(x, torch.randn(10, 5))


    encode_args = (torch.randn(1, 10),)
    aten_encode: ExportedProgram = export(
        capture_pre_autograd_graph(encode, encode_args),
        encode_args,
    )

    decode_args = (torch.randn(1, 5),)
    aten_decode: ExportedProgram = export(
        capture_pre_autograd_graph(decode, decode_args),
        decode_args,
    )

    edge_program: EdgeProgramManager = to_edge(
        {"encode": aten_encode, "decode": aten_decode}
    )
    for method in edge_program.methods:
        print(f"Edge Dialect graph of {method}")
        print(edge_program.exported_program(method))





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Edge Dialect graph of encode
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[1, 10]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:292, code: return torch.nn.functional.linear(x, torch.randn(5, 10))
                aten_randn_default: f32[5, 10] = executorch_exir_dialects_edge__ops_aten_randn_default([5, 10], device = device(type='cpu'), pin_memory = False)
                aten_permute_copy_default: f32[10, 5] = executorch_exir_dialects_edge__ops_aten_permute_copy_default(aten_randn_default, [1, 0]);  aten_randn_default = None
                aten_mm_default: f32[1, 5] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, aten_permute_copy_default);  arg0_1 = aten_permute_copy_default = None
                return (aten_mm_default,)
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_mm_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []

    Edge Dialect graph of decode
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[1, 5]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:296, code: return torch.nn.functional.linear(x, torch.randn(10, 5))
                aten_randn_default: f32[10, 5] = executorch_exir_dialects_edge__ops_aten_randn_default([10, 5], device = device(type='cpu'), pin_memory = False)
                aten_permute_copy_default: f32[5, 10] = executorch_exir_dialects_edge__ops_aten_permute_copy_default(aten_randn_default, [1, 0]);  aten_randn_default = None
                aten_mm_default: f32[1, 10] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, aten_permute_copy_default);  arg0_1 = aten_permute_copy_default = None
                return (aten_mm_default,)
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_mm_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []





.. GENERATED FROM PYTHON SOURCE LINES 319-328

We can also run additional passes on the exported program through
the ``transform`` API. In-depth documentation on how to write
transformations can be found
`here <../compiler-custom-compiler-passes.html>`__.

Note that since the graph is now in the Edge Dialect, all passes must also
produce a valid Edge Dialect graph. In particular, the operators are now in the
``executorch.exir.dialects.edge`` namespace rather than the ``torch.ops.aten``
namespace.
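
To make the Edge Dialect requirement concrete, here is a deliberately
simplified, pure-Python sketch that does not use ExecuTorch at all. It models a
graph as a list of ``(name, op, inputs)`` tuples and uses the string prefix
``"edge.aten."`` as a hypothetical stand-in for the
``executorch.exir.dialects.edge`` namespace. The helpers
``convert_relu_to_sigmoid`` and ``is_valid_edge_graph`` are illustrative only;
real Edge Dialect validation checks much more than an operator's namespace.

.. code-block:: python

    # Toy graph: a list of (node_name, op_name, input_names) tuples. The string
    # prefix "edge.aten." is a stand-in for the Edge Dialect operator namespace.
    EDGE_PREFIX = "edge.aten."


    def convert_relu_to_sigmoid(graph):
        """Return a new graph in which every Edge Dialect relu becomes a sigmoid."""
        new_graph = []
        for name, op, inputs in graph:
            if op == EDGE_PREFIX + "relu.default":
                # Emit the replacement from the same (edge) namespace.
                op = EDGE_PREFIX + "sigmoid.default"
            new_graph.append((name, op, inputs))
        return new_graph


    def is_valid_edge_graph(graph):
        """Simplified check: every node's op must come from the edge namespace."""
        return all(op.startswith(EDGE_PREFIX) for _, op, _ in graph)


    graph = [
        ("conv", EDGE_PREFIX + "convolution.default", ["x", "weight", "bias"]),
        ("relu", EDGE_PREFIX + "relu.default", ["conv"]),
    ]
    transformed = convert_relu_to_sigmoid(graph)
    print([op for _, op, _ in transformed])
    print(is_valid_edge_graph(transformed))  # the pass stayed in the edge namespace

    # A buggy pass that emits an ATen-namespace op breaks the invariant.
    buggy = [
        (name, "aten.sigmoid.default", inputs) if name == "relu" else (name, op, inputs)
        for name, op, inputs in graph
    ]
    print(is_valid_edge_graph(buggy))

This is why the real ``ConvertReluToSigmoid`` pass in the next example emits
``exir_ops.edge.aten.sigmoid.default`` rather than an operator from
``torch.ops.aten``.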

.. GENERATED FROM PYTHON SOURCE LINES 328-354

.. code-block:: default


    example_args = (torch.randn(1, 3, 256, 256),)
    pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
    aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    print("Edge Dialect Graph")
    print(edge_program.exported_program())

    from executorch.exir.dialects._ops import ops as exir_ops
    from executorch.exir.pass_base import ExportPass


    class ConvertReluToSigmoid(ExportPass):
        def call_operator(self, op, args, kwargs, meta):
            if op == exir_ops.edge.aten.relu.default:
                return super().call_operator(
                    exir_ops.edge.aten.sigmoid.default, args, kwargs, meta
                )
            else:
                return super().call_operator(op, args, kwargs, meta)


    transformed_edge_program = edge_program.transform((ConvertReluToSigmoid(),))
    print("Transformed Edge Dialect Graph")
    print(transformed_edge_program.exported_program())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Edge Dialect Graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
                aten_convolution_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_convolution_default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = arg0_1 = arg1_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
                aten_relu_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_relu_default(aten_convolution_default);  aten_convolution_default = None
                return (aten_relu_default,)
            
    Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['aten_relu_default'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []

    Transformed Edge Dialect Graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
                aten_convolution_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_convolution_default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = arg0_1 = arg1_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
                aten_sigmoid_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_sigmoid_default(aten_convolution_default);  aten_convolution_default = None
                return (aten_sigmoid_default,)
            
    Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['aten_sigmoid_default'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []





.. GENERATED FROM PYTHON SOURCE LINES 355-370

Delegating to a Backend
-----------------------

We can now delegate parts of the graph, or the whole graph, to a third-party
backend through the ``to_backend`` API. In-depth documentation on the
specifics of backend delegation, including how to delegate to a backend and
how to implement a backend, can be found
`here <../compiler-delegate-and-partitioner.html>`__.

There are three ways to use this API:

1. We can lower the whole module.
2. We can insert the lowered module into another, larger module.
3. We can partition the module into lowerable subgraphs, and then lower those
   subgraphs to a backend.

.. GENERATED FROM PYTHON SOURCE LINES 372-378

Lowering the Whole Module
^^^^^^^^^^^^^^^^^^^^^^^^^

To lower an entire module, we can pass ``to_backend`` the backend name, the
module to be lowered, and a list of compile specs to help the backend with the
lowering process.

.. GENERATED FROM PYTHON SOURCE LINES 378-416

.. code-block:: default



    class LowerableModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.sin(x)


    # Export and lower the module to Edge Dialect
    example_args = (torch.ones(1),)
    pre_autograd_aten_dialect = capture_pre_autograd_graph(LowerableModule(), example_args)
    aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    to_be_lowered_module = edge_program.exported_program()

    from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend

    # Import the backend
    from executorch.exir.backend.test.backend_with_compiler_demo import (  # noqa
        BackendWithCompilerDemo,
    )

    # Lower the module
    lowered_module: LoweredBackendModule = to_backend(
        "BackendWithCompilerDemo", to_be_lowered_module, []
    )
    print(lowered_module)
    print(lowered_module.backend_id)
    print(lowered_module.processed_bytes)
    print(lowered_module.original_module)

    # Serialize and save it to a file
    save_path = "delegate.pte"
    with open(save_path, "wb") as f:
        f.write(lowered_module.buffer())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    LoweredBackendModule()
    BackendWithCompilerDemo
    b'1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#'
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[1]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:385, code: return torch.sin(x)
                aten_sin_default: f32[1] = executorch_exir_dialects_edge__ops_aten_sin_default(arg0_1);  arg0_1 = None
                return (aten_sin_default,)
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_sin_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []

    /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/_pytree.py:590: UserWarning: pytree_to_str is deprecated. Please use treespec_dumps
      warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps")




.. GENERATED FROM PYTHON SOURCE LINES 417-425

In this call, ``to_backend`` will return a ``LoweredBackendModule``. Some
important attributes of the ``LoweredBackendModule`` are:

- ``backend_id``: the name of the backend this lowered module will run on in
  the runtime
- ``processed_bytes``: a binary blob that tells the backend how to run
  this program in the runtime
- ``original_module``: the original exported module

.. GENERATED FROM PYTHON SOURCE LINES 427-432

Compose the Lowered Module into Another Module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In cases where we want to reuse this lowered module in multiple programs, we
can compose this lowered module with another module.

.. GENERATED FROM PYTHON SOURCE LINES 432-468

.. code-block:: default



    class NotLowerableModule(torch.nn.Module):
        def __init__(self, bias):
            super().__init__()
            self.bias = bias

        def forward(self, a, b):
            return torch.add(torch.add(a, b), self.bias)


    class ComposedModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.non_lowerable = NotLowerableModule(torch.ones(1) * 0.3)
            self.lowerable = lowered_module

        def forward(self, x):
            a = self.lowerable(x)
            b = self.lowerable(a)
            ret = self.non_lowerable(a, b)
            return a, b, ret


    example_args = (torch.ones(1),)
    pre_autograd_aten_dialect = capture_pre_autograd_graph(ComposedModule(), example_args)
    aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    exported_program = edge_program.exported_program()
    print("Edge Dialect graph")
    print(exported_program)
    print("Lowered Module within the graph")
    print(exported_program.graph_module.lowered_module_0.backend_id)
    print(exported_program.graph_module.lowered_module_0.processed_bytes)
    print(exported_program.graph_module.lowered_module_0.original_module)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Edge Dialect graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[1], arg1_1: f32[1]):
                # File: /pytorch/executorch/exir/lowered_backend_module.py:273, code: return executorch_call_delegate(self, *args)
                lowered_module_0 = self.lowered_module_0
                executorch_call_delegate: f32[1] = torch.ops.executorch_call_delegate(lowered_module_0, arg1_1);  lowered_module_0 = arg1_1 = None
            
                # File: /pytorch/executorch/exir/lowered_backend_module.py:273, code: return executorch_call_delegate(self, *args)
                lowered_module_1 = self.lowered_module_0
                executorch_call_delegate_1: f32[1] = torch.ops.executorch_call_delegate(lowered_module_1, executorch_call_delegate);  lowered_module_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:440, code: return torch.add(torch.add(a, b), self.bias)
                aten_add_tensor: f32[1] = executorch_exir_dialects_edge__ops_aten_add_Tensor(executorch_call_delegate, executorch_call_delegate_1)
                aten_add_tensor_1: f32[1] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_add_tensor, arg0_1);  aten_add_tensor = arg0_1 = None
                return (executorch_call_delegate, executorch_call_delegate_1, aten_add_tensor_1)
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=['_tensor_constant0'], user_inputs=['arg1_1'], user_outputs=['executorch_call_delegate', 'executorch_call_delegate_1', 'aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={'arg0_1': '_tensor_constant0'}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []

    Lowered Module within the graph
    BackendWithCompilerDemo
    b'1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#'
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[1]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:385, code: return torch.sin(x)
                aten_sin_default: f32[1] = executorch_exir_dialects_edge__ops_aten_sin_default(arg0_1);  arg0_1 = None
                return (aten_sin_default,)
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_sin_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []





.. GENERATED FROM PYTHON SOURCE LINES 469-473

Notice that the graph now contains two ``torch.ops.executorch_call_delegate``
nodes, both of which call ``lowered_module_0``, because
``ComposedModule.forward`` invokes the lowered module twice. Additionally, the
contents of ``lowered_module_0`` are the same as the ``lowered_module`` we
created previously.

.. GENERATED FROM PYTHON SOURCE LINES 475-483

Partition and Lower Parts of a Module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A separate lowering flow is to pass ``to_backend`` the module that we want to
lower, and a backend-specific partitioner. ``to_backend`` will use the
backend-specific partitioner to tag nodes in the module which are lowerable,
partition those nodes into subgraphs, and then create a
``LoweredBackendModule`` for each of those subgraphs.
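
The tag-then-group flow described above can be sketched in plain Python. This
is a toy model, not ExecuTorch code: nodes are ``(name, op)`` pairs in
execution order, the supported op set ``{"mm", "add"}`` is an assumption chosen
to match what ``AddMulPartitionerDemo`` lowers in the output below, and
``tag_nodes`` and ``group_into_partitions`` are hypothetical helpers.

.. code-block:: python

    # Assumption: the ops this toy partitioner can lower (mirrors the demo below).
    SUPPORTED_OPS = {"mm", "add"}


    def tag_nodes(nodes, supported):
        """Step 1: tag each (name, op) node as lowerable (True) or not (False)."""
        return [(name, op, op in supported) for name, op in nodes]


    def group_into_partitions(tagged):
        """Step 2: group runs of consecutive lowerable nodes into subgraphs."""
        partitions, current = [], []
        for name, _op, lowerable in tagged:
            if lowerable:
                current.append(name)
            elif current:
                # A non-lowerable node closes the current subgraph.
                partitions.append(current)
                current = []
        if current:
            partitions.append(current)
        return partitions


    # Same sequence of ops as the function ``f`` below: mm -> add -> sub -> mm -> add
    nodes = [("mm", "mm"), ("add", "add"), ("sub", "sub"), ("mm_1", "mm"), ("add_1", "add")]
    partitions = group_into_partitions(tag_nodes(nodes, SUPPORTED_OPS))
    print(partitions)  # each partition would become one LoweredBackendModule

Grouping contiguous runs of tagged nodes is a simplification. A real
partitioner works on the ``torch.fx`` graph, so the subgraphs it forms depend on
the data dependencies between nodes rather than on a flat node order.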

.. GENERATED FROM PYTHON SOURCE LINES 483-510

.. code-block:: default



    def f(a, x, b):
        y = torch.mm(a, x)
        z = y + b
        a = z - a
        y = torch.mm(a, x)
        z = y + b
        return z


    example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
    pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
    aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    exported_program = edge_program.exported_program()
    print("Edge Dialect graph")
    print(exported_program)

    from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo

    delegated_program = to_backend(exported_program, AddMulPartitionerDemo)
    print("Delegated program")
    print(delegated_program)
    print(delegated_program.graph_module.lowered_module_0.original_module)
    print(delegated_program.graph_module.lowered_module_1.original_module)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Edge Dialect graph
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:486, code: y = torch.mm(a, x)
                aten_mm_default: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, arg1_1)
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:487, code: z = y + b
                aten_add_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default, arg2_1);  aten_mm_default = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:488, code: a = z - a
                aten_sub_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_add_tensor, arg0_1);  aten_add_tensor = arg0_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:489, code: y = torch.mm(a, x)
                aten_mm_default_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(aten_sub_tensor, arg1_1);  aten_sub_tensor = arg1_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:490, code: z = y + b
                aten_add_tensor_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default_1, arg2_1);  aten_mm_default_1 = arg2_1 = None
                return (aten_add_tensor_1,)
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []

    Delegated program
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
                # No stacktrace found for following nodes
                lowered_module_0 = self.lowered_module_0
                executorch_call_delegate = torch.ops.executorch_call_delegate(lowered_module_0, arg0_1, arg1_1, arg2_1);  lowered_module_0 = None
                getitem: f32[2, 2] = executorch_call_delegate[0];  executorch_call_delegate = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:488, code: a = z - a
                aten_sub_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_sub_Tensor(getitem, arg0_1);  getitem = arg0_1 = None
            
                # No stacktrace found for following nodes
                lowered_module_1 = self.lowered_module_1
                executorch_call_delegate_1 = torch.ops.executorch_call_delegate(lowered_module_1, aten_sub_tensor, arg1_1, arg2_1);  lowered_module_1 = aten_sub_tensor = arg1_1 = arg2_1 = None
                getitem_1: f32[2, 2] = executorch_call_delegate_1[0];  executorch_call_delegate_1 = None
                return (getitem_1,)
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['getitem_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:486, code: y = torch.mm(a, x)
                aten_mm_default: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:487, code: z = y + b
                aten_add_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default, arg2_1);  aten_mm_default = arg2_1 = None
                return [aten_add_tensor]
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, aten_sub_tensor: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:489, code: y = torch.mm(a, x)
                aten_mm_default_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(aten_sub_tensor, arg1_1);  aten_sub_tensor = arg1_1 = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:490, code: z = y + b
                aten_add_tensor_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default_1, arg2_1);  aten_mm_default_1 = arg2_1 = None
                return [aten_add_tensor_1]
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['aten_sub_tensor', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []





.. GENERATED FROM PYTHON SOURCE LINES 511-517

Notice that there are now two ``torch.ops.executorch_call_delegate`` nodes in
the graph, each containing the operations ``mm, add``. The ``sub`` operation,
which this partitioner does not tag for lowering, stays outside the delegates.

Alternatively, a more cohesive API for lowering parts of a module is to call
``to_backend`` directly on the ``EdgeProgramManager``:

.. GENERATED FROM PYTHON SOURCE LINES 517-538

.. code-block:: default



    def f(a, x, b):
        y = torch.mm(a, x)
        z = y + b
        a = z - a
        y = torch.mm(a, x)
        z = y + b
        return z


    example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
    pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
    aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
    edge_program: EdgeProgramManager = to_edge(aten_dialect)
    exported_program = edge_program.exported_program()
    delegated_program = edge_program.to_backend(AddMulPartitionerDemo)

    print("Delegated program")
    print(delegated_program.exported_program())





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Delegated program
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
                # No stacktrace found for following nodes
                lowered_module_0 = self.lowered_module_0
                executorch_call_delegate = torch.ops.executorch_call_delegate(lowered_module_0, arg0_1, arg1_1, arg2_1);  lowered_module_0 = None
                getitem: f32[2, 2] = executorch_call_delegate[0];  executorch_call_delegate = None
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:522, code: a = z - a
                aten_sub_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_sub_Tensor(getitem, arg0_1);  getitem = arg0_1 = None
            
                # No stacktrace found for following nodes
                lowered_module_1 = self.lowered_module_1
                executorch_call_delegate_1 = torch.ops.executorch_call_delegate(lowered_module_1, aten_sub_tensor, arg1_1, arg2_1);  lowered_module_1 = aten_sub_tensor = arg1_1 = arg2_1 = None
                getitem_1: f32[2, 2] = executorch_call_delegate_1[0];  executorch_call_delegate_1 = None
                return (getitem_1,)
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['getitem_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []





.. GENERATED FROM PYTHON SOURCE LINES 539-551

Running User-Defined Passes and Memory Planning
-----------------------------------------------

As a final step of lowering, we use the ``to_executorch()`` API. It accepts
user-defined, backend-specific passes (for example, passes that replace a set
of operators with a custom backend operator) as well as a memory planning pass,
which tells the runtime how to allocate memory ahead of time when running the
program.

A default memory planning pass is provided, but we can also choose a
backend-specific memory planning pass if one exists. More information on
writing a custom memory planning pass can be found
`here <../compiler-memory-planning.html>`__.

.. GENERATED FROM PYTHON SOURCE LINES 551-569

.. code-block:: default


    from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager
    from executorch.exir.passes import MemoryPlanningPass

    executorch_program: ExecutorchProgramManager = edge_program.to_executorch(
        ExecutorchBackendConfig(
            passes=[],  # User-defined passes
            memory_planning_pass=MemoryPlanningPass(
                "greedy"
            ),  # Default memory planning pass
        )
    )

    print("ExecuTorch Dialect")
    print(executorch_program.exported_program())

    import executorch.exir as exir





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/_pytree.py:590: UserWarning: pytree_to_str is deprecated. Please use treespec_dumps
      warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps")
    ExecuTorch Dialect
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
                # No stacktrace found for following nodes
                alloc: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32))
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:520, code: y = torch.mm(a, x)
                aten_mm_default: f32[2, 2] = torch.ops.aten.mm.out(arg0_1, arg1_1, out = alloc);  alloc = None
            
                # No stacktrace found for following nodes
                alloc_1: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32))
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:521, code: z = y + b
                aten_add_tensor: f32[2, 2] = torch.ops.aten.add.out(aten_mm_default, arg2_1, out = alloc_1);  aten_mm_default = alloc_1 = None
            
                # No stacktrace found for following nodes
                alloc_2: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32))
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:522, code: a = z - a
                aten_sub_tensor: f32[2, 2] = torch.ops.aten.sub.out(aten_add_tensor, arg0_1, out = alloc_2);  aten_add_tensor = arg0_1 = alloc_2 = None
            
                # No stacktrace found for following nodes
                alloc_3: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32))
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:523, code: y = torch.mm(a, x)
                aten_mm_default_1: f32[2, 2] = torch.ops.aten.mm.out(aten_sub_tensor, arg1_1, out = alloc_3);  aten_sub_tensor = arg1_1 = alloc_3 = None
            
                # No stacktrace found for following nodes
                alloc_4: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32))
            
                # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:524, code: z = y + b
                aten_add_tensor_1: f32[2, 2] = torch.ops.aten.add.out(aten_mm_default_1, arg2_1, out = alloc_4);  aten_mm_default_1 = arg2_1 = alloc_4 = None
                return (aten_add_tensor_1,)
            
    Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Range constraints: {}
    Equality constraints: []





.. GENERATED FROM PYTHON SOURCE LINES 570-586

Notice that in the graph we now see operators like ``torch.ops.aten.sub.out``
and ``torch.ops.aten.add.out`` rather than ``torch.ops.aten.sub.Tensor`` and
``torch.ops.aten.add.Tensor``.

This is because, between running the backend passes and the memory planning
pass, an out-variant pass runs on the graph to convert all of the operators to
their out variants, preparing the graph for memory planning. Instead of
allocating the returned tensor inside the kernel implementation, an operator's
``out`` variant takes a preallocated tensor through its ``out`` keyword
argument and stores the result there. This makes it easier for memory planners
to perform tensor lifetime analysis.
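
The difference between the two calling conventions can be sketched in plain
Python. The functions ``add_functional`` and ``add_out`` below are hypothetical
stand-ins, with lists in place of tensors. They only illustrate who owns the
output buffer and are not ExecuTorch or ATen kernels.

.. code-block:: python

    def add_functional(a, b):
        """Functional style: the kernel allocates and returns a new result."""
        return [x + y for x, y in zip(a, b)]


    def add_out(a, b, *, out):
        """Out-variant style: the caller supplies a preallocated buffer via ``out``."""
        for i, (x, y) in enumerate(zip(a, b)):
            out[i] = x + y
        return out


    a, b = [1.0, 2.0], [3.0, 4.0]
    buffer = [0.0, 0.0]  # allocated ahead of time, e.g. by a memory planner
    result = add_out(a, b, out=buffer)
    print(result is buffer)  # the result lives in the caller's buffer
    print(add_functional(a, b) == result)

Because ``add_out`` never allocates, the caller decides where ``buffer`` lives,
which is what lets a planner fix buffer placement before the program runs.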

We also insert ``alloc`` nodes into the graph, which call a special
``executorch.exir.memory.alloc`` operator. Each ``alloc`` node records the shape
and dtype of the tensor that the following out-variant operator writes to,
which tells us how much memory must be allocated for that tensor.
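
The ``alloc`` sizes and tensor lifetimes described above are the inputs a
memory planner works with. The sketch below is a toy greedy, first-fit planner
written for illustration only. It is not the algorithm behind
``MemoryPlanningPass("greedy")``, the helpers ``alloc_nbytes`` and
``greedy_plan`` are hypothetical, dtypes are plain strings such as
``"float32"`` instead of ``torch.float32``, and lifetimes are measured in op
steps read off the ExecuTorch Dialect graph printed earlier.

.. code-block:: python

    from math import prod

    # Element sizes in bytes; only float32 appears in this example.
    DTYPE_BYTES = {"float32": 4}


    def alloc_nbytes(shape, dtype):
        """Bytes needed for an alloc spec such as ((2, 2), torch.float32)."""
        return prod(shape) * DTYPE_BYTES[dtype]


    def greedy_plan(tensors):
        """Assign a byte offset to each tensor, reusing slots whose occupant is dead.

        ``tensors`` maps name -> (nbytes, first_use_step, last_use_step).
        Returns (offsets, total_bytes).
        """
        slots = []  # each slot: [offset, size, last_use_step of current occupant]
        offsets, total = {}, 0
        # Visit tensors in the order they are produced.
        for name, (nbytes, start, end) in sorted(tensors.items(), key=lambda kv: kv[1][1]):
            for slot in slots:
                # Reuse a slot only if it is big enough and its occupant's last
                # use happens strictly before this tensor is written.
                if slot[1] >= nbytes and slot[2] < start:
                    slot[2] = end
                    offsets[name] = slot[0]
                    break
            else:
                # No reusable slot: carve a new one at the end of the arena.
                slots.append([total, nbytes, end])
                offsets[name] = total
                total += nbytes
        return offsets, total


    size = alloc_nbytes((2, 2), "float32")  # 2 * 2 * 4 = 16 bytes
    # (nbytes, produced at step, last read at step) for the five allocs above;
    # step 5 stands for "end of the program" for the graph output.
    tensors = {
        "alloc": (size, 0, 1),    # mm    -> read by add
        "alloc_1": (size, 1, 2),  # add   -> read by sub
        "alloc_2": (size, 2, 3),  # sub   -> read by mm_1
        "alloc_3": (size, 3, 4),  # mm_1  -> read by add_1
        "alloc_4": (size, 4, 5),  # add_1 -> graph output
    }
    offsets, total = greedy_plan(tensors)
    print(offsets)
    print(total)

In this toy plan the five 16-byte tensors share two 16-byte slots, so 32 bytes
suffice instead of 80, since each intermediate is dead by the time its slot is
reused.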


.. GENERATED FROM PYTHON SOURCE LINES 588-595

Saving to a File
----------------

Finally, we can save the ExecuTorch program to a file and load it onto a device
to run it.

Here is an example of an entire end-to-end workflow:

.. GENERATED FROM PYTHON SOURCE LINES 595-628

.. code-block:: default


    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.export import export, ExportedProgram

    import executorch.exir as exir
    from executorch.exir import ExecutorchBackendConfig


    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)


    example_args = (torch.randn(3, 4),)
    pre_autograd_aten_dialect = capture_pre_autograd_graph(M(), example_args)
    # Optionally do quantization:
    # pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))
    aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
    edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)
    # Optionally do delegation:
    # edge_program = edge_program.to_backend(CustomBackendPartitioner)
    executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
        ExecutorchBackendConfig(
            passes=[],  # User-defined passes
        )
    )

    with open("model.pte", "wb") as file:
        file.write(executorch_program.buffer)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/_pytree.py:590: UserWarning: pytree_to_str is deprecated. Please use treespec_dumps
      warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps")




.. GENERATED FROM PYTHON SOURCE LINES 629-644

Conclusion
----------

In this tutorial, we went over the APIs and steps required to lower a PyTorch
program to a file that can be run on the ExecuTorch runtime.

Links Mentioned
^^^^^^^^^^^^^^^

- `torch.export Documentation <https://pytorch.org/docs/2.1/export.html>`__
- `Quantization Documentation <https://pytorch.org/docs/main/quantization.html#prototype-pytorch-2-export-quantization>`__
- `IR Spec <../ir-exir.html>`__
- `Writing Compiler Passes + Partitioner Documentation <../compiler-custom-compiler-passes.html>`__
- `Backend Delegation Documentation <../compiler-delegate-and-partitioner.html>`__
- `Memory Planning Documentation <../compiler-memory-planning.html>`__


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 4.224 seconds)


.. _sphx_glr_download_tutorials_export-to-executorch-tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example




    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: export-to-executorch-tutorial.py <export-to-executorch-tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: export-to-executorch-tutorial.ipynb <export-to-executorch-tutorial.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_