TorchDynamo-based ONNX Exporter

Warning

The ONNX exporter for TorchDynamo is a rapidly evolving beta technology.

Overview

The ONNX exporter leverages the TorchDynamo engine to hook into Python’s frame evaluation API and dynamically rewrite its bytecode into an FX Graph. The resulting FX Graph is then polished before it is finally translated into an ONNX graph.

The main advantage of this approach is that the FX graph is captured using bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.

The exporter is designed to be modular and extensible. It is composed of the following components:

  • ONNX Exporter: Exporter main class that orchestrates the export process.

  • ONNX Export Options: ExportOptions has a set of options that control the export process.

  • ONNX Registry: OnnxRegistry is the registry of ONNX operators and functions.

  • FX Graph Extractor: FXGraphExtractor extracts the FX graph from the PyTorch model.

  • Fake Mode: ONNXFakeContext is a context manager that enables fake mode for large scale models.

  • ONNX Program: ONNXProgram is the output of the exporter that contains the exported ONNX graph and diagnostics.

  • ONNX Program Serializer: ONNXProgramSerializer serializes the exported model to a file.

  • ONNX Diagnostic Options: DiagnosticOptions has a set of options that control the diagnostics emitted by the exporter.

Dependencies

The ONNX exporter depends on two extra Python packages, onnx and onnxscript, which can be installed through pip:

pip install --upgrade onnx onnxscript

A simple example

The following example demonstrates the exporter API with a simple Multilayer Perceptron (MLP):

import torch
import torch.nn as nn

class MLPModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(8, 8, bias=True)
        self.fc1 = nn.Linear(8, 4, bias=True)
        self.fc2 = nn.Linear(4, 2, bias=True)
        self.fc3 = nn.Linear(2, 2, bias=True)

    def forward(self, tensor_x: torch.Tensor):
        tensor_x = self.fc0(tensor_x)
        tensor_x = torch.sigmoid(tensor_x)
        tensor_x = self.fc1(tensor_x)
        tensor_x = torch.sigmoid(tensor_x)
        tensor_x = self.fc2(tensor_x)
        tensor_x = torch.sigmoid(tensor_x)
        output = self.fc3(tensor_x)
        return output

model = MLPModel()
tensor_x = torch.rand((97, 8), dtype=torch.float32)
onnx_program = torch.onnx.dynamo_export(model, tensor_x)

As the code above shows, all you need is to provide torch.onnx.dynamo_export() with an instance of the model and its input. The exporter will then return an instance of torch.onnx.ONNXProgram that contains the exported ONNX graph along with extra information.

The in-memory model available through onnx_program.model_proto is an onnx.ModelProto object in compliance with the ONNX IR spec. The ONNX model may then be serialized into a Protobuf file using the torch.onnx.ONNXProgram.save() API.

onnx_program.save("mlp.onnx")

Inspecting the ONNX model using GUI

You can view the exported model using Netron.

MLP model as viewed using Netron

Note that each layer is represented as a rectangular box with an f icon in the top right corner.

ONNX function highlighted on MLP model

By expanding it, the function body is shown.

ONNX function body

The function body is a sequence of ONNX operators or other functions.

Diagnosing issues with SARIF

ONNX diagnostics go beyond regular logs by adopting the Static Analysis Results Interchange Format (SARIF), which helps users debug and improve their models using a GUI such as Visual Studio Code’s SARIF Viewer.

The main advantages are:

  • The diagnostics are emitted in the machine-parseable Static Analysis Results Interchange Format (SARIF).

  • A clearer, structured way to add new diagnostic rules and keep track of existing ones.

  • A foundation for future improvements that consume the diagnostics.
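Because SARIF is plain JSON, a log can be summarized with the standard library alone. The sketch below is illustrative only: the sample log, the tool name and the rule ids (EX0001, EX0002) are hand-written assumptions, not output produced by the exporter. It relies only on the top-level SARIF 2.1.0 layout of runs containing results.

```python
import json
from collections import Counter

# Hypothetical, hand-written SARIF 2.1.0 log. Real exporter output will
# contain different tool names, rule ids and messages.
SAMPLE_SARIF = """
{
  "version": "2.1.0",
  "runs": [
    {
      "tool": {"driver": {"name": "example-exporter"}},
      "results": [
        {"ruleId": "EX0001", "level": "note",
         "message": {"text": "Operator dispatched successfully."}},
        {"ruleId": "EX0002", "level": "warning",
         "message": {"text": "Falling back to a generic implementation."}},
        {"ruleId": "EX0002", "level": "warning",
         "message": {"text": "Falling back to a generic implementation."}}
      ]
    }
  ]
}
"""


def summarize_sarif(sarif_text: str) -> Counter:
    """Count results per (ruleId, level) across all runs in a SARIF log."""
    log = json.loads(sarif_text)
    counts: Counter = Counter()
    for run in log.get("runs", []):
        for result in run.get("results", []):
            # An omitted "level" is treated as "warning" here, which is the
            # common default for SARIF results.
            key = (result.get("ruleId", "<no rule>"), result.get("level", "warning"))
            counts[key] += 1
    return counts


summary = summarize_sarif(SAMPLE_SARIF)
for (rule_id, level), count in sorted(summary.items()):
    print(f"{rule_id} [{level}]: {count}")
```

The same function could be pointed at the contents of a .sarif file written by the exporter, assuming that file follows the standard layout.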

API Reference

torch.onnx.dynamo_export(model, /, *model_args, export_options=None, **model_kwargs)

Export a torch.nn.Module to an ONNX graph.

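Parameters
  • model (Union[torch.nn.Module, Callable, ExportedProgram]) – The PyTorch model to be exported to ONNX.

  • model_args – Positional inputs to model.

  • model_kwargs – Keyword inputs to model.

  • export_options (Optional[ExportOptions]) – Options to influence the export to ONNX.

Returns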

An in-memory representation of the exported ONNX model.

Return type

ONNXProgram

Example 1 - Simplest export

import torch

class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
    def forward(self, x, bias=None):
        out = self.linear(x)
        out = out + bias
        return out
model = MyModel()
kwargs = {"bias": 3.}
args = (torch.randn(2, 2, 2),)
# save() returns None, so keep the ONNXProgram in its own variable
onnx_program = torch.onnx.dynamo_export(
    model,
    *args,
    **kwargs)
onnx_program.save("my_simple_model.onnx")

Example 2 - Exporting with dynamic shapes

# The previous model can be exported with dynamic shapes
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(
    model,
    *args,
    **kwargs,
    export_options=export_options)
onnx_program.save("my_dynamic_model.onnx")

Printing the first graph input shows that its shape is no longer the static (2, 2, 2):

>>> print(onnx_program.model_proto.graph.input[0])
name: "arg0"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "arg0_dim_0"
      }
      dim {
        dim_param: "arg0_dim_1"
      }
      dim {
        dim_param: "arg0_dim_2"
      }
    }
  }
}

class torch.onnx.ExportOptions(*, dynamic_shapes=None, op_level_debug=None, fake_context=None, onnx_registry=None, diagnostic_options=None)

Options to influence the TorchDynamo ONNX exporter.

Variables
  • dynamic_shapes (Optional[bool]) – Shape information hint for input/output tensors. When None, the exporter determines the most compatible setting. When True, all input shapes are considered dynamic. When False, all input shapes are considered static.

  • op_level_debug (Optional[bool]) – Whether to export the model with op-level debug information

  • diagnostic_options (DiagnosticOptions) – The diagnostic options for the exporter.

  • fake_context (Optional[ONNXFakeContext]) – The fake context used for symbolic tracing.

  • onnx_registry (Optional[OnnxRegistry]) – The ONNX registry used to register ATen operators to ONNX functions.

torch.onnx.enable_fake_mode()

Enable fake mode for the duration of the context.

Internally it instantiates a torch._subclasses.fake_tensor.FakeTensorMode context manager that converts user input and model parameters into torch._subclasses.fake_tensor.FakeTensor.

A torch._subclasses.fake_tensor.FakeTensor is a torch.Tensor with the ability to run PyTorch code without having to actually do computation through tensors allocated on a meta device. Because there is no actual data being allocated on the device, this API allows for exporting large models without the actual memory footprint needed for executing it.

It is highly recommended to enable fake mode when exporting models that are too large to fit into memory.

Returns

An ONNXFakeContext object that must be passed to dynamo_export() through the ExportOptions.fake_context argument.

Example:

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> class MyModel(torch.nn.Module):  # Dummy model
...     def __init__(self) -> None:
...         super().__init__()
...         self.linear = torch.nn.Linear(2, 2)
...     def forward(self, x):
...         out = self.linear(x)
...         return out
>>> with torch.onnx.enable_fake_mode() as fake_context:
...     my_nn_module = MyModel()
...     arg1 = torch.randn(2, 2, 2)  # positional input 1
>>> export_options = torch.onnx.ExportOptions(fake_context=fake_context)
>>> onnx_program = torch.onnx.dynamo_export(
...     my_nn_module,
...     arg1,
...     export_options=export_options
... )
>>> # Saving model WITHOUT initializers
>>> onnx_program.save("my_model_without_initializers.onnx")
>>> # Saving model WITH initializers
>>> onnx_program.save("my_model_with_initializers.onnx", model_state=MyModel().state_dict())
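To see why fake mode avoids the memory cost, the following sketch uses FakeTensorMode directly, outside the exporter. It assumes a PyTorch build that exposes torch._subclasses.fake_tensor (a private module whose API may change); the tensor sizes are arbitrary and chosen only because a real float32 matrix of that size would need roughly 10 GB.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# Inside the mode, factory functions return FakeTensors: shapes and dtypes
# are tracked, but no storage for the values is allocated.
with FakeTensorMode():
    weight = torch.empty(50_000, 50_000)  # about 10 GB if it were real float32
    bias = torch.zeros(50_000)
    out = weight @ torch.ones(50_000) + bias

print(type(out).__name__, tuple(out.shape), out.dtype)
```

Only metadata is propagated through the matrix-vector product, which is what allows an exporter to trace a model whose real weights would not fit in memory.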

Warning

This API is experimental and is NOT backward-compatible.

class torch.onnx.ONNXProgram(model_proto, input_adapter, output_adapter, diagnostic_context, *, fake_context=None, export_exception=None, model_signature=None, model_torch=None)

An in-memory representation of a PyTorch model that has been exported to ONNX.

Parameters
  • model_proto (onnx.ModelProto) – The exported ONNX model as an onnx.ModelProto.

  • input_adapter (io_adapter.InputAdapter) – The input adapter used to convert PyTorch inputs into ONNX inputs.

  • output_adapter (io_adapter.OutputAdapter) – The output adapter used to convert PyTorch outputs into ONNX outputs.

  • diagnostic_context (diagnostics.DiagnosticContext) – Context object for the SARIF diagnostic system responsible for logging errors and metadata.

  • fake_context (Optional[ONNXFakeContext]) – The fake context used for symbolic tracing.

  • export_exception (Optional[Exception]) – The exception that occurred during export, if any.

  • model_signature (Optional[torch.export.ExportGraphSignature]) – The model signature for the exported ONNX graph.

adapt_torch_inputs_to_onnx(*model_args, model_with_state_dict=None, **model_kwargs)

Converts the PyTorch model inputs to exported ONNX model inputs format.

Due to design differences, the input/output formats of a PyTorch model and its exported ONNX model are often not the same. For example, None is allowed as a PyTorch model input but is not supported by ONNX, and nested constructs of tensors are allowed for a PyTorch model while ONNX supports only flattened tensors.

The actual adapting steps are associated with each individual export. It depends on the PyTorch model, the particular set of model_args and model_kwargs used for the export, and export options.

This method replays the adapting steps recorded during export.

Parameters
  • model_args – The PyTorch model inputs.

  • model_with_state_dict (Optional[Union[Module, Callable, ExportedProgram]]) – The PyTorch model to get extra state from. If not specified, the model used during export is used. Required when enable_fake_mode() is used to extract real initializers as needed by the ONNX graph.

  • model_kwargs – The PyTorch model keyword inputs.

Returns

A sequence of tensors converted from PyTorch model inputs.

Return type

Sequence[Union[Tensor, int, float, bool, dtype]]

Example:

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> from typing import Dict, Tuple
>>> def func_nested_input(
...     x_dict: Dict[str, torch.Tensor],
...     y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
... ):
...     if "a" in x_dict:
...         x = x_dict["a"]
...     elif "b" in x_dict:
...         x = x_dict["b"]
...     else:
...         x = torch.randn(3)
...
...     y1, (y2, y3) = y_tuple
...
...     return x + y1 + y2 + y3
>>> x_dict = {"a": torch.tensor(1.)}
>>> y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.)))
>>> onnx_program = torch.onnx.dynamo_export(func_nested_input, x_dict, y_tuple)
>>> print(x_dict, y_tuple)
{'a': tensor(1.)} (tensor(2.), (tensor(3.), tensor(4.)))
>>> print(onnx_program.adapt_torch_inputs_to_onnx(x_dict, y_tuple, model_with_state_dict=func_nested_input))
(tensor(1.), tensor(2.), tensor(3.), tensor(4.))
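The kind of adaptation described above (dropping None and flattening nested containers) can be approximated in plain Python. This is a conceptual sketch, not the exporter's implementation: flatten_inputs is a hypothetical helper, and plain floats stand in for tensors so that the snippet runs without PyTorch.

```python
from typing import Any, List


def flatten_inputs(value: Any) -> List[Any]:
    """Depth-first flatten of dicts, lists and tuples, dropping None leaves."""
    if value is None:
        return []  # ONNX has no equivalent of a None input, so it is removed
    if isinstance(value, dict):
        flat: List[Any] = []
        for key in value:  # dicts preserve insertion order
            flat.extend(flatten_inputs(value[key]))
        return flat
    if isinstance(value, (list, tuple)):
        flat = []
        for item in value:
            flat.extend(flatten_inputs(item))
        return flat
    return [value]  # a leaf, for example a tensor or a scalar


# Same nesting as the example above, with floats in place of tensors,
# plus a trailing None to show that it is dropped.
x_dict = {"a": 1.0}
y_tuple = (2.0, (3.0, 4.0))
flat = flatten_inputs((x_dict, y_tuple, None))
print(flat)
```

The real method replays steps recorded for one specific export, so its result can differ from this generic flattening.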

Warning

This API is experimental and is NOT backward-compatible.

adapt_torch_outputs_to_onnx(model_outputs, model_with_state_dict=None)

Converts the PyTorch model outputs to exported ONNX model outputs format.

Due to design differences, the input/output formats of a PyTorch model and its exported ONNX model are often not the same. For example, None is allowed as a PyTorch model output but is not supported by ONNX, and nested constructs of tensors are allowed for a PyTorch model while ONNX supports only flattened tensors.

The actual adapting steps are associated with each individual export. It depends on the PyTorch model, the particular set of model_args and model_kwargs used for the export, and export options.

This method replays the adapting steps recorded during export.

Parameters
  • model_outputs (Any) – The PyTorch model outputs.

  • model_with_state_dict (Optional[Union[Module, Callable, ExportedProgram]]) – The PyTorch model to get extra state from. If not specified, the model used during export is used. Required when enable_fake_mode() is used to extract real initializers as needed by the ONNX graph.

Returns

PyTorch model outputs in exported ONNX model outputs format.

Return type

Sequence[Union[Tensor, int, float, bool]]

Example:

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> import torch.onnx
>>> def func_returning_tuples(x, y, z):
...     x = x + y
...     y = y + z
...     z = x + y
...     return (x, (y, z))
>>> x = torch.tensor(1.)
>>> y = torch.tensor(2.)
>>> z = torch.tensor(3.)
>>> onnx_program = torch.onnx.dynamo_export(func_returning_tuples, x, y, z)
>>> pt_output = func_returning_tuples(x, y, z)
>>> print(pt_output)
(tensor(3.), (tensor(5.), tensor(8.)))
>>> print(onnx_program.adapt_torch_outputs_to_onnx(pt_output, model_with_state_dict=func_returning_tuples))
[tensor(3.), tensor(5.), tensor(8.)]

Warning

This API is experimental and is NOT backward-compatible.

property diagnostic_context: diagnostics.DiagnosticContext

The diagnostic context associated with the export.

property fake_context: Optional[ONNXFakeContext]

The fake context associated with the export.

property model_proto: onnx.ModelProto

The exported ONNX model as an onnx.ModelProto.

property model_signature: Optional[ExportGraphSignature]

The model signature for the exported ONNX graph.

This information is relevant because the ONNX specification often differs from PyTorch’s, resulting in an ONNX graph whose input and output schema differs from the actual PyTorch model implementation. By using the model signature, users can understand the differences in inputs and outputs and properly execute the model in ONNX Runtime.

NOTE: Model signature is only available when the ONNX graph was exported from a torch.export.ExportedProgram object.

NOTE: Any transformation done to the model that changes the model signature must be accompanied by updates to this model signature as well through InputAdaptStep and/or OutputAdaptStep.

Example

The ONNX graph exported from the following model has a different set of inputs and outputs than the PyTorch model itself. The first 4 inputs are model parameters (namely conv1.weight, conv2.weight, fc1.weight, fc2.weight), the next 2 inputs are registered buffers (namely my_buffer2, my_buffer1), and the last 2 inputs are user inputs (namely x and b). The first output is a buffer mutation (namely my_buffer2) and the last output is the actual model output.

>>> import pprint
>>> class CustomModule(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
...         self.register_buffer("my_buffer1", torch.tensor(3.0))
...         self.register_buffer("my_buffer2", torch.tensor(4.0))
...         self.conv1 = torch.nn.Conv2d(1, 32, 3, 1, bias=False)
...         self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False)
...         self.fc1 = torch.nn.Linear(9216, 128, bias=False)
...         self.fc2 = torch.nn.Linear(128, 10, bias=False)
...     def forward(self, x, b):
...         tensor_x = self.conv1(x)
...         tensor_x = torch.nn.functional.sigmoid(tensor_x)
...         tensor_x = self.conv2(tensor_x)
...         tensor_x = torch.nn.functional.sigmoid(tensor_x)
...         tensor_x = torch.nn.functional.max_pool2d(tensor_x, 2)
...         tensor_x = torch.flatten(tensor_x, 1)
...         tensor_x = self.fc1(tensor_x)
...         tensor_x = torch.nn.functional.sigmoid(tensor_x)
...         tensor_x = self.fc2(tensor_x)
...         output = torch.nn.functional.log_softmax(tensor_x, dim=1)
...         (
...         self.my_buffer2.add_(1.0) + self.my_buffer1
...         )  # Mutate buffer through in-place addition
...         return output
>>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3))
>>> exported_program = torch.export.export(CustomModule(), args=inputs).run_decompositions({})
>>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs)
>>> pprint.pprint(onnx_program.model_signature)
ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>,
                                      arg=TensorArgument(name='p_conv1_weight'),
                                      target='conv1.weight',
                                      persistent=None),
                            InputSpec(kind=<InputKind.PARAMETER: 2>,
                                      arg=TensorArgument(name='p_conv2_weight'),
                                      target='conv2.weight',
                                      persistent=None),
                            InputSpec(kind=<InputKind.PARAMETER: 2>,
                                      arg=TensorArgument(name='p_fc1_weight'),
                                      target='fc1.weight',
                                      persistent=None),
                            InputSpec(kind=<InputKind.PARAMETER: 2>,
                                      arg=TensorArgument(name='p_fc2_weight'),
                                      target='fc2.weight',
                                      persistent=None),
                            InputSpec(kind=<InputKind.BUFFER: 3>,
                                      arg=TensorArgument(name='b_my_buffer2'),
                                      target='my_buffer2',
                                      persistent=True),
                            InputSpec(kind=<InputKind.BUFFER: 3>,
                                      arg=TensorArgument(name='b_my_buffer1'),
                                      target='my_buffer1',
                                      persistent=True),
                            InputSpec(kind=<InputKind.USER_INPUT: 1>,
                                      arg=TensorArgument(name='x'),
                                      target=None,
                                      persistent=None),
                            InputSpec(kind=<InputKind.USER_INPUT: 1>,
                                      arg=TensorArgument(name='b'),
                                      target=None,
                                      persistent=None)],
               output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>,
                                        arg=TensorArgument(name='add'),
                                        target='my_buffer2'),
                             OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>,
                                        arg=TensorArgument(name='_log_softmax'),
                                        target=None)])

save(destination, *, include_initializers=True, model_state=None, serializer=None)

Saves the in-memory ONNX model to destination using specified serializer.

Parameters
  • destination (Union[str, BufferedIOBase]) – The destination to save the ONNX model. It can be either a string or a file-like object. When used with model_state, it must be a string with a full path to the destination. If destination is a string, besides saving the ONNX model into a file, model weights are also stored in separate files in the same directory as the ONNX model. E.g. for destination="/path/model.onnx", the initializers are saved in the "/path/" folder along with "model.onnx".

  • include_initializers (bool) – Whether to include initializers in the ONNX graph as external data. Cannot be combined with model_state.

  • model_state (Optional[Union[Dict[str, Any], str]]) – The state_dict of the PyTorch model containing all of its weights. It can be either a string with the path to a checkpoint or a dictionary with the actual model state. The supported file formats are the same as those supported by torch.load and safetensors.safe_open. Required when enable_fake_mode() is used but real initializers are needed on the ONNX graph.

  • serializer (Optional[ONNXProgramSerializer]) – The serializer to use. If not specified, the model will be serialized as Protobuf.

save_diagnostics(destination)

Saves the export diagnostics as a SARIF log to the specified destination path.

Parameters

destination (str) – The destination to save the diagnostics SARIF log. It must have a .sarif extension.

Raises

ValueError – If the destination path does not end with .sarif extension.
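Since a destination without the .sarif extension is rejected, callers that build the path dynamically may want to normalize it first. The helper below is a hypothetical convenience (ensure_sarif_path is not part of torch.onnx) that appends the extension when it is missing.

```python
from pathlib import Path


def ensure_sarif_path(destination: str) -> str:
    """Return destination unchanged if it ends in .sarif, else append it."""
    path = Path(destination)
    if path.suffix != ".sarif":
        # Append rather than replace, so "report.json" becomes
        # "report.json.sarif" and no part of the original name is lost.
        path = path.with_name(path.name + ".sarif")
    return str(path)


print(ensure_sarif_path("report"))
print(ensure_sarif_path("export.sarif"))
```

The returned string could then be passed to save_diagnostics() without triggering the ValueError described above.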

class torch.onnx.ONNXProgramSerializer(*args, **kwargs)

Protocol for serializing an ONNX graph into a specific format (e.g. Protobuf). Note that this is an advanced usage scenario.

serialize(onnx_program, destination)

Protocol method that must be implemented for serialization.

Parameters
  • onnx_program (ONNXProgram) – Represents the in-memory exported ONNX model

  • destination (BufferedIOBase) – A binary IO stream or pre-allocated buffer into which the serialized model should be written.

Example

A simple serializer that writes the exported onnx.ModelProto in Protobuf format to destination:

# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import io
>>> import torch
>>> import torch.onnx
>>> class MyModel(torch.nn.Module):  # Dummy model
...     def __init__(self) -> None:
...         super().__init__()
...         self.linear = torch.nn.Linear(2, 2)
...     def forward(self, x):
...         out = self.linear(x)
...         return out
>>> class ProtobufONNXProgramSerializer:
...     def serialize(
...         self, onnx_program: torch.onnx.ONNXProgram, destination: io.BufferedIOBase
...     ) -> None:
...         destination.write(onnx_program.model_proto.SerializeToString())
>>> model = MyModel()
>>> arg1 = torch.randn(2, 2, 2)  # positional input 1
>>> torch.onnx.dynamo_export(model, arg1).save(
...     destination="exported_model.onnx",
...     serializer=ProtobufONNXProgramSerializer(),
... )

class torch.onnx.ONNXRuntimeOptions(*, session_options=None, execution_providers=None, execution_provider_options=None)

Options to influence the execution of the ONNX model through ONNX Runtime.

Variables
  • session_options (Optional[Sequence['onnxruntime.SessionOptions']]) – ONNX Runtime session options.

  • execution_providers (Optional[Sequence[Union[str, Tuple[str, Dict[Any, Any]]]]]) – ONNX Runtime execution providers to use during model execution.

  • execution_provider_options (Optional[Sequence[Dict[Any, Any]]]) – ONNX Runtime execution provider options.

class torch.onnx.InvalidExportOptionsError

Raised when the user specifies an invalid value for the ExportOptions.

class torch.onnx.OnnxExporterError(onnx_program, message)

Raised when an ONNX exporter error occurs.

This exception is thrown when there’s an error during the ONNX export process. It encapsulates the ONNXProgram object generated until the failure, allowing access to the partial export results and associated metadata.
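The pattern of an exception that carries partial results can be illustrated without PyTorch. In the sketch below every name (PartialExport, ExportFailed, run_export and the step names) is a hypothetical stand-in chosen for illustration; it only mirrors the idea that the error object exposes whatever was produced before the failure.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class PartialExport:
    """Hypothetical stand-in for the partially built export result."""
    completed_steps: List[str] = field(default_factory=list)


class ExportFailed(RuntimeError):
    """Hypothetical stand-in for an exporter error carrying partial results."""

    def __init__(self, partial: PartialExport, message: str) -> None:
        super().__init__(message)
        self.partial = partial


def run_export(steps: List[str], fail_at: str) -> PartialExport:
    """Run steps in order and raise ExportFailed when fail_at is reached."""
    partial = PartialExport()
    for step in steps:
        if step == fail_at:
            raise ExportFailed(partial, f"step {step!r} failed")
        partial.completed_steps.append(step)
    return partial


try:
    run_export(["capture", "decompose", "translate"], fail_at="translate")
except ExportFailed as err:
    # The work done before the failure is still reachable from the exception.
    recovered = err.partial.completed_steps
    print(f"{err}; completed before failure: {recovered}")
```

With the real exporter, the analogous step would be to catch OnnxExporterError and inspect the ONNXProgram it encapsulates.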

class torch.onnx.OnnxRegistry

Registry for ONNX functions.

The registry maintains a mapping from qualified names to symbolic functions under a fixed opset version. It supports registering custom onnx-script functions and lets the dispatcher dispatch calls to the appropriate function.

get_op_functions(namespace, op_name, overload=None)

Returns a list of ONNXFunctions for the given op: torch.ops.<namespace>.<op_name>.<overload>.

The list is ordered by the time of registration. The custom operators should be in the second half of the list.

Parameters
  • namespace (str) – The namespace of the operator to get.

  • op_name (str) – The name of the operator to get.

  • overload (Optional[str]) – The overload of the operator to get. For the default overload, leave it as None.

Returns

A list of ONNXFunctions corresponding to the given name, or None if the name is not in the registry.

Return type

Optional[List[ONNXFunction]]

is_registered_op(namespace, op_name, overload=None)

Returns whether the given op is registered: torch.ops.<namespace>.<op_name>.<overload>.

Parameters
  • namespace (str) – The namespace of the operator to check.

  • op_name (str) – The name of the operator to check.

  • overload (Optional[str]) – The overload of the operator to check. For the default overload, leave it as None.

Returns

True if the given op is registered, otherwise False.

Return type

bool

property opset_version: int

The ONNX opset version the exporter should target. Defaults to the latest supported ONNX opset version: 18. The default version will increment over time as ONNX continues to evolve.

register_op(function, namespace, op_name, overload=None, is_complex=False)

Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.

Parameters
  • function (Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction]) – The onnx-script function to register.

  • namespace (str) – The namespace of the operator to register.

  • op_name (str) – The name of the operator to register.

  • overload (Optional[str]) – The overload of the operator to register. For the default overload, leave it as None.

  • is_complex (bool) – Whether the function is a function that handles complex valued inputs.

Raises

ValueError – If the name is not in the form of ‘namespace::op’.
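The registry semantics described above (lookup by namespace, op name and overload; results ordered by registration time; None for unknown operators) can be modeled with a small dictionary-backed class. ToyRegistry below is a simplified, hypothetical model written for illustration; it is not how OnnxRegistry is implemented and it stores plain Python callables instead of onnx-script functions.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple

OpKey = Tuple[str, str, Optional[str]]


class ToyRegistry:
    """Toy model of a registry keyed by (namespace, op_name, overload)."""

    def __init__(self) -> None:
        self._functions: Dict[OpKey, List[Callable]] = defaultdict(list)

    def register_op(self, function: Callable, namespace: str, op_name: str,
                    overload: Optional[str] = None) -> None:
        # Appending keeps registration order, so later (custom) entries
        # end up toward the end of the list.
        self._functions[(namespace, op_name, overload)].append(function)

    def is_registered_op(self, namespace: str, op_name: str,
                         overload: Optional[str] = None) -> bool:
        return (namespace, op_name, overload) in self._functions

    def get_op_functions(self, namespace: str, op_name: str,
                         overload: Optional[str] = None) -> Optional[List[Callable]]:
        # .get() does not create an entry for unknown keys.
        found = self._functions.get((namespace, op_name, overload))
        return list(found) if found is not None else None


def builtin_add(a, b):
    return a + b


def custom_add(a, b):
    return b + a


registry = ToyRegistry()
registry.register_op(builtin_add, "aten", "add", "Tensor")
registry.register_op(custom_add, "aten", "add", "Tensor")

names = [f.__name__ for f in registry.get_op_functions("aten", "add", "Tensor")]
print(names)
print(registry.get_op_functions("aten", "mul"))
```

Note that the overload is part of the key: in this toy model, registering "add" with overload "Tensor" does not make the default overload (None) registered.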

class torch.onnx.DiagnosticOptions(verbosity_level=20, warnings_as_errors=False)

Options for diagnostic context.

Variables
  • verbosity_level (int) – Sets the amount of information logged for each diagnostic, equivalent to the ‘level’ in the Python logging module.

  • warnings_as_errors (bool) – When True, warning diagnostics are treated as error diagnostics.
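Because verbosity_level follows the numeric levels of the Python logging module, the default of 20 can be read off the standard constants. The snippet uses only the standard library; the variable name default_verbosity is just a label for the documented default.

```python
import logging

# Numeric values of the standard logging levels.
for name in ("DEBUG", "INFO", "WARNING", "ERROR"):
    print(f"{name:8} = {getattr(logging, name)}")

default_verbosity = 20  # the documented default of verbosity_level
print(logging.getLevelName(default_verbosity))
```

Under that equivalence, passing logging.DEBUG as verbosity_level would presumably request more detailed diagnostics, and logging.WARNING fewer.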
