ExportDB¶
ExportDB is a centralized dataset of supported and unsupported export cases. It is targeted towards users who want to understand specifically what types of code are supported, the subtleties of export, and how to modify their existing code to be compatible with export. Note that this is not an exhaustive set of everything that is supported by exportdb, but it covers the most common and confusing use cases that users will run into.
If you have a feature that you think needs a stronger guarantee from us to support in export, please create an issue in the pytorch/pytorch repo with a module:export tag.
Supported¶
assume_constant_result¶
Original source code:
# mypy: allow-untyped-defs
import torch
import torch._dynamo as torchdynamo
class AssumeConstantResult(torch.nn.Module):
    """
    Apply the `assume_constant_result` decorator to mark the result of
    non-traceable code as a constant that is burned into the graph.
    """

    def __init__(self):
        super().__init__()

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        # Casts `y` to int and extracts a Python scalar via `.item()`; the
        # decorator marks this method's result as one to assume constant.
        return y.int().item()

    def forward(self, x, y):
        # Keep the leading entries of `x` along dim 0, up to the scalar
        # returned by `get_item(y)`.
        return x[: self.get_item(y)]
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]", y: "i64[]"):
slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 4); x = None
return (slice_1,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_1'), target=None)])
Range constraints: {}
autograd_function¶
Note
Tags:
Support Level: SUPPORTED
Original source code:
# mypy: allow-untyped-defs
import torch
class MyAutogradFunction(torch.autograd.Function):
    """
    Custom autograd function whose forward clones its input and whose
    backward returns the incoming gradient plus one.
    """

    @staticmethod
    def forward(ctx, x):
        # `ctx` is unused: nothing is saved for the backward pass.
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Returns `grad_output + 1`, i.e. not the identity gradient of clone().
        return grad_output + 1
class AutogradFunction(torch.nn.Module):
    """
    TorchDynamo does not keep track of backward() on autograd functions. We
    recommend using `allow_in_graph` to mitigate this problem.
    """

    def forward(self, x):
        # Route the input through the custom autograd function.
        return MyAutogradFunction.apply(x)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
clone: "f32[3, 2]" = torch.ops.aten.clone.default(x); x = None
return (clone,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='clone'), target=None)])
Range constraints: {}
class_method¶
Note
Tags:
Support Level: SUPPORTED
Original source code:
# mypy: allow-untyped-defs
import torch
class ClassMethod(torch.nn.Module):
    """
    Class methods are inlined during tracing.
    """

    @classmethod
    def method(cls, x):
        # Does not use `cls`; simply adds one to the input.
        return x + 1

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        x = self.linear(x)
        # Call the same classmethod three ways -- through the instance,
        # through `self.__class__`, and through `type(self)` -- and multiply
        # the three results together.
        return self.method(x) * self.__class__.method(x) * type(self).method(x)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_linear_weight: "f32[2, 4]", p_linear_bias: "f32[2]", x: "f32[3, 4]"):
linear: "f32[3, 2]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias); x = p_linear_weight = p_linear_bias = None
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1)
add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1)
mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add, add_1); add = add_1 = None
add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1); linear = None
mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(mul, add_2); mul = add_2 = None
return (mul_1,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_1'), target=None)])
Range constraints: {}
cond_branch_class_method¶
Original source code:
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import cond
class MySubModule(torch.nn.Module):
    """Submodule whose forward delegates to `foo`, which returns `x.cos()`."""

    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)
