Note
This page describes an internal API which is not intended to be used outside of the PyTorch codebase and can be modified or removed without notice.
TorchDynamo-based ONNX Exporter
Overview
The ONNX exporter leverages the TorchDynamo engine to hook into Python’s frame evaluation API and dynamically rewrite its bytecode into an FX graph. The resulting FX graph is then refined before it is finally translated into an ONNX graph.
The main advantage of this approach is that the FX graph is captured through bytecode analysis, which preserves the dynamic nature of the model, rather than through traditional static tracing techniques.
In addition, during the export process, memory usage is significantly reduced compared to the TorchScript-enabled exporter. See the memory usage documentation for more information.
Dependencies
The ONNX exporter depends on extra Python packages:
- ONNX
- ONNX Script
They can be installed through pip:
pip install --upgrade onnx onnxscript
onnxruntime can then be used to execute the model on a large variety of processors.
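Before exporting, it can help to confirm that these packages are importable. The sketch below uses a hypothetical helper, check_onnx_export_dependencies, built only on the standard library's importlib.util.find_spec; it reports onnx and onnxscript as required and onnxruntime as optional, which is an assumption based on the description above.

```python
import importlib.util

# Packages needed by the dynamo-based exporter, plus the optional runtime
# used to execute the exported model.
REQUIRED = ("onnx", "onnxscript")
OPTIONAL = ("onnxruntime",)


def check_onnx_export_dependencies() -> dict:
    """Hypothetical helper: map each package name to whether it is importable."""
    return {
        name: importlib.util.find_spec(name) is not None
        for name in REQUIRED + OPTIONAL
    }


if __name__ == "__main__":
    status = check_onnx_export_dependencies()
    missing = [name for name in REQUIRED if not status[name]]
    if missing:
        print("Missing required packages:", ", ".join(missing))
        print("Install them with: pip install --upgrade onnx onnxscript")
    else:
        print("All required packages are installed.")
```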
A simple example
Below is a demonstration of the exporter API in action, using a simple Multilayer Perceptron (MLP) as an example:
import torch
import torch.nn as nn


class MLPModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(8, 8, bias=True)
        self.fc1 = nn.Linear(8, 4, bias=True)
        self.fc2 = nn.Linear(4, 2, bias=True)
        self.fc3 = nn.Linear(2, 2, bias=True)

    def forward(self, tensor_x: torch.Tensor):
        tensor_x = self.fc0(tensor_x)
        tensor_x = torch.sigmoid(tensor_x)
        tensor_x = self.fc1(tensor_x)
        tensor_x = torch.sigmoid(tensor_x)
        tensor_x = self.fc2(tensor_x)
        tensor_x = torch.sigmoid(tensor_x)
        output = self.fc3(tensor_x)
        return output


model = MLPModel()
tensor_x = torch.rand((97, 8), dtype=torch.float32)
onnx_program = torch.onnx.export(model, (tensor_x,), dynamo=True)
As the code above shows, all you need is to call torch.onnx.export() with dynamo=True, an instance of the model, and its input. The exporter then returns an instance of torch.onnx.ONNXProgram that contains the exported ONNX graph along with extra information.
onnx_program.optimize() can be called to optimize the ONNX graph with constant folding and elimination of redundant operators. The optimization is done in-place.
onnx_program.optimize()
The in-memory model available through onnx_program.model_proto is an onnx.ModelProto object in compliance with the ONNX IR spec.
The ONNX model may then be serialized into a Protobuf file using the torch.onnx.ONNXProgram.save() API.
onnx_program.save("mlp.onnx")
When the conversion fails
If the conversion fails, call torch.onnx.export() a second time with the parameter report=True. A markdown report is generated to help the user resolve the issue.
API Reference
- torch.onnx.export(model, args=(), f=None, *, kwargs=None, export_params=True, verbose=None, input_names=None, output_names=None, opset_version=None, dynamic_axes=None, keep_initializers_as_inputs=False, dynamo=False, external_data=True, dynamic_shapes=None, custom_translation_table=None, report=False, optimize=True, verify=False, profile=False, dump_exported_program=False, artifacts_dir='.', fallback=False, training=<TrainingMode.EVAL: 0>, operator_export_type=<OperatorExportTypes.ONNX: 0>, do_constant_folding=True, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True)
Exports a model into ONNX format.
Setting dynamo=True enables the new ONNX export logic, which is based on torch.export.ExportedProgram and a more modern set of translation logic. This is the recommended way to export models to ONNX.
When dynamo=True, the exporter tries the following strategies to get an ExportedProgram for conversion to ONNX:
1. If the model is already an ExportedProgram, it will be used as-is.
2. Use torch.export.export() and set strict=False.
3. Use torch.export.export() and set strict=True.
4. Use draft_export, which removes some soundness guarantees in data-dependent operations to allow export to proceed. You will get a warning if the exporter encounters any unsound data-dependent operation.
5. Use torch.jit.trace() to trace the model, then convert it to an ExportedProgram. This is the most unsound strategy but may be useful for converting TorchScript models to ONNX.
- Parameters
model (torch.nn.Module | torch.export.ExportedProgram | torch.jit.ScriptModule | torch.jit.ScriptFunction) – The model to be exported.
args (tuple[Any, ...]) – Example positional inputs. Any non-Tensor arguments will be hard-coded into the exported model; any Tensor arguments will become inputs of the exported model, in the order they occur in the tuple.
f (str | os.PathLike | None) – Path to the output ONNX model file. E.g. “model.onnx”.
kwargs (dict[str, Any] | None) – Optional example keyword inputs.
export_params (bool) – If False, parameters (weights) will not be exported.
verbose (bool | None) – Whether to enable verbose logging.
input_names (Sequence[str] | None) – names to assign to the input nodes of the graph, in order.
output_names (Sequence[str] | None) – names to assign to the output nodes of the graph, in order.
opset_version (int | None) – The version of the default (ai.onnx) opset to target. Must be >= 7.
dynamic_axes (Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None) –
By default the exported model will have the shapes of all input and output tensors set to exactly match those given in args. To specify axes of tensors as dynamic (i.e. known only at run-time), set dynamic_axes to a dict with schema:
- KEY (str): an input or output name. Each name must also be provided in input_names or output_names.
- VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a list, each element is an axis index.
For example:
import torch


class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)


torch.onnx.export(
    SumModule(),
    (torch.ones(2, 2),),
    "onnx.pb",
    input_names=["x"],
    output_names=["sum"],
)
Produces:
input {
  name: "x"
  ...
  shape {
    dim {
      dim_value: 2  # axis 0
    }
    dim {
      dim_value: 2  # axis 1
...
output {
  name: "sum"
  ...
  shape {
    dim {
      dim_value: 2  # axis 0
...
While:
torch.onnx.export(
    SumModule(),
    (torch.ones(2, 2),),
    "onnx.pb",
    input_names=["x"],
    output_names=["sum"],
    dynamic_axes={
        # dict value: manually named axes
        "x": {0: "my_custom_axis_name"},
        # list value: automatic names
        "sum": [0],
    },
)
Produces:
input {
  name: "x"
  ...
  shape {
    dim {
      dim_param: "my_custom_axis_name"  # axis 0
    }
    dim {
      dim_value: 2  # axis 1
...
output {
  name: "sum"
  ...
  shape {
    dim {
      dim_param: "sum_dynamic_axes_1"  # axis 0
...
keep_initializers_as_inputs (bool) –
If True, all the initializers (typically corresponding to model weights) in the exported graph will also be added as inputs to the graph. If False, then initializers are not added as inputs to the graph, and only the user inputs are added as inputs.
Set this to True if you intend to supply model weights at runtime. Set it to False if the weights are static to allow for better optimizations (e.g. constant folding) by backends/runtimes.
dynamo (bool) – Whether to export the model with torch.export ExportedProgram instead of TorchScript.
external_data (bool) – Whether to save the model weights as an external data file. This is required for models with large weights that exceed the ONNX file size limit (2GB). When False, the weights are saved in the ONNX file with the model architecture.
dynamic_shapes (dict[str, Any] | tuple[Any, ...] | list[Any] | None) – A dictionary or a tuple of dynamic shapes for the model inputs. Refer to torch.export.export() for more details. This is only used (and preferred) when dynamo is True. Note that dynamic_shapes is designed to be used when the model is exported with dynamo=True, while dynamic_axes is used when dynamo=False.
custom_translation_table (dict[Callable, Callable | Sequence[Callable]] | None) – A dictionary of custom decompositions for operators in the model. The dictionary should have the callable target in the fx Node as the key (e.g. torch.ops.aten.stft.default), and the value should be a function that builds that graph using ONNX Script. This option is only valid when dynamo is True.
report (bool) – Whether to generate a markdown report for the export process. This option is only valid when dynamo is True.
optimize (bool) – Whether to optimize the exported model. This option is only valid when dynamo is True. Default is True.
verify (bool) – Whether to verify the exported model using ONNX Runtime. This option is only valid when dynamo is True.
profile (bool) – Whether to profile the export process. This option is only valid when dynamo is True.
dump_exported_program (bool) – Whether to dump the torch.export.ExportedProgram to a file. This is useful for debugging the exporter. This option is only valid when dynamo is True.
artifacts_dir (str | os.PathLike) – The directory to save the debugging artifacts like the report and the serialized exported program. This option is only valid when dynamo is True.
fallback (bool) – Whether to fall back to the TorchScript exporter if the dynamo exporter fails. This option is only valid when dynamo is True. When fallback is enabled, it is recommended to set dynamic_axes even when dynamic_shapes is provided.
training (_C_onnx.TrainingMode) – Deprecated option. Instead, set the training mode of the model before exporting.
operator_export_type (_C_onnx.OperatorExportTypes) – Deprecated option. Only ONNX is supported.
do_constant_folding (bool) – Deprecated option.
custom_opsets (Mapping[str, int] | None) –
Deprecated. A dictionary:
- KEY (str): opset domain name
- VALUE (int): opset version
If a custom opset is referenced by model but not mentioned in this dictionary, the opset version is set to 1. Only custom opset domain name and version should be indicated through this argument.
export_modules_as_functions (bool | Collection[type[torch.nn.Module]]) –
Deprecated option.
Flag to enable exporting all nn.Module forward calls as local functions in ONNX, or a set to indicate the particular types of modules to export as local functions in ONNX. This feature requires opset_version >= 15, otherwise the export will fail. This is because opset_version < 15 implies IR version < 8, which means no local function support. Module variables will be exported as function attributes. There are two categories of function attributes:
1. Annotated attributes: class variables that have type annotations via PEP 526-style will be exported as attributes. Annotated attributes are not used inside the subgraph of ONNX local function because they are not created by PyTorch JIT tracing, but they may be used by consumers to determine whether or not to replace the function with a particular fused kernel.
2. Inferred attributes: variables that are used by operators inside the module. Attribute names will have prefix “inferred::”. This is to differentiate from predefined attributes retrieved from python module annotations. Inferred attributes are used inside the subgraph of ONNX local function.
- False (default): export nn.Module forward calls as fine-grained nodes.
- True: export all nn.Module forward calls as local function nodes.
- Set of type of nn.Module: export nn.Module forward calls as local function nodes, only if the type of the nn.Module is found in the set.
autograd_inlining (bool) – Deprecated. Flag used to control whether to inline autograd functions. Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
- Returns
torch.onnx.ONNXProgram if dynamo is True, otherwise None.
- Return type
ONNXProgram | None
Changed in version 2.6: training is now deprecated. Instead, set the training mode of the model before exporting. operator_export_type is now deprecated. Only ONNX is supported. do_constant_folding is now deprecated. It is always enabled. export_modules_as_functions is now deprecated. autograd_inlining is now deprecated.
Changed in version 2.7: optimize is now True by default.
- class torch.onnx.ONNXProgram(model, exported_program)
A class to represent an ONNX program that is callable with torch tensors.
- Variables
model – The ONNX model as an ONNX IR model object.
exported_program – The exported program that produced the ONNX model.
- apply_weights(state_dict)
Apply the weights from the specified state dict to the ONNX model.
Use this method to replace FakeTensors or other weights.
- Parameters
state_dict (dict[str, torch.Tensor]) – The state dict containing the weights to apply to the ONNX model.
- compute_values(value_names, args=(), kwargs=None)
Compute the values of the specified names in the ONNX model.
This method is used to compute the values of the specified names in the ONNX model. The values are returned as a dictionary mapping names to tensors.
- Parameters
value_names (Sequence[str]) – The names of the values to compute.
- Returns
A dictionary mapping names to tensors.
- Return type
Sequence[torch.Tensor]
- initialize_inference_session(initializer=<function _ort_session_initializer>)
Initialize the ONNX Runtime inference session.
- optimize()
Optimize the ONNX model.
This method optimizes the ONNX model by performing constant folding and eliminating redundancies in the graph. The optimization is done in-place.
- release()
Release the inference session.
You may call this method to release the resources used by the inference session.
- save(destination, *, include_initializers=True, keep_initializers_as_inputs=False, external_data=None)
Save the ONNX model to the specified destination.
When external_data is True or the model is larger than 2GB, the weights are saved as external data in a separate file.
Initializer (model weights) serialization behaviors:
* include_initializers=True, keep_initializers_as_inputs=False (default): The initializers are included in the saved model.
* include_initializers=True, keep_initializers_as_inputs=True: The initializers are included in the saved model and kept as model inputs. Choose this option if you want the ability to override the model weights during inference.
* include_initializers=False, keep_initializers_as_inputs=False: The initializers are not included in the saved model and are not listed as model inputs. Choose this option if you want to attach the initializers to the ONNX model in a separate, post-processing step.
* include_initializers=False, keep_initializers_as_inputs=True: The initializers are not included in the saved model but are listed as model inputs. Choose this option if you want to supply the initializers during inference and want to minimize the size of the saved model.
- Parameters
destination (str | os.PathLike) – The path to save the ONNX model to.
include_initializers (bool) – Whether to include the initializers in the saved model.
keep_initializers_as_inputs (bool) – Whether to keep the initializers as inputs in the saved model. If True, the initializers are added as inputs to the model, which means they can be overwritten by providing the initializers as model inputs.
external_data (bool | None) – Whether to save the weights as external data in a separate file.
- Raises
TypeError – If external_data is True and destination is not a file path.
- torch.onnx.is_in_onnx_export()
Returns whether it is in the middle of ONNX export.
- Return type
bool
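A common use is gating code paths inside forward so that eager-only behavior, such as logging, is skipped while an export is running. The sketch below shows the pattern with a hypothetical ScaleModel, called eagerly only. The TorchScript-based exporter (dynamo=False) sets this flag during export. Whether the dynamo=True path also sets it may depend on the PyTorch version, so treat that as an assumption.

```python
import torch
import torch.nn as nn


class ScaleModel(nn.Module):
    # Hypothetical model that skips an eager-only logging path during export.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not torch.onnx.is_in_onnx_export():
            print("eager call with shape", tuple(x.shape))
        return x * 2


model = ScaleModel()
print(torch.onnx.is_in_onnx_export())  # outside of any export call
out = model(torch.ones(2, 3))
```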
- class torch.onnx.OnnxExporterError
Errors raised by the ONNX exporter. This is the base class for all exporter errors.
- torch.onnx.enable_fake_mode()
Enable fake mode for the duration of the context.
Internally it instantiates a torch._subclasses.fake_tensor.FakeTensorMode context manager that converts user input and model parameters into torch._subclasses.fake_tensor.FakeTensor.
A torch._subclasses.fake_tensor.FakeTensor is a torch.Tensor with the ability to run PyTorch code without having to actually do computation through tensors allocated on a meta device. Because there is no actual data being allocated on the device, this API allows for initializing and exporting large models without the actual memory footprint needed for executing it.
It is highly recommended to initialize the model in fake mode when exporting models that are too large to fit into memory.
Note
This function does not support torch.onnx.export(…, dynamo=True, optimize=True). Please call ONNXProgram.optimize() outside of the function after the model is exported.
Example:
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> class MyModel(torch.nn.Module):  # Model with a parameter
...     def __init__(self) -> None:
...         super().__init__()
...         self.weight = torch.nn.Parameter(torch.tensor(42.0))
...     def forward(self, x):
...         return self.weight + x
>>> with torch.onnx.enable_fake_mode():
...     # When initialized in fake mode, the model's parameters are fake tensors
...     # They do not take up memory so we can initialize large models
...     my_nn_module = MyModel()
...     arg1 = torch.randn(2, 2, 2)
>>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True, optimize=False)
>>> # Saving model WITHOUT initializers (only the architecture)
>>> onnx_program.save(
...     "my_model_without_initializers.onnx",
...     include_initializers=False,
...     keep_initializers_as_inputs=True,
... )
>>> # Saving model WITH initializers after applying concrete weights
>>> onnx_program.apply_weights({"weight": torch.tensor(42.0)})
>>> onnx_program.save("my_model_with_initializers.onnx")
Warning
This API is experimental and is NOT backward-compatible.
Deprecated
The following classes and functions are deprecated and will be removed.
- torch.onnx.dynamo_export(model, /, *model_args, export_options=None, **model_kwargs)
Export a torch.nn.Module to an ONNX graph.
Deprecated since version 2.7: Please use torch.onnx.export(..., dynamo=True) instead.
- Parameters
model (torch.nn.Module | Callable | torch.export.ExportedProgram) – The PyTorch model to be exported to ONNX.
model_args – Positional inputs to model.
model_kwargs – Keyword inputs to model.
export_options (ExportOptions | None) – Options to influence the export to ONNX.
- Returns
An in-memory representation of the exported ONNX model.
- Return type
ONNXProgram
- class torch.onnx.ExportOptions(*, dynamic_shapes=None)
Options for dynamo_export.
Deprecated since version 2.7: Please use torch.onnx.export(..., dynamo=True) instead.
- Variables
dynamic_shapes – Shape information hint for input/output tensors. When None, the exporter determines the most compatible setting. When True, all input shapes are considered dynamic. When False, all input shapes are considered static.