Shortcuts

ExportDB

ExportDB is a centralized dataset of supported and unsupported export cases. It is targeted towards users who want to understand specifically what types of code are supported, the subtleties of export, and how to modify their existing code to be compatible with export. Note that this is not an exhaustive set of everything that is supported by exportdb, but it covers the most common and confusing use cases that users will run into.

If you have a feature that you think needs a stronger guarantee from us to support in export, please create an issue in the pytorch/pytorch repo with a module:export tag.

Supported

assume_constant_result

Note

Tags: torch.escape-hatch

Support Level: SUPPORTED

Original source code:

import torch
import torch._dynamo as torchdynamo



class AssumeConstantResult(torch.nn.Module):
    """
    Apply the `assume_constant_result` decorator to burn the result of
    non-traceable code into the graph as a constant.

    `get_item` converts a tensor to a Python int via `.item()`; `forward` uses
    that int as the end bound of a slice over the first dimension of `x`.
    """

    def __init__(self):
        super().__init__()

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        # Tensor -> Python int. The decorator marks the returned value as
        # something to be treated as a constant.
        return y.int().item()

    def forward(self, x, y):
        # Slice the leading dimension of `x` up to the value read from `y`.
        return x[: self.get_item(y)]

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]", arg1_1: "i64[]"):
                slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 4);  arg0_1 = None
            return (slice_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_1'), target=None)])
Range constraints: {}

autograd_function

Note

Tags:

Support Level: SUPPORTED

Original source code:

import torch



class MyAutogradFunction(torch.autograd.Function):
    """
    Custom autograd function with hand-written forward and backward passes.

    The forward pass returns a copy of its input; the backward pass returns the
    incoming gradient plus one.
    """

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output + 1


class AutogradFunction(torch.nn.Module):
    """
    TorchDynamo does not keep track of backward() on autograd functions. We recommend
    using `allow_in_graph` to mitigate this problem.

    `forward` simply applies `MyAutogradFunction` to its input.
    """

    def forward(self, x):
        return MyAutogradFunction.apply(x)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
                clone: "f32[3, 2]" = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
            return (clone,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='clone'), target=None)])
Range constraints: {}

class_method

Note

Tags:

Support Level: SUPPORTED

Original source code:

import torch



class ClassMethod(torch.nn.Module):
    """
    Class methods are inlined during tracing.

    `forward` runs the input through a linear layer, then calls the same
    classmethod through three different receivers (the instance,
    `self.__class__`, and `type(self)`) and multiplies the three results.
    """

    @classmethod
    def method(cls, x):
        return x + 1

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        x = self.linear(x)
        # Three different spellings that all reach `ClassMethod.method`.
        return self.method(x) * self.__class__.method(x) * type(self).method(x)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[2]", arg2_1: "f32[3, 4]"):
                t: "f32[4, 2]" = torch.ops.aten.t.default(arg0_1);  arg0_1 = None
            addmm: "f32[3, 2]" = torch.ops.aten.addmm.default(arg1_1, arg2_1, t);  arg1_1 = arg2_1 = t = None

                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(addmm, 1)
            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(addmm, 1)

                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add, add_1);  add = add_1 = None

                add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(addmm, 1);  addmm = None

                mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(mul, add_2);  mul = add_2 = None
            return (mul_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg1_1'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg2_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_1'), target=None)])
Range constraints: {}

cond_branch_class_method

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class MySubModule(torch.nn.Module):
    """Helper submodule whose `forward` delegates to `foo`, which returns `x.cos()`."""

    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)