class CondBranchClassMethod(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensor must have the same tensor metadata, e.g. shape and dtype
      - branch function can be free function, nested function, lambda, class methods
      - branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates using class method in cond().

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        # Predicate: size of dim 0 is at most 2.
        # True branch: the submodule's bound `forward` (cos).
        # False branch: this module's bound `bar` (sin).
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3]"):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [x]); true_graph_0 = false_graph_0 = x = None
getitem: "f32[3]" = conditional[0]; conditional = None
return (getitem,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[3]"):
cos: "f32[3]" = torch.ops.aten.cos.default(x); x = None
return (cos,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[3]"):
sin: "f32[3]" = torch.ops.aten.sin.default(x); x = None
return (sin,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
cond_branch_nested_function¶
Original source code:
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import cond
class CondBranchNestedFunction(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensor must have the same tensor metadata, e.g. shape and dtype
      - branch function can be free function, nested function, lambda, class methods
      - branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates using nested function in cond().

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        def true_fn(x):
            # The inner function reads `x` from `true_fn`'s scope and is
            # called with that same `x`, so this computes x + x.
            def inner_true_fn(y):
                return x + y

            return inner_true_fn(x)

        def false_fn(x):
            # Same shape as above, but computes x - x.
            def inner_false_fn(y):
                return x - y

            return inner_false_fn(x)

        # Predicate: size of dim 0 is less than 10.
        return cond(x.shape[0] < 10, true_fn, false_fn, [x])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3]"):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [x]); true_graph_0 = false_graph_0 = x = None
getitem: "f32[3]" = conditional[0]; conditional = None
return (getitem,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[3]"):
add: "f32[3]" = torch.ops.aten.add.Tensor(x, x); x = None
return (add,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[3]"):
sub: "f32[3]" = torch.ops.aten.sub.Tensor(x, x); x = None
return (sub,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
cond_branch_nonlocal_variables¶
Original source code:
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import cond
class CondBranchNonlocalVariables(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
    - both branches must take the same args, which must also match the branch args passed to cond.
    - both branches must return a single tensor
    - returned tensor must have the same tensor metadata, e.g. shape and dtype
    - branch function can be free function, nested function, lambda, class methods
    - branch function can not have closure variables
    - no inplace mutations on inputs or global variables

    This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.

    The code below will not work because capturing closure variables is not supported.
    ```
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    def true_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y + my_tensor_var + my_primitive_var

    def false_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y - my_tensor_var - my_primitive_var

    return cond(x.shape[0] > 5, true_fn, false_fn, [x])
    ```

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        my_tensor_var = x + 100
        my_primitive_var = 3.14

        # Both branches take every value they need as explicit parameters
        # instead of capturing the two locals above.
        def true_fn(x, y, z):
            return x + y + z

        def false_fn(x, y, z):
            return x - y - z

        return cond(
            x.shape[0] > 5,
            true_fn,
            false_fn,
            # The Python float is wrapped in a tensor before being passed as
            # an operand.
            [x, my_tensor_var, torch.tensor(my_primitive_var)],
        )
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, c_lifted_tensor_0: "f32[]", x: "f32[6]"):
add: "f32[6]" = torch.ops.aten.add.Tensor(x, 100)
lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None
detach: "f32[]" = torch.ops.aten.detach.default(lift_fresh_copy); lift_fresh_copy = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [x, add, detach]); true_graph_0 = false_graph_0 = x = add = detach = None
getitem: "f32[6]" = conditional[0]; conditional = None
return (getitem,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[6]", add: "f32[6]", detach: "f32[]"):
add_1: "f32[6]" = torch.ops.aten.add.Tensor(x, add); x = add = None
add_2: "f32[6]" = torch.ops.aten.add.Tensor(add_1, detach); add_1 = detach = None
return (add_2,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[6]", add: "f32[6]", detach: "f32[]"):
sub: "f32[6]" = torch.ops.aten.sub.Tensor(x, add); x = add = None
sub_1: "f32[6]" = torch.ops.aten.sub.Tensor(sub, detach); sub = detach = None
return (sub_1,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_lifted_tensor_0'), target='lifted_tensor_0', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
cond_closed_over_variable¶
Original source code:
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import cond
class CondClosedOverVariable(torch.nn.Module):
    """
    torch.cond() supports branches closed over arbitrary variables.
    """

    def forward(self, pred, x):
        # Both branches ignore their `val` parameter and instead read `x`
        # from the enclosing `forward` scope.
        def true_fn(val):
            return x * 2

        def false_fn(val):
            return x - 2

        # The operand handed to the branches is `x + 1`, not `x` itself.
        return cond(pred, true_fn, false_fn, [x + 1])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, pred: "b8[]", x: "f32[3, 2]"):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, [x]); pred = true_graph_0 = false_graph_0 = x = None
getitem: "f32[3, 2]" = conditional[0]; conditional = None
return (getitem,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, 2); x = None
return (mul,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(x, 2); x = None
return (sub,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='pred'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
cond_operands¶
Original source code:
# mypy: allow-untyped-defs
import torch
from torch.export import Dim
from functorch.experimental.control_flow import cond
# Module-level sample tensors and a named dimension. None of them is
# referenced by the class that follows in this listing -- presumably they
# feed the example's export inputs / dynamic-shape spec (TODO confirm).
x = torch.randn(3, 2)
y = torch.randn(2)
dim0_x = Dim("dim0_x")
class CondOperands(torch.nn.Module):
    """
    The operands passed to cond() must be:
    - a list of tensors
    - match arguments of `true_fn` and `false_fn`

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # Both branches accept the two operands in the same order in which
        # they appear in the operand list below.
        def true_fn(x, y):
            return x + y

        def false_fn(x, y):
            return x - y

        # Predicate: size of dim 0 of `x` is greater than 2.
        return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, 2]", y: "f32[2]"):
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
gt: "Sym(s0 > 2)" = sym_size_int > 2; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, y]); gt = true_graph_0 = false_graph_0 = x = y = None
getitem: "f32[s0, 2]" = conditional[0]; conditional = None
return (getitem,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[s0, 2]", y: "f32[2]"):
add: "f32[s0, 2]" = torch.ops.aten.add.Tensor(x, y); x = y = None
return (add,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[s0, 2]", y: "f32[2]"):
sub: "f32[s0, 2]" = torch.ops.aten.sub.Tensor(x, y); x = y = None
return (sub,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {s0: VR[0, 9223372036854775806]}
cond_predicate¶
Original source code:
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import cond
class CondPredicate(torch.nn.Module):
    """
    The conditional statement (aka predicate) passed to cond() must be one of the following:
      - torch.Tensor with a single element
      - boolean expression

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Boolean-expression predicate: `x` has more than two dims AND the
        # size of dim 2 exceeds 10. The rank check comes first, so
        # `x.shape[2]` is only read when it exists.
        pred = x.dim() > 2 and x.shape[2] > 10

        # True branch: cos; false branch: sin. Both are lambdas.
        return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[6, 4, 3]"):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [x]); true_graph_0 = false_graph_0 = x = None
getitem: "f32[6, 4, 3]" = conditional[0]; conditional = None
return (getitem,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[6, 4, 3]"):
cos: "f32[6, 4, 3]" = torch.ops.aten.cos.default(x); x = None
return (cos,)
class <lambda>(torch.nn.Module):
def forward(self, x: "f32[6, 4, 3]"):
sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(x); x = None
return (sin,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
constrain_as_size_example¶
Original source code:
# mypy: allow-untyped-defs
import torch
class ConstrainAsSizeExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide a hint so that
    we can trace further. Please look at the torch._check and
    torch._check_is_size APIs. torch._check_is_size is used for values that
    NEED to be used for constructing a tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Scalar extracted from the input tensor.
        a = x.item()
        # Hints: `a` is a size, and it is at most 5.
        torch._check_is_size(a)
        torch._check(a <= 5)
        # `a` is used as a dimension of the constructed tensor.
        return torch.zeros((a, 5))
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "i64[]"):
item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None
#
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item)
# No stacktrace found for following nodes
sym_constrain_range = torch.ops.aten.sym_constrain_range.default(item, min = 0, max = 5)
mul: "Sym(-u0)" = -1 * item
le: "Sym(-u0 <= 0)" = mul <= 0; mul = None
_assert_scalar = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression -u0 <= 0 on node 'le_1'"); le = None
le_1: "Sym(u0 <= 5)" = item <= 5
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_2'"); le_1 = None
zeros: "f32[u0, 5]" = torch.ops.aten.zeros.default([item, 5], device = device(type='cpu'), pin_memory = False); item = None
return (zeros,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='zeros'), target=None)])
Range constraints: {u0: VR[0, 5], u1: VR[0, 5], u2: VR[0, 5]}
constrain_as_value_example¶
Original source code:
# mypy: allow-untyped-defs
import torch
class ConstrainAsValueExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide a hint so that
    we can trace further. Please look at the torch._check and
    torch._check_is_size APIs. torch._check is used for values that don't need
    to be used for constructing a tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # Scalar extracted from the input tensor.
        a = x.item()
        # Hints: 0 <= a <= 5.
        torch._check(a >= 0)
        torch._check(a <= 5)

        # `a` is only used in a branch condition here, not as a tensor size.
        # Given the upper bound above, `a < 6` always holds.
        if a < 6:
            return y.sin()
        return y.cos()
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "i64[]", y: "f32[5, 5]"):
item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None
# No stacktrace found for following nodes
sym_constrain_range = torch.ops.aten.sym_constrain_range.default(item, min = 0, max = 5)
mul: "Sym(-u0)" = -1 * item
le: "Sym(-u0 <= 0)" = mul <= 0; mul = None
_assert_scalar = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression -u0 <= 0 on node 'le_1'"); le = None
le_1: "Sym(u0 <= 5)" = item <= 5; item = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_2'"); le_1 = None
sin: "f32[5, 5]" = torch.ops.aten.sin.default(y); y = None
return (sin,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)])
Range constraints: {u0: VR[0, 5], u1: VR[0, 5], u2: VR[0, 5]}
decorator¶
Note
Tags:
Support Level: SUPPORTED
Original source code:
# mypy: allow-untyped-defs
import functools
import torch
def test_decorator(func):
    """Return a wrapper that calls `func` and adds 1 to its result."""

    # functools.wraps copies `func`'s metadata (name, docstring, ...) onto
    # the wrapper.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs) + 1

    return wrapper
class Decorator(torch.nn.Module):
    """
    Decorator calls are inlined into the exported function during tracing.
    """

    # `forward` is wrapped by `test_decorator`, so the value actually
    # returned is (x + y) + 1.
    @test_decorator
    def forward(self, x, y):
        return x + y
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]", y: "f32[3, 2]"):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, y); x = y = None
add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1); add = None
return (add_1,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {}
dictionary¶
Original source code:
# mypy: allow-untyped-defs
import torch
class Dictionary(torch.nn.Module):
    """
    Dictionary structures are inlined and flattened during tracing.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # A dict is used both as an intermediate container ...
        elements = {}
        elements["x2"] = x * x
        y = y * elements["x2"]
        # ... and as the structure of the return value.
        return {"y": y}
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]", y: "i64[]"):
mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x); x = None
mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(y, mul); y = mul = None
return (mul_1,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_1'), target=None)])
Range constraints: {}
dynamic_shape_assert¶
Original source code:
# mypy: allow-untyped-defs
import torch
class DynamicShapeAssert(torch.nn.Module):
    """
    A basic usage of python assertion.

    `forward` returns its input unchanged after asserting on the size of
    dim 0, once with an error message and once without.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # assertion with error message
        # The message is only shown when the check fails, i.e. when the size
        # is NOT greater than 2, so it must say so.
        assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2"
        # assertion without error message
        assert x.shape[0] > 1
        return x
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
return (x,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='x'), target=None)])
Range constraints: {}
dynamic_shape_constructor¶
Original source code:
# mypy: allow-untyped-defs
import torch
class DynamicShapeConstructor(torch.nn.Module):
    """
    Tensor constructors should be captured with dynamic shape inputs rather
    than being baked in with static shape.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # The length of the new 1-D tensor is twice the size of dim 0 of `x`.
        return torch.zeros(x.shape[0] * 2)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
zeros: "f32[6]" = torch.ops.aten.zeros.default([6], device = device(type='cpu'), pin_memory = False)
return (zeros,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='zeros'), target=None)])
Range constraints: {}
dynamic_shape_if_guard¶
Original source code:
# mypy: allow-untyped-defs
import torch
class DynamicShapeIfGuard(torch.nn.Module):
    """
    `if` statement with backed dynamic shape predicate will be specialized into
    one particular branch and generate a guard. However, export will fail if
    the dimension is marked as dynamic shape from higher level API.
    """

    def forward(self, x):
        # Python-level branch on the size of dim 0.
        if x.shape[0] == 3:
            return x.cos()

        return x.sin()
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2, 2]"):
cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(x); x = None
return (cos,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}
dynamic_shape_map¶
Original source code:
# mypy: allow-untyped-defs
import torch
from functorch.experimental.control_flow import map
class DynamicShapeMap(torch.nn.Module):
    """
    functorch map() maps a function over the first tensor dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, xs, y):
        # `body` receives one slice of `xs` along dim 0 plus the extra
        # argument `y`.
        def body(x, y):
            return x + y

        return map(body, xs, y)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, xs: "f32[3, 2]", y: "f32[2]"):
