Shortcuts

ExportDB

ExportDB is a centralized dataset of supported and unsupported export cases. It is targeted towards users who want to understand specifically what types of code are supported, the subtleties of export, and how to modify their existing code to be compatible with export. Note that this is not an exhaustive set of everything that is supported by exportdb, but it covers the most common and confusing use cases that users will run into.

If you have a feature that you think needs a stronger guarantee of export support from us, please create an issue in the pytorch/pytorch repo with a module:export tag.

Supported

assume_constant_result

Note

Tags: torch.escape-hatch

Support Level: SUPPORTED

Original source code:

import torch
import torch._dynamo as torchdynamo



class AssumeConstantResult(torch.nn.Module):
    """
    Apply the `assume_constant_result` decorator to treat the result of
    non-traceable code as a constant.

    `get_item` turns the tensor `y` into a Python int via `.item()`. Because
    it is decorated with `torchdynamo.assume_constant_result`, the value seen
    at export time is burned into the graph. The exported program for this
    example slices `x` with the constant end 4 and never reads `y`.
    """

    def __init__(self):
        super().__init__()

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        """Return ``y`` as a Python int; export records the result as a constant."""
        return y.int().item()

    def forward(self, x, y):
        """Return the leading ``get_item(y)`` entries of ``x`` along dim 0."""
        return x[: self.get_item(y)]

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2], arg1_1: i64[]):
            #
            slice_1: f32[3, 2] = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 4);  arg0_1 = None
            return (slice_1,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['slice_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

autograd_function

Note

Tags:

Support Level: SUPPORTED

Original source code:

import torch



class MyAutogradFunction(torch.autograd.Function):
    """
    Custom autograd function: forward clones the input and backward returns
    ``grad_output + 1``.

    NOTE: ``grad_output + 1`` is not the mathematical gradient of a clone.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a copy of ``x``; ``ctx`` is not used."""
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output + 1`` as the gradient with respect to ``x``."""
        return grad_output + 1


class AutogradFunction(torch.nn.Module):
    """
    TorchDynamo does not keep track of backward() on autograd functions. We
    recommend using `allow_in_graph` to mitigate this problem.

    In the exported program for this example only the forward of
    `MyAutogradFunction` appears (a single `aten.clone`); its custom
    backward is not recorded.
    """

    def forward(self, x):
        return MyAutogradFunction.apply(x)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2]):
            #
            clone: f32[3, 2] = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
            return (clone,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['clone'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

class_method

Note

Tags:

Support Level: SUPPORTED

Original source code:

import torch



class ClassMethod(torch.nn.Module):
    """
    Class methods are inlined during tracing.

    `forward` calls the same classmethod three ways: `self.method`,
    `self.__class__.method` and `type(self).method`. Each call is inlined as
    a separate `add` of 1 on the linear output, and the three results are
    multiplied together.
    """

    @classmethod
    def method(cls, x):
        """Return ``x + 1``; ``cls`` is not used."""
        return x + 1

    def __init__(self):
        super().__init__()
        # in_features=4, out_features=2; its weight and bias become parameter
        # inputs of the exported program.
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        x = self.linear(x)
        return self.method(x) * self.__class__.method(x) * type(self).method(x)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 4], arg1_1: f32[2], arg2_1: f32[3, 4]):
            #
            permute: f32[4, 2] = torch.ops.aten.permute.default(arg0_1, [1, 0]);  arg0_1 = None
            addmm: f32[3, 2] = torch.ops.aten.addmm.default(arg1_1, arg2_1, permute);  arg1_1 = arg2_1 = permute = None
            add: f32[3, 2] = torch.ops.aten.add.Tensor(addmm, 1)
            add_1: f32[3, 2] = torch.ops.aten.add.Tensor(addmm, 1)
            mul: f32[3, 2] = torch.ops.aten.mul.Tensor(add, add_1);  add = add_1 = None
            add_2: f32[3, 2] = torch.ops.aten.add.Tensor(addmm, 1);  addmm = None
            mul_1: f32[3, 2] = torch.ops.aten.mul.Tensor(mul, add_2);  mul = add_2 = None
            return (mul_1,)

Graph Signature: ExportGraphSignature(parameters=['L__self___linear.weight', 'L__self___linear.bias'], buffers=[], user_inputs=['arg2_1'], user_outputs=['mul_1'], inputs_to_parameters={'arg0_1': 'L__self___linear.weight', 'arg1_1': 'L__self___linear.bias'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

cond_branch_class_method

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class MySubModule(torch.nn.Module):
    """Submodule whose forward returns ``x.cos()`` through the helper `foo`."""

    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)


class CondBranchClassMethod(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - the returned tensor must have the same tensor metadata in both branches, e.g. shape and dtype
      - a branch function can be a free function, nested function, lambda, or class method
      - a branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates using class methods (bound methods) as cond() branches.

    NOTE: If `pred` tests a dimension whose size is < 2, that dimension will be specialized.
    """

    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        # True branch: the submodule's bound `forward` (cos).
        # False branch: this module's bound method `bar` (sin).
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3]):
            #
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[3] = torch.ops.higher_order.cond(False, submodule_0, submodule_1, [arg0_1]);  submodule_0 = submodule_1 = arg0_1 = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3]):
                        cos: f32[3] = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return cos

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3]):
                        sin: f32[3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

cond_branch_nested_function

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


def cond_branch_nested_function(x):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - the returned tensor must have the same tensor metadata in both branches, e.g. shape and dtype
      - a branch function can be a free function, nested function, lambda, or class method
      - a branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates using nested functions in cond().

    NOTE: If `pred` tests a dimension whose size is < 2, that dimension will be specialized.
    """

    def true_fn(x):
        # inner_true_fn reads true_fn's own argument `x`, so this branch
        # computes x + x.
        def inner_true_fn(y):
            return x + y

        return inner_true_fn(x)

    def false_fn(x):
        # Likewise, this branch computes x - x.
        def inner_false_fn(y):
            return x - y

        return inner_false_fn(x)

    return cond(x.shape[0] < 10, true_fn, false_fn, [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3]):
            #
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[3] = torch.ops.higher_order.cond(True, submodule_0, submodule_1, [arg0_1]);  submodule_0 = submodule_1 = arg0_1 = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3]):
                        add: f32[3] = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return add

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3]):
                        sub: f32[3] = torch.ops.aten.sub.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return sub

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

cond_branch_nonlocal_variables

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


def cond_branch_nonlocal_variables(x):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - the returned tensor must have the same tensor metadata in both branches, e.g. shape and dtype
      - a branch function can be a free function, nested function, lambda, or class method
      - a branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.

    The code below will not work because capturing closure variables is not supported.
    ```
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    def true_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y + my_tensor_var + my_primitive_var

    def false_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y - my_tensor_var - my_primitive_var

    return cond(x.shape[0] > 5, true_fn, false_fn, [x])
    ```

    NOTE: If `pred` tests a dimension whose size is < 2, that dimension will be specialized.
    """

    my_tensor_var = x + 100
    my_primitive_var = 3.14

    # Instead of closing over the two values, both branches take them as
    # explicit arguments. The Python float is wrapped with torch.tensor()
    # below because cond() operands must be tensors.
    def true_fn(x, y, z):
        return x + y + z

    def false_fn(x, y, z):
        return x - y - z

    return cond(
        x.shape[0] > 5,
        true_fn,
        false_fn,
        [x, my_tensor_var, torch.tensor(my_primitive_var)],
    )

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[6]):
            #
            add: f32[6] = torch.ops.aten.add.Tensor(arg0_1, 100)
            _tensor_constant0: f32[] = self._tensor_constant0
            lift_fresh_copy: f32[] = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[6] = torch.ops.higher_order.cond(True, submodule_0, submodule_1, [arg0_1, add, lift_fresh_copy]);  submodule_0 = submodule_1 = arg0_1 = add = lift_fresh_copy = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[6], arg1_1: f32[6], arg2_1: f32[]):
                        add: f32[6] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                add_1: f32[6] = torch.ops.aten.add.Tensor(add, arg2_1);  add = arg2_1 = None
                return add_1

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[6], arg1_1: f32[6], arg2_1: f32[]):
                        sub: f32[6] = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                sub_1: f32[6] = torch.ops.aten.sub.Tensor(sub, arg2_1);  sub = arg2_1 = None
                return sub_1

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