class CondBranchClassMethod(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensor must have the same tensor metadata, e.g. shape and dtype
      - branch function can be free function, nested function, lambda, class methods
      - branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates using methods of a class as the branches of cond().

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        # True branch: the submodule's `forward` (cos); false branch: this
        # module's own `bar` method (sin).
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3]"):
                true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]);  true_graph_0 = false_graph_0 = arg0_1 = None
            getitem: "f32[3]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                        cos: "f32[3]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return (cos,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                        sin: "f32[3]" = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

cond_branch_nested_function

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class CondBranchNestedFunction(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensor must have the same tensor metadata, e.g. shape and dtype
      - branch function can be free function, nested function, lambda, class methods
      - branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates using nested functions in cond().

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.

    NOTE(review): the inner helpers below read `x` from their enclosing branch
    function; the "no closure variables" rule above presumably applies to the
    branch functions themselves rather than to helpers nested inside them - confirm.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        def true_fn(x):
            # Helper defined inside the branch; called with the branch's own
            # argument, so it computes x + x.
            def inner_true_fn(y):
                return x + y

            return inner_true_fn(x)

        def false_fn(x):
            # Mirror of `true_fn`: computes x - x.
            def inner_false_fn(y):
                return x - y

            return inner_false_fn(x)

        return cond(x.shape[0] < 10, true_fn, false_fn, [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3]"):
                true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [arg0_1]);  true_graph_0 = false_graph_0 = arg0_1 = None
            getitem: "f32[3]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                        add: "f32[3]" = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return (add,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3]"):
                        sub: "f32[3]" = torch.ops.aten.sub.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return (sub,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

cond_branch_nonlocal_variables

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class CondBranchNonlocalVariables(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
    - both branches must take the same args, which must also match the branch args passed to cond.
    - both branches must return a single tensor
    - returned tensor must have the same tensor metadata, e.g. shape and dtype
    - branch function can be free function, nested function, lambda, class methods
    - branch function can not have closure variables
    - no inplace mutations on inputs or global variables

    This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.

    The code below will not work because capturing closure variables is not supported.
    ```
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    def true_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y + my_tensor_var + my_primitive_var

    def false_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y - my_tensor_var - my_primitive_var

    return cond(x.shape[0] > 5, true_fn, false_fn, [x])
    ```

    The working version in `forward` passes those values to the branches as
    explicit operands instead, wrapping the Python float in a tensor.

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        my_tensor_var = x + 100
        my_primitive_var = 3.14

        # Both branches take the former closure variables as parameters.
        def true_fn(x, y, z):
            return x + y + z

        def false_fn(x, y, z):
            return x - y - z

        return cond(
            x.shape[0] > 5,
            true_fn,
            false_fn,
            # The float is converted to a tensor before being passed as an operand.
            [x, my_tensor_var, torch.tensor(my_primitive_var)],
        )

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, _lifted_tensor_constant0: "f32[]", arg0_1: "f32[6]"):
                add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, 100)

                lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(_lifted_tensor_constant0);  _lifted_tensor_constant0 = None

                true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [arg0_1, add, lift_fresh_copy]);  true_graph_0 = false_graph_0 = arg0_1 = add = lift_fresh_copy = None
            getitem: "f32[6]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"):
                        add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                add_1: "f32[6]" = torch.ops.aten.add.Tensor(add, arg2_1);  add = arg2_1 = None
                return (add_1,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"):
                        sub: "f32[6]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                sub_1: "f32[6]" = torch.ops.aten.sub.Tensor(sub, arg2_1);  sub = arg2_1 = None
                return (sub_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='_lifted_tensor_constant0'), target='_lifted_tensor_constant0', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

cond_closed_over_variable

Note

Tags: python.closure, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class CondClosedOverVariable(torch.nn.Module):
    """
    torch.cond() supports branches closed over arbitrary variables.

    Both branches ignore their `val` parameter and instead read `x` from the
    enclosing `forward` scope.
    """

    def forward(self, pred, x):
        def true_fn(val):
            # `x` comes from the enclosing scope, not from the operand list.
            return x * 2

        def false_fn(val):
            return x - 2

        # The operand `x + 1` is passed in as `val`, which neither branch uses.
        return cond(pred, true_fn, false_fn, [x + 1])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "b8[]", arg1_1: "f32[3, 2]"):
                true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(arg0_1, true_graph_0, false_graph_0, [arg1_1]);  arg0_1 = true_graph_0 = false_graph_0 = arg1_1 = None
            getitem: "f32[3, 2]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]"):
                        mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(arg0_1, 2);  arg0_1 = None
                return (mul,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 2]"):
                        sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(arg0_1, 2);  arg0_1 = None
                return (sub,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

cond_operands

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from torch.export import Dim
from functorch.experimental.control_flow import cond

# Module-level sample values. `CondOperands` itself never reads them (its
# `forward` parameters shadow `x` and `y`); presumably they are the example
# inputs and dynamic-dimension spec used when the case is exported - TODO confirm.
x = torch.randn(3, 2)
y = torch.ones(2)
dim0_x = Dim("dim0_x")

class CondOperands(torch.nn.Module):
    """
    The operands passed to cond() must be:
    - a list of tensors
    - match arguments of `true_fn` and `false_fn`

    Here both branches take two parameters and the operand list `[x, y]`
    supplies them in the same order.

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        def true_fn(x, y):
            return x + y

        def false_fn(x, y):
            return x - y

        # The predicate depends on the size of the leading dimension of `x`.
        return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
                sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: "Sym(s0 > 2)" = sym_size_int > 2;  sym_size_int = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
            getitem: "f32[s0, 2]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
                        add: "f32[s0, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return (add,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"):
                        sub: "f32[s0, 2]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return (sub,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {s0: ValueRanges(lower=2, upper=oo, is_bool=False)}

cond_predicate

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class CondPredicate(torch.nn.Module):
    """
    The conditional statement (aka predicate) passed to cond() must be one of the following:
      - torch.Tensor with a single element
      - boolean expression

    This example uses a boolean expression built from the input's rank and shape.

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Short-circuit `and`: `x.shape[2]` is only read when `x` has more
        # than two dimensions.
        pred = x.dim() > 2 and x.shape[2] > 10

        # Branches are given as lambdas: cos when `pred` holds, sin otherwise.
        return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[6, 4, 3]"):
                true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]);  true_graph_0 = false_graph_0 = arg0_1 = None
            getitem: "f32[6, 4, 3]" = conditional[0];  conditional = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[6, 4, 3]"):
                        cos: "f32[6, 4, 3]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return (cos,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[6, 4, 3]"):
                        sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