body_graph_0 = self.body_graph_0
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y]); body_graph_0 = xs = y = None
getitem: "f32[3, 2]" = map_impl[0]; map_impl = None
return (getitem,)
class <lambda>(torch.nn.Module):
def forward(self, xs: "f32[2]", y: "f32[2]"):
add: "f32[2]" = torch.ops.aten.add.Tensor(xs, y); xs = y = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='xs'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
dynamic_shape_slicing¶
Original source code:
# mypy: allow-untyped-defs
import torch
class DynamicShapeSlicing(torch.nn.Module):
    """
    Slices with dynamic shape arguments should be captured into the graph
    rather than being baked in.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # dim 0: everything before index `shape[0] - 2`.
        # dim 1: start at index `shape[1] - 1`, step 2.
        return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 1); x = None
slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2); slice_1 = None
return (slice_2,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_2'), target=None)])
Range constraints: {}
dynamic_shape_view¶
Original source code:
# mypy: allow-untyped-defs
import torch
class DynamicShapeView(torch.nn.Module):
    """
    Dynamic shapes should be propagated to view arguments instead of being
    baked into the exported graph.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Keep every dim but the last and replace the last with (2, 5).
        new_x_shape = x.size()[:-1] + (2, 5)
        x = x.view(*new_x_shape)
        # Swap dims 1 and 2; the fixed three-index permutation implies the
        # viewed tensor is 3-D, i.e. the input is 2-D.
        return x.permute(0, 2, 1)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 10]"):
