
Export a model with control flow to ONNX

Author: Xavier Dupré

Overview

This tutorial demonstrates how to handle control flow logic while exporting a PyTorch model to ONNX. It highlights the challenges of exporting conditional statements directly and provides solutions to circumvent them.

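Data-dependent conditional logic, that is, a branch whose outcome depends on the values of a tensor, cannot be exported to ONNX unless it is refactored to use torch.cond(). Let's start with a simple model whose forward method branches on the value of its input.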

What you will learn:

  • How to refactor the model to use torch.cond() for exporting.

  • How to export a model with control flow logic to ONNX.

  • How to optimize the exported model using the ONNX optimizer.

Prerequisites

  • torch >= 2.6

import torch

Define the Models

Two models are defined:

ForwardWithControlFlowTest: A model with a forward method containing an if-else conditional.

ModelWithControlFlowTest: A model that incorporates ForwardWithControlFlowTest as part of a simple MLP. The models are tested with a random input tensor to confirm they execute as expected.

class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlowTest()

Exporting the Model: First Attempt

Exporting this model with torch.export.export fails because the if statement in the forward pass depends on the value of a tensor (x.sum()). That value is unknown while the graph is being captured, so the exporter cannot decide which branch to record. This failure is expected: data-dependent conditional logic that is not written with torch.cond() is not supported.

A try-except block is used to capture the expected failure during the export process. If the export unexpectedly succeeds, an AssertionError is raised.

x = torch.randn(3)
model(x)

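try:
    torch.export.export(model, (x,), strict=False)
except Exception as e:
    # Expected: the data-dependent `if` in forward cannot be exported.
    print(e)
else:
    # Only reached if the export succeeded, i.e. PyTorch now supports this model.
    raise AssertionError("This export should fail unless PyTorch now supports this model.")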
Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:557 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py", line 56, in forward
    if x.sum():

Using torch.onnx.export() with JIT Tracing

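When the model is exported with torch.onnx.export() and the dynamo=True argument, the exporter first tries to capture the graph with torch.export.export (with strict=False, then in strict mode) and, only when both attempts fail, falls back to TorchScript (JIT tracing), as the log below shows. This fallback lets the export succeed, but the resulting ONNX graph may not faithfully represent the original model logic: tracing records only the branch taken for the example input. In the graph below, the condition has disappeared and only the multiplication by 2 remains.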

onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
/usr/local/lib/python3.10/dist-packages/onnxscript/converter.py:823: FutureWarning:

'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.

/usr/local/lib/python3.10/dist-packages/onnxscript/converter.py:823: FutureWarning:

'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.

[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script...
/var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56: TracerWarning:

Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script... ✅
[torch.onnx] Run decomposition...
/usr/local/lib/python3.10/dist-packages/torch/export/_unlift.py:75: UserWarning:

Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

/usr/local/lib/python3.10/dist-packages/torch/fx/graph.py:1801: UserWarning:

Node lifted_tensor_6 target lifted_tensor_6 lifted_tensor_6 of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target

[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
    ir_version=10,
    opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18},
    producer_name='pytorch',
    producer_version='2.6.0+cu124',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"input_1"<FLOAT,[3]>
    ),
    outputs=(
        %"mul"<FLOAT,[1]>
    ),
    initializers=(
        %"model.mlp.0.bias"<FLOAT,[2]>,
        %"model.mlp.0.weight"<FLOAT,[2,3]>,
        %"model.mlp.1.bias"<FLOAT,[1]>,
        %"model.mlp.1.weight"<FLOAT,[1,2]>
    ),
) {
    0 |  # node_Transpose_0
         %"val_0"<?,?> ⬅️ ::Transpose(%"model.mlp.0.weight") {perm=[1, 0]}
    1 |  # node_MatMul_1
         %"val_1"<?,?> ⬅️ ::MatMul(%"input_1", %"val_0")
    2 |  # node_Add_2
         %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"model.mlp.0.bias")
    3 |  # node_Transpose_3
         %"val_2"<?,?> ⬅️ ::Transpose(%"model.mlp.1.weight") {perm=[1, 0]}
    4 |  # node_MatMul_4
         %"val_3"<?,?> ⬅️ ::MatMul(%"linear", %"val_2")
    5 |  # node_Add_5
         %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"model.mlp.1.bias")
    6 |  # node_Constant_6
         %"val_4"<?,?> ⬅️ ::Constant() {value=Tensor<INT64,[]>(array(2), name=None)}
    7 |  # node_Cast_7
         %"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Cast(%"val_4") {to=FLOAT}
    8 |  # node_Mul_8
         %"mul"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default")
    return %"mul"<FLOAT,[1]>
}

<
    opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::Rank(
    inputs=(
        %"input"<?,?>
    ),
    outputs=(
        %"return_val"<?,?>
    ),
) {
    0 |  # n0
         %"tmp"<?,?> ⬅️ ::Shape(%"input")
    1 |  # n1
         %"return_val"<?,?> ⬅️ ::Size(%"tmp")
    return %"return_val"<?,?>
}

<
    opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::IsScalar(
    inputs=(
        %"input"<?,?>
    ),
    outputs=(
        %"return_val"<?,?>
    ),
) {
    0 |  # n0
         %"tmp"<?,?> ⬅️ ::Shape(%"input")
    1 |  # n1
         %"tmp_0"<?,?> ⬅️ ::Size(%"tmp")
    2 |  # n2
         %"tmp_1"<?,?> ⬅️ ::Constant() {value_int=0}
    3 |  # n3
         %"return_val"<?,?> ⬅️ ::Equal(%"tmp_0", %"tmp_1")
    return %"return_val"<?,?>
}
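The tracing limitation can be reproduced in isolation with torch.jit.trace, the tracer behind the TorchScript fallback (note the TracerWarning in the log above). This is a minimal sketch, assuming a torch version that still ships torch.jit.trace: it traces ForwardWithControlFlowTest with an input whose sum is nonzero, then calls it with an input whose sum is exactly zero, which takes the other branch in eager mode.

```python
import warnings

import torch


class ForwardWithControlFlowTest(torch.nn.Module):
    # Same module as above: the branch depends on the *value* of x.sum().
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


module = ForwardWithControlFlowTest()

# Trace with an input whose sum is nonzero: the tracer follows the
# "x * 2" branch and records only that branch.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the expected TracerWarning
    traced = torch.jit.trace(module, torch.tensor([1.0, 2.0]))

# An input whose sum is exactly zero takes the "-x" branch in eager mode...
zero_sum = torch.tensor([1.0, -1.0])
eager_out = module(zero_sum)
# ...but the traced module still multiplies by 2.
traced_out = traced(zero_sum)

print("eager: ", eager_out)
print("traced:", traced_out)
```

The same thing happens in the ONNX graph above: the branch chosen for the example input is frozen into the exported model.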

Suggested Patch: Refactoring with torch.cond()

To make the control flow exportable, the tutorial demonstrates replacing the forward method in ForwardWithControlFlowTest with a refactored version that uses torch.cond().

Details of the Refactoring:

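Two helper functions, identity2 (which, despite its name, multiplies its input by 2) and neg, represent the branches of the conditional logic:

  • torch.cond() is used to specify the condition and the two branches, along with the tuple of input arguments passed to whichever branch is executed.

  • The updated forward method is then dynamically assigned to the ForwardWithControlFlowTest instance within the model. A list of submodules is printed to confirm the replacement.

Note that the condition is now written explicitly as x.sum() > 0. This is not strictly equivalent to the original test if x.sum():, which is true for any nonzero sum, including a negative one.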

def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward
the list of submodules
 <class '__main__.ModelWithControlFlowTest'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>

Let’s see what the FX graph looks like.

print(torch.export.export(model, (x,), strict=False))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_mlp_0_weight: "f32[2, 3]", p_mlp_0_bias: "f32[2]", p_mlp_1_weight: "f32[1, 2]", p_mlp_1_bias: "f32[1]", x: "f32[3]"):
             # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[2]" = torch.ops.aten.linear.default(x, p_mlp_0_weight, p_mlp_0_bias);  x = p_mlp_0_weight = p_mlp_0_bias = None
            linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, p_mlp_1_weight, p_mlp_1_bias);  linear = p_mlp_1_weight = p_mlp_1_bias = None

             # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:250 in forward, code: input = module(input)
            sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None

             # File: <eval_with_key>.24:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, [l_args_3_0_]);  l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [linear_1]);  gt = true_graph_0 = false_graph_0 = linear_1 = None
            getitem: "f32[1]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, linear_1: "f32[1]"):
                 # File: <eval_with_key>.21:6 in forward, code: mul = l_args_3_0__1.mul(2);  l_args_3_0__1 = None
                mul: "f32[1]" = torch.ops.aten.mul.Tensor(linear_1, 2);  linear_1 = None
                return (mul,)

        class false_graph_0(torch.nn.Module):
            def forward(self, linear_1: "f32[1]"):
                 # File: <eval_with_key>.22:6 in forward, code: neg = l_args_3_0__1.neg();  l_args_3_0__1 = None
                neg: "f32[1]" = torch.ops.aten.neg.default(linear_1);  linear_1 = None
                return (neg,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_0_weight'), target='mlp.0.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_0_bias'), target='mlp.0.bias', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_1_weight'), target='mlp.1.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_1_bias'), target='mlp.1.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

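Let's export again. This time the first attempt, with torch.export.export(..., strict=False), succeeds, and the condition becomes an ONNX If node whose then_branch and else_branch call model-local functions holding the two branches.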

onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
    ir_version=10,
    opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18, 'pkg.torch.__subgraph__': 1},
    producer_name='pytorch',
    producer_version='2.6.0+cu124',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<FLOAT,[3]>
    ),
    outputs=(
        %"getitem"<FLOAT,[1]>
    ),
    initializers=(
        %"mlp.0.weight"<FLOAT,[2,3]>,
        %"mlp.0.bias"<FLOAT,[2]>,
        %"mlp.1.weight"<FLOAT,[1,2]>,
        %"mlp.1.bias"<FLOAT,[1]>
    ),
) {
     0 |  # node_Transpose_0
          %"val_0"<?,?> ⬅️ ::Transpose(%"mlp.0.weight") {perm=[1, 0]}
     1 |  # node_MatMul_1
          %"val_1"<?,?> ⬅️ ::MatMul(%"x", %"val_0")
     2 |  # node_Add_2
          %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias")
     3 |  # node_Transpose_3
          %"val_2"<?,?> ⬅️ ::Transpose(%"mlp.1.weight") {perm=[1, 0]}
     4 |  # node_MatMul_4
          %"val_3"<?,?> ⬅️ ::MatMul(%"linear", %"val_2")
     5 |  # node_Add_5
          %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias")
     6 |  # node_ReduceSum_6
          %"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {keepdims=False, noop_with_empty_axes=0}
     7 |  # node_Constant_7
          %"val_4"<?,?> ⬅️ ::Constant() {value=Tensor<INT64,[]>(array(0), name=None)}
     8 |  # node_Cast_8
          %"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Cast(%"val_4") {to=FLOAT}
     9 |  # node_Greater_9
          %"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default")
    10 |  # node_If_10
          %"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
              graph(
                  name=true_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"mul_true_graph_0"<?,?>
                  ),
              ) {
                  0 |  # node_true_graph_0_0
                       %"mul_true_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::true_graph_0(%"linear_1")
                  return %"mul_true_graph_0"<?,?>
              }, else_branch=
              graph(
                  name=false_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"neg_false_graph_0"<?,?>
                  ),
              ) {
                  0 |  # node_false_graph_0_0
                       %"neg_false_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::false_graph_0(%"linear_1")
                  return %"neg_false_graph_0"<?,?>
              }}
    return %"getitem"<FLOAT,[1]>
}

<
    opset_imports={'': 18},
>
def pkg.torch.__subgraph__::false_graph_0(
    inputs=(
        %"linear_1"<FLOAT,[1]>
    ),
    outputs=(
        %"neg"<FLOAT,[1]>
    ),
) {
    0 |  # node_Neg_0
         %"neg"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
    return %"neg"<FLOAT,[1]>
}

<
    opset_imports={'': 18},
>
def pkg.torch.__subgraph__::true_graph_0(
    inputs=(
        %"linear_1"<FLOAT,[1]>
    ),
    outputs=(
        %"mul"<FLOAT,[1]>
    ),
) {
    0 |  # node_Constant_0
         %"val_0"<?,?> ⬅️ ::Constant() {value=Tensor<INT64,[]>(array(2), name=None)}
    1 |  # node_Cast_1
         %"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Cast(%"val_0") {to=FLOAT}
    2 |  # node_Mul_2
         %"mul"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default")
    return %"mul"<FLOAT,[1]>
}

<
    opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::Rank(
    inputs=(
        %"input"<?,?>
    ),
    outputs=(
        %"return_val"<?,?>
    ),
) {
    0 |  # n0
         %"tmp"<?,?> ⬅️ ::Shape(%"input")
    1 |  # n1
         %"return_val"<?,?> ⬅️ ::Size(%"tmp")
    return %"return_val"<?,?>
}

<
    opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib.common::IsScalar(
    inputs=(
        %"input"<?,?>
    ),
    outputs=(
        %"return_val"<?,?>
    ),
) {
    0 |  # n0
         %"tmp"<?,?> ⬅️ ::Shape(%"input")
    1 |  # n1
         %"tmp_0"<?,?> ⬅️ ::Size(%"tmp")
    2 |  # n2
         %"tmp_1"<?,?> ⬅️ ::Constant() {value_int=0}
    3 |  # n3
         %"return_val"<?,?> ⬅️ ::Equal(%"tmp_0", %"tmp_1")
    return %"return_val"<?,?>
}

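We can optimize the model to get rid of the model-local functions created to capture the control flow branches. After optimization, the If branches contain the Mul and Neg nodes directly, and constants such as the transposed weights are folded.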

<
    ir_version=10,
    opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18, 'pkg.torch.__subgraph__': 1},
    producer_name='pytorch',
    producer_version='2.6.0+cu124',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<FLOAT,[3]>
    ),
    outputs=(
        %"getitem"<FLOAT,[1]>
    ),
    initializers=(
        %"mlp.0.bias"<FLOAT,[2]>,
        %"mlp.1.bias"<FLOAT,[1]>
    ),
) {
     0 |  # node_Constant_11
          %"val_0"<FLOAT,[3,2]> ⬅️ ::Constant() {value=Tensor<FLOAT,[3,2]>(array([[ 0.44140652,  0.53036046],
                 [ 0.47920528, -0.1264995 ],
                 [-0.13525727,  0.11650391]], dtype=float32), name='val_0')}
     1 |  # node_MatMul_1
          %"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0")
     2 |  # node_Add_2
          %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias")
     3 |  # node_Constant_12
          %"val_2"<FLOAT,[2,1]> ⬅️ ::Constant() {value=Tensor<FLOAT,[2,1]>(array([[ 0.62334496],
                 [-0.5187534 ]], dtype=float32), name='val_2')}
     4 |  # node_MatMul_4
          %"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2")
     5 |  # node_Add_5
          %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias")
     6 |  # node_ReduceSum_6
          %"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {keepdims=False, noop_with_empty_axes=0}
     7 |  # node_Constant_13
          %"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')}
     8 |  # node_Greater_9
          %"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default")
     9 |  # node_If_10
          %"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
              graph(
                  name=true_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"mul_true_graph_0"<FLOAT,[1]>
                  ),
              ) {
                  0 |  # node_Constant_1
                       %"scalar_tensor_default_2"<FLOAT,[]> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
                  1 |  # node_Mul_2
                       %"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2")
                  return %"mul_true_graph_0"<FLOAT,[1]>
              }, else_branch=
              graph(
                  name=false_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"neg_false_graph_0"<FLOAT,[1]>
                  ),
              ) {
                  0 |  # node_Neg_0
                       %"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
                  return %"neg_false_graph_0"<FLOAT,[1]>
              }}
    return %"getitem"<FLOAT,[1]>
}

Conclusion

This tutorial demonstrates the challenges of exporting models with conditional logic to ONNX and presents a practical solution using torch.cond(). While the default exporters may fail or produce imperfect graphs, refactoring the model’s logic ensures compatibility and generates a faithful ONNX representation.

By understanding these techniques, we can overcome common pitfalls when working with control flow in PyTorch models and ensure smooth integration with ONNX workflows.

Further reading

The list below refers to tutorials that range from basic examples to advanced scenarios, not necessarily in the order they are listed. Feel free to jump directly to the topics that interest you, or sit tight and have fun going through all of them to learn all there is about the ONNX exporter.

  • Introduction to ONNX

  • Exporting a PyTorch model to ONNX

  • Extending the ONNX exporter operator support

Total running time of the script: ( 0 minutes 2.979 seconds)