constrain_as_size_example

Note

Tags: torch.dynamic-value, torch.escape-hatch

Support Level: SUPPORTED

Original source code:

import torch



class ConstrainAsSizeExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide a hint so that we
    can trace further. Please look at the constrain_as_value and constrain_as_size APIs.
    constrain_as_size is used for values that NEED to be used for constructing
    a tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # `a` is a Python scalar read out of the tensor.
        a = x.item()
        # Declare the range [0, 5] for `a` before using it as a tensor size.
        torch._constrain_as_size(a, min=0, max=5)
        return torch.ones((a, 5))

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "i64[]"):
                _local_scalar_dense: "Sym(u4)" = torch.ops.aten._local_scalar_dense.default(arg0_1);  arg0_1 = None

                ge: "Sym(u4 >= 0)" = _local_scalar_dense >= 0
            scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge);  ge = None
            _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor = None
            le: "Sym(u4 <= 5)" = _local_scalar_dense <= 5
            scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le);  le = None
            _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor_1 = None

                sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = 0, max = 5)

                ones: "f32[u4, 5]" = torch.ops.aten.ones.default([_local_scalar_dense, 5], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
            return (ones,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='ones'), target=None)])
Range constraints: {u0: ValueRanges(lower=0, upper=5, is_bool=False), u1: ValueRanges(lower=0, upper=5, is_bool=False), u4: ValueRanges(lower=0, upper=5, is_bool=False)}

constrain_as_value_example

Note

Tags: torch.dynamic-value, torch.escape-hatch

Support Level: SUPPORTED

Original source code:

import torch



class ConstrainAsValueExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide a hint so that we
    can trace further. Please look at the constrain_as_value and constrain_as_size APIs.
    constrain_as_value is used for values that don't need to be used for constructing
    a tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # `a` is a Python scalar read out of the tensor.
        a = x.item()
        # Declare the range [0, 5] for `a`.
        torch._constrain_as_value(a, min=0, max=5)

        # Every value in the declared range [0, 5] satisfies `a < 6`, so only
        # the sin branch is reachable under that constraint.
        if a < 6:
            return y.sin()
        return y.cos()

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "i64[]", arg1_1: "f32[5, 5]"):
                _local_scalar_dense: "Sym(u4)" = torch.ops.aten._local_scalar_dense.default(arg0_1);  arg0_1 = None

                ge: "Sym(u4 >= 0)" = _local_scalar_dense >= 0
            scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge);  ge = None
            _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor = None
            le: "Sym(u4 <= 5)" = _local_scalar_dense <= 5
            scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le);  le = None
            _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor_1 = None

                sym_constrain_range = torch.ops.aten.sym_constrain_range.default(_local_scalar_dense, min = 0, max = 5);  _local_scalar_dense = None

                sin: "f32[5, 5]" = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
            return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)])
Range constraints: {u0: ValueRanges(lower=0, upper=5, is_bool=False), u1: ValueRanges(lower=0, upper=5, is_bool=False), u4: ValueRanges(lower=0, upper=5, is_bool=False)}

decorator

Note

Tags:

Support Level: SUPPORTED

Original source code:

import functools

import torch



def test_decorator(func):
    """Return a wrapper around `func` that adds 1 to whatever `func` returns."""
    # `functools.wraps` copies `func`'s metadata (name, docstring, ...) onto the wrapper.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs) + 1

    return wrapper


class Decorator(torch.nn.Module):
    """
    Decorator calls are inlined into the exported function during tracing.

    `forward` is wrapped by `test_decorator`, so the module returns `x + y + 1`.
    """

    @test_decorator
    def forward(self, x, y):
        return x + y

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]", arg1_1: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None

                add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1);  add = None
            return (add_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {}

dictionary

Note

Tags: python.data-structure

Support Level: SUPPORTED

Original source code:

import torch



class Dictionary(torch.nn.Module):
    """
    Dictionary structures are inlined and flattened along tracing.

    `forward` stores an intermediate tensor in a local dict, reads it back, and
    returns its result wrapped in a dict with the single key "y".
    """
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        elements = {}
        # Store the elementwise square of `x` under a string key...
        elements["x2"] = x * x
        # ...and read it back to scale `y`.
        y = y * elements["x2"]
        return {"y": y}

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]", arg1_1: "i64[]"):
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None

                mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(arg1_1, mul);  arg1_1 = mul = None
            return (mul_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_1'), target=None)])
Range constraints: {}

dynamic_shape_assert

Note

Tags: python.assert

Support Level: SUPPORTED

Original source code:

import torch



class DynamicShapeAssert(torch.nn.Module):
    """
    A basic usage of python assertion.

    `forward` checks the size of the input's leading dimension with two
    assertions (one with a message, one without) and returns the input unchanged.

    Raises:
        AssertionError: if `x.shape[0]` is not greater than 2.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # assertion with error message; the message is shown only when the
        # check fails, so it must describe the failing condition.
        assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2"
        # assertion without error message
        assert x.shape[0] > 1
        return x

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
            return (arg0_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None)])
Range constraints: {}

dynamic_shape_constructor

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch



class DynamicShapeConstructor(torch.nn.Module):
    """
    Build a tensor whose size is derived from an input's shape.

    The size handed to the tensor constructor should be captured from the
    (possibly dynamic) input shape rather than frozen to a static number.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Output length is twice the size of the input's leading dimension.
        doubled_len = x.shape[0] * 2
        return torch.ones(doubled_len)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
                ones: "f32[6]" = torch.ops.aten.ones.default([6], device = device(type='cpu'), pin_memory = False)
            return (ones,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='ones'), target=None)])
Range constraints: {}

dynamic_shape_if_guard

Note

Tags: python.control-flow, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch



class DynamicShapeIfGuard(torch.nn.Module):
    """
    `if` statement with backed dynamic shape predicate will be specialized into
    one particular branch and generate a guard. However, export will fail if
    the dimension is marked as dynamic shape from higher level API.
    """

    def forward(self, x):
        # Python-level branch on the size of the leading dimension.
        if x.shape[0] == 3:
            return x.cos()

        return x.sin()

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2, 2]"):
                cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
            return (cos,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}