view: "f32[10, 2, 5]" = torch.ops.aten.view.default(x, [10, 2, 5]); x = None
permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]); view = None
return (permute,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='permute'), target=None)])
Range constraints: {}
fn_with_kwargs¶
Original source code:
# mypy: allow-untyped-defs
import torch
),
tags={"python.data-structure"},
support_level=SupportLevel.SUPPORTED,
)
class FnWithKwargs(torch.nn.Module):
    """
    Keyword arguments (a keyword-only argument and `**kwargs`) are accepted
    together with positional, tuple and `*args` inputs.

    NOTE(review): this docstring previously said keyword arguments are not
    supported, which contradicts the signature below and the exported graph
    listed for this example -- confirm the intended wording.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
        # Multiply every input together, in declaration order.
        out = pos0
        for arg in tuple0:
            out = out * arg
        for arg in myargs:
            out = out * arg
        out = out * mykw0
        # Two specific keys are required in **mykwargs.
        out = out * mykwargs["input0"] * mykwargs["input1"]
        return out
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, pos0: "f32[4]", tuple0_0: "f32[4]", tuple0_1: "f32[4]", myargs_0: "f32[4]", myargs_1: "f32[4]", mykw0: "f32[4]", input0: "f32[4]", input1: "f32[4]"):
mul: "f32[4]" = torch.ops.aten.mul.Tensor(pos0, tuple0_0); pos0 = tuple0_0 = None
mul_1: "f32[4]" = torch.ops.aten.mul.Tensor(mul, tuple0_1); mul = tuple0_1 = None
mul_2: "f32[4]" = torch.ops.aten.mul.Tensor(mul_1, myargs_0); mul_1 = myargs_0 = None
mul_3: "f32[4]" = torch.ops.aten.mul.Tensor(mul_2, myargs_1); mul_2 = myargs_1 = None
mul_4: "f32[4]" = torch.ops.aten.mul.Tensor(mul_3, mykw0); mul_3 = mykw0 = None
mul_5: "f32[4]" = torch.ops.aten.mul.Tensor(mul_4, input0); mul_4 = input0 = None
mul_6: "f32[4]" = torch.ops.aten.mul.Tensor(mul_5, input1); mul_5 = input1 = None
return (mul_6,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='pos0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='tuple0_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='tuple0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='myargs_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='myargs_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='mykw0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_6'), target=None)])
Range constraints: {}
list_contains¶
Original source code:
# mypy: allow-untyped-defs
import torch
class ListContains(torch.nn.Module):
    """
    List containment relation can be checked on a dynamic shape or constants.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # `in` / `not in` against sizes of `x` ...
        assert x.size(-1) in [6, 2]
        assert x.size(0) not in [4, 5, 6]
        # ... and against plain Python constants.
        assert "monkey" not in ["cow", "pig"]
        return x + x
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, x); x = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
list_unpack¶
Original source code:
# mypy: allow-untyped-defs
from typing import List
import torch
class ListUnpack(torch.nn.Module):
    """
    Lists are treated as static construct, therefore unpacking should be
    erased after tracing.
    """

    def __init__(self):
        super().__init__()

    def forward(self, args: List[torch.Tensor]):
        """
        Lists are treated as static construct, therefore unpacking should be
        erased after tracing.
        """
        # Starred unpacking: `x` is the first element, `y` is the list of the
        # remaining ones. Only `x` and `y[0]` are used.
        x, *y = args
        return x + y[0]
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, args_0: "f32[3, 2]", args_1: "i64[]", args_2: "i64[]"):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(args_0, args_1); args_0 = args_1 = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_2'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
nested_function¶
Original source code:
# mypy: allow-untyped-defs
import torch
class NestedFunction(torch.nn.Module):
    """
    Nested functions are traced through. Side effects on global captures
    are not supported though.
    """

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        x = a + b
        z = a - b

        def closure(y):
            # `x` is declared nonlocal so the augmented assignment updates
            # the variable in `forward`; `z` is only read.
            nonlocal x
            x += 1
            return x * y + z

        # The closure is called with `x` itself, so `y` and the nonlocal `x`
        # start out as the same tensor.
        return closure(x)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, a: "f32[3, 2]", b: "f32[2]"):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(a, b)
sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(a, b); a = b = None
add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1); add = None
mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_1, add_1); add_1 = None
add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub); mul = sub = None
return (add_2,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='a'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='b'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)])
Range constraints: {}
null_context_manager¶
Original source code:
# mypy: allow-untyped-defs
import contextlib
import torch
class NullContextManager(torch.nn.Module):
    """
    Export case: using ``contextlib.nullcontext()`` inside ``forward``.

    The null context manager is traced out: the recorded graph contains
    only the sin, cos and add operations.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Return ``x.sin() + x.cos()``, computed inside a null context manager.
        """
        # The context manager does nothing; it exists only to exercise
        # tracing through a `with` statement.
        ctx = contextlib.nullcontext()
        with ctx:
            return x.sin() + x.cos()
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
sin: "f32[3, 2]" = torch.ops.aten.sin.default(x)
cos: "f32[3, 2]" = torch.ops.aten.cos.default(x); x = None
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(sin, cos); sin = cos = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
pytree_flatten¶
Note
Tags:
Support Level: SUPPORTED
Original source code:
# mypy: allow-untyped-defs
import torch
from torch.utils import _pytree as pytree
class PytreeFlatten(torch.nn.Module):
    """
    Export case: calling ``pytree.tree_flatten`` on the input.

    Pytree from PyTorch can be captured by TorchDynamo. In the recorded
    result the leaves of ``x`` appear as flat inputs (``x_1``, ``x_2``) and
    the graph adds 1 to the first leaf.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Flatten ``x`` and return its first leaf plus one."""
        # `spec` (the tree structure) is not used afterwards.
        y, spec = pytree.tree_flatten(x)
        return y[0] + 1
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x_1: "f32[3, 2]", x_2: "f32[3, 2]"):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x_1, 1); x_1 = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x_2'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
scalar_output¶
Original source code:
# mypy: allow-untyped-defs
import torch
from torch.export import Dim
# Module-level example data. Presumably `x` is the example input and
# `dim1_x` the dynamic dimension used when exporting this case (the
# recorded result shows dim 1 as symbolic) - TODO confirm against the
# export harness, which is not visible here.
x = torch.randn(3, 2)
dim1_x = Dim("dim1_x")