cond_closed_over_variable

Note

Tags: python.closure, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class CondClosedOverVariable(torch.nn.Module):
    """
    torch.cond() supports branches closed over arbitrary variables.

    Both branches ignore their operand `val` (which is `x + 1`) and read the
    closed-over `x` instead. In the exported program for this example, the
    captured `x` is lifted into extra operands of cond, so each branch graph
    receives it as an additional input.

    NOTE(review): other cond examples state that branch functions "can not
    have closure variables"; this entry shows closures being lifted
    automatically. Confirm which statement reflects the current cond().
    """

    def forward(self, pred, x):
        # `pred` is a boolean tensor in the example export (b8[]).
        def true_fn(val):
            return x * 2

        def false_fn(val):
            return x - 2

        return cond(pred, true_fn, false_fn, [x + 1])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: b8[], arg1_1: f32[3, 2]):
            #
            add: f32[3, 2] = torch.ops.aten.add.Tensor(arg1_1, 1)
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[3, 2] = torch.ops.higher_order.cond(arg0_1, submodule_0, submodule_1, [add, arg1_1, arg1_1]);  arg0_1 = submodule_0 = submodule_1 = add = arg1_1 = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2], arg1_1: f32[3, 2], arg2_1: f32[3, 2]):
                        mul: f32[3, 2] = torch.ops.aten.mul.Tensor(arg2_1, 2);  arg2_1 = None
                return mul

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2], arg1_1: f32[3, 2], arg2_1: f32[3, 2]):
                        sub: f32[3, 2] = torch.ops.aten.sub.Tensor(arg2_1, 2);  arg2_1 = None
                return sub

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