dynamic_shape_map

Note

Tags: torch.map, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import map


class DynamicShapeMap(torch.nn.Module):
    """
    functorch map() maps a function over the first tensor dimension.

    Here `body` is applied along the first dimension of `xs`, with `y` passed
    through as an additional argument.
    """

    def __init__(self):
        super().__init__()

    def forward(self, xs, y):
        def body(x, y):
            return x + y

        return map(body, xs, y)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]", arg1_1: "f32[2]"):
                body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [arg0_1], [arg1_1]);  body_graph_0 = arg0_1 = arg1_1 = None
            getitem: "f32[3, 2]" = map_impl[0];  map_impl = None
            return (getitem,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: "f32[2]", arg1_1: "f32[2]"):
                        add: "f32[2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

dynamic_shape_slicing

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch



class DynamicShapeSlicing(torch.nn.Module):
    """
    Slice a tensor using bounds computed from its own shape.

    Slice arguments that depend on dynamic shapes should be captured into the
    graph instead of being baked in as fixed numbers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Drop the last two rows; start at the last column and step by 2.
        row_stop = x.shape[0] - 2
        col_start = x.shape[1] - 1
        return x[:row_stop, col_start::2]

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
                slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 1);  arg0_1 = None
            slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2);  slice_1 = None
            return (slice_2,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_2'), target=None)])
Range constraints: {}

dynamic_shape_view

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch



class DynamicShapeView(torch.nn.Module):
    """
    Shapes read from the input should flow into `view` arguments symbolically
    instead of being baked into the exported graph.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Split the last dimension into (2, 5) and swap the two new axes."""
        leading_dims = x.size()[:-1]
        reshaped = x.view(*leading_dims, 2, 5)
        return reshaped.permute(0, 2, 1)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[10, 10]"):
            view: "f32[10, 2, 5]" = torch.ops.aten.view.default(arg0_1, [10, 2, 5]);  arg0_1 = None

            permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
            return (permute,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='permute'), target=None)])
Range constraints: {}

fn_with_kwargs

Note

Tags: python.data-structure

Support Level: SUPPORTED

Original source code:

import torch



    ),
    tags={"python.data-structure"},
    support_level=SupportLevel.SUPPORTED,
)
class FnWithKwargs(torch.nn.Module):
    """
    Keyword arguments are supported: keyword-only arguments and `**kwargs`
    entries become graph inputs alongside the positional arguments.
    """
    def __init__(self):
        super().__init__()

    def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
        """
        Multiply every supplied value together.

        Args:
            pos0: first factor, passed positionally.
            tuple0: iterable of factors passed as a single positional argument.
            *myargs: any further positional factors.
            mykw0: keyword-only factor.
            **mykwargs: must contain the keys "input0" and "input1".

        Returns:
            The running product of all of the above.
        """
        out = pos0
        for arg in tuple0:
            out = out * arg
        for arg in myargs:
            out = out * arg
        out = out * mykw0
        # Only these two keys are read; a missing key raises KeyError.
        out = out * mykwargs["input0"] * mykwargs["input1"]
        return out

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[4]", arg1_1: "f32[4]", arg2_1: "f32[4]", arg3_1: "f32[4]", arg4_1: "f32[4]", arg5_1: "f32[4]", arg6_1: "f32[4]", arg7_1: "f32[4]"):
            mul: "f32[4]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
            mul_1: "f32[4]" = torch.ops.aten.mul.Tensor(mul, arg2_1);  mul = arg2_1 = None

            mul_2: "f32[4]" = torch.ops.aten.mul.Tensor(mul_1, arg3_1);  mul_1 = arg3_1 = None
            mul_3: "f32[4]" = torch.ops.aten.mul.Tensor(mul_2, arg4_1);  mul_2 = arg4_1 = None

            mul_4: "f32[4]" = torch.ops.aten.mul.Tensor(mul_3, arg5_1);  mul_3 = arg5_1 = None

            mul_5: "f32[4]" = torch.ops.aten.mul.Tensor(mul_4, arg6_1);  mul_4 = arg6_1 = None
            mul_6: "f32[4]" = torch.ops.aten.mul.Tensor(mul_5, arg7_1);  mul_5 = arg7_1 = None
            return (mul_6,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg2_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg5_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg6_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg7_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_6'), target=None)])
Range constraints: {}

list_contains

Note

Tags: torch.dynamic-shape, python.assert, python.data-structure

Support Level: SUPPORTED

Original source code:

import torch



class ListContains(torch.nn.Module):
    """
    List containment relation can be checked on a dynamic shape or constants.

    The `in` / `not in` checks below are evaluated against plain Python lists;
    the recorded graph for this case contains only the final addition.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Validate the input's sizes with list-membership asserts and return `x + x`."""
        # Last dimension must be one of the listed sizes.
        assert x.size(-1) in [6, 2]
        # First dimension must not be one of the listed sizes.
        assert x.size(0) not in [4, 5, 6]
        # Membership check between constants only; independent of the input.
        assert "monkey" not in ["cow", "pig"]
        return x + x

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