class ScalarOutput(torch.nn.Module):
    """
    Export case: returning a scalar derived from a tensor shape.

    Returning scalar values from the graph is supported, in addition to Tensor
    outputs. Symbolic shapes are captured and rank is specialized.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the size of dimension 1 of ``x`` plus one."""
        # In the recorded result this becomes sym_size.int(x, 1) + 1 with a
        # symbolic output.
        return x.shape[1] + 1
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, s0]"):
# No stacktrace found for following nodes
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 1); x = None
add: "Sym(s0 + 1)" = sym_size_int + 1; sym_size_int = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='add'), target=None)])
Range constraints: {s0: VR[0, 9223372036854775806]}
specialized_attribute¶
Note
Tags:
Support Level: SUPPORTED
Original source code:
# mypy: allow-untyped-defs
from enum import Enum
import torch
class Animal(Enum):
    """Enum whose member value is compared against a module attribute below."""

    COW = "moo"
class SpecializedAttribute(torch.nn.Module):
    """
    Export case: branching on and reading plain Python module attributes.

    Model attributes are specialized: in the recorded result the branch is
    gone and ``self.b`` appears as the literal constant 4.
    """

    def __init__(self):
        super().__init__()
        # Plain Python values (a string and an int), not tensors.
        self.a = "moo"
        self.b = 4

    def forward(self, x):
        """
        Return ``x * x + self.b`` when ``self.a`` equals ``Animal.COW.value``.

        Raises:
            ValueError: if ``self.a`` does not match ``Animal.COW.value``.
        """
        if self.a == Animal.COW.value:
            return x * x + self.b
        else:
            raise ValueError("bad")
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x); x = None
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, 4); mul = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
static_for_loop¶
Original source code:
# mypy: allow-untyped-defs
import torch
class StaticForLoop(torch.nn.Module):
    """
    Export case: a ``for`` loop over a constant range.

    A for loop with a constant number of iterations is unrolled in the
    exported graph: the recorded result contains ten separate add nodes.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the list ``[0 + x, 1 + x, ..., 9 + x]``."""
        ret = []
        for i in range(10):  # constant
            ret.append(i + x)
        return ret
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 0)
add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1)
add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 2)
add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 3)
add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4)
add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 5)
add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 6)
add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 7)
add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 8)
add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 9); x = None
return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_4'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_5'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_6'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_8'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_9'), target=None)])
Range constraints: {}
static_if¶
Original source code:
# mypy: allow-untyped-defs
import torch
class StaticIf(torch.nn.Module):
    """
    Export case: an ``if`` whose predicate depends only on the input rank.

    An `if` statement with a static predicate value is traced through with
    the taken branch. The recorded result was produced with a rank-3 input,
    so only the ones/add branch appears in the graph.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` plus a ``(1, 1, 1)`` tensor of ones if ``x`` has rank 3, else ``x``."""
        if len(x.shape) == 3:
            return x + torch.ones(1, 1, 1)

        return x
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2, 2]"):
ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False)
add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(x, ones); x = ones = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
tensor_setattr¶
Original source code:
# mypy: allow-untyped-defs
import torch
class TensorSetattr(torch.nn.Module):
    """
    Export case: calling ``setattr()`` on a tensor input.

    setattr() call onto tensors is not supported.

    NOTE(review): this case is listed in the Supported section and its
    recorded result is a graph containing only the add, with ``attr``
    passed as a constant input. Confirm whether "not supported" is still
    accurate.
    """

    def forward(self, x, attr):
        """
        Set attribute ``attr`` on ``x`` to a random ``(3, 2)`` tensor and return ``x + 4``.
        """
        setattr(x, attr, torch.randn(3, 2))
        return x + 4
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]", attr):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4); x = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=ConstantArgument(name='attr', value='attr'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
type_reflection_method¶
Original source code:
# mypy: allow-untyped-defs
import torch
class A:
    """Plain (non-module) helper class exposing a classmethod."""

    @classmethod
    def func(cls, x):
        """Return ``1 + x``; ``cls`` is not used."""
        return 1 + x
class TypeReflectionMethod(torch.nn.Module):
    """
    Export case: calling a classmethod through ``type(obj)``.

    type() calls on custom objects followed by attribute accesses are not allowed
    due to their overly dynamic nature.

    NOTE(review): this case is listed in the Supported section and its
    recorded result is a graph with a single add. Confirm whether
    "not allowed" is still accurate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Instantiate ``A`` and return ``type(a).func(x)``, i.e. ``1 + x``."""
        a = A()
        return type(a).func(x)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 4]"):
add: "f32[3, 4]" = torch.ops.aten.add.Tensor(x, 1); x = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
You can rewrite the example above to something like the following:
class TypeReflectionMethodRewrite(torch.nn.Module):
    """
    Rewrite of the type-reflection case that calls the classmethod directly.

    Custom object class methods will be inlined.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``A.func(x)`` without creating an instance or using ``type()``."""
        return A.func(x)
user_input_mutation¶
Original source code:
# mypy: allow-untyped-defs
import torch
class UserInputMutation(torch.nn.Module):
    """
    Export case: directly mutating a user input in ``forward``.

    In the recorded result the in-place ``mul_`` appears as an out-of-place
    ``mul`` whose value is returned as an extra output marked
    ``USER_INPUT_MUTATION`` with target ``x``.
    """

    def forward(self, x):
        """Multiply ``x`` by 2 in place and return the cosine of the result."""
        # In-place op: the caller's tensor is modified.
        x.mul_(2)
        return x.cos()
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, 2); x = None
cos: "f32[3, 2]" = torch.ops.aten.cos.default(mul)
return (mul, cos)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_INPUT_MUTATION: 6>, arg=TensorArgument(name='mul'), target='x'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}
Not Supported Yet¶
dynamic_shape_round¶
Original source code:
# mypy: allow-untyped-defs
import torch
from torch.export import Dim
# Module-level example data. Presumably `x` is the example input and
# `dim0_x` the dynamic dimension used when exporting this case (the
# recorded error message mentions dim0_x) - TODO confirm against the
# export harness, which is not visible here.
x = torch.randn(3, 2)
dim0_x = Dim("dim0_x")


class DynamicShapeRound(torch.nn.Module):
    """
    Export case: calling the builtin ``round`` on a shape-derived value.

    Calling round on dynamic shapes is not supported; the recorded result
    is an AssertionError.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the leading ``round(x.shape[0] / 2)`` rows of ``x``."""
        return x[: round(x.shape[0] / 2)]
Result:
AssertionError: RoundToInt(IntTrueDiv(dim0_x, 2)) <= dim0_x
model_attr_mutation¶
Original source code:
# mypy: allow-untyped-defs
import torch
class ModelAttrMutation(torch.nn.Module):
    """
    Export case: reassigning a module attribute inside ``forward``.

    Attribute mutation is not supported; the recorded result is an
    AssertionError about mutating ``attr_list`` during export.
    """

    def __init__(self):
        super().__init__()
        # A plain Python list holding two random (3, 2) tensors.
        self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

    def recreate_list(self):
        """Return a new list of two zero-filled ``(3, 2)`` tensors."""
        return [torch.zeros(3, 2), torch.zeros(3, 2)]

    def forward(self, x):
        """
        Replace ``self.attr_list`` and return ``x.sum()`` plus the sum of its first entry.
        """
        # This assignment is the attribute mutation the case demonstrates.
        self.attr_list = self.recreate_list()
        return x.sum() + self.attr_list[0].sum()
Result:
AssertionError: Mutating module attribute attr_list during export.
optional_input¶
Original source code:
# mypy: allow-untyped-defs
import torch
class OptionalInput(torch.nn.Module):
    """
    Export case: a ``forward`` parameter with a tensor default value.

    Tracing through optional input is not supported yet; the recorded
    result is an AssertionError about an unexpected torch.Tensor in the
    inputs.
    """

    # The default for `y` is a tensor created once, when the function is defined.
    def forward(self, x, y=torch.randn(2, 3)):
        """Return ``x + y`` when ``y`` is not None, otherwise ``x``."""
        if y is not None:
            return x + y
        return x
Result:
AssertionError: Unexpectedly found a <class 'torch.Tensor'> in the inputs.
torch_sym_min¶
Original source code:
# mypy: allow-untyped-defs
import torch
class TorchSymMin(torch.nn.Module):
    """
    Export case: calling ``torch.sym_min`` on a size value.

    torch.sym_min operator is not supported in export; the recorded result
    is an ``Unsupported`` error stating that a torch.* op returned a
    non-Tensor int.
    """

    def forward(self, x):
        """Return ``x.sum()`` plus ``torch.sym_min(x.size(0), 100)``."""
        return x.sum() + torch.sym_min(x.size(0), 100)
Result:
Unsupported: torch.* op returned non-Tensor int call_function <function sym_min at 0x7f0b55cdcb80>