TorchDynamo-based ONNX Exporter
Warning
The ONNX exporter for TorchDynamo is a rapidly evolving beta technology.
Overview
The ONNX exporter leverages the TorchDynamo engine to hook into Python's frame evaluation API and dynamically rewrite its bytecode into an FX graph. The resulting FX graph is then refined by a series of transformation passes before it is finally translated into an ONNX graph.
The main advantage of this approach is that the FX graph is captured through bytecode analysis, which preserves the dynamic nature of the model, rather than through traditional static tracing techniques.
In addition, during the export process, memory usage is significantly reduced compared to the TorchScript-enabled exporter. See the documentation for more information.
The exporter is designed to be modular and extensible. It is composed of the following components:
- ONNX Exporter: Exporter is the main class that orchestrates the export process.
- ONNX Export Options: ExportOptions has a set of options that control the export process.
- ONNX Registry: OnnxRegistry is the registry of ONNX operators and functions.
- FX Graph Extractor: FXGraphExtractor extracts the FX graph from the PyTorch model.
- Fake Mode: ONNXFakeContext is a context manager that enables fake mode for large-scale models.
- ONNX Program: ONNXProgram is the output of the exporter that contains the exported ONNX graph and diagnostics.
- ONNX Diagnostic Options: DiagnosticOptions has a set of options that control the diagnostics emitted by the exporter.
Dependencies
The ONNX exporter depends on extra Python packages: ONNX (onnx) and ONNX Script (onnxscript).
They can be installed through pip:
pip install --upgrade onnx onnxscript
onnxruntime can then be used to execute the model on a large variety of processors.
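Before exporting, it may help to confirm that these packages are importable in the current environment. The sketch below uses only the Python standard library; the helper names check_onnx_dependencies and missing_required are hypothetical and not part of torch.onnx.

```python
import importlib.util

# Packages the exporter needs; onnxruntime is only needed to run the exported model.
REQUIRED_PACKAGES = ("onnx", "onnxscript")
OPTIONAL_PACKAGES = ("onnxruntime",)


def check_onnx_dependencies() -> dict:
    """Hypothetical helper: map each package name to whether it can be imported."""
    return {
        name: importlib.util.find_spec(name) is not None
        for name in REQUIRED_PACKAGES + OPTIONAL_PACKAGES
    }


def missing_required(status: dict) -> list:
    """Return the required packages that are not installed."""
    return [name for name in REQUIRED_PACKAGES if not status[name]]


if __name__ == "__main__":
    status = check_onnx_dependencies()
    missing = missing_required(status)
    if missing:
        print("Install with: pip install --upgrade " + " ".join(missing))
    else:
        print("onnx and onnxscript are available")
    if not status["onnxruntime"]:
        print("onnxruntime is not installed; the exported model cannot be executed locally")
```

find_spec only locates the top-level package without importing it, so the check is cheap and has no side effects.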
A simple example
The following demonstrates the exporter API in action, using a simple Multilayer Perceptron (MLP) as an example:
import torch
import torch.nn as nn


class MLPModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(8, 8, bias=True)
        self.fc1 = nn.Linear(8, 4, bias=True)
        self.fc2 = nn.Linear(4, 2, bias=True)
        self.fc3 = nn.Linear(2, 2, bias=True)

    def forward(self, tensor_x: torch.Tensor):
        tensor_x = self.fc0(tensor_x)
        tensor_x = torch.sigmoid(tensor_x)
        tensor_x = self.fc1(tensor_x)
        tensor_x = torch.sigmoid(tensor_x)
        tensor_x = self.fc2(tensor_x)
        tensor_x = torch.sigmoid(tensor_x)
        output = self.fc3(tensor_x)
        return output


model = MLPModel()
tensor_x = torch.rand((97, 8), dtype=torch.float32)
onnx_program = torch.onnx.export(model, (tensor_x,), dynamo=True)
As the code above shows, all you need to do is provide torch.onnx.export() with an instance of the model and its input. The exporter then returns an instance of torch.onnx.ONNXProgram that contains the exported ONNX graph along with extra information.
The in-memory model available through onnx_program.model_proto is an onnx.ModelProto object that complies with the ONNX IR spec. The ONNX model may then be serialized into a Protobuf file using the torch.onnx.ONNXProgram.save() API.
onnx_program.save("mlp.onnx")
Two functions exist to export a model to ONNX based on the TorchDynamo engine. They differ slightly in the way they produce the ExportedProgram. torch.onnx.dynamo_export() was introduced with PyTorch 2.1, and torch.onnx.export() was extended with PyTorch 2.5 to easily switch from TorchScript to TorchDynamo. To call the former function, the last line of the previous example can be replaced by the following one:
onnx_program = torch.onnx.dynamo_export(model, tensor_x)
Inspecting the ONNX model using a GUI
You can view the exported model using Netron.
Note that each layer is represented by a rectangular box with an f icon in the top-right corner.
Expanding the box shows the function body, which is a sequence of ONNX operators or other functions.
When the conversion fails
If torch.onnx.export() fails, it should be called a second time with the parameter report=True. A Markdown report is then generated to help the user resolve the issue.
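The retry-with-report workflow can be wrapped in a small helper. This is a minimal sketch, assuming export_fn behaves like torch.onnx.export (it raises on failure and accepts a report keyword); export_with_report is a hypothetical name, not part of torch.onnx.

```python
from typing import Any, Callable


def export_with_report(export_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Hypothetical helper: call export_fn and, on failure, retry once with report=True.

    The retry exists so that the exporter writes its Markdown report. If the
    retry also fails, its exception propagates to the caller, with the first
    failure attached as __context__.
    """
    try:
        return export_fn(*args, **kwargs)
    except Exception:
        # Same arguments as the first attempt, with report=True forced on.
        retry_kwargs = dict(kwargs, report=True)
        return export_fn(*args, **retry_kwargs)
```

With the example above, the call would read export_with_report(torch.onnx.export, model, (tensor_x,), dynamo=True). Passing the export function as a parameter keeps the helper independent of torch, so it can be exercised with a stand-in function.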
torch.onnx.dynamo_export() generates a report in the SARIF format.
ONNX diagnostics go beyond regular logs through the adoption of the Static Analysis Results Interchange Format (SARIF), which helps users debug and improve their model using a GUI such as Visual Studio Code's SARIF Viewer.
The main advantages are:
- The diagnostics are emitted in the machine-parseable Static Analysis Results Interchange Format (SARIF).
- There is a new, clearer, structured way to add and keep track of diagnostic rules.
- The diagnostics serve as a foundation for future improvements that consume them.
The diagnostic rules include:
- FXE0007:fx-graph-to-onnx
- FXE0008:fx-node-to-onnx
- FXE0010:fx-pass
- FXE0011:no-symbolic-function-for-call-function
- FXE0012:unsupported-fx-node-analysis
- FXE0013:op-level-debugging
- FXE0014:find-opschema-matched-symbolic-function
- FXE0015:fx-node-insert-type-promotion
- FXE0016:find-operator-overloads-in-onnx-registry
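Because SARIF is plain JSON, a report can also be inspected without a GUI. The sketch below summarizes results per rule using standard SARIF 2.1.0 fields (runs, tool.driver.rules, results, ruleId, level, message.text). The sample log and the helper names are hypothetical; the exact content and location of the exporter's real report are not covered here.

```python
import json
from collections import Counter

# Hand-written sample in SARIF 2.1.0 shape; it is NOT real exporter output.
SAMPLE_LOG = {
    "version": "2.1.0",
    "runs": [
        {
            "tool": {
                "driver": {
                    "name": "hypothetical-tool",
                    "rules": [{"id": "FXE0008", "name": "fx-node-to-onnx"}],
                }
            },
            "results": [
                {
                    "ruleId": "FXE0008",
                    "level": "error",
                    "message": {"text": "hypothetical example message"},
                }
            ],
        }
    ],
}


def load_sarif(path: str) -> dict:
    """Read a SARIF log (plain JSON) from disk."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def summarize_sarif(sarif: dict) -> list:
    """Return (rule_id, rule_name, level, message) for every result in every run."""
    rows = []
    for run in sarif.get("runs", []):
        # Map rule ids such as "FXE0008" to names such as "fx-node-to-onnx".
        rules = run.get("tool", {}).get("driver", {}).get("rules", [])
        names = {rule["id"]: rule.get("name", "") for rule in rules}
        for result in run.get("results", []):
            rule_id = result.get("ruleId", "")
            rows.append(
                (
                    rule_id,
                    names.get(rule_id, ""),
                    # Simplification: fall back to "warning", SARIF's usual default level.
                    result.get("level", "warning"),
                    result.get("message", {}).get("text", ""),
                )
            )
    return rows


def count_levels(rows: list) -> Counter:
    """Count results per severity level (error, warning, note, ...)."""
    return Counter(level for _, _, level, _ in rows)


if __name__ == "__main__":
    for rule_id, name, level, text in summarize_sarif(SAMPLE_LOG):
        print(f"[{level}] {rule_id}:{name} {text}")
```

For a report on disk, summarize_sarif(load_sarif(path)) produces the same rows, which can then be filtered by rule id or sorted by level.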
API Reference
- torch.onnx.dynamo_export(model, /, *model_args, export_options=None, **model_kwargs)
Export a torch.nn.Module to an ONNX graph.
- Parameters
model (torch.nn.Module | Callable | torch.export.ExportedProgram) – The PyTorch model to be exported to ONNX.
model_args – Positional inputs to model.
model_kwargs – Keyword inputs to model.
export_options (ExportOptions | None) – Options to influence the export to ONNX.
- Returns
An in-memory representation of the exported ONNX model.
- Return type
ONNXProgram
Example 1 - Simplest export
import torch


class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x, bias=None):
        out = self.linear(x)
        if bias is not None:
            out = out + bias
        return out


model = MyModel()
kwargs = {"bias": 3.0}
args = (torch.randn(2, 2, 2),)
# save() returns None, so keep the ONNXProgram and save it in a separate step
onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs)
onnx_program.save("my_simple_model.onnx")
Example 2 - Exporting with dynamic shapes
# The previous model can be exported with dynamic shapes
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(
    model, *args, **kwargs, export_options=export_options
)
onnx_program.save("my_dynamic_model.onnx")
- class torch.onnx.ExportOptions(*, dynamic_shapes=None, fake_context=None, onnx_registry=None, diagnostic_options=None)
Options to influence the TorchDynamo ONNX exporter.
- Variables
dynamic_shapes (bool | None) – Shape information hint for input/output tensors. When None, the exporter determines the most compatible setting. When True, all input shapes are considered dynamic. When False, all input shapes are considered static.
diagnostic_options (DiagnosticOptions) – The diagnostic options for the exporter.
fake_context (ONNXFakeContext | None) – The fake context used for symbolic tracing.
onnx_registry (OnnxRegistry | None) – The ONNX registry used to register ATen operators to ONNX functions.
- torch.onnx.enable_fake_mode()
Enable fake mode for the duration of the context.
Internally it instantiates a torch._subclasses.fake_tensor.FakeTensorMode context manager that converts user input and model parameters into torch._subclasses.fake_tensor.FakeTensor.
A torch._subclasses.fake_tensor.FakeTensor is a torch.Tensor with the ability to run PyTorch code without having to actually do computation through tensors allocated on a meta device. Because no actual data is allocated on the device, this API allows for initializing and exporting large models without the memory footprint that executing them would require.
It is highly recommended to initialize the model in fake mode when exporting models that are too large to fit into memory.
Example:
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> class MyModel(torch.nn.Module):  # Model with a parameter
...     def __init__(self) -> None:
...         super().__init__()
...         self.weight = torch.nn.Parameter(torch.tensor(42.0))
...     def forward(self, x):
...         return self.weight + x
>>> with torch.onnx.enable_fake_mode():
...     # When initialized in fake mode, the model's parameters are fake tensors
...     # They do not take up memory so we can initialize large models
...     my_nn_module = MyModel()
...     arg1 = torch.randn(2, 2, 2)
>>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True)
>>> # Saving model WITHOUT initializers (only the architecture)
>>> onnx_program.save(
...     "my_model_without_initializers.onnx",
...     include_initializers=False,
...     keep_initializers_as_inputs=True,
... )
>>> # Saving model WITH initializers after applying concrete weights
>>> onnx_program.apply_weights({"weight": torch.tensor(42.0)})
>>> onnx_program.save("my_model_with_initializers.onnx")
Warning
This API is experimental and is NOT backward-compatible.
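To see why avoiding real allocations matters, the back-of-the-envelope arithmetic below estimates how much memory real float32 parameters would occupy. It counts parameters only (not activations or buffers), and the 7-billion-parameter size is an arbitrary illustration rather than a figure from this documentation. In fake mode, these parameters are fake tensors and no such data is allocated.

```python
# Bytes per element for a few common floating-point dtypes.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameters of an nn.Linear layer: an out x in weight matrix plus an optional bias."""
    return in_features * out_features + (out_features if bias else 0)


def param_bytes(num_params: int, dtype: str = "float32") -> int:
    """Memory needed to hold num_params real (non-fake) parameters of the given dtype."""
    return num_params * BYTES_PER_ELEMENT[dtype]


# The MLPModel from the simple example: Linear(8, 8), Linear(8, 4), Linear(4, 2), Linear(2, 2).
MLP_LAYERS = [(8, 8), (8, 4), (4, 2), (2, 2)]
mlp_params = sum(linear_param_count(i, o) for i, o in MLP_LAYERS)

# An arbitrary, illustrative large model size (7 billion parameters).
LARGE_PARAMS = 7_000_000_000

if __name__ == "__main__":
    print(f"MLP: {mlp_params} parameters, {param_bytes(mlp_params)} bytes in float32")
    gib = param_bytes(LARGE_PARAMS) / 1024**3
    print(f"7B-parameter model: about {gib:.1f} GiB in float32")
```

The MLP needs only 124 parameters (496 bytes), while the illustrative large model would need 28,000,000,000 bytes of weights before any computation, which is the situation fake mode is meant for.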
- class torch.onnx.ONNXProgram(model, exported_program)
A class to represent an ONNX program that is callable with torch tensors.
- apply_weights(state_dict)
Apply the weights from the specified state dict to the ONNX model.
Use this method to replace FakeTensors or other weights.
- Parameters
state_dict (dict[str, torch.Tensor]) – The state dict containing the weights to apply to the ONNX model.
- initialize_inference_session(initializer=<function _ort_session_initializer>)
Initialize the ONNX Runtime inference session.
- property model_proto: ModelProto
Return the ONNX ModelProto object.
- optimize()
Optimize the ONNX model.
This method optimizes the ONNX model by performing constant folding and eliminating redundancies in the graph. The optimization is done in-place.
- release()
Release the inference session.
You may call this method to release the resources used by the inference session.
- save(destination, *, include_initializers=True, keep_initializers_as_inputs=False, external_data=None)
Save the ONNX model to the specified destination.
When external_data is True or the model is larger than 2GB, the weights are saved as external data in a separate file.
Initializer (model weights) serialization behaviors:
* include_initializers=True, keep_initializers_as_inputs=False (default): The initializers are included in the saved model.
* include_initializers=True, keep_initializers_as_inputs=True: The initializers are included in the saved model and kept as model inputs. Choose this option if you want the ability to override the model weights during inference.
* include_initializers=False, keep_initializers_as_inputs=False: The initializers are not included in the saved model and are not listed as model inputs. Choose this option if you want to attach the initializers to the ONNX model in a separate post-processing step.
* include_initializers=False, keep_initializers_as_inputs=True: The initializers are not included in the saved model but are listed as model inputs. Choose this option if you want to supply the initializers during inference and want to minimize the size of the saved model.
- Parameters
destination (str | os.PathLike) – The path to save the ONNX model to.
include_initializers (bool) – Whether to include the initializers in the saved model.
keep_initializers_as_inputs (bool) – Whether to keep the initializers as inputs in the saved model. If True, the initializers are added as inputs to the model, which means they can be overwritten by providing the initializers as model inputs.
external_data (bool | None) – Whether to save the weights as external data in a separate file.
- Raises
TypeError – If external_data is True and destination is not a file path.
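The four initializer combinations and the external-data rule can be condensed into a small lookup, for example to decide which flags to pass before calling save(). The helper describe_save and the SaveLayout type are hypothetical; they do not call torch.onnx and only restate the behavior described above. Treating "2GB" as 2 GiB, and skipping external data when no initializers are written, are assumptions of this sketch.

```python
from dataclasses import dataclass
from typing import Optional

# Assumption: "larger than 2GB" is taken as 2 GiB, roughly protobuf's message size limit.
TWO_GB = 2 * 1024**3


@dataclass(frozen=True)
class SaveLayout:
    weights_in_model: bool    # initializers are written out with the model
    weights_as_inputs: bool   # initializers are listed as graph inputs (overridable)
    external_data_file: bool  # weights go to a separate external-data file


def describe_save(
    include_initializers: bool = True,
    keep_initializers_as_inputs: bool = False,
    external_data: Optional[bool] = None,
    model_size_bytes: int = 0,
) -> SaveLayout:
    """Hypothetical helper restating the documented ONNXProgram.save() behaviors."""
    # Assumption: with include_initializers=False there are no weights to write externally.
    external = include_initializers and (
        external_data is True or model_size_bytes > TWO_GB
    )
    return SaveLayout(
        weights_in_model=include_initializers,
        weights_as_inputs=keep_initializers_as_inputs,
        external_data_file=external,
    )
```

For instance, describe_save(include_initializers=False, keep_initializers_as_inputs=True) corresponds to onnx_program.save("model.onnx", include_initializers=False, keep_initializers_as_inputs=True), the layout used in the enable_fake_mode() example.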
- class torch.onnx.ONNXRuntimeOptions(*, session_options=None, execution_providers=None, execution_provider_options=None)
Options to influence the execution of the ONNX model through ONNX Runtime.
- Variables
session_options (Sequence[onnxruntime.SessionOptions] | None) – ONNX Runtime session options.
execution_providers (Sequence[str | tuple[str, dict[Any, Any]]] | None) – ONNX Runtime execution providers to use during model execution.
execution_provider_options (Sequence[dict[Any, Any]] | None) – ONNX Runtime execution provider options.
- class torch.onnx.OnnxExporterError
Errors raised by the ONNX exporter. This is the base class for all exporter errors.
- class torch.onnx.OnnxRegistry
Registry for ONNX functions.
The registry maintains a mapping from qualified names to symbolic functions under a fixed opset version. It supports registering custom onnx-script functions and allows the dispatcher to dispatch calls to the appropriate function.
- get_op_functions(namespace, op_name, overload=None)
Returns a list of ONNXFunctions for the given op: torch.ops.<namespace>.<op_name>.<overload>.
The list is ordered by the time of registration. The custom operators should be in the second half of the list.
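- Parameters
namespace (str) – The namespace of the operator to get.
op_name (str) – The name of the operator to get.
overload (str | None) – The overload of the operator to get. If it is the default overload, leave it as None.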
- Returns
A list of ONNXFunctions corresponding to the given name, or None if the name is not in the registry.
- Return type
list[registration.ONNXFunction] | None
- is_registered_op(namespace, op_name, overload=None)
Returns whether the given op is registered: torch.ops.<namespace>.<op_name>.<overload>.
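- Parameters
namespace (str) – The namespace of the operator to check.
op_name (str) – The name of the operator to check.
overload (str | None) – The overload of the operator to check. If it is the default overload, leave it as None.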
- Returns
True if the given op is registered, otherwise False.
- Return type
bool
- property opset_version: int
The ONNX opset version the exporter should target. Defaults to the latest supported ONNX opset version: 18. The default version will increment over time as ONNX continues to evolve.
- register_op(function, namespace, op_name, overload=None, is_complex=False)
Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.
- Parameters
function (onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction) – The onnx-script function to register.
namespace (str) – The namespace of the operator to register.
op_name (str) – The name of the operator to register.
overload (str | None) – The overload of the operator to register. If it is the default overload, leave it as None.
is_complex (bool) – Whether the function is a function that handles complex valued inputs.
- Raises
ValueError – If the name is not in the form of ‘namespace::op’.
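The lookup rules described above (ordering by registration time, None for the default overload, None returned for unknown names) can be modeled with a dictionary of lists. ToyRegistry, qualified_name, and split_qualified_name are hypothetical stand-ins for illustration only: the real OnnxRegistry also tracks the opset version and complex-valued variants, which this sketch omits, and the namespace::op_name.overload key format with a "default" overload name is an assumption.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple


def qualified_name(namespace: str, op_name: str, overload: Optional[str] = None) -> str:
    """Build 'namespace::op_name.overload', treating None as the default overload."""
    return f"{namespace}::{op_name}.{overload or 'default'}"


def split_qualified_name(name: str) -> Tuple[str, str, Optional[str]]:
    """Parse 'namespace::op[.overload]'; raise ValueError if '::' is missing."""
    if "::" not in name:
        raise ValueError(f"Expected a name of the form 'namespace::op', got {name!r}")
    namespace, rest = name.split("::", 1)
    op_name, _, overload = rest.partition(".")
    return namespace, op_name, overload or None


class ToyRegistry:
    """Hypothetical stand-in mirroring the documented OnnxRegistry lookup semantics."""

    def __init__(self) -> None:
        self._functions: Dict[str, List[Callable]] = defaultdict(list)

    def register_op(self, function: Callable, namespace: str, op_name: str,
                    overload: Optional[str] = None) -> None:
        # Appending preserves registration order, so functions registered later
        # (e.g. custom ones added after the built-ins) come last.
        self._functions[qualified_name(namespace, op_name, overload)].append(function)

    def get_op_functions(self, namespace: str, op_name: str,
                         overload: Optional[str] = None) -> Optional[List[Callable]]:
        # .get() on a defaultdict does not create empty entries for unknown names.
        functions = self._functions.get(qualified_name(namespace, op_name, overload))
        return list(functions) if functions else None

    def is_registered_op(self, namespace: str, op_name: str,
                         overload: Optional[str] = None) -> bool:
        return self.get_op_functions(namespace, op_name, overload) is not None
```

With the real class, a populated registry is handed to the exporter through ExportOptions(onnx_registry=...), as listed in the ExportOptions variables.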
- class torch.onnx.DiagnosticOptions(verbosity_level=20, warnings_as_errors=False)
Options for diagnostic context.