python.control-flow
=======================
dynamic_shape_if_guard
^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    class DynamicShapeIfGuard(torch.nn.Module):
        """
        `if` statement with backed dynamic shape predicate will be specialized into
        one particular branch and generate a guard. However, export will fail if the
        the dimension is marked as dynamic shape from higher level API.
        """
    
        def forward(self, x):
            if x.shape[0] == 3:
                return x.cos()
    
            return x.sin()
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2, 2]):
                # 
                cos: f32[3, 2, 2] = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return (cos,)
                
    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cos'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}
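
Since the example input was traced with ``x.shape[0] == 3``, the exported graph is
specialized to the ``cos`` branch and carries a guard on that dimension. Below is a
minimal sketch of both outcomes, assuming the ``torch.export`` API (the exact
``Dim``/``dynamic_shapes`` spelling may vary across versions):

.. code-block:: python

    import torch
    from torch.export import Dim, export

    m = DynamicShapeIfGuard()
    x = torch.randn(3, 2, 2)

    # Static export: the predicate x.shape[0] == 3 is evaluated once at
    # trace time, the `cos` branch is kept, and a guard on dim 0 is recorded.
    ep = export(m, (x,))

    # Marking dim 0 as dynamic conflicts with that guard, so export is
    # expected to fail with a constraint violation.
    try:
        export(m, (x,), dynamic_shapes={"x": {0: Dim("dim0")}})
    except Exception as e:
        print(type(e).__name__)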
    


list_unpack
^^^^^^^^^^^

.. note::

    Tags: :doc:`python.data-structure <python.data-structure>`, :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    from typing import List
    
    import torch
    
    
    def list_unpack(args: List[torch.Tensor]):
        """
        Lists are treated as a static construct, so unpacking should be
        erased after tracing.
        """
        x, *y = args
        return x + y[0]
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2], arg1_1: i64[], arg2_1: i64[]):
                # 
                add: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return (add,)
                
    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}
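
The starred unpacking runs entirely at trace time, which is why the graph above
contains a single ``aten.add`` and no list handling. Below is a minimal sketch of
reproducing this, assuming the ``torch.export`` API; the ``ListUnpackModule``
wrapper is ours, added because export takes a module:

.. code-block:: python

    import torch
    from torch.export import export


    class ListUnpackModule(torch.nn.Module):
        def forward(self, args):
            # Unpacking happens once while tracing; only the resulting
            # tensor ops are recorded in the graph.
            x, *y = args
            return x + y[0]


    inputs = ([torch.randn(3, 2), torch.tensor(0), torch.tensor(1)],)
    ep = export(ListUnpackModule(), inputs)
    print(ep.graph)  # a single aten.add; the list structure is flattened away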
    


static_for_loop
^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    class StaticForLoop(torch.nn.Module):
        """
        A for loop with a constant number of iterations should be unrolled in
        the exported graph.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            ret = []
            for i in range(10):  # constant
                ret.append(i + x)
            return ret
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2]):
                # 
                add: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 0)
                add_1: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 1)
                add_2: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 2)
                add_3: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 3)
                add_4: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 4)
                add_5: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 5)
                add_6: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 6)
                add_7: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 7)
                add_8: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 8)
                add_9: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 9);  arg0_1 = None
                return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)
                
    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add', 'add_1', 'add_2', 'add_3', 'add_4', 'add_5', 'add_6', 'add_7', 'add_8', 'add_9'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}
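
Each iteration of the Python loop executes while tracing, so the graph contains
ten independent ``aten.add`` nodes. Below is a minimal sketch that checks the
unrolling, assuming the ``torch.export`` API:

.. code-block:: python

    import torch
    from torch.export import export

    ep = export(StaticForLoop(), (torch.randn(3, 2),))

    # One call_function node per unrolled iteration.
    adds = [n for n in ep.graph.nodes if n.op == "call_function"]
    print(len(adds))  # expected: 10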
    


static_if
^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    class StaticIf(torch.nn.Module):
        """
        An `if` statement with a statically known predicate value should be
        traced through, keeping only the taken branch.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            if len(x.shape) == 3:
                return x + torch.ones(1, 1, 1)
    
            return x
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2, 2]):
                # 
                full: f32[1, 1, 1] = torch.ops.aten.full.default([1, 1, 1], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
                add: f32[3, 2, 2] = torch.ops.aten.add.Tensor(arg0_1, full);  arg0_1 = full = None
                return (add,)
                
    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}
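
Because the predicate only inspects the rank of ``x``, it is resolved while
tracing and only the taken branch survives. Below is a minimal sketch showing
that the traced branch depends on the example input's rank, assuming the
``torch.export`` API:

.. code-block:: python

    import torch
    from torch.export import export

    # Rank-3 input: len(x.shape) == 3 is True at trace time, so the
    # exported graph bakes in the `x + ones` branch, as shown above.
    ep3 = export(StaticIf(), (torch.randn(3, 2, 2),))

    # Rank-2 input: the predicate is False, so the exported graph contains
    # no add and simply passes the input through.
    ep2 = export(StaticIf(), (torch.randn(3, 2),))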