list_unpack

Note

Tags: python.control-flow, python.data-structure

Support Level: SUPPORTED

Original source code:

from typing import List

import torch



class ListUnpack(torch.nn.Module):
    """
    Lists are treated as static construct, therefore unpacking should be
    erased after tracing.
    """

    def __init__(self):
        super().__init__()

    def forward(self, args: List[torch.Tensor]):
        """
        Add the first two elements of `args`.

        Args:
            args: list of tensors; it must hold at least two elements because
                the element right after the first one is indexed.

        Returns:
            The sum of the first element and the second element.
        """
        # Star-unpacking: `x` is the head, `y` the list of remaining elements.
        x, *y = args
        return x + y[0]

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]", arg1_1: "i64[]", arg2_1: "i64[]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg2_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

nested_function

Note

Tags: python.closure

Support Level: SUPPORTED

Original source code:

import torch



class NestedFunction(torch.nn.Module):
    """
    Nested functions are traced through. Side effects on global captures
    are not supported though.
    """
    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        """Compute `a + b` and `a - b`, then combine them inside a nested closure."""
        x = a + b
        z = a - b

        def closure(y):
            # `x` is rebound in the enclosing scope; `z` is only read.
            nonlocal x
            x += 1
            # The caller passes `x` itself as `y`, and the recorded graph for
            # this case multiplies the incremented value by itself.
            return x * y + z

        return closure(x)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]", arg1_1: "f32[2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

            sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None

            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1);  add = None

            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_1, add_1);  add_1 = None
            add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub);  mul = sub = None
            return (add_2,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)])
Range constraints: {}

null_context_manager

Note

Tags: python.context-manager

Support Level: SUPPORTED

Original source code:

import contextlib

import torch



class NullContextManager(torch.nn.Module):
    """
    A `contextlib.nullcontext()` block in Python will be traced out.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Return `sin(x) + cos(x)`, computed inside a null context manager.
        """
        with contextlib.nullcontext():
            sin_x = x.sin()
            cos_x = x.cos()
            return sin_x + cos_x

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
            sin: "f32[3, 2]" = torch.ops.aten.sin.default(arg0_1)
            cos: "f32[3, 2]" = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(sin, cos);  sin = cos = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

pytree_flatten

Note

Tags:

Support Level: SUPPORTED

Original source code:

import torch

from torch.utils import _pytree as pytree


class PytreeFlatten(torch.nn.Module):
    """
    Calls into PyTorch's pytree utilities can be captured by TorchDynamo.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Flatten the container `x` and return its first leaf plus one."""
        # The tree spec returned next to the leaves is not needed here.
        leaves, _ = pytree.tree_flatten(x)
        first_leaf = leaves[0]
        return first_leaf + 1

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]", arg1_1: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

scalar_output

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch

from torch.export import Dim

