python.control-flow
=======================
dynamic_shape_if_guard
^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    
    class DynamicShapeIfGuard(torch.nn.Module):
        """
        An `if` statement with a backed dynamic-shape predicate will be specialized
        into one particular branch and generate a guard. However, export will fail
        if the dimension is marked as dynamic from the higher-level API.
        """
    
        def forward(self, x):
            if x.shape[0] == 3:
                return x.cos()
    
            return x.sin()
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2, 2]"):
                cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return (cos,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
    Range constraints: {}
    

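The guard comes from the ``x.shape[0] == 3`` check being specialized at trace time. As a minimal sketch (assuming the ``torch.export.export`` / ``Dim`` API and the sample input shape ``f32[3, 2, 2]`` shown in the result), a static export succeeds, while marking dimension 0 as dynamic is expected to fail because of that guard:

.. code-block:: python

    import torch
    from torch.export import Dim, export

    m = DynamicShapeIfGuard()
    x = torch.randn(3, 2, 2)  # matches the f32[3, 2, 2] input in the result above

    # Static export: the cos() branch is specialized in, with a guard on x.shape[0] == 3.
    ep = export(m, (x,))

    # Marking dim 0 as dynamic conflicts with that guard, so export is expected to raise.
    try:
        export(m, (x,), dynamic_shapes={"x": {0: Dim("dim0")}})
    except Exception as e:
        print("export failed:", e)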

list_unpack
^^^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`, :doc:`python.data-structure <python.data-structure>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    from typing import List
    
    import torch
    
    
    
    class ListUnpack(torch.nn.Module):
        """
        Lists are treated as static constructs; therefore, unpacking should be
        erased after tracing.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, args: List[torch.Tensor]):
            """
            Lists are treated as static constructs; therefore, unpacking should be
            erased after tracing.
            """
            x, *y = args
            return x + y[0]
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]", arg1_1: "i64[]", arg2_1: "i64[]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return (add,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg2_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
    Range constraints: {}
    

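The ``f32[3, 2]`` / ``i64[]`` / ``i64[]`` inputs in the signature correspond to a list holding one float tensor followed by two scalar integer tensors. A minimal export sketch (assuming the ``torch.export.export`` API; the concrete sample values are illustrative):

.. code-block:: python

    import torch
    from torch.export import export

    m = ListUnpack()
    args = [torch.randn(3, 2), torch.tensor(4), torch.tensor(5)]  # illustrative values

    # The list is a static construct: `x, *y = args` is erased during tracing,
    # leaving a single aten.add between the first element and y[0].
    ep = export(m, (args,))
    print(ep)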

static_for_loop
^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    
    class StaticForLoop(torch.nn.Module):
        """
        A for loop with a constant number of iterations should be unrolled in the exported graph.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            ret = []
            for i in range(10):  # constant
                ret.append(i + x)
            return ret
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 0)
                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 1)
                add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 2)
                add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 3)
                add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 4)
                add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 5)
                add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 6)
                add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 7)
                add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 8)
                add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 9);  arg0_1 = None
                return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_4'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_5'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_6'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_8'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_9'), target=None)])
    Range constraints: {}
    

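Because ``range(10)`` is constant, the loop is unrolled into ten ``aten.add`` nodes, one per iteration, so element ``i`` of the output equals ``x + i``. A minimal export sketch (assuming the ``torch.export.export`` API and the ``f32[3, 2]`` input shown in the result):

.. code-block:: python

    import torch
    from torch.export import export

    m = StaticForLoop()
    x = torch.randn(3, 2)

    ep = export(m, (x,))
    outs = ep.module()(x)   # runs the unrolled graph with the original calling convention
    assert len(outs) == 10  # one output per unrolled iteration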

static_if
^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    
    class StaticIf(torch.nn.Module):
        """
        An `if` statement with a static predicate value should be traced through
        with the taken branch.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            if len(x.shape) == 3:
                return x + torch.ones(1, 1, 1)
    
            return x
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2, 2]"):
                ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False)
                add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(arg0_1, ones);  arg0_1 = ones = None
                return (add,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
    Range constraints: {}
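
``len(x.shape)`` is known at trace time, so the predicate is static: with a 3-D input the ``ones`` + ``add`` branch is baked into the graph, and the fall-through ``return x`` branch is dropped. A minimal sketch (assuming the ``torch.export.export`` API):

.. code-block:: python

    import torch
    from torch.export import export

    m = StaticIf()

    # 3-D input: the `x + torch.ones(1, 1, 1)` branch is traced, as in the result above.
    ep_3d = export(m, (torch.randn(3, 2, 2),))

    # A 2-D input would instead trace only the fall-through `return x` branch.
    ep_2d = export(m, (torch.randn(3, 2),))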