Exporting to ExecuTorch Tutorial

Author: Angela Yi

ExecuTorch is a unified ML stack for lowering PyTorch models to edge devices. It introduces improved entry points to perform model, device, and/or use-case specific optimizations such as backend delegation, user-defined compiler transformations, default or user-defined memory planning, and more.

At a high level, the workflow looks as follows:

[Figure: the ExecuTorch stack (executorch_stack.png)]

In this tutorial, we will cover the APIs in the “Program preparation” steps, which lower a PyTorch model to a format that can be loaded onto a device and run on the ExecuTorch runtime.

Prerequisites

To run this tutorial, you’ll first need to Set up your ExecuTorch environment.

Exporting a Model

Note: The Export APIs are still undergoing changes to align better with the longer term state of export. Please refer to this issue for more details.

The first step of lowering to ExecuTorch is to export the given model (any callable or torch.nn.Module) to a graph representation. This is done in two stages, first with torch._export.capture_pre_autograd_graph and then with torch.export.

Both APIs take in a model (any callable or torch.nn.Module), a tuple of positional arguments, optionally a dictionary of keyword arguments (not shown in the example), and a list of constraints (covered later).

import torch
from torch._export import capture_pre_autograd_graph
from torch.export import export, ExportedProgram


class SimpleConv(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.conv(x)
        return self.relu(a)


example_args = (torch.randn(1, 3, 256, 256),)
pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
print("Pre-Autograd ATen Dialect Graph")
print(pre_autograd_aten_dialect)

aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
print("ATen Dialect Graph")
print(aten_dialect)
Pre-Autograd ATen Dialect Graph
GraphModule()



def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    _param_constant0 = self._param_constant0
    _param_constant1 = self._param_constant1
    conv2d_default = torch.ops.aten.conv2d.default(arg0, _param_constant0, _param_constant1, [1, 1], [1, 1]);  arg0 = _param_constant0 = _param_constant1 = None
    relu_default = torch.ops.aten.relu.default(conv2d_default);  conv2d_default = None
    return pytree.tree_unflatten([relu_default], self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`
ATen Dialect Graph
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
            convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = arg0_1 = arg1_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
            relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(convolution);  convolution = None
            return (relu,)

Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['relu'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

The output of torch._export.capture_pre_autograd_graph is a fully flattened graph (meaning the graph does not contain any module hierarchy, except in the case of control flow operators). Furthermore, the captured graph contains only ATen operators (~3000 ops) which are Autograd safe, that is, safe for eager mode training.

torch.export further compiles the graph to a lower-level and cleaner representation. Specifically, its output has the following properties:

  • The graph is purely functional, meaning it does not contain operations with side effects such as mutations or aliasing.

  • The graph contains only a small defined Core ATen IR operator set (~180 ops), along with registered custom operators.

  • The nodes in the graph contain metadata captured during tracing, such as a stack trace from the user’s code.

More specifications about the result of torch.export can be found here.

Since the result of torch.export is a graph containing the Core ATen operators, we will call this the ATen Dialect, and since torch._export.capture_pre_autograd_graph returns a graph containing the set of ATen operators which are Autograd safe, we will call it the Pre-Autograd ATen Dialect.

Expressing Dynamism

By default, the export flow traces the program assuming that all input shapes are static, so if we run the program with input shapes that differ from the ones we used while tracing, we will run into an error:

import traceback as tb


def f(x, y):
    return x + y


example_args = (torch.randn(3, 3), torch.randn(3, 3))
pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
aten_dialect: ExportedProgram = export(f, example_args)

# Works correctly
print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3)))

# Errors
try:
    print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))
except Exception:
    tb.print_exc()
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
Traceback (most recent call last):
  File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 132, in <module>
    print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 369, in __call__
    self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 664, in _check_input_constraints
    _assertion_graph(*args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 728, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 307, in __call__
    raise e
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 294, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.107", line 11, in forward
    _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, 'Input arg1_1.shape[1] is specialized at 3');  scalar_tensor = None
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 516, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: Input arg1_1.shape[1] is specialized at 3

To express that some input shapes are dynamic, we can pass constraints to the export flow. This is done through the dynamic_dim API:

from torch.export import dynamic_dim


def f(x, y):
    return x + y


example_args = (torch.randn(3, 3), torch.randn(3, 3))
constraints = [
    # Input 0, dimension 1 is dynamic
    dynamic_dim(example_args[0], 1),
    # Input 0, dimension 1 must be greater than or equal to 1
    1 <= dynamic_dim(example_args[0], 1),
    # Input 0, dimension 1 must be less than or equal to 10
    dynamic_dim(example_args[0], 1) <= 10,
    # Input 1, dimension 1 is equal to input 0, dimension 1
    dynamic_dim(example_args[1], 1) == dynamic_dim(example_args[0], 1),
]
pre_autograd_aten_dialect = capture_pre_autograd_graph(
    f, example_args, constraints=constraints
)
aten_dialect: ExportedProgram = export(f, example_args, constraints=constraints)
print("ATen Dialect Graph")
print(aten_dialect)
ATen Dialect Graph
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, s0], arg1_1: f32[3, s0]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:144, code: return x + y
            add: f32[3, s0] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
            return (add,)

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {s0: RangeConstraint(min_val=2, max_val=10)}
Equality constraints: [(InputDim(input_name='arg1_1', dim=1), InputDim(input_name='arg0_1', dim=1))]

Note that the inputs arg0_1 and arg1_1 now have shapes (3, s0), with s0 being a symbol representing that this dimension can take a range of values.

Additionally, the Range constraints show that s0 has the range [2, 10]. The upper bound of 10 comes from our constraints; the lower bound is reported as 2 rather than the 1 we specified, because export treats sizes 0 and 1 as special cases. The Equality constraints contain the tuple (InputDim(input_name='arg1_1', dim=1), InputDim(input_name='arg0_1', dim=1)), meaning that input 1’s dimension 1 is equal to input 0’s dimension 1, which was also specified by our constraints.

Now let’s try running the model with different shapes:

# Works correctly
print(aten_dialect(torch.ones(3, 3), torch.ones(3, 3)))
print(aten_dialect(torch.ones(3, 2), torch.ones(3, 2)))

# Errors because it violates our constraint that input 0, dim 1 <= 10
try:
    print(aten_dialect(torch.ones(3, 15), torch.ones(3, 15)))
except Exception:
    tb.print_exc()

# Errors because it violates our constraint that input 0, dim 1 == input 1, dim 1
try:
    print(aten_dialect(torch.ones(3, 3), torch.ones(3, 2)))
except Exception:
    tb.print_exc()
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
tensor([[2., 2.],
        [2., 2.],
        [2., 2.]])
Traceback (most recent call last):
  File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 184, in <module>
    print(aten_dialect(torch.ones(3, 15), torch.ones(3, 15)))
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 369, in __call__
    self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 664, in _check_input_constraints
    _assertion_graph(*args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 728, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 307, in __call__
    raise e
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 294, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.141", line 17, in forward
    _assert_async_2 = torch.ops.aten._assert_async.msg(scalar_tensor_2, 'Input arg0_1.shape[1] is outside of specified dynamic range [2, 10]');  scalar_tensor_2 = None
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 516, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: Input arg0_1.shape[1] is outside of specified dynamic range [2, 10]
Traceback (most recent call last):
  File "/pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py", line 190, in <module>
    print(aten_dialect(torch.ones(3, 3), torch.ones(3, 2)))
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 369, in __call__
    self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 664, in _check_input_constraints
    _assertion_graph(*args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 728, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 307, in __call__
    raise e
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 294, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.146", line 11, in forward
    _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, 'Input arg1_1.shape[1] is not equal to input arg0_1.shape[1]');  scalar_tensor = None
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 516, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: Input arg1_1.shape[1] is not equal to input arg0_1.shape[1]
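
The three kinds of failure seen in this section (a specialized static dimension, a dimension outside its range, and two dimensions that must match) can be modeled with a few lines of plain Python. This is only a sketch of the idea: check_inputs and its dictionary and list arguments are hypothetical stand-ins invented for illustration, not torch.export data structures or APIs, and the range [2, 10] is copied from the printed Range constraints.

```python
def check_inputs(shapes, traced_shapes, ranges, equalities):
    """Return a list of violation messages (empty if the shapes are accepted).

    shapes / traced_shapes: one shape tuple per input.
    ranges: {(input_index, dim): (min, max)} for dynamic dimensions.
    equalities: [((input_index, dim), (input_index, dim)), ...] pairs that must match.
    All of these are hypothetical structures, not torch.export types.
    """
    errors = []
    for i, (shape, traced) in enumerate(zip(shapes, traced_shapes)):
        for d, (size, traced_size) in enumerate(zip(shape, traced)):
            if (i, d) in ranges:
                # Dynamic dimension: any size inside the range is accepted.
                lo, hi = ranges[(i, d)]
                if not lo <= size <= hi:
                    errors.append(f"input {i} dim {d} = {size} is outside [{lo}, {hi}]")
            elif size != traced_size:
                # Static dimension: must equal the size seen while tracing.
                errors.append(f"input {i} dim {d} is specialized at {traced_size}")
    for (i, d), (j, e) in equalities:
        if shapes[i][d] != shapes[j][e]:
            errors.append(f"input {i} dim {d} != input {j} dim {e}")
    return errors


traced = [(3, 3), (3, 3)]
ranges = {(0, 1): (2, 10), (1, 1): (2, 10)}
equalities = [((1, 1), (0, 1))]

ok = check_inputs([(3, 2), (3, 2)], traced, ranges, equalities)
too_big = check_inputs([(3, 15), (3, 15)], traced, ranges, equalities)
mismatch = check_inputs([(3, 3), (3, 2)], traced, ranges, equalities)
static = check_inputs([(3, 2), (3, 2)], traced, {}, [])
```

With these inputs, ok is empty, too_big reports both inputs as out of range, mismatch reports only the equality violation, and static (no constraints at all) reports that dimension 1 of each input is specialized at 3.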

Addressing Untraceable Code

As our goal is to capture the entire computational graph from a PyTorch program, we might ultimately run into untraceable parts of programs. To address these issues, the torch.export documentation or the torch.export tutorial is the best place to look.

Performing Quantization

Quantization is applied between the calls to torch._export.capture_pre_autograd_graph and torch.export, while the graph is in the Pre-Autograd ATen Dialect. This is because quantization must operate at a level which is safe for eager mode training.

Compared to FX Graph Mode Quantization, we call two new APIs, prepare_pt2e and convert_pt2e, instead of prepare_fx and convert_fx. The difference is that prepare_pt2e takes a backend-specific Quantizer as an argument, which annotates the nodes in the graph with the information needed to quantize the model properly for that backend.

example_args = (torch.randn(1, 3, 256, 256),)
pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
print("Pre-Autograd ATen Dialect Graph")
print(pre_autograd_aten_dialect)

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
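# Calibration with a sample dataset would go here. It is omitted in this example,
# which is why the output below warns that the observers were never run.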
converted_graph = convert_pt2e(prepared_graph)
print("Quantized Graph")
print(converted_graph)

aten_dialect: ExportedProgram = export(converted_graph, example_args)
print("ATen Dialect Graph")
print(aten_dialect)
Pre-Autograd ATen Dialect Graph
GraphModule()



def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    _param_constant0 = self._param_constant0
    _param_constant1 = self._param_constant1
    conv2d_default = torch.ops.aten.conv2d.default(arg0, _param_constant0, _param_constant1, [1, 1], [1, 1]);  arg0 = _param_constant0 = _param_constant1 = None
    relu_default = torch.ops.aten.relu.default(conv2d_default);  conv2d_default = None
    return pytree.tree_unflatten([relu_default], self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/observer.py:1220: UserWarning: must run observer before calling calculate_qparams.                                    Returning default scale and zero point
  warnings.warn(
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/utils.py:339: UserWarning: must run observer before calling calculate_qparams. Returning default values.
  warnings.warn(
Quantized Graph
GraphModule()



def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg0, 1.0, 0, -128, 127, torch.int8);  arg0 = None
    dequantize_per_tensor_default = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor_default = None
    _param_constant0 = self._param_constant0
    quantize_per_tensor_default_1 = torch.ops.quantized_decomposed.quantize_per_tensor.default(_param_constant0, 1.0, 0, -127, 127, torch.int8);  _param_constant0 = None
    dequantize_per_tensor_default_1 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_1, 1.0, 0, -127, 127, torch.int8);  quantize_per_tensor_default_1 = None
    _param_constant1 = self._param_constant1
    conv2d_default = torch.ops.aten.conv2d.default(dequantize_per_tensor_default, dequantize_per_tensor_default_1, _param_constant1, [1, 1], [1, 1]);  dequantize_per_tensor_default = dequantize_per_tensor_default_1 = _param_constant1 = None
    relu_default = torch.ops.aten.relu.default(conv2d_default);  conv2d_default = None
    quantize_per_tensor_default_2 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu_default, 1.0, 0, -128, 127, torch.int8);  relu_default = None
    dequantize_per_tensor_default_2 = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default_2, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor_default_2 = None
    return pytree.tree_unflatten([dequantize_per_tensor_default_2], self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`
ATen Dialect Graph
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
            # No stacktrace found for following nodes
            quantize_per_tensor: i8[1, 3, 256, 256] = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg2_1, 1.0, 0, -128, 127, torch.int8);  arg2_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
            dequantize_per_tensor: f32[1, 3, 256, 256] = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor = None
            quantize_per_tensor_1: i8[16, 3, 3, 3] = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg0_1, 1.0, 0, -127, 127, torch.int8);  arg0_1 = None
            dequantize_per_tensor_1: f32[16, 3, 3, 3] = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_1, 1.0, 0, -127, 127, torch.int8);  quantize_per_tensor_1 = None
            convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(dequantize_per_tensor, dequantize_per_tensor_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  dequantize_per_tensor = dequantize_per_tensor_1 = arg1_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
            relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(convolution);  convolution = None
            quantize_per_tensor_2: i8[1, 16, 256, 256] = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu, 1.0, 0, -128, 127, torch.int8);  relu = None

            # No stacktrace found for following nodes
            dequantize_per_tensor_2: f32[1, 16, 256, 256] = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_2, 1.0, 0, -128, 127, torch.int8);  quantize_per_tensor_2 = None
            return (dequantize_per_tensor_2,)

Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['dequantize_per_tensor_2'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

More information on how to quantize a model, and how a backend can implement a Quantizer can be found here.
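
To make the quantize_per_tensor and dequantize_per_tensor nodes in the printed graphs more concrete, here is a minimal plain-Python sketch of the standard affine quantization arithmetic. It assumes the decomposed ops follow the usual formula (round, shift by the zero point, clamp to the integer range); the two helper functions are simplified, hypothetical stand-ins that operate on Python lists, not the torch.ops.quantized_decomposed operators themselves. The parameter values are the ones shown in the graph, which are defaults because no calibration was run.

```python
def quantize_per_tensor(values, scale, zero_point, quant_min, quant_max):
    """Map floats to integers: clamp(round(x / scale) + zero_point, quant_min, quant_max)."""
    return [
        max(quant_min, min(quant_max, round(v / scale) + zero_point))
        for v in values
    ]


def dequantize_per_tensor(q_values, scale, zero_point):
    """Map integers back to floats: (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in q_values]


# Parameters as printed for the activation nodes in the graph above.
scale, zero_point, quant_min, quant_max = 1.0, 0, -128, 127

activations = [0.4, -1.6, 3.2, 200.0]
q = quantize_per_tensor(activations, scale, zero_point, quant_min, quant_max)
roundtrip = dequantize_per_tensor(q, scale, zero_point)

print(q)          # [0, -2, 3, 127]
print(roundtrip)  # [0.0, -2.0, 3.0, 127.0]
```

The round trip shows why calibration matters: with the default scale of 1.0, every value is rounded to a whole number and anything above 127 saturates, so a scale fitted to the real data range would preserve far more information.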

Lowering to Edge Dialect

After exporting and lowering the graph to the ATen Dialect, the next step is to lower to the Edge Dialect, in which specializations that are useful for edge devices but not necessary for general (server) environments will be applied. Some of these specializations include:

  • DType specialization

  • Scalar to tensor conversion

  • Converting all ops to the executorch.exir.dialects.edge namespace.

Note that this dialect is still backend (or target) agnostic.

The lowering is done through the to_edge API.

from executorch.exir import EdgeProgramManager, to_edge

example_args = (torch.randn(1, 3, 256, 256),)
pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
print("Pre-Autograd ATen Dialect Graph")
print(pre_autograd_aten_dialect)

aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
print("ATen Dialect Graph")
print(aten_dialect)

edge_program: EdgeProgramManager = to_edge(aten_dialect)
print("Edge Dialect Graph")
print(edge_program.exported_program())
Pre-Autograd ATen Dialect Graph
GraphModule()



def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    _param_constant0 = self._param_constant0
    _param_constant1 = self._param_constant1
    conv2d_default = torch.ops.aten.conv2d.default(arg0, _param_constant0, _param_constant1, [1, 1], [1, 1]);  arg0 = _param_constant0 = _param_constant1 = None
    relu_default = torch.ops.aten.relu.default(conv2d_default);  conv2d_default = None
    return pytree.tree_unflatten([relu_default], self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`
ATen Dialect Graph
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
            convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = arg0_1 = arg1_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
            relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(convolution);  convolution = None
            return (relu,)

Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['relu'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

Edge Dialect Graph
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
            aten_convolution_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_convolution_default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = arg0_1 = arg1_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
            aten_relu_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_relu_default(aten_convolution_default);  aten_convolution_default = None
            return (aten_relu_default,)

Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['aten_relu_default'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

to_edge() returns an EdgeProgramManager object, which contains the exported programs that will be placed on the device. This data structure allows users to export multiple programs and combine them into one binary. If there is only one program, it is saved under the name “forward” by default.

def encode(x):
    return torch.nn.functional.linear(x, torch.randn(5, 10))


def decode(x):
    return torch.nn.functional.linear(x, torch.randn(10, 5))


encode_args = (torch.randn(1, 10),)
aten_encode: ExportedProgram = export(
    capture_pre_autograd_graph(encode, encode_args),
    encode_args,
)

decode_args = (torch.randn(1, 5),)
aten_decode: ExportedProgram = export(
    capture_pre_autograd_graph(decode, decode_args),
    decode_args,
)

edge_program: EdgeProgramManager = to_edge(
    {"encode": aten_encode, "decode": aten_decode}
)
for method in edge_program.methods:
    print(f"Edge Dialect graph of {method}")
    print(edge_program.exported_program(method))
Edge Dialect graph of encode
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[1, 10]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:292, code: return torch.nn.functional.linear(x, torch.randn(5, 10))
            aten_randn_default: f32[5, 10] = executorch_exir_dialects_edge__ops_aten_randn_default([5, 10], device = device(type='cpu'), pin_memory = False)
            aten_permute_copy_default: f32[10, 5] = executorch_exir_dialects_edge__ops_aten_permute_copy_default(aten_randn_default, [1, 0]);  aten_randn_default = None
            aten_mm_default: f32[1, 5] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, aten_permute_copy_default);  arg0_1 = aten_permute_copy_default = None
            return (aten_mm_default,)

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_mm_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

Edge Dialect graph of decode
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[1, 5]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:296, code: return torch.nn.functional.linear(x, torch.randn(10, 5))
            aten_randn_default: f32[10, 5] = executorch_exir_dialects_edge__ops_aten_randn_default([10, 5], device = device(type='cpu'), pin_memory = False)
            aten_permute_copy_default: f32[5, 10] = executorch_exir_dialects_edge__ops_aten_permute_copy_default(aten_randn_default, [1, 0]);  aten_randn_default = None
            aten_mm_default: f32[1, 10] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, aten_permute_copy_default);  arg0_1 = aten_permute_copy_default = None
            return (aten_mm_default,)

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_mm_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

We can also run additional passes on the exported program through the transform API. In-depth documentation on how to write transformations can be found here.

Note that since the graph is now in the Edge Dialect, all passes must also result in a valid Edge Dialect graph. In particular, the operators are now in the executorch.exir.dialects.edge namespace rather than the torch.ops.aten namespace.

example_args = (torch.randn(1, 3, 256, 256),)
pre_autograd_aten_dialect = capture_pre_autograd_graph(SimpleConv(), example_args)
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
print("Edge Dialect Graph")
print(edge_program.exported_program())

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ConvertReluToSigmoid(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        if op == exir_ops.edge.aten.relu.default:
            return super().call_operator(
                exir_ops.edge.aten.sigmoid.default, args, kwargs, meta
            )
        else:
            return super().call_operator(op, args, kwargs, meta)


transformed_edge_program = edge_program.transform((ConvertReluToSigmoid(),))
print("Transformed Edge Dialect Graph")
print(transformed_edge_program.exported_program())
Edge Dialect Graph
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
            aten_convolution_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_convolution_default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = arg0_1 = arg1_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
            aten_relu_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_relu_default(aten_convolution_default);  aten_convolution_default = None
            return (aten_relu_default,)

Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['aten_relu_default'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

Transformed Edge Dialect Graph
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:68, code: a = self.conv(x)
            aten_convolution_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_convolution_default(arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = arg0_1 = arg1_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:69, code: return self.relu(a)
            aten_sigmoid_default: f32[1, 16, 256, 256] = executorch_exir_dialects_edge__ops_aten_sigmoid_default(aten_convolution_default);  aten_convolution_default = None
            return (aten_sigmoid_default,)

Graph signature: ExportGraphSignature(parameters=['_param_constant0', '_param_constant1'], buffers=[], user_inputs=['arg2_1'], user_outputs=['aten_sigmoid_default'], inputs_to_parameters={'arg0_1': '_param_constant0', 'arg1_1': '_param_constant1'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

Delegating to a Backend

We can now delegate parts of the graph or the whole graph to a third-party backend through the to_backend API. In-depth documentation on backend delegation, including how to delegate to a backend and how to implement a backend, can be found here.

There are three ways to use this API:

  1. We can lower the whole module.

  2. We can take the lowered module, and insert it in another larger module.

  3. We can partition the module into subgraphs that are lowerable, and then lower those subgraphs to a backend.

Lowering the Whole Module

To lower an entire module, we can pass to_backend the backend name, the module to be lowered, and a list of compile specs to help the backend with the lowering process.

class LowerableModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sin(x)


# Export and lower the module to Edge Dialect
example_args = (torch.ones(1),)
pre_autograd_aten_dialect = capture_pre_autograd_graph(LowerableModule(), example_args)
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
to_be_lowered_module = edge_program.exported_program()

from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend

# Import the backend
from executorch.exir.backend.test.backend_with_compiler_demo import (  # noqa
    BackendWithCompilerDemo,
)

# Lower the module
lowered_module: LoweredBackendModule = to_backend(
    "BackendWithCompilerDemo", to_be_lowered_module, []
)
print(lowered_module)
print(lowered_module.backend_id)
print(lowered_module.processed_bytes)
print(lowered_module.original_module)

# Serialize and save it to a file
save_path = "delegate.pte"
with open(save_path, "wb") as f:
    f.write(lowered_module.buffer())
LoweredBackendModule()
BackendWithCompilerDemo
b'1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#'
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[1]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:385, code: return torch.sin(x)
            aten_sin_default: f32[1] = executorch_exir_dialects_edge__ops_aten_sin_default(arg0_1);  arg0_1 = None
            return (aten_sin_default,)

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_sin_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/_pytree.py:590: UserWarning: pytree_to_str is deprecated. Please use treespec_dumps
  warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps")

In this call, to_backend will return a LoweredBackendModule. Some important attributes of the LoweredBackendModule are:

  • backend_id: The name of the backend this lowered module will run on in the runtime

  • processed_bytes: a binary blob which will tell the backend how to run this program in the runtime

  • original_module: the original exported module
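
To make the role of processed_bytes more concrete, here is a minimal sketch that decodes the blob printed above for BackendWithCompilerDemo. The layout it assumes (an operator count, then "#"-separated entries with op, numel, dtype and a debug handle) is inferred only from that single printed sample, and parse_demo_blob is a hypothetical helper, not part of ExecuTorch. Real backends are free to use any binary format.

```python
import re
from typing import Dict, List

# The blob printed by the demo backend in the output above.
SAMPLE = b"1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#"

# Assumed entry layout, inferred from SAMPLE only.
ENTRY = re.compile(
    r"op:(?P<op>[^,]+), numel:(?P<numel>\d+), "
    r"dtype:(?P<dtype>[^<]+)<debug_handle>(?P<debug_handle>\d+)"
)


def parse_demo_blob(blob: bytes) -> List[Dict[str, str]]:
    """Split the demo blob into one dict per delegated operator."""
    text = blob.decode("utf-8")
    count, _, rest = text.partition("#")
    entries = [entry for entry in rest.split("#") if entry]
    if len(entries) != int(count):
        raise ValueError("operator count does not match number of entries")
    parsed = []
    for entry in entries:
        match = ENTRY.fullmatch(entry)
        if match is None:
            raise ValueError(f"unrecognized entry: {entry!r}")
        parsed.append(match.groupdict())
    return parsed


ops = parse_demo_blob(SAMPLE)
print(ops)
```

The point of the sketch is only that the blob is opaque to ExecuTorch itself: the backend that produced it is the component expected to interpret it at runtime.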

Compose the Lowered Module into Another Module

In cases where we want to reuse this lowered module in multiple programs, we can compose this lowered module with another module.

class NotLowerableModule(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.bias = bias

    def forward(self, a, b):
        return torch.add(torch.add(a, b), self.bias)


class ComposedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.non_lowerable = NotLowerableModule(torch.ones(1) * 0.3)
        self.lowerable = lowered_module

    def forward(self, x):
        a = self.lowerable(x)
        b = self.lowerable(a)
        ret = self.non_lowerable(a, b)
        return a, b, ret


example_args = (torch.ones(1),)
pre_autograd_aten_dialect = capture_pre_autograd_graph(ComposedModule(), example_args)
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
print("Edge Dialect graph")
print(exported_program)
print("Lowered Module within the graph")
print(exported_program.graph_module.lowered_module_0.backend_id)
print(exported_program.graph_module.lowered_module_0.processed_bytes)
print(exported_program.graph_module.lowered_module_0.original_module)
Edge Dialect graph
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[1], arg1_1: f32[1]):
            # File: /pytorch/executorch/exir/lowered_backend_module.py:273, code: return executorch_call_delegate(self, *args)
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate: f32[1] = torch.ops.executorch_call_delegate(lowered_module_0, arg1_1);  lowered_module_0 = arg1_1 = None

            # File: /pytorch/executorch/exir/lowered_backend_module.py:273, code: return executorch_call_delegate(self, *args)
            lowered_module_1 = self.lowered_module_0
            executorch_call_delegate_1: f32[1] = torch.ops.executorch_call_delegate(lowered_module_1, executorch_call_delegate);  lowered_module_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:440, code: return torch.add(torch.add(a, b), self.bias)
            aten_add_tensor: f32[1] = executorch_exir_dialects_edge__ops_aten_add_Tensor(executorch_call_delegate, executorch_call_delegate_1)
            aten_add_tensor_1: f32[1] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_add_tensor, arg0_1);  aten_add_tensor = arg0_1 = None
            return (executorch_call_delegate, executorch_call_delegate_1, aten_add_tensor_1)

Graph signature: ExportGraphSignature(parameters=[], buffers=['_tensor_constant0'], user_inputs=['arg1_1'], user_outputs=['executorch_call_delegate', 'executorch_call_delegate_1', 'aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={'arg0_1': '_tensor_constant0'}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

Lowered Module within the graph
BackendWithCompilerDemo
b'1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#'
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[1]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:385, code: return torch.sin(x)
            aten_sin_default: f32[1] = executorch_exir_dialects_edge__ops_aten_sin_default(arg0_1);  arg0_1 = None
            return (aten_sin_default,)

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['aten_sin_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

Notice that the graph now contains two torch.ops.executorch_call_delegate nodes, one for each call to self.lowerable, and both of them call lowered_module_0. Additionally, the contents of lowered_module_0 are the same as the lowered_module we created previously.

Partition and Lower Parts of a Module

A separate lowering flow is to pass to_backend the module that we want to lower, and a backend-specific partitioner. to_backend will use the backend-specific partitioner to tag nodes in the module which are lowerable, partition those nodes into subgraphs, and then create a LoweredBackendModule for each of those subgraphs.

def f(a, x, b):
    y = torch.mm(a, x)
    z = y + b
    a = z - a
    y = torch.mm(a, x)
    z = y + b
    return z


example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
print("Edge Dialect graph")
print(exported_program)

from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo

delegated_program = to_backend(exported_program, AddMulPartitionerDemo)
print("Delegated program")
print(delegated_program)
print(delegated_program.graph_module.lowered_module_0.original_module)
print(delegated_program.graph_module.lowered_module_1.original_module)
Edge Dialect graph
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:486, code: y = torch.mm(a, x)
            aten_mm_default: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, arg1_1)

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:487, code: z = y + b
            aten_add_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default, arg2_1);  aten_mm_default = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:488, code: a = z - a
            aten_sub_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_add_tensor, arg0_1);  aten_add_tensor = arg0_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:489, code: y = torch.mm(a, x)
            aten_mm_default_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(aten_sub_tensor, arg1_1);  aten_sub_tensor = arg1_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:490, code: z = y + b
            aten_add_tensor_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default_1, arg2_1);  aten_mm_default_1 = arg2_1 = None
            return (aten_add_tensor_1,)

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

Delegated program
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.executorch_call_delegate(lowered_module_0, arg0_1, arg1_1, arg2_1);  lowered_module_0 = None
            getitem: f32[2, 2] = executorch_call_delegate[0];  executorch_call_delegate = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:488, code: a = z - a
            aten_sub_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_sub_Tensor(getitem, arg0_1);  getitem = arg0_1 = None

            # No stacktrace found for following nodes
            lowered_module_1 = self.lowered_module_1
            executorch_call_delegate_1 = torch.ops.executorch_call_delegate(lowered_module_1, aten_sub_tensor, arg1_1, arg2_1);  lowered_module_1 = aten_sub_tensor = arg1_1 = arg2_1 = None
            getitem_1: f32[2, 2] = executorch_call_delegate_1[0];  executorch_call_delegate_1 = None
            return (getitem_1,)

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['getitem_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:486, code: y = torch.mm(a, x)
            aten_mm_default: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(arg0_1, arg1_1);  arg0_1 = arg1_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:487, code: z = y + b
            aten_add_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default, arg2_1);  aten_mm_default = arg2_1 = None
            return [aten_add_tensor]

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, aten_sub_tensor: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:489, code: y = torch.mm(a, x)
            aten_mm_default_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_mm_default(aten_sub_tensor, arg1_1);  aten_sub_tensor = arg1_1 = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:490, code: z = y + b
            aten_add_tensor_1: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_mm_default_1, arg2_1);  aten_mm_default_1 = arg2_1 = None
            return [aten_add_tensor_1]

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['aten_sub_tensor', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

Notice that there are now 2 torch.ops.executorch_call_delegate nodes in the graph, each calling a lowered module that contains the operations mm and add. The sub operation between them stays in the top-level graph because the partitioner did not tag it as lowerable.
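
The tag-and-partition behaviour seen in this output can be sketched in plain Python. This is a simplified illustration, not how the ExecuTorch partitioner is implemented: the supported set {mm, add} is read off the lowered modules printed above, partition_supported is a hypothetical helper, and grouping consecutive supported nodes ignores the data-dependency analysis a real partitioner performs.

```python
from typing import List, Set


def partition_supported(ops: List[str], supported: Set[str]) -> List[List[int]]:
    """Group indices of consecutive supported ops into partitions."""
    partitions: List[List[int]] = []
    current: List[int] = []
    for index, op in enumerate(ops):
        if op in supported:
            # Tag the node as lowerable and extend the open partition.
            current.append(index)
        elif current:
            # An unsupported op closes the open partition.
            partitions.append(current)
            current = []
    if current:
        partitions.append(current)
    return partitions


# Operators of f in execution order, as shown in the Edge Dialect graph.
ops = ["mm", "add", "sub", "mm", "add"]
partitions = partition_supported(ops, {"mm", "add"})
print(partitions)
```

Each resulting partition corresponds to one lowered module in the delegated program, and the op left out of every partition (sub) is the one that remains in the top-level graph.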

Alternatively, a more cohesive API to lower parts of a module is to call to_backend directly on the EdgeProgramManager:

def f(a, x, b):
    y = torch.mm(a, x)
    z = y + b
    a = z - a
    y = torch.mm(a, x)
    z = y + b
    return z


example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
pre_autograd_aten_dialect = capture_pre_autograd_graph(f, example_args)
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
delegated_program = edge_program.to_backend(AddMulPartitionerDemo)

print("Delegated program")
print(delegated_program.exported_program())
Delegated program
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.executorch_call_delegate(lowered_module_0, arg0_1, arg1_1, arg2_1);  lowered_module_0 = None
            getitem: f32[2, 2] = executorch_call_delegate[0];  executorch_call_delegate = None

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:522, code: a = z - a
            aten_sub_tensor: f32[2, 2] = executorch_exir_dialects_edge__ops_aten_sub_Tensor(getitem, arg0_1);  getitem = arg0_1 = None

            # No stacktrace found for following nodes
            lowered_module_1 = self.lowered_module_1
            executorch_call_delegate_1 = torch.ops.executorch_call_delegate(lowered_module_1, aten_sub_tensor, arg1_1, arg2_1);  lowered_module_1 = aten_sub_tensor = arg1_1 = arg2_1 = None
            getitem_1: f32[2, 2] = executorch_call_delegate_1[0];  executorch_call_delegate_1 = None
            return (getitem_1,)

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['getitem_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

Running User-Defined Passes and Memory Planning

As a final step of lowering, we can use the to_executorch() API to pass in backend-specific passes, such as replacing sets of operators with a custom backend operator, and a memory planning pass, which tells the runtime how to allocate memory ahead of time when running the program.

A default memory planning pass is provided, but we can also choose a backend-specific memory planning pass if it exists. More information on writing a custom memory planning pass can be found here.

from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager
from executorch.exir.passes import MemoryPlanningPass

executorch_program: ExecutorchProgramManager = edge_program.to_executorch(
    ExecutorchBackendConfig(
        passes=[],  # User-defined passes
        memory_planning_pass=MemoryPlanningPass(
            "greedy"
        ),  # Default memory planning pass
    )
)

print("ExecuTorch Dialect")
print(executorch_program.exported_program())

import executorch.exir as exir
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/_pytree.py:590: UserWarning: pytree_to_str is deprecated. Please use treespec_dumps
  warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps")
ExecuTorch Dialect
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 2], arg1_1: f32[2, 2], arg2_1: f32[2, 2]):
            # No stacktrace found for following nodes
            alloc: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32))

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:520, code: y = torch.mm(a, x)
            aten_mm_default: f32[2, 2] = torch.ops.aten.mm.out(arg0_1, arg1_1, out = alloc);  alloc = None

            # No stacktrace found for following nodes
            alloc_1: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32))

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:521, code: z = y + b
            aten_add_tensor: f32[2, 2] = torch.ops.aten.add.out(aten_mm_default, arg2_1, out = alloc_1);  aten_mm_default = alloc_1 = None

            # No stacktrace found for following nodes
            alloc_2: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32))

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:522, code: a = z - a
            aten_sub_tensor: f32[2, 2] = torch.ops.aten.sub.out(aten_add_tensor, arg0_1, out = alloc_2);  aten_add_tensor = arg0_1 = alloc_2 = None

            # No stacktrace found for following nodes
            alloc_3: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32))

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:523, code: y = torch.mm(a, x)
            aten_mm_default_1: f32[2, 2] = torch.ops.aten.mm.out(aten_sub_tensor, arg1_1, out = alloc_3);  aten_sub_tensor = arg1_1 = alloc_3 = None

            # No stacktrace found for following nodes
            alloc_4: f32[2, 2] = executorch_exir_memory_alloc(((2, 2), torch.float32))

            # File: /pytorch/executorch/docs/source/tutorials_source/export-to-executorch-tutorial.py:524, code: z = y + b
            aten_add_tensor_1: f32[2, 2] = torch.ops.aten.add.out(aten_mm_default_1, arg2_1, out = alloc_4);  aten_mm_default_1 = arg2_1 = alloc_4 = None
            return (aten_add_tensor_1,)

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['aten_add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {}
Equality constraints: []

Notice that in the graph we now see operators like torch.ops.aten.sub.out and torch.ops.aten.add.out rather than torch.ops.aten.sub.Tensor and torch.ops.aten.add.Tensor.

This is because between running the backend passes and memory planning passes, to prepare the graph for memory planning, an out-variant pass is run on the graph to convert all of the operators to their out variants. Instead of allocating returned tensors in the kernel implementations, an operator’s out variant takes a preallocated tensor as its out kwarg and stores the result there, making it easier for memory planners to do tensor lifetime analysis.
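
The difference between a functional operator and its out variant can be shown with a small plain-Python analogy. The functions add_functional and add_out below are hypothetical stand-ins that operate on lists, not ExecuTorch or ATen kernels.

```python
from typing import List


def add_functional(a: List[float], b: List[float]) -> List[float]:
    # Functional variant: allocates and returns a fresh result buffer.
    return [x + y for x, y in zip(a, b)]


def add_out(a: List[float], b: List[float], *, out: List[float]) -> List[float]:
    # Out variant: writes into a caller-provided buffer and allocates nothing.
    if not (len(a) == len(b) == len(out)):
        raise ValueError("a, b and out must have the same length")
    for i, (x, y) in enumerate(zip(a, b)):
        out[i] = x + y
    return out


a = [1.0, 2.0, 3.0, 4.0]
b = [10.0, 20.0, 30.0, 40.0]
planned = [0.0] * 4  # buffer chosen ahead of time, e.g. by a memory planner
result = add_out(a, b, out=planned)
print(result is planned, planned)
```

Because the caller decides where the result lives, the storage for every intermediate value is known before the program runs.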

We also insert alloc nodes into the graph containing calls to a special executorch.exir.memory.alloc operator. This tells us how much memory is needed to allocate each tensor output by the out-variant operator.
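
To illustrate why lifetime information helps, here is a toy greedy planner applied to the five alloc nodes in the graph above. It is a sketch under stated assumptions and not the algorithm behind MemoryPlanningPass: greedy_plan is a hypothetical helper, each 2x2 float32 tensor is taken to need 16 bytes, the step numbers are read off the printed graph, and a buffer is reused only once its previous occupant is no longer read.

```python
from typing import Dict, List, Tuple

# (name, size in bytes, step where it is written, last step where it is read)
Tensor = Tuple[str, int, int, int]


def greedy_plan(tensors: List[Tensor]) -> Tuple[Dict[str, int], int]:
    """Assign each tensor an offset, reusing slots whose occupant is dead."""
    slots: List[Dict[str, int]] = []
    offsets: Dict[str, int] = {}
    total = 0
    for name, size, start, end in sorted(tensors, key=lambda t: t[2]):
        for slot in slots:
            # Reuse only if the slot is big enough and its occupant is
            # last read strictly before this tensor is written.
            if slot["size"] >= size and slot["free_after"] < start:
                slot["free_after"] = end
                offsets[name] = slot["offset"]
                break
        else:
            slots.append({"offset": total, "size": size, "free_after": end})
            offsets[name] = total
            total += size
    return offsets, total


tensors = [
    ("aten_mm_default", 16, 0, 1),
    ("aten_add_tensor", 16, 1, 2),
    ("aten_sub_tensor", 16, 2, 3),
    ("aten_mm_default_1", 16, 3, 4),
    ("aten_add_tensor_1", 16, 4, 5),  # graph output, kept live to the end
]
offsets, total_bytes = greedy_plan(tensors)
print(offsets, total_bytes)
```

In this toy model the five intermediate tensors fit in two 16-byte slots (32 bytes) instead of five separate allocations (80 bytes).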

Saving to a File

Finally, we can save the ExecuTorch Program to a file and load it to a device to be run.

Here is an example for an entire end-to-end workflow:

import executorch.exir as exir
import torch
from executorch.exir import ExecutorchBackendConfig
from torch._export import capture_pre_autograd_graph
from torch.export import export, ExportedProgram


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


example_args = (torch.randn(3, 4),)
pre_autograd_aten_dialect = capture_pre_autograd_graph(M(), example_args)
# Optionally do quantization:
# pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)
# Optionally do delegation:
# edge_program = edge_program.to_backend(CustomBackendPartitioner)
executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
    ExecutorchBackendConfig(
        passes=[],  # User-defined passes
    )
)

with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/_pytree.py:590: UserWarning: pytree_to_str is deprecated. Please use treespec_dumps
  warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps")
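
As a quick sanity check before copying a saved file to a device, one option is to inspect its header. The sketch below assumes that the .pte file is a FlatBuffers binary, in which case the optional 4-byte file identifier sits at byte offsets 4 to 8. The helper flatbuffer_file_identifier is hypothetical, and the identifier value used in the synthetic buffer ("ET12") is an assumption that should be checked against the ExecuTorch program schema for your version.

```python
def flatbuffer_file_identifier(data: bytes) -> bytes:
    """Return the 4-byte FlatBuffers file identifier stored at offsets 4..8."""
    if len(data) < 8:
        raise ValueError("buffer too short to contain a FlatBuffers header")
    return data[4:8]


# Synthetic stand-in for file contents: a 4-byte little-endian root offset,
# followed by an assumed identifier and some padding.
fake_file = (8).to_bytes(4, "little") + b"ET12" + b"\x00" * 8
identifier = flatbuffer_file_identifier(fake_file)
print(identifier)
```

For a real file, the same helper can be applied to the bytes read back with open("model.pte", "rb").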

Conclusion

In this tutorial, we went over the APIs and steps required to lower a PyTorch program to a file that can be run on the ExecuTorch runtime.