# A 3x2 tensor of ones and a named dimension object. Presumably these are the
# example input and the dynamic-shape spec handed to the export-case
# registration, which is not shown in this snippet -- TODO confirm.
x = torch.ones(3, 2)
dim1_x = Dim("dim1_x")

class ScalarOutput(torch.nn.Module):
    """
    Besides Tensor outputs, a graph may also return scalar values. Symbolic
    shapes are captured and rank is specialized.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the size of dimension 1 of `x`, plus one."""
        dim1_size = x.shape[1]
        return dim1_size + 1

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, s0]"):
            # No stacktrace found for following nodes
            sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(arg0_1, 1);  arg0_1 = None
            add: "Sym(s0 + 1)" = sym_size_int + 1;  sym_size_int = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='add'), target=None)])
Range constraints: {s0: ValueRanges(lower=2, upper=oo, is_bool=False)}

specialized_attribute

Note

Tags:

Support Level: SUPPORTED

Original source code:

from enum import Enum

import torch



class Animal(Enum):
    """Enum whose single member supplies the string compared in the module below."""

    COW = "moo"


class SpecializedAttribute(torch.nn.Module):
    """
    Model attributes are specialized.
    """

    def __init__(self):
        super().__init__()
        self.a = "moo"
        self.b = 4

    def forward(self, x):
        """Return `x * x + self.b`; raise ValueError if `self.a` is not the COW value."""
        # Guard clause: reject any attribute value other than the enum's string.
        if self.a != Animal.COW.value:
            raise ValueError("bad")
        squared = x * x
        return squared + self.b

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, 4);  mul = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

static_for_loop

Note

Tags: python.control-flow

Support Level: SUPPORTED

Original source code:

import torch



class StaticForLoop(torch.nn.Module):
    """
    A for loop with constant number of iterations should be unrolled in the exported graph.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return a list of ten values, `i + x` for each `i` in 0..9."""
        ret = []
        # The trip count is a literal, so it is known before any input is seen.
        for i in range(10):  # constant
            ret.append(i + x)
        return ret

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 0)
            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 1)
            add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 2)
            add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 3)
            add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 4)
            add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 5)
            add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 6)
            add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 7)
            add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 8)
            add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 9);  arg0_1 = None
            return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_4'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_5'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_6'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_8'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_9'), target=None)])
Range constraints: {}

static_if

Note

Tags: python.control-flow

Support Level: SUPPORTED

Original source code:

import torch



class StaticIf(torch.nn.Module):
    """
    `if` statement with static predicate value should be traced through with the
    taken branch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Add a (1, 1, 1) tensor of ones to 3-dimensional inputs; return any other input unchanged."""
        # The predicate depends only on the number of dimensions of `x`.
        if len(x.shape) == 3:
            return x + torch.ones(1, 1, 1)

        return x

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2, 2]"):
            ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False)
            add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(arg0_1, ones);  arg0_1 = ones = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

tensor_setattr

Note

Tags: python.builtin

Support Level: SUPPORTED

Original source code:

import torch



class TensorSetattr(torch.nn.Module):
    """
    setattr() calls on tensors are traced through. The attribute assignment
    itself does not show up as an operation in the exported graph.
    """
    def forward(self, x, attr):
        """
        Attach a random (3, 2) tensor to `x` under the name `attr`, then return `x + 4`.

        Args:
            x: tensor that receives the new attribute.
            attr: attribute name, as a string.
        """
        setattr(x, attr, torch.randn(3, 2))
        return x + 4

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]", arg1_1):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(arg0_1, 4);  arg0_1 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=ConstantArgument(value='attr'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

type_reflection_method

Note

Tags: python.builtin

Support Level: SUPPORTED

Original source code:

import torch



class A:
    """Plain Python class exposing a classmethod that adds one to its argument."""

    @classmethod
    def func(cls, x):
        return 1 + x


class TypeReflectionMethod(torch.nn.Module):
    """
    type() calls on custom objects followed by attribute accesses are traced
    through: the classmethod reached via `type(a)` is inlined into the graph.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return `1 + x` by calling `A.func` through `type()` of a fresh instance."""
        a = A()
        # Deliberately resolves the class via reflection instead of naming `A`.
        return type(a).func(x)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 4]"):
            add: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

