python.control-flow
=======================
dynamic_shape_if_guard
^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    
    class DynamicShapeIfGuard(torch.nn.Module):
        """
        An `if` statement with a backed dynamic shape predicate will be specialized into
        one particular branch and generate a guard. However, export will fail if
        the dimension is marked as dynamic shape from a higher-level API.
        """
    
        def forward(self, x):
            if x.shape[0] == 3:
                return x.cos()
    
            return x.sin()
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, l_x_: "f32[3, 2, 2]"):
                cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(l_x_);  l_x_ = None
                return (cos,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
    Range constraints: {}
    Equality constraints: []
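
Both behaviors described in the docstring can be reproduced with a short script. This is a minimal sketch, assuming a recent PyTorch release in which ``torch.export.export`` accepts a ``dynamic_shapes`` argument and ``torch.export.Dim`` is available. The example input and the dimension name ``batch`` are illustrative choices, not part of the original case.

.. code-block:: python

    import torch
    from torch.export import Dim, export


    class DynamicShapeIfGuard(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] == 3:
                return x.cos()
            return x.sin()


    example = (torch.randn(3, 2, 2),)

    # Static export: the predicate is evaluated on the example input,
    # so only the cos branch ends up in the graph.
    static_ep = export(DynamicShapeIfGuard(), example)
    targets = [n.target for n in static_ep.graph.nodes if n.op == "call_function"]

    # Marking dim 0 as dynamic conflicts with the specialization to 3.
    # The exception type is an assumption-free catch-all because it has
    # differed between releases.
    try:
        export(
            DynamicShapeIfGuard(),
            example,
            dynamic_shapes={"x": {0: Dim("batch")}},  # "batch" is a hypothetical name
        )
        dynamic_export_failed = False
    except Exception:
        dynamic_export_failed = True

If the sketch behaves as the docstring describes, ``targets`` contains the ``cos`` op but not ``sin``, and ``dynamic_export_failed`` is ``True``.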
    


list_unpack
^^^^^^^^^^^

.. note::

    Tags: :doc:`python.data-structure <python.data-structure>`, :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    from typing import List
    
    import torch
    
    
    
    def list_unpack(args: List[torch.Tensor]):
        """
        Lists are treated as a static construct; therefore, unpacking should be
        erased after tracing.
        """
        x, *y = args
        return x + y[0]
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]", l_args_1_: "i64[]", arg2: "i64[]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, l_args_1_);  x = l_args_1_ = None
                return (add,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_args_1_'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg2'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
    Range constraints: {}
    Equality constraints: []
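
To see the list flattened into individual graph inputs, the function can be exported and its placeholders inspected. This is a sketch under assumptions: it wraps the function in a hypothetical ``ListUnpack`` module (recent ``torch.export.export`` releases expect an ``nn.Module``), and the example tensors are chosen to match the shapes and dtypes shown in the result above.

.. code-block:: python

    from typing import List

    import torch
    from torch.export import export


    class ListUnpack(torch.nn.Module):  # hypothetical wrapper around list_unpack
        def forward(self, args: List[torch.Tensor]):
            x, *y = args
            return x + y[0]


    example_list = [torch.randn(3, 2), torch.tensor(4), torch.tensor(5)]
    ep = export(ListUnpack(), (example_list,))

    # Each list element becomes its own placeholder; the unpacking itself
    # leaves no trace in the graph.
    placeholders = [n for n in ep.graph.nodes if n.op == "placeholder"]
    call_targets = [n.target for n in ep.graph.nodes if n.op == "call_function"]

    # The exported module still accepts the original list structure.
    out = ep.module()(example_list)

The third list element is unused by ``forward`` but still appears as an input, which matches the ``arg2`` placeholder in the result above.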
    


static_for_loop
^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    
    class StaticForLoop(torch.nn.Module):
        """
        A for loop with a constant number of iterations should be unrolled in the exported graph.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            ret = []
            for i in range(10):  # constant
                ret.append(i + x)
            return ret
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, l_x_: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 0)
                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 1)
                add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 2)
                add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 3)
                add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 4)
                add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 5)
                add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 6)
                add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 7)
                add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 8)
                add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 9);  l_x_ = None
                return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_4'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_5'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_6'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_8'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_9'), target=None)])
    Range constraints: {}
    Equality constraints: []
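
The unrolling can be checked by counting the ``add`` nodes in the exported graph. This is a minimal sketch, assuming a recent PyTorch release that provides ``torch.export.export``; the input shape is chosen to match the result above.

.. code-block:: python

    import torch
    from torch.export import export


    class StaticForLoop(torch.nn.Module):
        def forward(self, x):
            ret = []
            for i in range(10):  # constant trip count, known at trace time
                ret.append(i + x)
            return ret


    x = torch.randn(3, 2)
    ep = export(StaticForLoop(), (x,))

    # One add node per loop iteration; the loop itself is not in the graph.
    add_nodes = [
        n
        for n in ep.graph.nodes
        if n.op == "call_function" and n.target == torch.ops.aten.add.Tensor
    ]

    # Running the exported module returns one tensor per iteration.
    outputs = ep.module()(x)

Because the trip count is baked in at export time, changing ``range(10)`` requires re-exporting the module.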
    


static_if
^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    
    class StaticIf(torch.nn.Module):
        """
        An `if` statement with a static predicate value should be traced through with the
        taken branch.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            if len(x.shape) == 3:
                return x + torch.ones(1, 1, 1)
    
            return x
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, l_x_: "f32[3, 2, 2]"):
                ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False)
                add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(l_x_, ones);  l_x_ = ones = None
                return (add,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
    Range constraints: {}
    Equality constraints: []
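
The predicate here depends only on the rank of ``x``, not on any dimension size, so it should stay static even when a dimension is marked dynamic. The sketch below illustrates this under assumptions: it relies on ``torch.export.Dim`` and the ``dynamic_shapes`` argument from a recent PyTorch release, and the dimension name ``batch`` and the batch size of 5 are illustrative.

.. code-block:: python

    import torch
    from torch.export import Dim, export


    class StaticIf(torch.nn.Module):
        def forward(self, x):
            if len(x.shape) == 3:  # rank is fixed at trace time
                return x + torch.ones(1, 1, 1)
            return x


    example = (torch.randn(3, 2, 2),)

    # Dim 0 is marked dynamic; the rank check does not guard on its size.
    ep = export(
        StaticIf(),
        example,
        dynamic_shapes={"x": {0: Dim("batch")}},  # "batch" is a hypothetical name
    )
    targets = [n.target for n in ep.graph.nodes if n.op == "call_function"]

    # The exported module accepts a different batch size.
    out = ep.module()(torch.randn(5, 2, 2))

Only the taken branch is recorded, so the exported program always adds the ``ones`` tensor, regardless of the batch size passed at run time.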