python.control-flow
=======================
dynamic_shape_if_guard
^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class DynamicShapeIfGuard(torch.nn.Module):
        """
        `if` statement with backed dynamic shape predicate will be specialized into
        one particular branch and generate a guard. However, export will fail if the
        dimension is marked as a dynamic shape from a higher-level API.
        """
    
        def forward(self, x):
            if x.shape[0] == 3:
                return x.cos()
    
            return x.sin()
    
    example_args = (torch.randn(3, 2, 2),)
    tags = {"torch.dynamic-shape", "python.control-flow"}
    model = DynamicShapeIfGuard()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2, 2]"):
                cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(x);  x = None
                return (cos,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
    Range constraints: {}
    


list_unpack
^^^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`, :doc:`python.data-structure <python.data-structure>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    from typing import List
    
    import torch
    
    class ListUnpack(torch.nn.Module):
        """
        Lists are treated as static constructs; therefore, unpacking should be
        erased after tracing.
        """
    
        def forward(self, args: List[torch.Tensor]):
            """
            Lists are treated as static constructs; therefore, unpacking should be
            erased after tracing.
            """
            x, *y = args
            return x + y[0]
    
    example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],)
    tags = {"python.control-flow", "python.data-structure"}
    model = ListUnpack()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, args_0: "f32[3, 2]", args_1: "i64[]", args_2: "i64[]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(args_0, args_1);  args_0 = args_1 = None
                return (add,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_2'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
    Range constraints: {}
    


static_for_loop
^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class StaticForLoop(torch.nn.Module):
        """
        A for loop with a constant number of iterations should be unrolled in the exported graph.
        """
    
        def forward(self, x):
            # constant
            ret = [i + x for i in range(10)]
            return ret
    
    example_args = (torch.randn(3, 2),)
    tags = {"python.control-flow"}
    model = StaticForLoop()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 0)
                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1)
                add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 2)
                add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 3)
                add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4)
                add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 5)
                add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 6)
                add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 7)
                add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 8)
                add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 9);  x = None
                return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_4'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_5'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_6'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_8'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_9'), target=None)])
    Range constraints: {}
    


static_if
^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class StaticIf(torch.nn.Module):
        """
        `if` statement with static predicate value should be traced through with the
        taken branch.
        """
    
        def forward(self, x):
            if len(x.shape) == 3:
                return x + torch.ones(1, 1, 1)
    
            return x
    
    example_args = (torch.randn(3, 2, 2),)
    tags = {"python.control-flow"}
    model = StaticIf()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2, 2]"):
                ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False)
                add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(x, ones);  x = ones = None
                return (add,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
    Range constraints: {}