cond_operands

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from torch._export import dynamic_dim
from functorch.experimental.control_flow import cond

# Example inputs, presumably the ones used to export `cond_operands`: they
# match its exported signature (f32[s0, 2], f32[2]). Dimension 0 of `x` is
# marked dynamic and appears as the symbol `s0` in the exported graph.
# NOTE: `cond_operands` takes its own `x`/`y` parameters, which shadow these
# module-level names inside the function.
x = torch.randn(3, 2)
y = torch.ones(2)
dynamic_constraint = dynamic_dim(x, 0)

def cond_operands(x, y):
    """
    The operands passed to cond() must be:
      - a list of tensors
      - match arguments of `true_fn` and `false_fn`

    With dimension 0 of `x` dynamic, the predicate `x.shape[0] > 2` stays
    symbolic in the exported graph instead of being resolved at trace time.

    NOTE: If `pred` tests a dimension whose size is < 2, that dimension will be specialized.
    """

    def true_fn(x, y):
        return x + y

    def false_fn(x, y):
        return x - y

    # Both branches take (x, y), matching the operand list [x, y].
    return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
            #
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 2) = sym_size > 2;  sym_size = None
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[s0, 2] = torch.ops.higher_order.cond(gt, submodule_0, submodule_1, [arg0_1, arg1_1]);  gt = submodule_0 = submodule_1 = arg0_1 = arg1_1 = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
                        add: f32[s0, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return add

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
                        sub: f32[s0, 2] = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return sub

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}

cond_predicate

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to cond() must be one of the following:
      - torch.Tensor with a single element
      - boolean expression

    NOTE: If `pred` tests a dimension whose size is < 2, that dimension will be specialized.
    """

    # `and` short-circuits, so x.shape[2] is only read when x has more than
    # 2 dims. With the static [6, 4, 3] example input this evaluates to the
    # Python bool False at trace time.
    pred = x.dim() > 2 and x.shape[2] > 10

    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[6, 4, 3]):
            #
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[6, 4, 3] = torch.ops.higher_order.cond(False, submodule_0, submodule_1, [arg0_1]);  submodule_0 = submodule_1 = arg0_1 = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[6, 4, 3]):
                        cos: f32[6, 4, 3] = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return cos

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[6, 4, 3]):
                        sin: f32[6, 4, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

constrain_as_size_example

Note

Tags: torch.escape-hatch, torch.dynamic-value

Support Level: SUPPORTED

Original source code:

import torch
from torch._export.constraints import constrain_as_size



def constrain_as_size_example(x):
    """
    If a value is not known at tracing time, you can provide a hint so that
    we can trace further. Please look at the constrain_as_value and
    constrain_as_size APIs. constrain_as_size is used for values that NEED
    to be used for constructing a tensor.
    """
    # .item() yields a data-dependent integer whose value is unknown at trace
    # time (the symbol `i4` in the exported graph).
    a = x.item()
    # Bound `a` to [0, 5] and allow it to be used as a size. Export turns the
    # bound into runtime range asserts.
    constrain_as_size(a, min=0, max=5)
    return torch.ones((a, 5))

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: i64[]):
            #
            _local_scalar_dense: Sym(i4) = torch.ops.aten._local_scalar_dense.default(arg0_1);  arg0_1 = None
            ge: Sym(i4 >= 0) = _local_scalar_dense >= 0
            scalar_tensor: f32[] = torch.ops.aten.scalar_tensor.default(ge);  ge = None
            _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor = None
            le: Sym(i4 <= 5) = _local_scalar_dense <= 5
            scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(le);  le = None
            _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor_1 = None
            sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = 0, max = 5)
            full: f32[i4, 5] = torch.ops.aten.full.default([_local_scalar_dense, 5], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
            sym_size: Sym(i4) = torch.ops.aten.sym_size.int(full, 0)
            ge_1: Sym(i4 >= 0) = sym_size >= 0
            scalar_tensor_2: f32[] = torch.ops.aten.scalar_tensor.default(ge_1);  ge_1 = None
            _assert_async_2 = torch.ops.aten._assert_async.msg(scalar_tensor_2, 'full.shape[0] is outside of inline constraint [0, 5].');  scalar_tensor_2 = None
            le_1: Sym(i4 <= 5) = sym_size <= 5;  sym_size = None
            scalar_tensor_3: f32[] = torch.ops.aten.scalar_tensor.default(le_1);  le_1 = None
            _assert_async_3 = torch.ops.aten._assert_async.msg(scalar_tensor_3, 'full.shape[0] is outside of inline constraint [0, 5].');  scalar_tensor_3 = None
            return (full,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['full'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {i0: RangeConstraint(min_val=2, max_val=5), i1: RangeConstraint(min_val=2, max_val=5), i2: RangeConstraint(min_val=2, max_val=5), i3: RangeConstraint(min_val=2, max_val=5), i4: RangeConstraint(min_val=2, max_val=5)}

constrain_as_value_example

Note

Tags: torch.escape-hatch, torch.dynamic-value

Support Level: SUPPORTED

Original source code:

import torch
from torch._export.constraints import constrain_as_value



def constrain_as_value_example(x, y):
    """
    If a value is not known at tracing time, you can provide a hint so that
    we can trace further. Please look at the constrain_as_value and
    constrain_as_size APIs. constrain_as_value is used for values that don't
    need to be used for constructing a tensor.
    """
    a = x.item()
    # Bound `a` to [0, 5]; export emits runtime range asserts for it.
    constrain_as_value(a, min=0, max=5)

    # Given max=5 this condition always holds, so only the `sin` branch
    # appears in the exported graph.
    if a < 6:
        return y.sin()
    return y.cos()

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: i64[], arg1_1: f32[5, 5]):
            #
            _local_scalar_dense: Sym(i4) = torch.ops.aten._local_scalar_dense.default(arg0_1);  arg0_1 = None
            ge: Sym(i4 >= 0) = _local_scalar_dense >= 0
            scalar_tensor: f32[] = torch.ops.aten.scalar_tensor.default(ge);  ge = None
            _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor = None
            le: Sym(i4 <= 5) = _local_scalar_dense <= 5
            scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(le);  le = None
            _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor_1 = None
            sym_constrain_range = torch.ops.aten.sym_constrain_range.default(_local_scalar_dense, min = 0, max = 5);  _local_scalar_dense = None
            sin: f32[5, 5] = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
            return (sin,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['sin'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {i0: RangeConstraint(min_val=0, max_val=5), i1: RangeConstraint(min_val=0, max_val=5), i2: RangeConstraint(min_val=0, max_val=5), i3: RangeConstraint(min_val=0, max_val=5), i4: RangeConstraint(min_val=0, max_val=5)}

decorator

Note

Tags:

Support Level: SUPPORTED

Original source code:

import functools

import torch



def test_decorator(func):
    """
    Wrap `func` so that its result is incremented by 1.

    `functools.wraps` copies `func`'s metadata (name, docstring) onto the
    wrapper.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs) + 1

    return wrapper


