.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "tutorials/export-to-executorch-tutorial.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end <sphx_glr_download_tutorials_export-to-executorch-tutorial.py>` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr_tutorials_export-to-executorch-tutorial.py: Exporting to ExecuTorch Tutorial ================================ **Author:** `Angela Yi <https://github.com/angelayi>`__ .. GENERATED FROM PYTHON SOURCE LINES 16-29 ExecuTorch is a unified ML stack for lowering PyTorch models to edge devices. It introduces improved entry points to perform model, device, and/or use-case specific optimizations such as backend delegation, user-defined compiler transformations, default or user-defined memory planning, and more. At a high level, the workflow looks as follows: .. image:: ../executorch_stack.png :width: 560 In this tutorial, we will cover the APIs in the "Program preparation" steps to lower a PyTorch model to a format which can be loaded to device and run on the ExecuTorch runtime. .. GENERATED FROM PYTHON SOURCE LINES 31-36 Prerequisites ------------- To run this tutorial, you’ll first need to `Set up your ExecuTorch environment <../getting-started-setup.html>`__. .. GENERATED FROM PYTHON SOURCE LINES 38-53 Exporting a Model ----------------- Note: The Export APIs are still undergoing changes to align better with the longer term state of export. Please refer to this `issue <https://github.com/pytorch/executorch/issues/290>`__ for more details. The first step of lowering to ExecuTorch is to export the given model (any callable or ``torch.nn.Module``) to a graph representation. This is done via the two-stage APIs, ``torch._export.capture_pre_autograd_graph``, and ``torch.export``. 
Both APIs take in a model (any callable or ``torch.nn.Module``), a tuple of positional arguments, optionally a dictionary of keyword arguments (not shown in the example), and a list of constraints (covered later). .. GENERATED FROM PYTHON SOURCE LINES 53-81 .. code-block:: default import torch from torch._export import capture_pre_autograd_graph from torch.export import export, ExportedProgram class SimpleConv(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d( in_channels=3, out_channels=16, kernel_size=3, padding=1 ) self.relu = torch.nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: a = self.conv(x) return self.relu(a) example_args = (torch.randn(1, 3, 256, 256),) pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args) print("Pre-Autograd ATen Dialect Graph") print(pre_autograd_aten_dialect) aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args) print("ATen Dialect Graph") print(aten_dialect) .. rst-class:: sphx-glr-script-out .. 
code-block:: none Pre-Autograd ATen Dialect Graph GraphModule() def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) _param_constant0 = self._param_constant0 _param_constant1 = self._param_constant1 conv2d_default = torch.ops.aten.conv2d.default(arg0, _param_constant0, _param_constant1, [1, 1], [1, 1]); arg0 = _param_constant0 = _param_constant1 = None relu_default = torch.ops.aten.relu.default(conv2d_default); conv2d_default = None return pytree.tree_unflatten([relu_default], self._out_spec) # To see more debug info, please use `graph_module.print_readable()` ATen Dialect Graph ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x) convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); arg2_1 = arg0_1 = arg1_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a) relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(convolution); convolution = None return (relu,) Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['relu'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] .. GENERATED FROM PYTHON SOURCE LINES 82-107 The output of ``torch._export.capture_pre_autograd_graph`` is a fully flattened graph (meaning the graph does not contain any module hierarchy, except in the case of control flow operators). 
Furthermore, the captured graph contains only ATen operators (~3000 ops) which are Autograd safe, that is, safe for eager mode training. The output of ``torch.export`` further compiles the graph to a lower and cleaner representation. Specifically, it has the following properties: - The graph is purely functional, meaning it does not contain operations with side effects such as mutations or aliasing. - The graph contains only a small defined `Core ATen IR <https://pytorch.org/docs/stable/torch.compiler_ir.html#core-aten-ir>`__ operator set (~180 ops), along with registered custom operators. - The nodes in the graph contain metadata captured during tracing, such as a stack trace from the user's code. More specifications about the result of ``torch.export`` can be found `here <https://pytorch.org/docs/2.1/export.html>`__. Since the result of ``torch.export`` is a graph containing the Core ATen operators, we will call this the ``ATen Dialect``, and since ``torch._export.capture_pre_autograd_graph`` returns a graph containing the set of ATen operators which are Autograd safe, we will call it the ``Pre-Autograd ATen Dialect``. .. GENERATED FROM PYTHON SOURCE LINES 109-115 Expressing Dynamism ^^^^^^^^^^^^^^^^^^^ By default, the exporting flow will trace the program assuming that all input shapes are static, so if we run the program with input shapes that differ from the ones we used while tracing, we will run into an error: .. GENERATED FROM PYTHON SOURCE LINES 115-136 .. code-block:: default import traceback as tb def f(x, y): return x + y example_args = (torch.randn(3, 3), torch.randn(3, 3)) pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args) aten_dialect: ExportedProgram = export(f, example_args) # Works correctly print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3))) # Errors try: print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2))) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. 
code-block:: none tensor([[2., 2., 2.], [2., 2., 2.], [2., 2., 2.]]) Traceback (most recent call last): File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 132, in <module> print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2))) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 369, in __call__ self._check_input_constraints(*ordered_params, *ordered_buffers, *args) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 664, in _check_input_constraints _assertion_graph(*args) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 728, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 307, in __call__ raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 294, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl return forward_call(*args, **kwargs) File "<eval_with_key>.107", line 11, in forward _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, 'Input arg1_1.shape[1] is specialized at 3'); scalar_tensor = None File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 516, in __call__ return self._op(*args, **kwargs or {}) RuntimeError: Input arg1_1.shape[1] is specialized at 3 .. GENERATED FROM PYTHON SOURCE LINES 137-139 To express that some input shapes are dynamic, we can insert constraints to the exporting flow. This is done through the ``dynamic_dim`` API: .. GENERATED FROM PYTHON SOURCE LINES 139-165 .. 
code-block:: default from torch.export import dynamic_dim def f(x, y): return x + y example_args = (torch.randn(3, 3), torch.randn(3, 3)) constraints = [ # Input 0, dimension 1 is dynamic dynamic_dim(example_args[0], 1), # Input 0, dimension 1 must be greater than or equal to 1 1 <= dynamic_dim(example_args[0], 1), # Input 0, dimension 1 must be less than or equal to 10 dynamic_dim(example_args[0], 1) <= 10, # Input 1, dimension 1 is equal to input 0, dimension 1 dynamic_dim(example_args[1], 1) == dynamic_dim(example_args[0], 1), ] pre_autograd_aten_dialect = capture_pre_autograd_graph( f, example_args, constraints=constraints ) aten_dialect: ExportedProgram = export(f, example_args, constraints=constraints) print("ATen Dialect Graph") print(aten_dialect) .. rst-class:: sphx-glr-script-out .. code-block:: none ATen Dialect Graph ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[3, s0], arg1_1: f32[3, s0]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:144, code: return x + y add: f32[3, s0] = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None return (add,) Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {s0: RangeConstraint(min_val=2, max_val=10)} Equality constraints: [(InputDim(input_name='arg1_1', dim=1), InputDim(input_name='arg0_1', dim=1))] .. GENERATED FROM PYTHON SOURCE LINES 166-177 Note that the inputs ``arg0_1`` and ``arg1_1`` now have shapes (3, s0), with ``s0`` being a symbol representing that this dimension can take a range of values. Additionally, we can see in the **Range constraints** that the value of ``s0`` has the range [2, 10]. The upper bound of 10 was specified by our constraints; the lower bound is printed as 2 rather than the 1 we requested, likely because export treats sizes 0 and 1 as special cases. 
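
To make the range check concrete, here is a toy sketch in plain Python. It is only an illustration of the idea, not how ``torch.export`` implements it: ``ToyRange`` and ``check_dim`` are hypothetical names invented for this example, and the bounds simply mirror the ``RangeConstraint(min_val=2, max_val=10)`` printed above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyRange:
    """Hypothetical stand-in for a range constraint on one symbolic dimension."""

    min_val: int
    max_val: int


def check_dim(shape, dim, allowed):
    """Return shape[dim] if it lies within the allowed range, else raise RuntimeError."""
    size = shape[dim]
    if not (allowed.min_val <= size <= allowed.max_val):
        raise RuntimeError(
            f"shape[{dim}] = {size} is outside of range "
            f"[{allowed.min_val}, {allowed.max_val}]"
        )
    return size


# Bounds copied from the printed range constraint for s0 (an assumption of this sketch).
s0 = ToyRange(min_val=2, max_val=10)

check_dim((3, 3), 1, s0)  # accepted
check_dim((3, 2), 1, s0)  # accepted
try:
    check_dim((3, 15), 1, s0)  # rejected: 15 > 10
except RuntimeError as err:
    print(err)
```

In this sketch, a shape whose dynamic dimension falls inside the range passes through unchanged, while anything outside it is rejected before any computation would run.
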
We also see in the **Equality constraints** the tuple ``(InputDim(input_name='arg1_1', dim=1), InputDim(input_name='arg0_1', dim=1))``, meaning that input 1's dimension 1 is equal to input 0's dimension 1, which was also specified by our constraints. Now let's try running the model with different shapes: .. GENERATED FROM PYTHON SOURCE LINES 177-195 .. code-block:: default # Works correctly print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3))) print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2))) # Errors because it violates our constraint that input 0, dim 1 <= 10 try: print(aten_dialect(torch.ones(3, 15), torch.ones(3, 15))) except Exception: tb.print_exc() # Errors because it violates our constraint that input 0, dim 1 == input 1, dim 1 try: print(aten_dialect(torch.ones(3, 3), torch.ones(3, 2))) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. code-block:: none tensor([[2., 2., 2.], [2., 2., 2.], [2., 2., 2.]]) tensor([[2., 2.], [2., 2.], [2., 2.]]) Traceback (most recent call last): File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 184, in <module> print(aten_dialect(torch.ones(3, 15), torch.ones(3, 15))) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 369, in __call__ self._check_input_constraints(*ordered_params, *ordered_buffers, *args) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 664, in _check_input_constraints _assertion_graph(*args) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 728, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 307, in __call__ raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 294, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: 
ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl return forward_call(*args, **kwargs) File "<eval_with_key>.141", line 17, in forward _assert_async_2 = torch.ops.aten._assert_async.msg(scalar_tensor_2, 'Input arg0_1.shape[1] is outside of specified dynamic range [2, 10]'); scalar_tensor_2 = None File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 516, in __call__ return self._op(*args, **kwargs or {}) RuntimeError: Input arg0_1.shape[1] is outside of specified dynamic range [2, 10] Traceback (most recent call last): File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 190, in <module> print(aten_dialect(torch.ones(3, 3), torch.ones(3, 2))) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 369, in __call__ self._check_input_constraints(*ordered_params, *ordered_buffers, *args) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 664, in _check_input_constraints _assertion_graph(*args) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 728, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 307, in __call__ raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 294, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File 
"/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl return forward_call(*args, **kwargs) File "<eval_with_key>.146", line 11, in forward _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, 'Input arg1_1.shape[1] is not equal to input arg0_1.shape[1]'); scalar_tensor = None File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 516, in __call__ return self._op(*args, **kwargs or {}) RuntimeError: Input arg1_1.shape[1] is not equal to input arg0_1.shape[1] .. GENERATED FROM PYTHON SOURCE LINES 196-206 Addressing Untraceable Code ^^^^^^^^^^^^^^^^^^^^^^^^^^^ As our goal is to capture the entire computational graph from a PyTorch program, we might ultimately run into untraceable parts of programs. To address these issues, the `torch.export documentation <https://pytorch.org/docs/2.1/export.html#limitations-of-torch-export>`__, or the `torch.export tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__ would be the best place to look. .. GENERATED FROM PYTHON SOURCE LINES 208-223 Performing Quantization ----------------------- To quantize a model, we can do so between the call to ``torch._export.capture_pre_autograd_graph`` and ``torch.export``, in the ``Pre-Autograd ATen Dialect``. This is because quantization must operate at a level which is safe for eager mode training. Compared to `FX Graph Mode Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`__, we will need to call two new APIs: ``prepare_pt2e`` and ``convert_pt2e`` instead of ``prepare_fx`` and ``convert_fx``. It differs in that ``prepare_pt2e`` takes a backend-specific ``Quantizer`` as an argument, which will annotate the nodes in the graph with information needed to quantize the model properly for a specific backend. .. GENERATED FROM PYTHON SOURCE LINES 223-246 .. 
code-block:: default example_args = (torch.randn(1, 3, 256, 256),) pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args) print("Pre-Autograd ATen Dialect Graph") print(pre_autograd_aten_dialect) from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer) # calibrate with a sample dataset converted_graph = convert_pt2e(prepared_graph) print("Quantized Graph") print(converted_graph) aten_dialect: ExportedProgram = export(converted_graph, example_args) print("ATen Dialect Graph") print(aten_dialect) .. rst-class:: sphx-glr-script-out .. code-block:: none Pre-Autograd ATen Dialect Graph GraphModule() def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) _param_constant0 = self._param_constant0 _param_constant1 = self._param_constant1 conv2d_default = torch.ops.aten.conv2d.default(arg0, _param_constant0, _param_constant1, [1, 1], [1, 1]); arg0 = _param_constant0 = _param_constant1 = None relu_default = torch.ops.aten.relu.default(conv2d_default); conv2d_default = None return pytree.tree_unflatten([relu_default], self._out_spec) # To see more debug info, please use `graph_module.print_readable()` /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/observer.py:1220: UserWarning: must run observer before calling calculate_qparams. Returning default scale and zero point warnings.warn( /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/utils.py:339: UserWarning: must run observer before calling calculate_qparams. Returning default values. 
warnings.warn( Quantized Graph GraphModule() def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg0, 1.0, 0, -128, 127, torch.int8); arg0 = None dequantize_per_tensor_default = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, 1.0, 0, -128, 127, torch.int8); quantize_per_tensor_default = None _param_constant0 = self._param_constant0 quantize_per_tensor_default_1 = torch.ops.quantized_decomposed.quantize_per_tensor.default(_param_constant0, 1.0, 0, -127, 127, torch.int8); _param_constant0 = None dequantize_per_tensor_default_1 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_1, 1.0, 0, -127, 127, torch.int8); quantize_per_tensor_default_1 = None _param_constant1 = self._param_constant1 conv2d_default = torch.ops.aten.conv2d.default(dequantize_per_tensor_default, dequantize_per_tensor_default_1, _param_constant1, [1, 1], [1, 1]); dequantize_per_tensor_default = dequantize_per_tensor_default_1 = _param_constant1 = None relu_default = torch.ops.aten.relu.default(conv2d_default); conv2d_default = None quantize_per_tensor_default_2 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu_default, 1.0, 0, -128, 127, torch.int8); relu_default = None dequantize_per_tensor_default_2 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_2, 1.0, 0, -128, 127, torch.int8); quantize_per_tensor_default_2 = None return pytree.tree_unflatten([dequantize_per_tensor_default_2], self._out_spec) # To see more debug info, please use `graph_module.print_readable()` ATen Dialect Graph ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]): # No stacktrace found for following nodes quantize_per_tensor: i8[1, 3, 256, 256] = 
torch.ops.quantized_decomposed.quantize_per_tensor.default(arg2_1, 1.0, 0, -128, 127, torch.int8); arg2_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x) dequantize_per_tensor: f32[1, 3, 256, 256] = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor, 1.0, 0, -128, 127, torch.int8); quantize_per_tensor = None quantize_per_tensor_1: i8[16, 3, 3, 3] = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg0_1, 1.0, 0, -127, 127, torch.int8); arg0_1 = None dequantize_per_tensor_1: f32[16, 3, 3, 3] = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_1, 1.0, 0, -127, 127, torch.int8); quantize_per_tensor_1 = None convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(dequantize_per_tensor, dequantize_per_tensor_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); dequantize_per_tensor = dequantize_per_tensor_1 = arg1_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a) relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(convolution); convolution = None quantize_per_tensor_2: i8[1, 16, 256, 256] = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu, 1.0, 0, -128, 127, torch.int8); relu = None # No stacktrace found for following nodes dequantize_per_tensor_2: f32[1, 16, 256, 256] = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_2, 1.0, 0, -128, 127, torch.int8); quantize_per_tensor_2 = None return (dequantize_per_tensor_2,) Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['dequantize_per_tensor_2'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} 
Equality constraints: [] .. GENERATED FROM PYTHON SOURCE LINES 247-250 More information on how to quantize a model, and how a backend can implement a ``Quantizer`` can be found `here <https://pytorch.org/docs/main/quantization.html#prototype-pytorch-2-export-quantization>`__. .. GENERATED FROM PYTHON SOURCE LINES 252-268 Lowering to Edge Dialect ------------------------ After exporting and lowering the graph to the ``ATen Dialect``, the next step is to lower to the ``Edge Dialect``, in which specializations that are useful for edge devices but not necessary for general (server) environments will be applied. Some of these specializations include: - DType specialization - Scalar to tensor conversion - Converting all ops to the ``executorch.exir.dialects.edge`` namespace. Note that this dialect is still backend (or target) agnostic. The lowering is done through the ``to_edge`` API. .. GENERATED FROM PYTHON SOURCE LINES 268-284 .. code-block:: default from executorch.exir import EdgeProgramManager, to_edge example_args = (torch.randn(1, 3, 256, 256),) pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args) print("Pre-Autograd ATen Dialect Graph") print(pre_autograd_aten_dialect) aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args) print("ATen Dialect Graph") print(aten_dialect) edge_program: EdgeProgramManager = to_edge(aten_dialect) print("Edge Dialect Graph") print(edge_program.exported_program()) .. rst-class:: sphx-glr-script-out .. 
code-block:: none Pre-Autograd ATen Dialect Graph GraphModule() def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) _param_constant0 = self._param_constant0 _param_constant1 = self._param_constant1 conv2d_default = torch.ops.aten.conv2d.default(arg0, _param_constant0, _param_constant1, [1, 1], [1, 1]); arg0 = _param_constant0 = _param_constant1 = None relu_default = torch.ops.aten.relu.default(conv2d_default); conv2d_default = None return pytree.tree_unflatten([relu_default], self._out_spec) # To see more debug info, please use `graph_module.print_readable()` ATen Dialect Graph ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x) convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); arg2_1 = arg0_1 = arg1_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a) relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(convolution); convolution = None return (relu,) Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['relu'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] Edge Dialect Graph ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x) aten_convolution_default: f32[1, 16, 256, 256] = 
executorch_exir_dialects_edge__ops_aten_convolution_default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); arg2_1 = arg0_1 = arg1_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a) aten_relu_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_relu_default(aten_convolution_default); aten_convolution_default = None return (aten_relu_default,) Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['aten_relu_default'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] .. GENERATED FROM PYTHON SOURCE LINES 285-289 ``to_edge()`` returns an ``EdgeProgramManager`` object, which holds the exported programs that will be placed on the device. This data structure allows users to export multiple programs and combine them into one binary. If there is only one program, it is saved under the name "forward" by default. .. GENERATED FROM PYTHON SOURCE LINES 289-318 .. code-block:: default def encode(x): return torch.nn.functional.linear(x, torch.randn(5, 10)) def decode(x): return torch.nn.functional.linear(x, torch.randn(10, 5)) encode_args = (torch.randn(1, 10),) aten_encode: ExportedProgram = export( capture_pre_autograd_graph(encode, encode_args), encode_args, ) decode_args = (torch.randn(1, 5),) aten_decode: ExportedProgram = export( capture_pre_autograd_graph(decode, decode_args), decode_args, ) edge_program: EdgeProgramManager = to_edge( {"encode": aten_encode, "decode": aten_decode} ) for method in edge_program.methods: print(f"Edge Dialect graph of {method}") print(edge_program.exported_program(method)) .. rst-class:: sphx-glr-script-out .. 
code-block:: none Edge Dialect graph of encode ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[1, 10]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:292, code: return torch.nn.functional.linear(x, torch.randn(5, 10)) aten_randn_default: f32[5, 10] = executorch_exir_dialects_edge__ops_aten_randn_default([5, 10], device = device(type='cpu'), pin_memory = False) aten_permute_copy_default: f32[10, 5] = executorch_exir_dialects_edge__ops_aten_permute_copy_default(aten_randn_default, [1, 0]); aten_randn_default = None aten_mm_default: f32[1, 5] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, aten_permute_copy_default); arg0_1 = aten_permute_copy_default = None return (aten_mm_default,) Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_mm_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] Edge Dialect graph of decode ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[1, 5]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:296, code: return torch.nn.functional.linear(x, torch.randn(10, 5)) aten_randn_default: f32[10, 5] = executorch_exir_dialects_edge__ops_aten_randn_default([10, 5], device = device(type='cpu'), pin_memory = False) aten_permute_copy_default: f32[5, 10] = executorch_exir_dialects_edge__ops_aten_permute_copy_default(aten_randn_default, [1, 0]); aten_randn_default = None aten_mm_default: f32[1, 10] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, aten_permute_copy_default); arg0_1 = aten_permute_copy_default = None return (aten_mm_default,) Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_mm_default'], inputs_to_parameters={}, 
inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] .. GENERATED FROM PYTHON SOURCE LINES 319-328 We can also run additional passes on the exported program through the ``transform`` API. In-depth documentation on how to write transformations can be found `here <../compiler-custom-compiler-passes.html>`__. Note that since the graph is now in the Edge Dialect, all passes must also result in a valid Edge Dialect graph. In particular, the operators are now in the ``executorch.exir.dialects.edge`` namespace, rather than the ``torch.ops.aten`` namespace. .. GENERATED FROM PYTHON SOURCE LINES 328-354 .. code-block:: default example_args = (torch.randn(1, 3, 256, 256),) pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args) aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args) edge_program: EdgeProgramManager = to_edge(aten_dialect) print("Edge Dialect Graph") print(edge_program.exported_program()) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass class ConvertReluToSigmoid(ExportPass): def call_operator(self, op, args, kwargs, meta): if op == exir_ops.edge.aten.relu.default: return super().call_operator( exir_ops.edge.aten.sigmoid.default, args, kwargs, meta ) else: return super().call_operator(op, args, kwargs, meta) transformed_edge_program = edge_program.transform((ConvertReluToSigmoid(),)) print("Transformed Edge Dialect Graph") print(transformed_edge_program.exported_program()) .. rst-class:: sphx-glr-script-out .. 
code-block:: none Edge Dialect Graph ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x) aten_convolution_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_convolution_default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); arg2_1 = arg0_1 = arg1_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a) aten_relu_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_relu_default(aten_convolution_default); aten_convolution_default = None return (aten_relu_default,) Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['aten_relu_default'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] Transformed Edge Dialect Graph ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x) aten_convolution_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_convolution_default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); arg2_1 = arg0_1 = arg1_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a) aten_sigmoid_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_sigmoid_default(aten_convolution_default); aten_convolution_default = None return 
(aten_sigmoid_default,) Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['aten_sigmoid_default'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] .. GENERATED FROM PYTHON SOURCE LINES 355-370 Delegating to a Backend ----------------------- We can now delegate parts of the graph or the whole graph to a third-party backend through the ``to_backend`` API. In-depth documentation on the specifics of backend delegation, including how to delegate to a backend and how to implement a backend, can be found `here <../compiler-delegate-and-partitioner.html>`__. There are three ways to use this API: 1. We can lower the whole module. 2. We can take the lowered module, and insert it in another larger module. 3. We can partition the module into subgraphs that are lowerable, and then lower those subgraphs to a backend. .. GENERATED FROM PYTHON SOURCE LINES 372-378 Lowering the Whole Module ^^^^^^^^^^^^^^^^^^^^^^^^^ To lower an entire module, we can pass ``to_backend`` the backend name, the module to be lowered, and a list of compile specs to help the backend with the lowering process. .. GENERATED FROM PYTHON SOURCE LINES 378-416 .. 
code-block:: default class LowerableModule(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.sin(x) # Export and lower the module to Edge Dialect example_args = (torch.ones(1),) pre_autograd_aten_dialect = capture_pre_autograd_graph(LowerableModule(), example_args) aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args) edge_program: EdgeProgramManager = to_edge(aten_dialect) to_be_lowered_module = edge_program.exported_program() from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend # Import the backend from executorch.exir.backend.test.backend_with_compiler_demo import ( # noqa BackendWithCompilerDemo, ) # Lower the module lowered_module: LoweredBackendModule = to_backend( "BackendWithCompilerDemo", to_be_lowered_module, [] ) print(lowered_module) print(lowered_module.backend_id) print(lowered_module.processed_bytes) print(lowered_module.original_module) # Serialize and save it to a file save_path = "delegate.pte" with open(save_path, "wb") as f: f.write(lowered_module.buffer()) .. rst-class:: sphx-glr-script-out .. 
code-block:: none LoweredBackendModule() BackendWithCompilerDemo b'1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#' ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[1]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:385, code: return torch.sin(x) aten_sin_default: f32[1] = executorch_exir_dialects_edge__ops_aten_sin_default(arg0_1); arg0_1 = None return (aten_sin_default,) Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_sin_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/_pytree.py:590: UserWarning: pytree_to_str is deprecated. Please use treespec_dumps warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps") .. GENERATED FROM PYTHON SOURCE LINES 417-425 In this call, ``to_backend`` will return a ``LoweredBackendModule``. Some important attributes of the ``LoweredBackendModule`` are: - ``backend_id``: The name of the backend this lowered module will run on in the runtime - ``processed_bytes``: a binary blob which will tell the backend how to run this program in the runtime - ``original_module``: the original exported module .. GENERATED FROM PYTHON SOURCE LINES 427-432 Compose the Lowered Module into Another Module ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ In cases where we want to reuse this lowered module in multiple programs, we can compose this lowered module with another module. .. GENERATED FROM PYTHON SOURCE LINES 432-468 .. 
code-block:: default class NotLowerableModule(torch.nn.Module): def __init__(self, bias): super().__init__() self.bias = bias def forward(self, a, b): return torch.add(torch.add(a, b), self.bias) class ComposedModule(torch.nn.Module): def __init__(self): super().__init__() self.non_lowerable = NotLowerableModule(torch.ones(1) * 0.3) self.lowerable = lowered_module def forward(self, x): a = self.lowerable(x) b = self.lowerable(a) ret = self.non_lowerable(a, b) return a, b, ret example_args = (torch.ones(1),) pre_autograd_aten_dialect = capture_pre_autograd_graph(ComposedModule(), example_args) aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args) edge_program: EdgeProgramManager = to_edge(aten_dialect) exported_program = edge_program.exported_program() print("Edge Dialect graph") print(exported_program) print("Lowered Module within the graph") print(exported_program.graph_module.lowered_module_0.backend_id) print(exported_program.graph_module.lowered_module_0.processed_bytes) print(exported_program.graph_module.lowered_module_0.original_module) .. rst-class:: sphx-glr-script-out .. 
code-block:: none Edge Dialect graph ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[1], arg1_1: f32[1]): # File: /pytorch/executorch/exir/lowered_backend_module.py:273, code: return executorch_call_delegate(self, *args) lowered_module_0 = self.lowered_module_0 executorch_call_delegate: f32[1] = torch.ops.executorch_call_delegate(lowered_module_0, arg1_1); lowered_module_0 = arg1_1 = None # File: /pytorch/executorch/exir/lowered_backend_module.py:273, code: return executorch_call_delegate(self, *args) lowered_module_1 = self.lowered_module_0 executorch_call_delegate_1: f32[1] = torch.ops.executorch_call_delegate(lowered_module_1, executorch_call_delegate); lowered_module_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:440, code: return torch.add(torch.add(a, b), self.bias) aten_add_tensor: f32[1] = executorch_exir_dialects_edge__ops_aten_add_Tensor(executorch_call_delegate, executorch_call_delegate_1) aten_add_tensor_1: f32[1] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_add_tensor, arg0_1); aten_add_tensor = arg0_1 = None return (executorch_call_delegate, executorch_call_delegate_1, aten_add_tensor_1) Graph signature: ExportGraphSignature(parameters=[], buffers=['_tensor_constant0'], user_inputs=['arg1_1'], user_outputs=['executorch_call_delegate', 'executorch_call_delegate_1', 'aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={'arg0_1': '_tensor_constant0'}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] Lowered Module within the graph BackendWithCompilerDemo b'1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#' ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[1]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:385, code: return torch.sin(x) aten_sin_default: f32[1] = 
executorch_exir_dialects_edge__ops_aten_sin_default(arg0_1); arg0_1 = None return (aten_sin_default,) Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_sin_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] .. GENERATED FROM PYTHON SOURCE LINES 469-473 Notice that the graph now contains ``torch.ops.executorch_call_delegate`` nodes, one for each call to ``self.lowerable``, and that both of them call ``lowered_module_0``. Additionally, the contents of ``lowered_module_0`` are the same as the ``lowered_module`` we created previously. .. GENERATED FROM PYTHON SOURCE LINES 475-483 Partition and Lower Parts of a Module ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ A separate lowering flow is to pass ``to_backend`` the module that we want to lower, and a backend-specific partitioner. ``to_backend`` will use the backend-specific partitioner to tag nodes in the module which are lowerable, partition those nodes into subgraphs, and then create a ``LoweredBackendModule`` for each of those subgraphs. .. GENERATED FROM PYTHON SOURCE LINES 483-510 .. 
code-block:: default def f(a, x, b): y = torch.mm(a, x) z = y + b a = z - a y = torch.mm(a, x) z = y + b return z example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)) pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args) aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args) edge_program: EdgeProgramManager = to_edge(aten_dialect) exported_program = edge_program.exported_program() print("Edge Dialect graph") print(exported_program) from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo delegated_program = to_backend(exported_program, AddMulPartitionerDemo) print("Delegated program") print(delegated_program) print(delegated_program.graph_module.lowered_module_0.original_module) print(delegated_program.graph_module.lowered_module_1.original_module) .. rst-class:: sphx-glr-script-out .. code-block:: none Edge Dialect graph ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:486, code: y = torch.mm(a, x) aten_mm_default: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, arg1_1) # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:487, code: z = y + b aten_add_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default, arg2_1); aten_mm_default = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:488, code: a = z - a aten_sub_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_add_tensor, arg0_1); aten_add_tensor = arg0_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:489, code: y = torch.mm(a, x) aten_mm_default_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(aten_sub_tensor, arg1_1); 
aten_sub_tensor = arg1_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:490, code: z = y + b aten_add_tensor_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default_1, arg2_1); aten_mm_default_1 = arg2_1 = None return (aten_add_tensor_1,) Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] Delegated program ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]): # No stacktrace found for following nodes lowered_module_0 = self.lowered_module_0 executorch_call_delegate = torch.ops.executorch_call_delegate(lowered_module_0, arg0_1, arg1_1, arg2_1); lowered_module_0 = None getitem: f32[2, 2] = executorch_call_delegate[0]; executorch_call_delegate = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:488, code: a = z - a aten_sub_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_sub_Tensor(getitem, arg0_1); getitem = arg0_1 = None # No stacktrace found for following nodes lowered_module_1 = self.lowered_module_1 executorch_call_delegate_1 = torch.ops.executorch_call_delegate(lowered_module_1, aten_sub_tensor, arg1_1, arg2_1); lowered_module_1 = aten_sub_tensor = arg1_1 = arg2_1 = None getitem_1: f32[2, 2] = executorch_call_delegate_1[0]; executorch_call_delegate_1 = None return (getitem_1,) Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['getitem_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] ExportedProgram: class 
GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:486, code: y = torch.mm(a, x) aten_mm_default: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, arg1_1); arg0_1 = arg1_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:487, code: z = y + b aten_add_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default, arg2_1); aten_mm_default = arg2_1 = None return [aten_add_tensor] Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, aten_sub_tensor: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]): # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:489, code: y = torch.mm(a, x) aten_mm_default_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(aten_sub_tensor, arg1_1); aten_sub_tensor = arg1_1 = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:490, code: z = y + b aten_add_tensor_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default_1, arg2_1); aten_mm_default_1 = arg2_1 = None return [aten_add_tensor_1] Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['aten_sub_tensor', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] .. 
GENERATED FROM PYTHON SOURCE LINES 511-517 Notice that there are now two ``torch.ops.executorch_call_delegate`` nodes in the graph, each calling a lowered module that contains the operations ``mm`` and ``add``. The ``sub`` between them is not lowered and stays in the top-level graph. Alternatively, a more cohesive API to lower parts of a module is to call ``to_backend`` directly on the ``EdgeProgramManager``: .. GENERATED FROM PYTHON SOURCE LINES 517-538 .. code-block:: default def f(a, x, b): y = torch.mm(a, x) z = y + b a = z - a y = torch.mm(a, x) z = y + b return z example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)) pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args) aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args) edge_program: EdgeProgramManager = to_edge(aten_dialect) exported_program = edge_program.exported_program() delegated_program = edge_program.to_backend(AddMulPartitionerDemo) print("Delegated program") print(delegated_program.exported_program()) .. rst-class:: sphx-glr-script-out .. code-block:: none Delegated program ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]): # No stacktrace found for following nodes lowered_module_0 = self.lowered_module_0 executorch_call_delegate = torch.ops.executorch_call_delegate(lowered_module_0, arg0_1, arg1_1, arg2_1); lowered_module_0 = None getitem: f32[2, 2] = executorch_call_delegate[0]; executorch_call_delegate = None # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:522, code: a = z - a aten_sub_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_sub_Tensor(getitem, arg0_1); getitem = arg0_1 = None # No stacktrace found for following nodes lowered_module_1 = self.lowered_module_1 executorch_call_delegate_1 = torch.ops.executorch_call_delegate(lowered_module_1, aten_sub_tensor, arg1_1, arg2_1); lowered_module_1 = aten_sub_tensor = arg1_1 = arg2_1 = None getitem_1: f32[2, 2] = 
executorch_call_delegate_1[0]; executorch_call_delegate_1 = None return (getitem_1,) Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['getitem_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] .. GENERATED FROM PYTHON SOURCE LINES 539-551 Running User-Defined Passes and Memory Planning ----------------------------------------------- As a final step of lowering, we can use the ``to_executorch()`` API to pass in backend-specific passes, such as ones that replace sets of operators with a custom backend operator, and a memory planning pass that tells the runtime how to allocate memory ahead of time when running the program. A default memory planning pass is provided, but we can also choose a backend-specific memory planning pass if it exists. More information on writing a custom memory planning pass can be found `here <../compiler-memory-planning.html>`__. .. GENERATED FROM PYTHON SOURCE LINES 551-569 .. code-block:: default from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager from executorch.exir.passes import MemoryPlanningPass executorch_program: ExecutorchProgramManager = edge_program.to_executorch( ExecutorchBackendConfig( passes=[], # User-defined passes memory_planning_pass=MemoryPlanningPass( "greedy" ), # Default memory planning pass ) ) print("ExecuTorch Dialect") print(executorch_program.exported_program()) import executorch.exir as exir .. rst-class:: sphx-glr-script-out .. code-block:: none /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/_pytree.py:590: UserWarning: pytree_to_str is deprecated. Please use treespec_dumps warnings.warn("pytree_to_str is deprecated. 
Please use treespec_dumps") ExecuTorch Dialect ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]): # No stacktrace found for following nodes alloc: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32)) # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:520, code: y = torch.mm(a, x) aten_mm_default: f32[2, 2] = torch.ops.aten.mm.out(arg0_1, arg1_1, out = alloc); alloc = None # No stacktrace found for following nodes alloc_1: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32)) # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:521, code: z = y + b aten_add_tensor: f32[2, 2] = torch.ops.aten.add.out(aten_mm_default, arg2_1, out = alloc_1); aten_mm_default = alloc_1 = None # No stacktrace found for following nodes alloc_2: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32)) # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:522, code: a = z - a aten_sub_tensor: f32[2, 2] = torch.ops.aten.sub.out(aten_add_tensor, arg0_1, out = alloc_2); aten_add_tensor = arg0_1 = alloc_2 = None # No stacktrace found for following nodes alloc_3: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32)) # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:523, code: y = torch.mm(a, x) aten_mm_default_1: f32[2, 2] = torch.ops.aten.mm.out(aten_sub_tensor, arg1_1, out = alloc_3); aten_sub_tensor = arg1_1 = alloc_3 = None # No stacktrace found for following nodes alloc_4: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32)) # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:524, code: z = y + b aten_add_tensor_1: f32[2, 2] = torch.ops.aten.add.out(aten_mm_default_1, arg2_1, out = alloc_4); aten_mm_default_1 = arg2_1 = alloc_4 = None return (aten_add_tensor_1,) Graph 
signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Range constraints: {} Equality constraints: [] .. GENERATED FROM PYTHON SOURCE LINES 570-586 Notice that in the graph we now see operators like ``torch.ops.aten.sub.out`` and ``torch.ops.aten.add.out`` rather than ``torch.ops.aten.sub.Tensor`` and ``torch.ops.aten.add.Tensor``. This is because, between the backend passes and the memory planning pass, an out-variant pass is run on the graph to convert all of the operators to their out variants in preparation for memory planning. Instead of allocating returned tensors in the kernel implementations, an operator's ``out`` variant takes a preallocated tensor through its ``out`` kwarg and stores the result there, making it easier for memory planners to do tensor lifetime analysis. We also insert ``alloc`` nodes into the graph containing calls to a special ``executorch.exir.memory.alloc`` operator. Each of these records how much memory is needed to allocate a tensor output by an out-variant operator. .. GENERATED FROM PYTHON SOURCE LINES 588-595 Saving to a File ---------------- Finally, we can save the ExecuTorch Program to a file and load it to a device to be run. Here is an example of the entire end-to-end workflow: .. GENERATED FROM PYTHON SOURCE LINES 595-628 .. 
code-block:: default import executorch.exir as exir import torch from executorch.exir import ExecutorchBackendConfig from torch._export import capture_pre_autograd_graph from torch.export import export, ExportedProgram class M(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return self.linear(x + self.param).clamp(min=0.0, max=1.0) example_args = (torch.randn(3, 4),) pre_autograd_aten_dialect = capture_pre_autograd_graph(M(), example_args) # Optionally do quantization: # pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer)) aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args) edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect) # Optionally do delegation: # edge_program = edge_program.to_backend(CustomBackendPartitioner) executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch( ExecutorchBackendConfig( passes=[], # User-defined passes ) ) with open("model.pte", "wb") as file: file.write(executorch_program.buffer) .. rst-class:: sphx-glr-script-out .. code-block:: none /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/_pytree.py:590: UserWarning: pytree_to_str is deprecated. Please use treespec_dumps warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps") .. GENERATED FROM PYTHON SOURCE LINES 629-644 Conclusion ---------- In this tutorial, we went over the APIs and steps required to lower a PyTorch program to a file that can be run on the ExecuTorch runtime. 
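
As a recap of the flow described in "Partition and Lower Parts of a Module", here is a toy, pure-Python sketch of how tagged operators might end up grouped into separately lowered subgraphs. It is not the ExecuTorch partitioner. The helper ``group_lowerable`` and the plain-string operator names are hypothetical, and the sketch assumes a linear operator sequence, whereas a real partitioner works on a graph.

.. code-block:: python

    from typing import List, Set, Tuple


    def group_lowerable(
        ops: List[str], supported: Set[str]
    ) -> Tuple[List[List[str]], List[str]]:
        """Split ``ops`` into runs of supported operators and the leftovers.

        Hypothetical helper for illustration only: each run of consecutive
        supported operators stands in for one lowered subgraph.
        """
        partitions: List[List[str]] = []
        remaining: List[str] = []
        current: List[str] = []
        for op in ops:
            if op in supported:
                # Tag the operator as lowerable and extend the current run.
                current.append(op)
            else:
                # An unsupported operator ends the current run, if any.
                if current:
                    partitions.append(current)
                    current = []
                remaining.append(op)
        if current:
            partitions.append(current)
        return partitions, remaining


    # Operator sequence of ``f`` from the tutorial: mm, add, sub, mm, add.
    partitions, remaining = group_lowerable(
        ["mm", "add", "sub", "mm", "add"], supported={"mm", "add"}
    )
    print(partitions)
    print(remaining)

With this input the sketch yields two groups of ``mm`` and ``add`` and leaves ``sub`` outside them, which mirrors the two ``lowered_module_*`` submodules and the remaining ``aten_sub_tensor`` node in the delegated program printed earlier.
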
Links Mentioned ^^^^^^^^^^^^^^^ - `torch.export Documentation <https://pytorch.org/docs/2.1/export.html>`__ - `Quantization Documentation <https://pytorch.org/docs/main/quantization.html#prototype-pytorch-2-export-quantization>`__ - `IR Spec <../ir-exir.html>`__ - `Writing Compiler Passes + Partitioner Documentation <../compiler-custom-compiler-passes.html>`__ - `Backend Delegation Documentation <../compiler-delegate-and-partitioner.html>`__ - `Memory Planning Documentation <../compiler-memory-planning.html>`__ .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 4.224 seconds) .. _sphx_glr_download_tutorials_export-to-executorch-tutorial.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: export-to-executorch-tutorial.py <export-to-executorch-tutorial.py>` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: export-to-executorch-tutorial.ipynb <export-to-executorch-tutorial.ipynb>` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_