.. _torch.export_db:

ExportDB
========

ExportDB is a centralized dataset of supported and unsupported export cases.
It is targeted towards users who want to understand specifically what types of
code are supported, the subtleties of export, and how to modify their existing
code to be compatible with export. Note that this is not an exhaustive set of
everything that export supports, but it covers the
most common and confusing use cases that users will run into.

If you have a feature that you think needs a stronger guarantee of support in
export, please create an issue in the pytorch/pytorch repo with the ``module:export`` tag.


.. toctree::
    :maxdepth: 1
    :caption: Tags

    torch.escape-hatch
    torch.dynamic-shape
    torch.cond
    python.closure
    torch.dynamic-value
    python.data-structure
    python.assert
    python.control-flow
    torch.map
    python.builtin
    python.object-model
    python.context-manager
    torch.operator
    torch.mutation


Supported
---------

assume_constant_result
^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.escape-hatch <torch.escape-hatch>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    import torch._dynamo as torchdynamo
    
    
    class AssumeConstantResult(torch.nn.Module):
        """
        Applying the `assume_constant_result` decorator burns non-traceable code into the graph as a constant.
        """
    
        @torchdynamo.assume_constant_result
        def get_item(self, y):
            return y.int().item()
    
        def forward(self, x, y):
            return x[: self.get_item(y)]
    
    example_args = (torch.randn(3, 2), torch.tensor(4))
    tags = {"torch.escape-hatch"}
    model = AssumeConstantResult()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]", y: "i64[]"):
                slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 4);  x = None
                return (slice_1,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        y: USER_INPUT
        
        # outputs
        slice_1: USER_OUTPUT
        
    Range constraints: {}
    


autograd_function
^^^^^^^^^^^^^^^^^

.. note::

    Tags: 

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class MyAutogradFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.clone()
    
        @staticmethod
        def backward(ctx, grad_output):
            return grad_output + 1
    
    class AutogradFunction(torch.nn.Module):
        """
        TorchDynamo does not keep track of backward() on autograd functions. We recommend
        using `allow_in_graph` to mitigate this problem.
        """
    
        def forward(self, x):
            return MyAutogradFunction.apply(x)
    
    example_args = (torch.randn(3, 2),)
    model = AutogradFunction()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                clone: "f32[3, 2]" = torch.ops.aten.clone.default(x);  x = None
                return (clone,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        clone: USER_OUTPUT
        
    Range constraints: {}
    


class_method
^^^^^^^^^^^^

.. note::

    Tags: 

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class ClassMethod(torch.nn.Module):
        """
        Class methods are inlined during tracing.
        """
    
        @classmethod
        def method(cls, x):
            return x + 1
    
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)
    
        def forward(self, x):
            x = self.linear(x)
            return self.method(x) * self.__class__.method(x) * type(self).method(x)
    
    example_args = (torch.randn(3, 4),)
    model = ClassMethod()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_linear_weight: "f32[2, 4]", p_linear_bias: "f32[2]", x: "f32[3, 4]"):
                linear: "f32[3, 2]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias);  x = p_linear_weight = p_linear_bias = None
                
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1)
                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1)
                
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add, add_1);  add = add_1 = None
                
                add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1);  linear = None
                
                mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(mul, add_2);  mul = add_2 = None
                return (mul_1,)
                
    Graph signature: 
        # inputs
        p_linear_weight: PARAMETER target='linear.weight'
        p_linear_bias: PARAMETER target='linear.bias'
        x: USER_INPUT
        
        # outputs
        mul_1: USER_OUTPUT
        
    Range constraints: {}
    

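Because all three class-method calls are inlined as separate ``add`` nodes, the
exported program should compute the same value as eager mode, namely
``(linear(x) + 1) ** 3``. A minimal sketch that checks this, assuming a recent
PyTorch 2.x release; ``ClassMethodCopy`` is a hypothetical copy of the module above:

.. code-block:: python

    import torch


    class ClassMethodCopy(torch.nn.Module):
        # Hypothetical copy of the ClassMethod example above.
        @classmethod
        def method(cls, x):
            return x + 1

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, x):
            x = self.linear(x)
            return self.method(x) * self.__class__.method(x) * type(self).method(x)


    model = ClassMethodCopy()
    x = torch.randn(3, 4)
    ep = torch.export.export(model, (x,))

    # The three inlined calls multiply to (linear(x) + 1) ** 3.
    with torch.no_grad():
        expected = (model.linear(x) + 1) ** 3
        actual = ep.module()(x)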

cond_branch_class_method
^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    from functorch.experimental.control_flow import cond
    
    class MySubModule(torch.nn.Module):
        def foo(self, x):
            return x.cos()
    
        def forward(self, x):
            return self.foo(x)
    
    class CondBranchClassMethod(torch.nn.Module):
        """
        The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
          - both branches must take the same args, which must also match the branch args passed to cond.
          - both branches must return a single tensor
          - the returned tensors must have the same tensor metadata, e.g. shape and dtype
          - a branch function can be a free function, nested function, lambda, or class method
          - a branch function cannot have closure variables
          - no in-place mutations on inputs or global variables
    
    
        This example demonstrates using a class method in cond().
    
        NOTE: If `pred` tests a dim with batch size < 2, that dim will be specialized.
        """
    
        def __init__(self) -> None:
            super().__init__()
            self.subm = MySubModule()
    
        def bar(self, x):
            return x.sin()
    
        def forward(self, x):
            return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
    
    example_args = (torch.randn(3),)
    tags = {
        "torch.cond",
        "torch.dynamic-shape",
    }
    model = CondBranchClassMethod()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3]"):
                sin: "f32[3]" = torch.ops.aten.sin.default(x);  x = None
                return (sin,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        sin: USER_OUTPUT
        
    Range constraints: {}
    

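In the result above only ``sin`` remains: without ``dynamic_shapes``, ``x.shape[0]``
is the constant 3, so the predicate is decided during tracing. A minimal sketch,
assuming a recent PyTorch 2.x release, that contrasts this with a dynamic first
dimension, where ``torch.cond`` keeps both branches. ``CondOnSize``, ``has_cond``
and the dimension name ``n`` are hypothetical:

.. code-block:: python

    import torch
    from torch.export import Dim


    class CondOnSize(torch.nn.Module):
        # Hypothetical module: same predicate and branches as CondBranchClassMethod.
        def forward(self, x):
            return torch.cond(x.shape[0] <= 2, lambda t: t.cos(), lambda t: t.sin(), (x,))


    def has_cond(ep):
        # True if the exported graph still contains a torch.cond node.
        return any(n.target is torch.ops.higher_order.cond for n in ep.graph.nodes)


    model = CondOnSize()
    example = (torch.randn(3),)

    # Static shapes: 3 <= 2 is evaluated at trace time, so one branch is inlined.
    static_ep = torch.export.export(model, example)

    # Dynamic dim 0: the predicate stays symbolic and both branches are captured.
    dynamic_ep = torch.export.export(model, example, dynamic_shapes={"x": {0: Dim("n")}})

    small, large = torch.randn(2), torch.randn(4)
    print(has_cond(static_ep), has_cond(dynamic_ep))

With the dynamic export, an input of size 2 takes the ``cos`` branch and an input
of size 4 takes the ``sin`` branch from the same program.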

cond_branch_nested_function
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    from functorch.experimental.control_flow import cond
    
    class CondBranchNestedFunction(torch.nn.Module):
        """
        The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
          - both branches must take the same args, which must also match the branch args passed to cond.
          - both branches must return a single tensor
          - the returned tensors must have the same tensor metadata, e.g. shape and dtype
          - a branch function can be a free function, nested function, lambda, or class method
          - a branch function cannot have closure variables
          - no in-place mutations on inputs or global variables
    
        This example demonstrates using a nested function in cond().
    
        NOTE: If `pred` tests a dim with batch size < 2, that dim will be specialized.
        """
    
        def forward(self, x):
            def true_fn(x):
                def inner_true_fn(y):
                    return x + y
    
                return inner_true_fn(x)
    
            def false_fn(x):
                def inner_false_fn(y):
                    return x - y
    
                return inner_false_fn(x)
    
            return cond(x.shape[0] < 10, true_fn, false_fn, [x])
    
    example_args = (torch.randn(3),)
    tags = {
        "torch.cond",
        "torch.dynamic-shape",
    }
    model = CondBranchNestedFunction()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3]"):
                add: "f32[3]" = torch.ops.aten.add.Tensor(x, x);  x = None
                return (add,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {}
    


cond_branch_nonlocal_variables
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    from functorch.experimental.control_flow import cond
    
    class CondBranchNonlocalVariables(torch.nn.Module):
        """
        The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
        - both branches must take the same args, which must also match the branch args passed to cond.
        - both branches must return a single tensor
        - the returned tensors must have the same tensor metadata, e.g. shape and dtype
        - a branch function can be a free function, nested function, lambda, or class method
        - a branch function cannot have closure variables
        - no in-place mutations on inputs or global variables
    
        This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.
    
        The code below will not work because capturing closure variables is not supported.
        ```
        my_tensor_var = x + 100
        my_primitive_var = 3.14
    
        def true_fn(y):
            nonlocal my_tensor_var, my_primitive_var
            return y + my_tensor_var + my_primitive_var
    
        def false_fn(y):
            nonlocal my_tensor_var, my_primitive_var
            return y - my_tensor_var - my_primitive_var
    
        return cond(x.shape[0] > 5, true_fn, false_fn, [x])
        ```
    
        NOTE: If `pred` tests a dim with batch size < 2, that dim will be specialized.
        """
    
        def forward(self, x):
            my_tensor_var = x + 100
            my_primitive_var = 3.14
    
            def true_fn(x, y, z):
                return x + y + z
    
            def false_fn(x, y, z):
                return x - y - z
    
            return cond(
                x.shape[0] > 5,
                true_fn,
                false_fn,
                [x, my_tensor_var, torch.tensor(my_primitive_var)],
            )
    
    example_args = (torch.randn(6),)
    tags = {
        "torch.cond",
        "torch.dynamic-shape",
    }
    model = CondBranchNonlocalVariables()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, c_lifted_tensor_0: "f32[]", x: "f32[6]"):
                add: "f32[6]" = torch.ops.aten.add.Tensor(x, 100)
                
                lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
                detach_: "f32[]" = torch.ops.aten.detach_.default(lift_fresh_copy);  lift_fresh_copy = None
                
                add_1: "f32[6]" = torch.ops.aten.add.Tensor(x, add);  x = add = None
                add_2: "f32[6]" = torch.ops.aten.add.Tensor(add_1, detach_);  add_1 = detach_ = None
                return (add_2,)
                
    Graph signature: 
        # inputs
        c_lifted_tensor_0: CONSTANT_TENSOR target='lifted_tensor_0'
        x: USER_INPUT
        
        # outputs
        add_2: USER_OUTPUT
        
    Range constraints: {}
    


cond_closed_over_variable
^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.closure <python.closure>`, :doc:`torch.cond <torch.cond>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    from functorch.experimental.control_flow import cond
    
    class CondClosedOverVariable(torch.nn.Module):
        """
        torch.cond() supports branches closed over arbitrary variables.
        """
    
        def forward(self, pred, x):
            def true_fn(val):
                return x * 2
    
            def false_fn(val):
                return x - 2
    
            return cond(pred, true_fn, false_fn, [x + 1])
    
    example_args = (torch.tensor(True), torch.randn(3, 2))
    tags = {"torch.cond", "python.closure"}
    model = CondClosedOverVariable()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, pred: "b8[]", x: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1);  add = None
                
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                cond = torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, (x,));  pred = true_graph_0 = false_graph_0 = x = None
                getitem: "f32[3, 2]" = cond[0];  cond = None
                return (getitem,)
                
            class true_graph_0(torch.nn.Module):
                def forward(self, x: "f32[3, 2]"):
                    mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, 2);  x = None
                    return (mul,)
                    
            class false_graph_0(torch.nn.Module):
                def forward(self, x: "f32[3, 2]"):
                    sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(x, 2);  x = None
                    return (sub,)
                    
    Graph signature: 
        # inputs
        pred: USER_INPUT
        x: USER_INPUT
        
        # outputs
        getitem: USER_OUTPUT
        
    Range constraints: {}
    

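Because ``pred`` is a tensor input rather than a trace-time constant, both branches
are captured and the branch is chosen at run time. A minimal sketch, assuming a
recent PyTorch 2.x release and using ``torch.cond`` in place of the ``functorch``
import; ``ClosedOver`` is a hypothetical copy of the module above:

.. code-block:: python

    import torch


    class ClosedOver(torch.nn.Module):
        # Hypothetical copy of CondClosedOverVariable using torch.cond.
        def forward(self, pred, x):
            def true_fn(val):
                return x * 2  # closes over x; val is unused

            def false_fn(val):
                return x - 2

            return torch.cond(pred, true_fn, false_fn, (x + 1,))


    x = torch.randn(3, 2)
    ep = torch.export.export(ClosedOver(), (torch.tensor(True), x))
    m = ep.module()

    # The same exported program serves both values of pred.
    out_true = m(torch.tensor(True), x)
    out_false = m(torch.tensor(False), x)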

cond_operands
^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    from torch.export import Dim
    
    x = torch.randn(3, 2)
    y = torch.randn(2)
    dim0_x = Dim("dim0_x")
    
    class CondOperands(torch.nn.Module):
        """
        The operands passed to cond() must:
        - be a list of tensors
        - match the arguments of `true_fn` and `false_fn`
    
        NOTE: If `pred` tests a dim with batch size < 2, that dim will be specialized.
        """
    
        def forward(self, x, y):
            def true_fn(x, y):
                return x + y
    
            def false_fn(x, y):
                return x - y
    
            return torch.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
    
    example_args = (x, y)
    tags = {
        "torch.cond",
        "torch.dynamic-shape",
    }
    extra_inputs = (torch.randn(2, 2), torch.randn(2))
    dynamic_shapes = {"x": {0: dim0_x}, "y": None}
    model = CondOperands()
    

    torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[s77, 2]", y: "f32[2]"):
                sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
                
                gt: "Sym(s77 > 2)" = sym_size_int_1 > 2;  sym_size_int_1 = None
                
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, y));  gt = true_graph_0 = false_graph_0 = x = y = None
                getitem: "f32[s77, 2]" = cond[0];  cond = None
                return (getitem,)
                
            class true_graph_0(torch.nn.Module):
                def forward(self, x: "f32[s77, 2]", y: "f32[2]"):
                    add: "f32[s77, 2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
                    return (add,)
                    
            class false_graph_0(torch.nn.Module):
                def forward(self, x: "f32[s77, 2]", y: "f32[2]"):
                    sub: "f32[s77, 2]" = torch.ops.aten.sub.Tensor(x, y);  x = y = None
                    return (sub,)
                    
    Graph signature: 
        # inputs
        x: USER_INPUT
        y: USER_INPUT
        
        # outputs
        getitem: USER_OUTPUT
        
    Range constraints: {s77: VR[0, int_oo]}
    

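The example defines ``extra_inputs`` with a first dimension of 2, for which
``x.shape[0] > 2`` is false. Because ``dim0_x`` keeps that dimension symbolic, one
exported program handles both inputs. A minimal sketch, assuming a recent PyTorch
2.x release; ``CondOperandsSketch`` is a hypothetical copy of the module above:

.. code-block:: python

    import torch
    from torch.export import Dim


    class CondOperandsSketch(torch.nn.Module):
        # Hypothetical copy of CondOperands above.
        def forward(self, x, y):
            def true_fn(x, y):
                return x + y

            def false_fn(x, y):
                return x - y

            return torch.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])


    example_args = (torch.randn(3, 2), torch.randn(2))
    extra_inputs = (torch.randn(2, 2), torch.randn(2))
    ep = torch.export.export(
        CondOperandsSketch(),
        example_args,
        dynamic_shapes={"x": {0: Dim("dim0_x")}, "y": None},
    )
    m = ep.module()

    # First dim 3 takes true_fn; first dim 2 takes false_fn.
    out_example = m(*example_args)
    out_extra = m(*extra_inputs)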

cond_predicate
^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    from functorch.experimental.control_flow import cond
    
    class CondPredicate(torch.nn.Module):
        """
        The conditional statement (aka predicate) passed to cond() must be one of the following:
          - a torch.Tensor with a single element
          - a boolean expression
    
        NOTE: If `pred` tests a dim with batch size < 2, that dim will be specialized.
        """
    
        def forward(self, x):
            pred = x.dim() > 2 and x.shape[2] > 10
    
            return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
    
    example_args = (torch.randn(6, 4, 3),)
    tags = {
        "torch.cond",
        "torch.dynamic-shape",
    }
    model = CondPredicate()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[6, 4, 3]"):
                sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(x);  x = None
                return (sin,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        sin: USER_OUTPUT
        
    Range constraints: {}
    

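Here ``pred`` is a plain Python ``bool`` computed from static shapes, so export
resolves it during tracing and only the ``sin`` branch reaches the graph. A minimal
sketch that confirms no ``torch.ops.higher_order.cond`` node is emitted, assuming a
recent PyTorch 2.x release; ``PredicateSketch`` is hypothetical:

.. code-block:: python

    import torch


    class PredicateSketch(torch.nn.Module):
        def forward(self, x):
            # x.dim() and x.shape[2] are static ints, so pred is a Python bool.
            pred = x.dim() > 2 and x.shape[2] > 10
            return torch.cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])


    x = torch.randn(6, 4, 3)
    ep = torch.export.export(PredicateSketch(), (x,))

    # Collect any cond nodes left in the exported graph.
    cond_nodes = [n for n in ep.graph.nodes if n.target is torch.ops.higher_order.cond]
    out = ep.module()(x)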

constrain_as_size_example
^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-value <torch.dynamic-value>`, :doc:`torch.escape-hatch <torch.escape-hatch>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    
    class ConstrainAsSizeExample(torch.nn.Module):
        """
        If the value is not known at tracing time, you can provide a hint so that we
        can trace further. Please look at the torch._check and torch._check_is_size APIs.
        torch._check_is_size is used for values that NEED to be used for constructing
        a tensor.
        """
    
        def forward(self, x):
            a = x.item()
            torch._check_is_size(a)
            torch._check(a <= 5)
            return torch.zeros((a, 5))
    
    
    example_args = (torch.tensor(4),)
    tags = {
        "torch.dynamic-value",
        "torch.escape-hatch",
    }
    model = ConstrainAsSizeExample()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "i64[]"):
                item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
                
                sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item);  sym_constrain_range_for_size_default = None
                
                ge_1: "Sym(u0 >= 0)" = item >= 0
                _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
                le_1: "Sym(u0 <= 5)" = item <= 5
                _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None
                
                zeros: "f32[u0, 5]" = torch.ops.aten.zeros.default([item, 5], device = device(type='cpu'), pin_memory = False);  item = None
                return (zeros,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        zeros: USER_OUTPUT
        
    Range constraints: {u0: VR[0, 5], u1: VR[0, 5]}
    

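Because ``a`` is data-dependent, the first output dimension stays symbolic
(``u0``) and is filled in from the value of ``x`` on every call. A minimal sketch,
assuming a recent PyTorch 2.x release in which ``torch._check_is_size`` is still
available; ``SizeFromValue`` is a hypothetical copy of the module above:

.. code-block:: python

    import torch


    class SizeFromValue(torch.nn.Module):
        # Hypothetical copy of ConstrainAsSizeExample above.
        def forward(self, x):
            a = x.item()
            torch._check_is_size(a)
            torch._check(a <= 5)
            return torch.zeros((a, 5))


    ep = torch.export.export(SizeFromValue(), (torch.tensor(4),))
    m = ep.module()

    # The output shape follows the runtime value of x, within the checked range.
    shapes = [tuple(m(torch.tensor(v)).shape) for v in (2, 5)]
    print(shapes)  # [(2, 5), (5, 5)]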

constrain_as_value_example
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-value <torch.dynamic-value>`, :doc:`torch.escape-hatch <torch.escape-hatch>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    
    class ConstrainAsValueExample(torch.nn.Module):
        """
        If the value is not known at tracing time, you can provide a hint so that we
        can trace further. Please look at the torch._check and torch._check_is_size APIs.
        torch._check is used for values that don't need to be used for constructing
        a tensor.
        """
    
        def forward(self, x, y):
            a = x.item()
            torch._check(a >= 0)
            torch._check(a <= 5)
    
            if a < 6:
                return y.sin()
            return y.cos()
    
    
    example_args = (torch.tensor(4), torch.randn(5, 5))
    tags = {
        "torch.dynamic-value",
        "torch.escape-hatch",
    }
    model = ConstrainAsValueExample()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "i64[]", y: "f32[5, 5]"):
                item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
                ge_1: "Sym(u0 >= 0)" = item >= 0
                _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
                le_1: "Sym(u0 <= 5)" = item <= 5;  item = None
                _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None
                
                sin: "f32[5, 5]" = torch.ops.aten.sin.default(y);  y = None
                return (sin,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        y: USER_INPUT
        
        # outputs
        sin: USER_OUTPUT
        
    Range constraints: {u0: VR[0, 5], u1: VR[0, 5]}
    

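Because ``torch._check`` bounds ``a`` to ``[0, 5]``, the test ``a < 6`` is resolved
during tracing, and only the bounds survive as ``_assert_scalar`` nodes. A minimal
sketch of what those nodes do at run time, assuming a recent PyTorch 2.x release
in which a failed runtime assertion raises ``RuntimeError``; ``ValueBounds`` is a
hypothetical copy of the module above:

.. code-block:: python

    import torch


    class ValueBounds(torch.nn.Module):
        # Hypothetical copy of ConstrainAsValueExample above.
        def forward(self, x, y):
            a = x.item()
            torch._check(a >= 0)
            torch._check(a <= 5)
            # With a in [0, 5], this branch is resolved statically.
            if a < 6:
                return y.sin()
            return y.cos()


    ep = torch.export.export(ValueBounds(), (torch.tensor(4), torch.randn(5, 5)))
    m = ep.module()

    y = torch.randn(5, 5)
    ok = m(torch.tensor(3), y)  # within bounds: computes y.sin()

    try:
        m(torch.tensor(9), y)  # violates a <= 5
        rejected = False
    except RuntimeError:
        rejected = True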

decorator
^^^^^^^^^

.. note::

    Tags: 

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import functools
    
    import torch
    
    def test_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs) + 1
    
        return wrapper
    
    class Decorator(torch.nn.Module):
        """
        Decorator calls are inlined into the exported function during tracing.
        """
    
        @test_decorator
        def forward(self, x, y):
            return x + y
    
    example_args = (torch.randn(3, 2), torch.randn(3, 2))
    model = Decorator()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]", y: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
                
                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1);  add = None
                return (add_1,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        y: USER_INPUT
        
        # outputs
        add_1: USER_OUTPUT
        
    Range constraints: {}
    

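Since the wrapper's ``+ 1`` is traced together with the decorated body, the exported
program reproduces the decorated behaviour. A minimal check, assuming a recent
PyTorch 2.x release; ``add_one`` and ``Decorated`` are hypothetical names mirroring
``test_decorator`` and ``Decorator`` above:

.. code-block:: python

    import functools

    import torch


    def add_one(func):
        # Hypothetical decorator, equivalent to test_decorator above.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs) + 1

        return wrapper


    class Decorated(torch.nn.Module):
        @add_one
        def forward(self, x, y):
            return x + y


    x, y = torch.randn(3, 2), torch.randn(3, 2)
    ep = torch.export.export(Decorated(), (x, y))

    # The wrapper's "+ 1" appears in the graph as a second add node.
    out = ep.module()(x, y)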

dictionary
^^^^^^^^^^

.. note::

    Tags: :doc:`python.data-structure <python.data-structure>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class Dictionary(torch.nn.Module):
        """
        Dictionary structures are inlined and flattened during tracing.
        """
    
        def forward(self, x, y):
            elements = {}
            elements["x2"] = x * x
            y = y * elements["x2"]
            return {"y": y}
    
    example_args = (torch.randn(3, 2), torch.tensor(4))
    tags = {"python.data-structure"}
    model = Dictionary()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]", y: "i64[]"):
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x);  x = None
                
                mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(y, mul);  y = mul = None
                return (mul_1,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        y: USER_INPUT
        
        # outputs
        mul_1: USER_OUTPUT
        
    Range constraints: {}
    

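The graph above returns a flat tuple, while the original module returns
``{"y": y}``. A minimal sketch, assuming a recent PyTorch 2.x release, showing that
``ExportedProgram.module()`` restores the dictionary structure on output;
``DictOut`` is a hypothetical copy of the module above:

.. code-block:: python

    import torch


    class DictOut(torch.nn.Module):
        # Hypothetical copy of the Dictionary example above.
        def forward(self, x, y):
            elements = {}
            elements["x2"] = x * x
            y = y * elements["x2"]
            return {"y": y}


    x, y = torch.randn(3, 2), torch.tensor(4)
    ep = torch.export.export(DictOut(), (x, y))

    # The graph returns a flat tuple; ep.module() rebuilds the dict.
    out = ep.module()(x, y)
    print(type(out).__name__, list(out))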

dynamic_shape_assert
^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.assert <python.assert>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class DynamicShapeAssert(torch.nn.Module):
        """
        A basic usage of Python assertions.
        """
    
        def forward(self, x):
            # assertion with error message
            assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2"
            # assertion without error message
            assert x.shape[0] > 1
            return x
    
    example_args = (torch.randn(3, 2),)
    tags = {"python.assert"}
    model = DynamicShapeAssert()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                return (x,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        x: USER_OUTPUT
        
    Range constraints: {}
    

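With static shapes, both assertions are evaluated on the concrete example sizes
while tracing, which is why no assertion appears in the graph above. A consequence,
sketched below under the assumption of a recent PyTorch 2.x release, is that an
example input violating an assertion makes ``torch.export.export`` itself fail.
``AssertOnShape`` and ``try_export`` are hypothetical, and because the exception type
can differ between releases and tracing modes, the sketch catches ``Exception``:

.. code-block:: python

    import torch


    class AssertOnShape(torch.nn.Module):
        def forward(self, x):
            assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2"
            return x


    def try_export(example):
        # Return True if export succeeds for this example input.
        try:
            torch.export.export(AssertOnShape(), (example,))
            return True
        except Exception:  # exception class may vary between versions and modes
            return False


    print(try_export(torch.randn(3, 2)), try_export(torch.randn(2, 2)))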

dynamic_shape_constructor
^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class DynamicShapeConstructor(torch.nn.Module):
        """
        Tensor constructors should be captured with dynamic shape inputs rather
        than being baked in with static shape.
        """
    
        def forward(self, x):
            return torch.zeros(x.shape[0] * 2)
    
    example_args = (torch.randn(3, 2),)
    tags = {"torch.dynamic-shape"}
    model = DynamicShapeConstructor()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                zeros: "f32[6]" = torch.ops.aten.zeros.default([6], device = device(type='cpu'), pin_memory = False)
                return (zeros,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        zeros: USER_OUTPUT
        
    Range constraints: {}
    

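In the result above the size is baked in as ``[6]`` because no dimension was marked
dynamic. A minimal sketch, assuming a recent PyTorch 2.x release, of exporting with
``dynamic_shapes`` so that the constructor follows the input; ``ZerosFromShape``
and the dimension name ``batch`` are hypothetical:

.. code-block:: python

    import torch
    from torch.export import Dim


    class ZerosFromShape(torch.nn.Module):
        # Hypothetical copy of DynamicShapeConstructor above.
        def forward(self, x):
            return torch.zeros(x.shape[0] * 2)


    ep = torch.export.export(
        ZerosFromShape(),
        (torch.randn(3, 2),),
        dynamic_shapes={"x": {0: Dim("batch")}},
    )

    # The size expression 2 * batch is now part of the graph.
    out = ep.module()(torch.randn(5, 2))
    print(out.shape)  # torch.Size([10])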

dynamic_shape_if_guard
^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class DynamicShapeIfGuard(torch.nn.Module):
        """
        An `if` statement with a backed dynamic-shape predicate is specialized into
        one particular branch and generates a guard. However, export will fail if
        the dimension is marked as dynamic by a higher-level API.
        """
    
        def forward(self, x):
            if x.shape[0] == 3:
                return x.cos()
    
            return x.sin()
    
    example_args = (torch.randn(3, 2, 2),)
    tags = {"torch.dynamic-shape", "python.control-flow"}
    model = DynamicShapeIfGuard()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2, 2]"):
                cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(x);  x = None
                return (cos,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        cos: USER_OUTPUT
        
    Range constraints: {}
    

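The docstring states that export fails when the dimension is marked dynamic,
because the ``x.shape[0] == 3`` check would specialize it. A minimal sketch of both
outcomes, assuming a recent PyTorch 2.x release; ``IfOnSize`` and the dimension name
``n`` are hypothetical, and since the exact exception class may vary between
releases, the sketch catches ``Exception``:

.. code-block:: python

    import torch
    from torch.export import Dim


    class IfOnSize(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] == 3:  # guards on the concrete size
                return x.cos()
            return x.sin()


    model = IfOnSize()
    example = (torch.randn(3, 2, 2),)

    # Static export succeeds: the branch is chosen and a guard is recorded.
    static_ep = torch.export.export(model, example)

    # Marking dim 0 dynamic conflicts with the guard x.shape[0] == 3.
    try:
        torch.export.export(model, example, dynamic_shapes={"x": {0: Dim("n")}})
        dynamic_failed = False
    except Exception:  # exact exception class may differ across versions
        dynamic_failed = True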

dynamic_shape_map
^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.map <torch.map>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    from functorch.experimental.control_flow import map
    
    class DynamicShapeMap(torch.nn.Module):
        """
        functorch map() maps a function over the first tensor dimension.
        """
    
        def forward(self, xs, y):
            def body(x, y):
                return x + y
    
            return map(body, xs, y)
    
    example_args = (torch.randn(3, 2), torch.randn(2))
    tags = {"torch.dynamic-shape", "torch.map"}
    model = DynamicShapeMap()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, xs: "f32[3, 2]", y: "f32[2]"):
                body_graph_0 = self.body_graph_0
                map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y]);  body_graph_0 = xs = y = None
                getitem: "f32[3, 2]" = map_impl[0];  map_impl = None
                return (getitem,)
                
            class body_graph_0(torch.nn.Module):
                def forward(self, xs: "f32[2]", y: "f32[2]"):
                    add: "f32[2]" = torch.ops.aten.add.Tensor(xs, y);  xs = y = None
                    return (add,)
                    
    Graph signature: 
        # inputs
        xs: USER_INPUT
        y: USER_INPUT
        
        # outputs
        getitem: USER_OUTPUT
        
    Range constraints: {}
    

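As the docstring says, ``map(body, xs, y)`` applies ``body`` to each slice of ``xs`` along dimension 0. It passes ``y`` through unchanged and stacks the per-slice results. The graph above has the same structure: ``body_graph_0`` operates on a single ``f32[2]`` slice. The eager sketch below spells out those semantics as an ordinary Python loop. The helper ``loop_map`` is a hypothetical reference for illustration and is not part of the ``functorch`` API.

.. code-block:: python

    import torch


    def loop_map(body, xs, *args):
        # Hypothetical eager reference for map(): call body on each slice of xs
        # along dim 0 (extra args passed unchanged), then stack the results.
        return torch.stack([body(x, *args) for x in xs.unbind(0)])


    def body(x, y):
        return x + y


    xs = torch.randn(3, 2)
    y = torch.randn(2)
    out = loop_map(body, xs, y)
    print(out.shape)  # torch.Size([3, 2])

For this particular ``body``, the loop reduces to broadcasting ``xs + y``. That equivalence is what the check below relies on.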

dynamic_shape_slicing
^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class DynamicShapeSlicing(torch.nn.Module):
        """
        Slices with dynamic shape arguments should be captured into the graph
        rather than being baked in.
        """
    
        def forward(self, x):
            return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
    
    example_args = (torch.randn(3, 2),)
    tags = {"torch.dynamic-shape"}
    model = DynamicShapeSlicing()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 1);  x = None
                slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2);  slice_1 = None
                return (slice_2,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        slice_2: USER_OUTPUT
        
    Range constraints: {}
    

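Here ``export`` was called without ``dynamic_shapes``, so the end of the row slice in the graph above is baked in as the constant ``1``. The sketch below assumes you want dimension 0 to stay dynamic. The dimension name ``rows`` and the bound ``min=3`` are illustrative choices; the lower bound keeps ``x.shape[0] - 2`` positive. The exported program is then called with a different row count.

.. code-block:: python

    import torch
    from torch.export import Dim, export


    class DynamicShapeSlicing(torch.nn.Module):
        def forward(self, x):
            return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]


    # Illustrative Dim: min=3 keeps the row-slice end x.shape[0] - 2 at 1 or more.
    rows = Dim("rows", min=3)
    ep = export(
        DynamicShapeSlicing(),
        (torch.randn(3, 2),),
        dynamic_shapes={"x": {0: rows}},
    )

    # The exported program now accepts a different number of rows.
    x = torch.randn(8, 2)
    out = ep.module()(x)
    print(out.shape)  # torch.Size([6, 1])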

dynamic_shape_view
^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class DynamicShapeView(torch.nn.Module):
        """
        Dynamic shapes should be propagated to view arguments instead of being
        baked into the exported graph.
        """
    
        def forward(self, x):
            new_x_shape = x.size()[:-1] + (2, 5)
            x = x.view(*new_x_shape)
            return x.permute(0, 2, 1)
    
    example_args = (torch.randn(10, 10),)
    tags = {"torch.dynamic-shape"}
    model = DynamicShapeView()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[10, 10]"):
                view: "f32[10, 2, 5]" = torch.ops.aten.view.default(x, [10, 2, 5]);  x = None
                permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
                return (permute,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        permute: USER_OUTPUT
        
    Range constraints: {}
    

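The ``view`` arguments above are the literal ``[10, 2, 5]`` because the example input was exported with static shapes. The sketch below assumes dimension 0 is marked dynamic with an illustrative ``Dim`` named ``batch`` (``min=2``). It shows that the same model then accepts other batch sizes.

.. code-block:: python

    import torch
    from torch.export import Dim, export


    class DynamicShapeView(torch.nn.Module):
        def forward(self, x):
            new_x_shape = x.size()[:-1] + (2, 5)
            x = x.view(*new_x_shape)
            return x.permute(0, 2, 1)


    # Illustrative Dim for the leading dimension; the last dimension stays 10.
    batch = Dim("batch", min=2)
    ep = export(
        DynamicShapeView(),
        (torch.randn(10, 10),),
        dynamic_shapes={"x": {0: batch}},
    )

    x4 = torch.randn(4, 10)
    out = ep.module()(x4)
    print(out.shape)  # torch.Size([4, 5, 2])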

fn_with_kwargs
^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.data-structure <python.data-structure>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class FnWithKwargs(torch.nn.Module):
        """
        Positional, variadic (*args), keyword-only and **kwargs arguments are
        supported; export flattens them into individual graph inputs.
        """
    
        def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
            out = pos0
            for arg in tuple0:
                out = out * arg
            for arg in myargs:
                out = out * arg
            out = out * mykw0
            out = out * mykwargs["input0"] * mykwargs["input1"]
            return out
    
    example_args = (
        torch.randn(4),
        (torch.randn(4), torch.randn(4)),
        *[torch.randn(4), torch.randn(4)]
    )
    example_kwargs = {
        "mykw0": torch.randn(4),
        "input0": torch.randn(4),
        "input1": torch.randn(4),
    }
    tags = {"python.data-structure"}
    model = FnWithKwargs()
    

    torch.export.export(model, example_args, example_kwargs)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, pos0: "f32[4]", tuple0_0: "f32[4]", tuple0_1: "f32[4]", myargs_0: "f32[4]", myargs_1: "f32[4]", mykw0: "f32[4]", input0: "f32[4]", input1: "f32[4]"):
                mul: "f32[4]" = torch.ops.aten.mul.Tensor(pos0, tuple0_0);  pos0 = tuple0_0 = None
                mul_1: "f32[4]" = torch.ops.aten.mul.Tensor(mul, tuple0_1);  mul = tuple0_1 = None
                mul_2: "f32[4]" = torch.ops.aten.mul.Tensor(mul_1, myargs_0);  mul_1 = myargs_0 = None
                mul_3: "f32[4]" = torch.ops.aten.mul.Tensor(mul_2, myargs_1);  mul_2 = myargs_1 = None
                mul_4: "f32[4]" = torch.ops.aten.mul.Tensor(mul_3, mykw0);  mul_3 = mykw0 = None
                mul_5: "f32[4]" = torch.ops.aten.mul.Tensor(mul_4, input0);  mul_4 = input0 = None
                mul_6: "f32[4]" = torch.ops.aten.mul.Tensor(mul_5, input1);  mul_5 = input1 = None
                mul_6: "f32[4]" = torch.ops.aten.mul.Tensor(mul_5, input1);  mul_5 = input1 = None
                return (mul_6,)
                
    Graph signature: 
        # inputs
        pos0: USER_INPUT
        tuple0_0: USER_INPUT
        tuple0_1: USER_INPUT
        myargs_0: USER_INPUT
        myargs_1: USER_INPUT
        mykw0: USER_INPUT
        input0: USER_INPUT
        input1: USER_INPUT
        
        # outputs
        mul_6: USER_OUTPUT
        
    Range constraints: {}
    

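The graph gives every positional, variadic and keyword argument its own input. ``ExportedProgram.module()`` nevertheless still accepts the original calling convention. The sketch below checks the exported module against eager execution using the same ``args`` and ``kwargs``; the variable names are illustrative.

.. code-block:: python

    import torch
    from torch.export import export


    class FnWithKwargs(torch.nn.Module):
        def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
            out = pos0
            for arg in tuple0:
                out = out * arg
            for arg in myargs:
                out = out * arg
            out = out * mykw0
            out = out * mykwargs["input0"] * mykwargs["input1"]
            return out


    args = (torch.randn(4), (torch.randn(4), torch.randn(4)), torch.randn(4), torch.randn(4))
    kwargs = {"mykw0": torch.randn(4), "input0": torch.randn(4), "input1": torch.randn(4)}

    model = FnWithKwargs()
    ep = export(model, args, kwargs)

    # The exported module re-flattens (args, kwargs) into the eight graph inputs.
    result = ep.module()(*args, **kwargs)
    expected = model(*args, **kwargs)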

list_contains
^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.assert <python.assert>`, :doc:`python.data-structure <python.data-structure>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class ListContains(torch.nn.Module):
        """
        List containment relation can be checked on a dynamic shape or constants.
        """
    
        def forward(self, x):
            assert x.size(-1) in [6, 2]
            assert x.size(0) not in [4, 5, 6]
            assert "monkey" not in ["cow", "pig"]
            return x + x
    
    example_args = (torch.randn(3, 2),)
    tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"}
    model = ListContains()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, x);  x = None
                return (add,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {}
    


list_unpack
^^^^^^^^^^^

.. note::

    Tags: :doc:`python.data-structure <python.data-structure>`, :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    
    import torch
    
    class ListUnpack(torch.nn.Module):
        """
        Lists are treated as a static construct, so unpacking them is erased
        after tracing.
        """
    
        def forward(self, args: list[torch.Tensor]):
            x, *y = args
            return x + y[0]
    
    example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],)
    tags = {"python.control-flow", "python.data-structure"}
    model = ListUnpack()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, args_0: "f32[3, 2]", args_1: "i64[]", args_2: "i64[]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(args_0, args_1);  args_0 = args_1 = None
                return (add,)
                
    Graph signature: 
        # inputs
        args_0: USER_INPUT
        args_1: USER_INPUT
        args_2: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {}
    


nested_function
^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.closure <python.closure>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class NestedFunction(torch.nn.Module):
        """
        Nested functions are traced through. Side effects on global captures
        are not supported though.
        """
    
        def forward(self, a, b):
            x = a + b
            z = a - b
    
            def closure(y):
                nonlocal x
                x += 1
                return x * y + z
    
            return closure(x)
    
    example_args = (torch.randn(3, 2), torch.randn(2))
    tags = {"python.closure"}
    model = NestedFunction()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, a: "f32[3, 2]", b: "f32[2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(a, b)
                sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(a, b);  a = b = None
                add_: "f32[3, 2]" = torch.ops.aten.add_.Tensor(add, 1);  add = None
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_, add_);  add_ = None
                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub);  mul = sub = None
                return (add_1,)
                
    Graph signature: 
        # inputs
        a: USER_INPUT
        b: USER_INPUT
        
        # outputs
        add_1: USER_OUTPUT
        
    Range constraints: {}
    

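The graph multiplies ``add_`` by itself because the closure argument ``y`` and the captured ``x`` refer to the same tensor. ``x += 1`` updates that tensor in place, so both operands see the incremented value. Below is a short eager check of that aliasing, using the same model.

.. code-block:: python

    import torch


    class NestedFunction(torch.nn.Module):
        def forward(self, a, b):
            x = a + b
            z = a - b

            def closure(y):
                nonlocal x
                x += 1  # in-place on the tensor that y also refers to
                return x * y + z

            return closure(x)


    a, b = torch.randn(3, 2), torch.randn(2)
    out = NestedFunction()(a, b)

    # Because y aliases x, the result is (a + b + 1) squared, plus (a - b).
    expected = (a + b + 1) * (a + b + 1) + (a - b)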

null_context_manager
^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.context-manager <python.context-manager>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import contextlib
    
    import torch
    
    class NullContextManager(torch.nn.Module):
        """
        A null context manager is traced out: it contributes no nodes to the
        exported graph.
        """
    
        def forward(self, x):
            ctx = contextlib.nullcontext()
            with ctx:
                return x.sin() + x.cos()
    
    example_args = (torch.randn(3, 2),)
    tags = {"python.context-manager"}
    model = NullContextManager()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                sin: "f32[3, 2]" = torch.ops.aten.sin.default(x)
                cos: "f32[3, 2]" = torch.ops.aten.cos.default(x);  x = None
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(sin, cos);  sin = cos = None
                return (add,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {}
    


pytree_flatten
^^^^^^^^^^^^^^

.. note::

    Tags: 

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    from torch.utils import _pytree as pytree
    
    class PytreeFlatten(torch.nn.Module):
        """
        PyTorch's pytree utilities (torch.utils._pytree) can be traced through by TorchDynamo.
        """
    
        def forward(self, x):
            y, _spec = pytree.tree_flatten(x)
            return y[0] + 1
    
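    # The trailing comma wraps the dict in an extra tuple, so forward() receives
    # a 1-tuple holding the dict; hence the graph inputs x_0_1 and x_0_2.
    example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},),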
    model = PytreeFlatten()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x_0_1: "f32[3, 2]", x_0_2: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x_0_1, 1);  x_0_1 = None
                return (add,)
                
    Graph signature: 
        # inputs
        x_0_1: USER_INPUT
        x_0_2: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {}
    

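``tree_flatten`` returns the leaves of a nested container in a deterministic order, together with a ``TreeSpec`` that records the structure. ``tree_unflatten`` rebuilds the container from the two. The sketch below shows that round trip on the structure ``forward`` receives in this example. Note that ``torch.utils._pytree`` is a private module whose interface may change between releases.

.. code-block:: python

    import torch
    from torch.utils import _pytree as pytree

    a, b = torch.randn(3, 2), torch.randn(3, 2)
    x = ({1: a, 2: b},)  # a 1-tuple holding the dict, as in the example above

    leaves, spec = pytree.tree_flatten(x)
    print(len(leaves))  # 2

    # The spec is enough to rebuild the original container from the leaves.
    rebuilt = pytree.tree_unflatten(leaves, spec)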

scalar_output
^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    from torch.export import Dim
    
    x = torch.randn(3, 2)
    dim1_x = Dim("dim1_x")
    
    class ScalarOutput(torch.nn.Module):
        """
        Returning scalar values from the graph is supported, in addition to Tensor
        outputs. Symbolic shapes are captured and rank is specialized.
        """
        def __init__(self) -> None:
            super().__init__()
    
        def forward(self, x):
            return x.shape[1] + 1
    
    example_args = (x,)
    tags = {"torch.dynamic-shape"}
    dynamic_shapes = {"x": {1: dim1_x}}
    model = ScalarOutput()
    

    torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, s27]"):
                sym_size_int_1: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1);  x = None
                add: "Sym(s27 + 1)" = sym_size_int_1 + 1;  sym_size_int_1 = None
                return (add,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {s27: VR[0, int_oo]}
    

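Dimension 1 is marked dynamic, so the returned scalar is computed from the runtime shape instead of being baked in as ``3``. The sketch below exports the same model and calls it with a wider input; the width ``7`` is arbitrary.

.. code-block:: python

    import torch
    from torch.export import Dim, export


    class ScalarOutput(torch.nn.Module):
        def forward(self, x):
            return x.shape[1] + 1


    ep = export(
        ScalarOutput(),
        (torch.randn(3, 2),),
        dynamic_shapes={"x": {1: Dim("dim1_x")}},
    )

    # Dimension 1 is symbolic, so the result tracks the runtime width.
    width_plus_one = ep.module()(torch.randn(3, 7))
    print(width_plus_one)  # 8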

specialized_attribute
^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: 

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    from enum import Enum
    
    import torch
    
    class Animal(Enum):
        COW = "moo"
    
    class SpecializedAttribute(torch.nn.Module):
        """
        Model attributes are specialized.
        """
    
        def __init__(self) -> None:
            super().__init__()
            self.a = "moo"
            self.b = 4
    
        def forward(self, x):
            if self.a == Animal.COW.value:
                return x * x + self.b
            else:
                raise ValueError("bad")
    
    example_args = (torch.randn(3, 2),)
    model = SpecializedAttribute()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x);  x = None
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, 4);  mul = None
                return (add,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {}
    

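Specialization means the exported graph holds a snapshot of the attribute values seen during tracing. The string comparison is resolved at trace time, and ``self.b`` becomes the literal ``4``. The sketch below shows that mutating the Python attribute after export changes eager behaviour but not the exported program.

.. code-block:: python

    from enum import Enum

    import torch
    from torch.export import export


    class Animal(Enum):
        COW = "moo"


    class SpecializedAttribute(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = "moo"
            self.b = 4

        def forward(self, x):
            if self.a == Animal.COW.value:
                return x * x + self.b
            else:
                raise ValueError("bad")


    model = SpecializedAttribute()
    x = torch.randn(3, 2)
    ep = export(model, (x,))

    # Changing the attribute afterwards affects eager mode only.
    model.b = 100
    baked = ep.module()(x)
    eager = model(x)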

static_for_loop
^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class StaticForLoop(torch.nn.Module):
        """
        A for loop with a constant number of iterations is unrolled in the exported graph.
        """
    
        def forward(self, x):
            # range(10) is a Python constant, so the loop unrolls into 10 separate adds
            ret = [i + x for i in range(10)]
            return ret
    
    example_args = (torch.randn(3, 2),)
    tags = {"python.control-flow"}
    model = StaticForLoop()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 0)
                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1)
                add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 2)
                add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 3)
                add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4)
                add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 5)
                add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 6)
                add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 7)
                add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 8)
                add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 9);  x = None
                return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        add_1: USER_OUTPUT
        add_2: USER_OUTPUT
        add_3: USER_OUTPUT
        add_4: USER_OUTPUT
        add_5: USER_OUTPUT
        add_6: USER_OUTPUT
        add_7: USER_OUTPUT
        add_8: USER_OUTPUT
        add_9: USER_OUTPUT
        
    Range constraints: {}
    

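The exported module returns the same list structure as the eager model, with one unrolled addition per iteration. The quick sketch below calls the exported program and checks each element.

.. code-block:: python

    import torch
    from torch.export import export


    class StaticForLoop(torch.nn.Module):
        def forward(self, x):
            return [i + x for i in range(10)]


    x = torch.randn(3, 2)
    ep = export(StaticForLoop(), (x,))

    outs = ep.module()(x)
    # One output per unrolled iteration, in loop order.
    print(len(outs))  # 10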

static_if
^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow <python.control-flow>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class StaticIf(torch.nn.Module):
        """
        An `if` statement whose predicate is static is traced through, and only
        the taken branch is kept.
        """
    
        def forward(self, x):
            if len(x.shape) == 3:
                return x + torch.ones(1, 1, 1)
    
            return x
    
    example_args = (torch.randn(3, 2, 2),)
    tags = {"python.control-flow"}
    model = StaticIf()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2, 2]"):
                ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False)
                add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(x, ones);  x = ones = None
                return (add,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {}
    

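The predicate depends only on the rank, which is static, so export evaluates it once at trace time. The sketch below exports the same model with a 2-D example input. The other branch is taken, so no ``aten.ones`` call appears in the graph and the exported module simply returns its input.

.. code-block:: python

    import torch
    from torch.export import export


    class StaticIf(torch.nn.Module):
        def forward(self, x):
            if len(x.shape) == 3:
                return x + torch.ones(1, 1, 1)
            return x


    # With a 2-D example the predicate is False at trace time.
    x2 = torch.randn(3, 2)
    ep_2d = export(StaticIf(), (x2,))
    targets = [n.target for n in ep_2d.graph.nodes if n.op == "call_function"]
    print(targets)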

tensor_setattr
^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.builtin <python.builtin>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    
    class TensorSetattr(torch.nn.Module):
        """
        Calling setattr() on a tensor input does not affect the exported graph:
        the assigned value is computed but unused, and only x + 4 is returned.
        """
        def forward(self, x, attr):
            setattr(x, attr, torch.randn(3, 2))
            return x + 4
    
    example_args = (torch.randn(3, 2), "attr")
    tags = {"python.builtin"}
    model = TensorSetattr()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]", attr):
                randn: "f32[3, 2]" = torch.ops.aten.randn.default([3, 2], device = device(type='cpu'), pin_memory = False);  randn = None
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4);  x = None
                return (add,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        attr: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {}
    


type_reflection_method
^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.builtin <python.builtin>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    class A:
        @classmethod
        def func(cls, x):
            return 1 + x
    
    class TypeReflectionMethod(torch.nn.Module):
        """
        Calling type() on a custom object and then accessing an attribute on the
        result (here a classmethod) is resolved at trace time and traced through.
        """
    
        def forward(self, x):
            a = A()
            return type(a).func(x)
    
    
    example_args = (torch.randn(3, 4),)
    tags = {"python.builtin"}
    model = TypeReflectionMethod()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 4]"):
                add: "f32[3, 4]" = torch.ops.aten.add.Tensor(x, 1);  x = None
                return (add,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {}
    


user_input_mutation
^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.mutation <torch.mutation>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    
    class UserInputMutation(torch.nn.Module):
        """
        Directly mutates a user input in forward().
        """
    
        def forward(self, x):
            x.mul_(2)
            return x.cos()
    
    
    example_args = (torch.randn(3, 2),)
    tags = {"torch.mutation"}
    model = UserInputMutation()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                mul_: "f32[3, 2]" = torch.ops.aten.mul_.Tensor(x, 2);  x = None
                cos: "f32[3, 2]" = torch.ops.aten.cos.default(mul_);  mul_ = None
                return (cos,)
                
    Graph signature: 
        # inputs
        x: USER_INPUT
        
        # outputs
        cos: USER_OUTPUT
        
    Range constraints: {}
    

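The in-place ``mul_`` is part of the exported program's observable behaviour. Calling ``ExportedProgram.module()`` mutates the caller's tensor, just as eager execution does. The sketch below checks this with an all-ones input.

.. code-block:: python

    import torch
    from torch.export import export


    class UserInputMutation(torch.nn.Module):
        def forward(self, x):
            x.mul_(2)
            return x.cos()


    ep = export(UserInputMutation(), (torch.randn(3, 2),))

    x = torch.ones(3, 2)
    out = ep.module()(x)
    # The input itself is doubled, and the output is cos of the doubled value.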

Not Supported Yet
-----------------

dynamic_shape_round
^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.builtin <python.builtin>`

    Support Level: NOT_SUPPORTED_YET

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    
    from torch._export.db.case import SupportLevel
    from torch.export import Dim
    
    class DynamicShapeRound(torch.nn.Module):
        """
        Calling round on dynamic shapes is not supported.
        """
    
        def forward(self, x):
            return x[: round(x.shape[0] / 2)]
    
    x = torch.randn(3, 2)
    dim0_x = Dim("dim0_x")
    example_args = (x,)
    tags = {"torch.dynamic-shape", "python.builtin"}
    support_level = SupportLevel.NOT_SUPPORTED_YET
    dynamic_shapes = {"x": {0: dim0_x}}
    model = DynamicShapeRound()
    

    torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

.. code-block::

    Unsupported: Constraints violated (dim0_x)! For more information, run with TORCH_LOGS="+dynamic".


model_attr_mutation
^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.object-model <python.object-model>`

    Support Level: NOT_SUPPORTED_YET

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    from torch._export.db.case import SupportLevel
    
    
    class ModelAttrMutation(torch.nn.Module):
        """
        Attribute mutation is not supported.
        """
    
        def __init__(self) -> None:
            super().__init__()
            self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]
    
        def recreate_list(self):
            return [torch.zeros(3, 2), torch.zeros(3, 2)]
    
        def forward(self, x):
            self.attr_list = self.recreate_list()
            return x.sum() + self.attr_list[0].sum()
    
    
    example_args = (torch.randn(3, 2),)
    tags = {"python.object-model"}
    support_level = SupportLevel.NOT_SUPPORTED_YET
    model = ModelAttrMutation()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    AssertionError: Mutating module attribute attr_list during export.

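One possible workaround is to keep the recreated list in a local variable, so that ``forward`` no longer reassigns a module attribute. This is a hypothetical sketch, not an official recommendation, and the class name ``ModelAttrLocal`` is illustrative.

.. code-block:: python

    import torch
    from torch.export import export


    class ModelAttrLocal(torch.nn.Module):
        # Hypothetical rewrite: the recreated list lives in a local variable
        # instead of being assigned to self.attr_list during forward().
        def __init__(self) -> None:
            super().__init__()
            self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

        def recreate_list(self):
            return [torch.zeros(3, 2), torch.zeros(3, 2)]

        def forward(self, x):
            attr_list = self.recreate_list()
            return x.sum() + attr_list[0].sum()


    x = torch.randn(3, 2)
    ep = export(ModelAttrLocal(), (x,))
    out = ep.module()(x)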

optional_input
^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.object-model <python.object-model>`

    Support Level: NOT_SUPPORTED_YET

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    from torch._export.db.case import SupportLevel
    
    
    class OptionalInput(torch.nn.Module):
        """
        Tracing through an optional input is not supported yet.
        """
    
        def forward(self, x, y=torch.randn(2, 3)):
            if y is not None:
                return x + y
            return x
    
    
    example_args = (torch.randn(2, 3),)
    tags = {"python.object-model"}
    support_level = SupportLevel.NOT_SUPPORTED_YET
    model = OptionalInput()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    Unsupported: Tracing through optional input is not supported yet

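A possible workaround is to give ``y`` a ``None`` default instead of a tensor default, and to pass ``y`` explicitly in the example arguments so that export traces it as an ordinary input. This is a hypothetical sketch, not an official recommendation.

.. code-block:: python

    import torch
    from torch.export import export


    class OptionalInput(torch.nn.Module):
        # Hypothetical variant: None default instead of a tensor default.
        def forward(self, x, y=None):
            if y is not None:
                return x + y
            return x


    x, y = torch.randn(2, 3), torch.randn(2, 3)
    ep = export(OptionalInput(), (x, y))
    out = ep.module()(x, y)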

unsupported_operator
^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.operator <torch.operator>`

    Support Level: NOT_SUPPORTED_YET

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch
    from torch._export.db.case import SupportLevel
    
    
    class TorchSymMin(torch.nn.Module):
        """
        torch.sym_min operator is not supported in export.
        """
    
        def forward(self, x):
            return x.sum() + torch.sym_min(x.size(0), 100)
    
    
    example_args = (torch.randn(3, 2),)
    tags = {"torch.operator"}
    support_level = SupportLevel.NOT_SUPPORTED_YET
    model = TorchSymMin()
    

    torch.export.export(model, example_args)

Result:

.. code-block::

    Unsupported: torch.* op returned non-Tensor int call_function <function sym_min at 0x7f1961f81700>