class Decorator(torch.nn.Module):
    """
    Decorator calls are inlined into the exported function during tracing.

    Through `test_decorator`, `forward` returns ``x + y + 1``. Both adds
    appear in the exported graph.
    """

    @test_decorator
    def forward(self, x, y):
        return x + y

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2], arg1_1: f32[3, 2]):
            #
            add: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
            add_1: f32[3, 2] = torch.ops.aten.add.Tensor(add, 1);  add = None
            return (add_1,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

dictionary

Note

Tags: python.data-structure

Support Level: SUPPORTED

Original source code:

import torch



def dictionary(x, y):
    """
    Dictionary structures are inlined and flattened along tracing.

    Stores ``x * x`` in a temporary dict, multiplies ``y`` by that entry and
    returns the product under the key ``"y"``. Only the two multiplications
    survive in the exported graph; the dicts are flattened away.
    """
    scratch = dict()
    scratch["x2"] = x * x
    # Keep ``y`` as the left operand, matching the traced mul(arg1_1, mul).
    product = y * scratch["x2"]
    return dict(y=product)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2], arg1_1: i64[]):
            #
            mul: f32[3, 2] = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
            mul_1: f32[3, 2] = torch.ops.aten.mul.Tensor(arg1_1, mul);  arg1_1 = mul = None
            return (mul_1,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['mul_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

dynamic_shape_assert

Note

Tags: python.assert

Support Level: SUPPORTED

Original source code:

import torch



def dynamic_shape_assert(x):
    """
    A basic usage of python assertion.

    Both checks read only ``x.shape[0]``. For the static [3, 2] example input
    they hold at trace time, and no assertion op appears in the exported
    graph.

    Raises:
        AssertionError: if ``x.shape[0]`` is not greater than 2.
    """
    # assertion with error message; the message is only shown when the check
    # fails, so it must state that the size is NOT greater than 2
    assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2"
    # assertion without error message (implied by the check above; it
    # illustrates the message-less form)
    assert x.shape[0] > 1
    return x

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2]):
            return (arg0_1,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['arg0_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

dynamic_shape_constructor

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch



def dynamic_shape_constructor(x):
    """
    Tensor constructors should be captured with dynamic shape inputs rather
    than being baked in with static shape.

    Returns a 1-D tensor of ones with ``2 * x.shape[0]`` elements. In the
    exported program for the static [3, 2] example, no dimension was marked
    dynamic, so the length appears as the constant 6.
    """
    return torch.ones(x.shape[0] * 2)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2]):
            #
            full: f32[6] = torch.ops.aten.full.default([6], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            return (full,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['full'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

dynamic_shape_if_guard

Note

Tags: torch.dynamic-shape, python.control-flow

Support Level: SUPPORTED

Original source code:

import torch



class DynamicShapeIfGuard(torch.nn.Module):
    """
    An `if` statement with a backed dynamic shape predicate will be
    specialized into one particular branch and generate a guard. However,
    export will fail if the dimension is marked as dynamic shape from a
    higher-level API.
    """

    def forward(self, x):
        # For the [3, 2, 2] example input this is True, so only the `cos`
        # branch is exported.
        if x.shape[0] == 3:
            return x.cos()

        return x.sin()

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2, 2]):
            #
            cos: f32[3, 2, 2] = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
            return (cos,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cos'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

dynamic_shape_map

Note

Tags: torch.map, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import map


def dynamic_shape_map(xs, y):
    """
    functorch map() maps a function over the first tensor dimension.

    `body` is applied to each slice ``xs[i]`` together with the unchanged
    `y`. The per-slice results are stacked back along dimension 0.
    """

    def body(x, y):
        # `x` is one slice of `xs`; `y` is passed through unchanged.
        return x + y

    return map(body, xs, y)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2], arg1_1: f32[2]):
            #
            submodule_0 = self.submodule_0
            map_impl = torch.ops.map_impl(submodule_0, 1, arg0_1, arg1_1);  submodule_0 = arg0_1 = arg1_1 = None
            getitem: f32[3, 2] = map_impl[0];  map_impl = None
            return (getitem,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[2], arg1_1: f32[2]):
                        add: f32[2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return [add]

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['getitem'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

dynamic_shape_slicing

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch



def dynamic_shape_slicing(x):
    """
    Slices with dynamic shape arguments should be captured into the graph
    rather than being baked in.

    Keeps rows ``[0, x.shape[0] - 2)``, and every second column starting at
    ``x.shape[1] - 1``.
    """
    return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2]):
            #
            slice_1: f32[1, 2] = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 1);  arg0_1 = None
            slice_2: f32[1, 1] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2);  slice_1 = None
            return (slice_2,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['slice_2'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

dynamic_shape_view

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch



def dynamic_shape_view(x):
    """
    Dynamic shapes should be propagated to view arguments instead of being
    baked into the exported graph.

    Splits the last dimension into (2, 5), then swaps the last two dims.
    `x` is therefore expected to be 2-D with a last dimension of 10:
    `permute(0, 2, 1)` needs a 3-D view, and the view must keep the element
    count.
    """
    new_x_shape = x.size()[:-1] + (2, 5)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 10]):
            #
            view: f32[10, 2, 5] = torch.ops.aten.view.default(arg0_1, [10, 2, 5]);  arg0_1 = None
            permute: f32[10, 5, 2] = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
            return (permute,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['permute'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

fn_with_kwargs

Note

Tags: python.data-structure

Support Level: SUPPORTED

Original source code:

import torch



def fn_with_kwargs(pos0, tuple0, *myargs, mykw0, **mykwargs):
    """
    Multiply together a positional argument, the elements of a tuple, extra
    positional arguments (*myargs), a keyword-only argument and two entries
    of **mykwargs, in that order.

    The exported program for this example has eight positional tensor inputs,
    one per multiplied value, so all of these argument kinds are flattened
    into graph inputs.

    NOTE(review): this docstring previously said "Keyword arguments are not
    supported at the moment", which contradicts this entry's SUPPORTED level
    and its exported program. Confirm against current torch.export behavior.

    Expects `mykwargs` to contain the keys "input0" and "input1".
    """
    out = pos0
    for arg in tuple0:
        out = out * arg
    for arg in myargs:
        out = out * arg
    out = out * mykw0
    out = out * mykwargs["input0"] * mykwargs["input1"]
    return out

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[4], arg1_1: f32[4], arg2_1: f32[4], arg3_1: f32[4], arg4_1: f32[4], arg5_1: f32[4], arg6_1: f32[4], arg7_1: f32[4]):
            #
            mul: f32[4] = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
            mul_1: f32[4] = torch.ops.aten.mul.Tensor(mul, arg2_1);  mul = arg2_1 = None
            mul_2: f32[4] = torch.ops.aten.mul.Tensor(mul_1, arg3_1);  mul_1 = arg3_1 = None
            mul_3: f32[4] = torch.ops.aten.mul.Tensor(mul_2, arg4_1);  mul_2 = arg4_1 = None
            mul_4: f32[4] = torch.ops.aten.mul.Tensor(mul_3, arg5_1);  mul_3 = arg5_1 = None
            mul_5: f32[4] = torch.ops.aten.mul.Tensor(mul_4, arg6_1);  mul_4 = arg6_1 = None
            mul_6: f32[4] = torch.ops.aten.mul.Tensor(mul_5, arg7_1);  mul_5 = arg7_1 = None
            return (mul_6,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1', 'arg3_1', 'arg4_1', 'arg5_1', 'arg6_1', 'arg7_1'], user_outputs=['mul_6'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

list_contains

Note

Tags: torch.dynamic-shape, python.assert, python.data-structure

Support Level: SUPPORTED

Original source code:

import torch



def list_contains(x):
    """
    List containment relation can be checked on a dynamic shape or constants.
    """
    # These checks are evaluated while tracing; none of the asserts appear as
    # operations in the exported graph, which only contains ``x + x``.
    assert x.size(-1) in [6, 2]
    assert x.size(0) not in [4, 5, 6]
    assert "monkey" not in ["cow", "pig"]
    return x + x

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2]):
            #
            add: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
            return (add,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

list_unpack

Note

Tags: python.data-structure, python.control-flow

Support Level: SUPPORTED

Original source code:

from typing import List

import torch



def list_unpack(args: List[torch.Tensor]):
    """
    Lists are treated as a static construct, so unpacking is erased during
    tracing; only the resulting tensor arithmetic is exported.
    """
    x, *y = args
    # Only the first element of the remainder contributes to the output.
    return x + y[0]

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2], arg1_1: i64[], arg2_1: i64[]):
            #
            add: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
            return (add,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

nested_function

Note

Tags: python.closure

Support Level: SUPPORTED

Original source code:

import torch



def nested_function(a, b):
    """
    Nested functions are traced through, including a closure that rebinds an
    enclosing-scope variable via ``nonlocal``. Side effects on global captures
    are not supported though.
    """
    x = a + b
    z = a - b

    def closure(y):
        nonlocal x
        # For tensors ``+=`` is in-place, and ``y`` is the same tensor object
        # as ``x`` (see ``closure(x)`` below), so ``y`` also observes the
        # increment: the traced graph multiplies the incremented value by itself.
        x += 1
        return x * y + z

    return closure(x)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2], arg1_1: f32[2]):
            #
            add: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
            sub: f32[3, 2] = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
            add_1: f32[3, 2] = torch.ops.aten.add.Tensor(add, 1);  add = None
            mul: f32[3, 2] = torch.ops.aten.mul.Tensor(add_1, add_1);  add_1 = None
            add_2: f32[3, 2] = torch.ops.aten.add.Tensor(mul, sub);  mul = sub = None
            return (add_2,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add_2'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

null_context_manager

Note

Tags: python.context-manager

Support Level: SUPPORTED

Original source code:

import contextlib

import torch



def null_context_manager(x):
    """
    Null context manager in Python will be traced out.
    """
    # Entering/exiting nullcontext() contributes no operations to the exported
    # graph; only the sin/cos/add computation remains.
    ctx = contextlib.nullcontext()
    with ctx:
        return x.sin() + x.cos()

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2]):
            #
            sin: f32[3, 2] = torch.ops.aten.sin.default(arg0_1)
            cos: f32[3, 2] = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
            add: f32[3, 2] = torch.ops.aten.add.Tensor(sin, cos);  sin = cos = None
            return (add,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

pytree_flatten

Note

Tags:

Support Level: SUPPORTED

Original source code:

import torch

from torch.utils import _pytree as pytree


def pytree_flatten(x):
    """
    Flattening an input with PyTorch's pytree utilities is supported: the
    flatten call is traced through and only the tensor arithmetic on the
    resulting leaves is exported.
    """
    # `spec` is unused; only the first flattened leaf contributes to the output.
    y, spec = pytree.tree_flatten(x)
    return y[0] + 1

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2], arg1_1: f32[3, 2]):
            #
            add: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
            return (add,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

scalar_output

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

import torch

from torch._export import dynamic_dim

x = torch.ones(3, 2)
# Mark dimension 1 of the example input as dynamic; it appears as the symbol
# s0 in the exported graph's input shape.
dynamic_constraint = dynamic_dim(x, 1)

def scalar_output(x):
    """
    Returning scalar values from the graph is supported, in addition to Tensor
    outputs. Symbolic shapes are captured and rank is specialized.
    """
    # The result is a symbolic integer (s0 + 1), not a tensor.
    return x.shape[1] + 1

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, s0]):
            #
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 1);  arg0_1 = None
            add: Sym(s0 + 1) = sym_size + 1;  sym_size = None
            return (add,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}