You can rewrite the example above to something like the following:

class TypeReflectionMethodRewrite(torch.nn.Module):
    """
    Class methods of custom objects will be inlined.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Call the classmethod on `A` by name, without going through `type()`."""
        result = A.func(x)
        return result

user_input_mutation

Note

Tags: torch.mutation

Support Level: SUPPORTED

Original source code:

import torch



class UserInputMutation(torch.nn.Module):
    """
    Directly mutate user input in forward
    """

    def forward(self, x):
        """Double `x` in place, then return the cosine of the updated tensor."""
        # In-place op on the caller's tensor; the recorded graph signature for
        # this case reports it as a USER_INPUT_MUTATION output.
        x.mul_(2)
        return x.cos()

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 2]"):
            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(arg0_1, 2);  arg0_1 = None

            cos: "f32[3, 2]" = torch.ops.aten.cos.default(mul)
            return (mul, cos)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_INPUT_MUTATION: 6>, arg=TensorArgument(name='mul'), target='arg0_1'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}

Not Supported Yet

dynamic_shape_round

Note

Tags: torch.dynamic-shape, python.builtin

Support Level: NOT_SUPPORTED_YET

Original source code:

import torch

from torch.export import Dim

# A 3x2 tensor of ones and a named dimension object. Presumably these are the
# example input and the dynamic-shape spec handed to the export-case
# registration, which is not shown in this snippet -- TODO confirm.
x = torch.ones(3, 2)
dim0_x = Dim("dim0_x")

class DynamicShapeRound(torch.nn.Module):
    """
    Calling round on dynamic shapes is not supported.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Keep the leading rows of `x`, up to half its first dimension rounded with `round`."""
        half_rows = x.shape[0] / 2
        stop = round(half_rows)
        return x[:stop]

Result:

AssertionError:

model_attr_mutation

Note

Tags: python.object-model

Support Level: NOT_SUPPORTED_YET

Original source code:

import torch



class ModelAttrMutation(torch.nn.Module):
    """
    Attribute mutation is not supported.
    """

    def __init__(self):
        super().__init__()
        # Plain Python list of two 3x2 tensors of ones.
        self.attr_list = [torch.ones(3, 2), torch.ones(3, 2)]

    def recreate_list(self):
        """Return a fresh list of two 3x2 zero tensors."""
        return [torch.zeros(3, 2), torch.zeros(3, 2)]

    def forward(self, x):
        """Replace `self.attr_list`, then return `x.sum()` plus the sum of its first entry."""
        # Rebinding a module attribute inside forward is the step that export
        # rejects (see the recorded AssertionError for this case).
        self.attr_list = self.recreate_list()
        return x.sum() + self.attr_list[0].sum()

Result:

AssertionError: Mutating module attribute attr_list during export.

optional_input

Note

Tags: python.object-model

Support Level: NOT_SUPPORTED_YET

Original source code:

import torch



class OptionalInput(torch.nn.Module):
    """
    Tracing through optional input is not supported yet
    """

    # The default for `y` is a tensor created once, when the method is defined.
    def forward(self, x, y=torch.ones(2, 3)):
        """Return `x + y` when `y` is not None, otherwise return `x` unchanged."""
        if y is not None:
            return x + y
        return x

Result:

AssertionError: Unexpectedly found a <class 'torch.Tensor'> in the inputs.

torch_sym_min

Note

Tags: torch.operator

Support Level: NOT_SUPPORTED_YET

Original source code:

import torch



class TorchSymMin(torch.nn.Module):
    """
    torch.sym_min operator is not supported in export.
    """

    def forward(self, x):
        """Return the sum of `x` plus the smaller of its first dimension size and 100."""
        total = x.sum()
        capped_rows = torch.sym_min(x.size(0), 100)
        return total + capped_rows

Result:

Unsupported: torch.* op returned non-Tensor int call_function <function sym_min at 0x7fc7f5082d30>

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources