torch.dynamic-shape
=======================
cond_branch_class_method
^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    from functorch.experimental.control_flow import cond
    
    
    class MySubModule(torch.nn.Module):
        def foo(self, x):
            return x.cos()
    
        def forward(self, x):
            return self.foo(x)
    
    
    class CondBranchClassMethod(torch.nn.Module):
        """
        The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
          - both branches must take the same args, which must also match the branch args passed to cond.
          - both branches must return a single tensor
          - returned tensor must have the same tensor metadata, e.g. shape and dtype
          - a branch function can be a free function, a nested function, a lambda, or a class method
          - a branch function cannot have closure variables
          - no inplace mutations on inputs or global variables
    
    
        This example demonstrates using a class method in cond().

        NOTE: If `pred` tests a dimension with batch size < 2, that dimension will be specialized.
        """
    
        def __init__(self):
            super().__init__()
            self.subm = MySubModule()
    
        def bar(self, x):
            return x.sin()
    
        def forward(self, x):
            return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
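
The result below was produced from a plain export with a static input; a minimal sketch of
such a call (the length-3 example input is an assumption inferred from the ``f32[3]`` shapes
in the result) is:

.. code-block:: python

    # Hedged sketch: export with a static input. Because x.shape[0] == 3 is a
    # concrete value, the predicate x.shape[0] <= 2 folds to the constant False
    # that appears in the cond() call of the exported graph.
    ep = torch.export.export(CondBranchClassMethod(), (torch.randn(3),))
    print(ep)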
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]);  true_graph_0 = false_graph_0 = arg0_1 = None
                getitem: "f32[3]" = conditional[0];  conditional = None
                return (getitem,)
                
            class <lambda>(torch.nn.Module):
                def forward(self, arg0_1: "f32[3]"):
                    cos: "f32[3]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                    return (cos,)
                    
            class <lambda>(torch.nn.Module):
                def forward(self, arg0_1: "f32[3]"):
                    sin: "f32[3]" = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                    return (sin,)
                    
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
    Range constraints: {}
    


cond_branch_nested_function
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    from functorch.experimental.control_flow import cond
    
    
    class CondBranchNestedFunction(torch.nn.Module):
        """
        The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
          - both branches must take the same args, which must also match the branch args passed to cond.
          - both branches must return a single tensor
          - returned tensor must have the same tensor metadata, e.g. shape and dtype
          - a branch function can be a free function, a nested function, a lambda, or a class method
          - a branch function cannot have closure variables
          - no inplace mutations on inputs or global variables
    
        This example demonstrates using a nested function in cond().

        NOTE: If `pred` tests a dimension with batch size < 2, that dimension will be specialized.
        """
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            def true_fn(x):
                def inner_true_fn(y):
                    return x + y
    
                return inner_true_fn(x)
    
            def false_fn(x):
                def inner_false_fn(y):
                    return x - y
    
                return inner_false_fn(x)
    
            return cond(x.shape[0] < 10, true_fn, false_fn, [x])
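
Note that each nested function is called with the same tensor for both of its arguments, so the
branches reduce to ``x + x`` and ``x - x``; this is why the subgraphs in the result reuse
``arg0_1`` twice. A small sanity check against the exported module (a sketch; the length-3
input is an assumption matching the shapes in the result below):

.. code-block:: python

    # Hedged sketch: with x of length 3, x.shape[0] < 10 is a concrete True,
    # so the true branch (x + x) is the one that executes.
    ep = torch.export.export(CondBranchNestedFunction(), (torch.ones(3),))
    torch.testing.assert_close(ep.module()(torch.ones(3)), torch.ones(3) * 2)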
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [arg0_1]);  true_graph_0 = false_graph_0 = arg0_1 = None
                getitem: "f32[3]" = conditional[0];  conditional = None
                return (getitem,)
                
            class <lambda>(torch.nn.Module):
                def forward(self, arg0_1: "f32[3]"):
                    add: "f32[3]" = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
                    return (add,)
                    
            class <lambda>(torch.nn.Module):
                def forward(self, arg0_1: "f32[3]"):
                    sub: "f32[3]" = torch.ops.aten.sub.Tensor(arg0_1, arg0_1);  arg0_1 = None
                    return (sub,)
                    
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
    Range constraints: {}
    


cond_branch_nonlocal_variables
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    from functorch.experimental.control_flow import cond
    
    
    class CondBranchNonlocalVariables(torch.nn.Module):
        """
        The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
        - both branches must take the same args, which must also match the branch args passed to cond.
        - both branches must return a single tensor
        - returned tensor must have the same tensor metadata, e.g. shape and dtype
        - a branch function can be a free function, a nested function, a lambda, or a class method
        - a branch function cannot have closure variables
        - no inplace mutations on inputs or global variables
    
        This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.
    
        The code below will not work because capturing closure variables is not supported.
        ```
        my_tensor_var = x + 100
        my_primitive_var = 3.14
    
        def true_fn(y):
            nonlocal my_tensor_var, my_primitive_var
            return y + my_tensor_var + my_primitive_var
    
        def false_fn(y):
            nonlocal my_tensor_var, my_primitive_var
            return y - my_tensor_var - my_primitive_var
    
        return cond(x.shape[0] > 5, true_fn, false_fn, [x])
        ```
    
        NOTE: If `pred` tests a dimension with batch size < 2, that dimension will be specialized.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            my_tensor_var = x + 100
            my_primitive_var = 3.14
    
            def true_fn(x, y, z):
                return x + y + z
    
            def false_fn(x, y, z):
                return x - y - z
    
            return cond(
                x.shape[0] > 5,
                true_fn,
                false_fn,
                [x, my_tensor_var, torch.tensor(my_primitive_var)],
            )
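
The ``torch.tensor(my_primitive_var)`` call creates a constant tensor at trace time, which is why
the result below has a lifted constant input (``_lifted_tensor_constant0``) in its graph
signature. A sketch of an export call consistent with the ``f32[6]`` shapes in the result (the
input size is an assumption):

.. code-block:: python

    # Hedged sketch: with x of length 6, x.shape[0] > 5 is a concrete True,
    # so the predicate is folded into the cond() call of the exported graph.
    ep = torch.export.export(CondBranchNonlocalVariables(), (torch.randn(6),))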
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, _lifted_tensor_constant0: "f32[]", arg0_1: "f32[6]"):
                add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, 100)

                lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(_lifted_tensor_constant0);  _lifted_tensor_constant0 = None

                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [arg0_1, add, lift_fresh_copy]);  true_graph_0 = false_graph_0 = arg0_1 = add = lift_fresh_copy = None
                getitem: "f32[6]" = conditional[0];  conditional = None
                return (getitem,)
                
            class <lambda>(torch.nn.Module):
                def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"):
                    add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                    add_1: "f32[6]" = torch.ops.aten.add.Tensor(add, arg2_1);  add = arg2_1 = None
                    return (add_1,)
                    
            class <lambda>(torch.nn.Module):
                def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"):
                    sub: "f32[6]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                    sub_1: "f32[6]" = torch.ops.aten.sub.Tensor(sub, arg2_1);  sub = arg2_1 = None
                    return (sub_1,)
                    
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='_lifted_tensor_constant0'), target='_lifted_tensor_constant0', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
    Range constraints: {}
    


cond_operands
^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    from torch.export import Dim
    from functorch.experimental.control_flow import cond
    
    x = torch.randn(3, 2)
    y = torch.ones(2)
    dim0_x = Dim("dim0_x")
    
    class CondOperands(torch.nn.Module):
        """
        The operands passed to cond() must be:
        - a list of tensors
        - matching the arguments of `true_fn` and `false_fn`
    
        NOTE: If `pred` tests a dimension with batch size < 2, that dimension will be specialized.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x, y):
            def true_fn(x, y):
                return x + y
    
            def false_fn(x, y):
                return x - y
    
            return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
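
Unlike the earlier examples, the result below keeps the predicate symbolic (``Sym(s0 > 2)``).
That is consistent with dimension 0 of ``x`` being marked dynamic via the ``dim0_x`` object
defined above; a sketch of such an export call (the exact invocation is an assumption) is:

.. code-block:: python

    # Hedged sketch: marking dim 0 of x as dynamic so the shape test stays in
    # the graph instead of being specialized to a constant.
    ep = torch.export.export(
        CondOperands(),
        (x, y),
        dynamic_shapes={"x": {0: dim0_x}, "y": None},
    )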
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
                sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 0)
                gt: "Sym(s0 > 2)" = sym_size_int > 2;  sym_size_int = None
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
                getitem: "f32[s0, 2]" = conditional[0];  conditional = None
                return (getitem,)
                
            class <lambda>(torch.nn.Module):
                def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
                    add: "f32[s0, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                    return (add,)
                    
            class <lambda>(torch.nn.Module):
                def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
                    sub: "f32[s0, 2]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                    return (sub,)
                    
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
    Range constraints: {s0: ValueRanges(lower=2, upper=oo, is_bool=False)}
    


cond_predicate
^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    from functorch.experimental.control_flow import cond
    
    
    class CondPredicate(torch.nn.Module):
        """
        The conditional statement (aka predicate) passed to cond() must be one of the following:
          - a torch.Tensor with a single element
          - a boolean expression
    
        NOTE: If `pred` tests a dimension with batch size < 2, that dimension will be specialized.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            pred = x.dim() > 2 and x.shape[2] > 10
    
            return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
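
For the static ``f32[6, 4, 3]`` input in the result, ``x.dim() > 2 and x.shape[2] > 10``
evaluates to the plain Python ``False`` that shows up in the ``cond()`` call. Both accepted
predicate forms look like this (a sketch; the tensors are illustrative):

.. code-block:: python

    x = torch.randn(6, 4, 3)
    pred_as_bool = x.dim() > 2 and x.shape[2] > 10  # boolean expression (here: False)
    pred_as_tensor = (x.sum() > 0)                  # single-element boolean tensor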
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[6, 4, 3]"):
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]);  true_graph_0 = false_graph_0 = arg0_1 = None
                getitem: "f32[6, 4, 3]" = conditional[0];  conditional = None
                return (getitem,)
                
            class <lambda>(torch.nn.Module):
                def forward(self, arg0_1: "f32[6, 4, 3]"):
                    cos: "f32[6, 4, 3]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                    return (cos,)
                    
            class <lambda>(torch.nn.Module):
                def forward(self, arg0_1: "f32[6, 4, 3]"):
                    sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                    return (sin,)
                    
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
    Range constraints: {}
    


dynamic_shape_constructor
^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    
    class DynamicShapeConstructor(torch.nn.Module):
        """
        Tensor constructors should be captured with their dynamic shape inputs
        rather than having a static shape baked in.
        """
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            return torch.ones(x.shape[0] * 2)
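
In the result below the constructor size is the constant ``[6]`` because the export used a
static ``(3, 2)`` input. Per the docstring above, marking dimension 0 as dynamic should instead
keep the ``x.shape[0] * 2`` computation in the graph; a sketch of such a dynamic export (the
``Dim`` name and shapes here are assumptions):

.. code-block:: python

    from torch.export import Dim

    # Hedged sketch: dim 0 marked dynamic so the size argument of torch.ones
    # is derived from the input shape rather than specialized.
    dim0 = Dim("dim0")
    ep = torch.export.export(
        DynamicShapeConstructor(),
        (torch.randn(3, 2),),
        dynamic_shapes={"x": {0: dim0}},
    )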
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]"):
                ones: "f32[6]" = torch.ops.aten.ones.default([6], device = device(type='cpu'), pin_memory = False)
                return (ones,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='ones'), target=None)])
    Range constraints: {}
    


dynamic_shape_if_guard
^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    
    class DynamicShapeIfGuard(torch.nn.Module):
        """
        An `if` statement with a backed dynamic shape predicate will be specialized into
        one particular branch, and a guard will be generated. However, export will fail if
        the dimension is marked as dynamic through the higher-level API.
        """
    
        def forward(self, x):
            if x.shape[0] == 3:
                return x.cos()
    
            return x.sin()
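
As the docstring notes, the same module is expected to fail to export once the guarded dimension
is marked dynamic, because the ``x.shape[0] == 3`` branch decision specializes it. A sketch of
that failing case (the invocation and shapes are assumptions, for illustration only):

.. code-block:: python

    from torch.export import Dim

    dim0 = Dim("dim0")
    try:
        # Hedged sketch: marking dim 0 dynamic conflicts with the `if` guard,
        # which specializes x.shape[0] to 3, so export is expected to raise.
        torch.export.export(
            DynamicShapeIfGuard(),
            (torch.randn(3, 2, 2),),
            dynamic_shapes={"x": {0: dim0, 1: None, 2: None}},
        )
    except Exception as e:
        print(f"export failed as expected: {e}")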
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2, 2]"):
                cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return (cos,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
    Range constraints: {}
    


dynamic_shape_map
^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.map <torch.map>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    from functorch.experimental.control_flow import map
    
    
    class DynamicShapeMap(torch.nn.Module):
        """
        functorch map() maps a function over the first tensor dimension.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, xs, y):
            def body(x, y):
                return x + y
    
            return map(body, xs, y)
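
``map`` applies ``body`` to each slice of ``xs`` along dimension 0 and stacks the results. For
this particular body that is just a broadcasting addition, as the following eager check shows
(a sketch of the semantics, not the operator's implementation):

.. code-block:: python

    xs = torch.randn(3, 2)
    y = torch.randn(2)
    # map(body, xs, y) computes body(xs[i], y) for each slice along dim 0 ...
    per_slice = torch.stack([xs[i] + y for i in range(xs.shape[0])])
    # ... which here is equivalent to broadcasting addition.
    torch.testing.assert_close(per_slice, xs + y)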
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]", arg1_1: "f32[2]"):
                body_graph_0 = self.body_graph_0
                map_impl = torch.ops.higher_order.map_impl(body_graph_0, [arg0_1], [arg1_1]);  body_graph_0 = arg0_1 = arg1_1 = None
                getitem: "f32[3, 2]" = map_impl[0];  map_impl = None
                return (getitem,)
                
            class <lambda>(torch.nn.Module):
                def forward(self, arg0_1: "f32[2]", arg1_1: "f32[2]"):
                    add: "f32[2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                    return (add,)
                    
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
    Range constraints: {}
    


dynamic_shape_round
^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.builtin <python.builtin>`

    Support Level: NOT_SUPPORTED_YET

Original source code:

.. code-block:: python

    import torch
    
    from torch.export import Dim
    
    x = torch.ones(3, 2)
    dim0_x = Dim("dim0_x")
    
    class DynamicShapeRound(torch.nn.Module):
        """
        Calling round on dynamic shapes is not supported.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            return x[: round(x.shape[0] / 2)]
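
One possible rewrite (not from the source; an illustrative workaround) is to replace Python's
``round()`` with integer floor division, which stays within the symbolic-integer operations
that export handles:

.. code-block:: python

    class DynamicShapeFloorDiv(torch.nn.Module):
        # Hypothetical variant: x.shape[0] // 2 keeps the computation on
        # symbolic integers instead of calling the Python builtin round().
        def forward(self, x):
            return x[: x.shape[0] // 2]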
    

Result:

.. code-block::

    AssertionError: 


dynamic_shape_slicing
^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    
    class DynamicShapeSlicing(torch.nn.Module):
        """
        Slices with dynamic shape arguments should be captured into the graph
        rather than being baked in.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
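
For the static ``(3, 2)`` input used in the result, ``x.shape[0] - 2 == 1`` and
``x.shape[1] - 1 == 1``, so the expression selects ``x[:1, 1::2]``, a ``(1, 1)`` view. A quick
eager check:

.. code-block:: python

    x = torch.arange(6.0).reshape(3, 2)
    out = DynamicShapeSlicing()(x)
    assert out.shape == (1, 1)
    assert out.item() == x[0, 1].item()  # first row, second column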
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]"):
                slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 1);  arg0_1 = None
                slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2);  slice_1 = None
                return (slice_2,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_2'), target=None)])
    Range constraints: {}
    


dynamic_shape_view
^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    
    class DynamicShapeView(torch.nn.Module):
        """
        Dynamic shapes should be propagated to view arguments instead of being
        baked into the exported graph.
        """
    
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            new_x_shape = x.size()[:-1] + (2, 5)
            x = x.view(*new_x_shape)
            return x.permute(0, 2, 1)
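
For the ``(10, 10)`` input in the result, ``x.size()[:-1] + (2, 5)`` is ``(10, 2, 5)``, and the
permutation then swaps the last two dimensions. A quick eager check:

.. code-block:: python

    x = torch.randn(10, 10)
    assert DynamicShapeView()(x).shape == (10, 5, 2)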
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[10, 10]"):
                view: "f32[10, 2, 5]" = torch.ops.aten.view.default(arg0_1, [10, 2, 5]);  arg0_1 = None

                permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
                return (permute,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='permute'), target=None)])
    Range constraints: {}
    


list_contains
^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.assert <python.assert>`, :doc:`python.data-structure <python.data-structure>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    
    
    class ListContains(torch.nn.Module):
        """
        List containment can be checked against a dynamic shape or constants.
        """
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            assert x.size(-1) in [6, 2]
            assert x.size(0) not in [4, 5, 6]
            assert "monkey" not in ["cow", "pig"]
            return x + x
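
With the ``(3, 2)`` input used in the result, all three assertions hold (``2 in [6, 2]``,
``3 not in [4, 5, 6]``, and the string check), so they fold away at export time and only the
``add`` remains in the graph. A quick eager check:

.. code-block:: python

    x = torch.randn(3, 2)
    torch.testing.assert_close(ListContains()(x), x + x)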
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return (add,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
    Range constraints: {}
    


scalar_output
^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    
    from torch.export import Dim
    
    x = torch.ones(3, 2)
    dim1_x = Dim("dim1_x")
    
    class ScalarOutput(torch.nn.Module):
        """
        Returning scalar values from the graph is supported, in addition to Tensor
        outputs. Symbolic shapes are captured and rank is specialized.
        """
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            return x.shape[1] + 1
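
The ``Sym(s0 + 1)`` output and the range constraint in the result are consistent with dimension 1
being marked dynamic via the ``dim1_x`` object defined above; a sketch of such an export (the
exact invocation and the sample call are assumptions):

.. code-block:: python

    ep = torch.export.export(ScalarOutput(), (x,), dynamic_shapes={"x": {1: dim1_x}})
    # The exported module returns a plain integer computed from the symbolic dim.
    print(ep.module()(torch.ones(3, 5)))  # expected: 6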
    

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, s0]"):
                # No stacktrace found for following nodes
                sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 1);  arg0_1 = None
                add: "Sym(s0 + 1)" = sym_size_int + 1;  sym_size_int = None
                return (add,)
                
    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='add'), target=None)])
    Range constraints: {s0: ValueRanges(lower=2, upper=oo, is_bool=False)}