specialized_attribute

Note

Tags:

Support Level: SUPPORTED

Original source code:

from enum import Enum

import torch



class Animal(Enum):
    # Enum whose member value SpecializedAttribute.forward compares against.
    COW = "moo"


class SpecializedAttribute(torch.nn.Module):
    """
    Model attributes are specialized.

    Plain Python attributes (``self.a``, ``self.b``) are treated as constants
    during tracing: the string comparison is resolved at export time and
    ``self.b`` is inlined as the literal 4.
    """

    def __init__(self):
        super().__init__()
        self.a = "moo"
        self.b = 4

    def forward(self, x):
        # Evaluated while tracing; only the taken branch is exported.
        if self.a == Animal.COW.value:
            return x * x + self.b
        else:
            raise ValueError("bad")

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2]):
            #
            mul: f32[3, 2] = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
            add: f32[3, 2] = torch.ops.aten.add.Tensor(mul, 4);  mul = None
            return (add,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

static_for_loop

Note

Tags: python.control-flow

Support Level: SUPPORTED

Original source code:

import torch



class StaticForLoop(torch.nn.Module):
    """
    A for loop with constant number of iterations should be unrolled in the exported graph.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Each iteration becomes its own add node; the Python list is returned
        # as ten separate graph outputs.
        ret = []
        for i in range(10):  # constant
            ret.append(i + x)
        return ret

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2]):
            #
            add: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 0)
            add_1: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 1)
            add_2: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 2)
            add_3: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 3)
            add_4: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 4)
            add_5: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 5)
            add_6: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 6)
            add_7: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 7)
            add_8: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 8)
            add_9: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 9);  arg0_1 = None
            return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add', 'add_1', 'add_2', 'add_3', 'add_4', 'add_5', 'add_6', 'add_7', 'add_8', 'add_9'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

static_if

Note

Tags: python.control-flow

Support Level: SUPPORTED

Original source code:

import torch



class StaticIf(torch.nn.Module):
    """
    `if` statement with static predicate value should be traced through with the
    taken branch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # The rank of x is known during tracing, so this condition is decided
        # at export time and no branch is recorded in the graph.
        if len(x.shape) == 3:
            return x + torch.ones(1, 1, 1)

        return x

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2, 2]):
            #
            full: f32[1, 1, 1] = torch.ops.aten.full.default([1, 1, 1], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            add: f32[3, 2, 2] = torch.ops.aten.add.Tensor(arg0_1, full);  arg0_1 = full = None
            return (add,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

tensor_setattr

Note

Tags: python.builtin

Support Level: SUPPORTED

Original source code:

import torch



def tensor_setattr(x, attr):
    """
    setattr() calls on tensors are supported. The attribute assignment is not
    recorded as a graph operation; only ``x + 4`` appears in the exported graph.
    """
    setattr(x, attr, torch.randn(3, 2))
    return x + 4

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2], arg1_1):
            #
            add: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 4);  arg0_1 = None
            return (add,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

Not Supported Yet

dynamic_shape_round

Note

Tags: torch.dynamic-shape, python.builtin

Support Level: NOT_SUPPORTED_YET

Original source code:

import torch

from torch._export import dynamic_dim

x = torch.ones(3, 2)
# Mark dimension 0 of the example input as dynamic.
dynamic_constraint = dynamic_dim(x, 0)

def dynamic_shape_round(x):
    """
    Calling round on dynamic shapes is not supported.
    """
    # x.shape[0] is symbolic here, so the round() call is rejected during export.
    return x[: round(x.shape[0] / 2)]

Result:

Unsupported: Calling round() on symbolic value is not supported. You can use floor() to implement this functionality

type_reflection_method

Note

Tags: python.builtin

Support Level: NOT_SUPPORTED_YET

Original source code:

import torch



class A:
    # Plain (non-Module) Python class exposing a classmethod.
    @classmethod
    def func(cls, x):
        return 1 + x


def type_reflection_method(x):
    """
    type() calls on custom objects followed by method calls are not allowed
    due to its overly dynamic nature.
    """
    a = A()
    # Calling type() on the locally constructed custom object is what export rejects.
    return type(a).func(x)

Result:

Unsupported: Can't call type() on generated custom object. Please use __class__ instead

You can rewrite the example above to something like the following:

def type_reflection_method_rewrite(x):
    """
    Custom object class methods will be inlined.

    Calling the classmethod directly on the class avoids the type() call.
    """
    return A.func(x)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources