
torch.export Tutorial

Created On: Oct 02, 2023 | Last Updated: Jan 27, 2025 | Last Verified: Nov 05, 2024

Author: William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan

Warning

torch.export and its related features are in prototype status and are subject to backwards compatibility breaking changes. This tutorial provides a snapshot of torch.export usage as of PyTorch 2.5.

torch.export() is the PyTorch 2.X way to export PyTorch models into standardized model representations, intended to be run in different environments (e.g. Python-less ones). See the official torch.export documentation for full details.

In this tutorial, you will learn how to use torch.export() to extract ExportedPrograms (i.e. single-graph representations) from PyTorch programs. We also detail some considerations and modifications that you may need to make so that your model is compatible with torch.export.


Basic Usage

torch.export extracts single-graph representations from PyTorch programs by tracing the target function, given example inputs. torch.export.export() is the main entry point for torch.export.

In this tutorial, torch.export and torch.export.export() are practically synonymous, though torch.export generally refers to the PyTorch 2.X export process, and torch.export.export() generally refers to the actual function call.

The signature of torch.export.export() is:

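export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any, ...], List[Any]]] = None,
    strict: bool = True,
    preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram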

torch.export.export() traces the tensor computation graph from calling mod(*args, **kwargs) and wraps it in an ExportedProgram, which can be serialized or executed later with different inputs. To execute the ExportedProgram, we can call .module() on it, which returns a callable torch.nn.Module, just like the original program. We will detail the dynamic_shapes argument later in the tutorial.

import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))
<class 'torch.export.exported_program.ExportedProgram'>
tensor([[0.8632, 0.8407, 0.0407, 0.0000, 0.4132, 0.0000, 0.0000, 0.1538, 0.6111,
         0.0000],
        [0.0000, 0.0000, 0.0273, 0.8057, 0.0000, 1.0162, 0.8042, 0.0000, 0.2660,
         0.0000],
        [0.9481, 0.1396, 1.0225, 0.9563, 0.5832, 0.2546, 0.4095, 0.4591, 0.0000,
         2.0053],
        [1.1300, 0.4873, 0.0000, 0.9663, 1.2275, 1.4015, 0.0000, 0.9444, 0.0000,
         0.0000],
        [0.0000, 0.8724, 1.1648, 0.6867, 0.0000, 0.2833, 0.3202, 0.5848, 0.0000,
         0.0833],
        [1.1311, 0.1324, 0.0000, 1.7842, 0.0000, 0.3474, 0.9916, 0.3571, 0.0000,
         0.0000],
        [1.4348, 1.0570, 0.1771, 0.0000, 0.9510, 0.0000, 0.0000, 0.0000, 0.2618,
         0.0000],
        [0.8853, 0.0000, 0.0000, 0.4486, 0.0000, 0.0000, 0.5841, 0.7604, 0.0000,
         0.0000]], grad_fn=<ReluBackward0>)

Let’s review some attributes of ExportedProgram that are of interest.

The graph attribute is an FX graph traced from the function we exported, that is, the computation graph of all PyTorch operations. The FX graph is in “ATen IR”, meaning that it contains only “ATen-level” operations.

The graph_signature attribute gives a more detailed description of the input and output nodes in the exported graph, describing which ones are parameters, buffers, user inputs, or user outputs.

The range_constraints attribute will be covered later.

print(exported_mod)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_lin_weight: "f32[10, 100]", p_lin_bias: "f32[10]", x: "f32[8, 100]", y: "f32[8, 100]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:71 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
            add: "f32[8, 100]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            linear: "f32[8, 10]" = torch.ops.aten.linear.default(add, p_lin_weight, p_lin_bias);  add = p_lin_weight = p_lin_bias = None
            relu_: "f32[8, 10]" = torch.ops.aten.relu_.default(linear);  linear = None
            return (relu_,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lin_weight'), target='lin.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lin_bias'), target='lin.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='relu_'), target=None)])
Range constraints: {}

See the torch.export documentation for more details.

Graph Breaks

Although torch.export shares components with torch.compile, the key limitation of torch.export, especially when compared to torch.compile, is that it does not support graph breaks. This is because handling graph breaks involves interpreting the unsupported operation with default Python evaluation, which is incompatible with the export use case. Therefore, in order to make your model code compatible with torch.export, you will need to modify your code to remove graph breaks.
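To make the difference concrete, here is a small sketch (Branchy is a hypothetical module for illustration) that compiles a function with a data-dependent branch twice. With fullgraph=True, torch.compile must capture a single graph, as export does, so it raises; with default settings it graph-breaks and falls back to Python for the branch. It uses backend="eager" so that only Dynamo's graph capture is exercised, not code generation:

```python
import torch
import torch._dynamo


class Branchy(torch.nn.Module):
    def forward(self, x):
        # The branch condition depends on tensor values, not just shapes.
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)


x = torch.randn(3, 3)
mod = Branchy()

# fullgraph=True forbids graph breaks, so the data-dependent branch is an error.
strict_compiled = torch.compile(mod, backend="eager", fullgraph=True)
try:
    strict_compiled(x)
    fullgraph_failed = False
except Exception as e:
    fullgraph_failed = True
    print("fullgraph=True raised", type(e).__name__)

torch._dynamo.reset()  # clear Dynamo's caches before compiling again

# The default mode splits the function at the branch and runs it in Python.
compiled = torch.compile(mod, backend="eager")
print(torch.allclose(compiled(x), mod(x)))
```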

A graph break is necessary in cases such as:

  • data-dependent control flow

class Bad1(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)

import traceback as tb
try:
    export(Bad1(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 122, in <module>
    export(Bad1(), (torch.randn(3, 3),))
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 640, in inner
    raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 116, in forward
    if x.sum() > 0:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

  • accessing tensor data with .data

class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

try:
    export(Bad2(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

  • calling unsupported functions (such as many built-in functions)

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

try:
    export(Bad3(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 148, in <module>
    export(Bad3(), (torch.randn(3, 3),))
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1004, in call_function
    return handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 843, in builtin_dispatch
    rv = handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 772, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1936, in call_id
    return tensor_variable.call_id(tx)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/tensor.py", line 469, in call_id
    unimplemented("call_id not supported for sourceless TensorVariable")
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: call_id not supported for sourceless TensorVariable

from user code:
   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 145, in forward
    return x + id(x)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Non-Strict Export

By default, torch.export traces the program with TorchDynamo, a bytecode analysis engine that symbolically analyzes the Python code and builds a graph from the results. This analysis allows torch.export to provide stronger safety guarantees, but not all Python code is supported, which leads to the graph breaks shown above.

To address this issue, PyTorch 2.3 introduced a new mode of exporting called non-strict mode, in which the program is traced by the ordinary Python interpreter, executing it exactly as it would in eager mode, which lets us skip over unsupported Python features. Non-strict mode is enabled by passing strict=False.

Looking at some of the previous examples which resulted in graph breaks:

  • Calling unsupported functions (such as many built-in functions) traces through, but in this case, id(x) gets specialized as a constant integer in the graph. This is because id(x) is not a tensor operation, so the operation is not recorded in the graph.

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:179 in forward, code: x = x + 1
            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, 1);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:180 in forward, code: return x + id(x)
            add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, 140039851959984);  add = None
            return (add_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {}

tensor([[1.4004e+14, 1.4004e+14, 1.4004e+14],
        [1.4004e+14, 1.4004e+14, 1.4004e+14],
        [1.4004e+14, 1.4004e+14, 1.4004e+14]])

However, there are still some features that require rewrites to the original module:

Control Flow Ops

torch.export does support data-dependent control flow, but it needs to be expressed using control flow ops. For example, we can fix the control flow example above using the cond op, like so:

class Bad1Fixed(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)
        def false_fn(x):
            return torch.cos(x)
        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed)
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:205 in forward, code: return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
            sum_1: "f32[]" = torch.ops.aten.sum.default(x)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None

             # File: /usr/local/lib/python3.10/dist-packages/torch/_higher_order_ops/cond.py:144 in cond, code: return cond_op(pred, true_fn, false_fn, operands)
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x]);  gt = true_graph_0 = false_graph_0 = x = None
            getitem: "f32[3, 3]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:202 in true_fn, code: return torch.sin(x)
                sin: "f32[3, 3]" = torch.ops.aten.sin.default(x);  x = None
                return (sin,)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:204 in false_fn, code: return torch.cos(x)
                cos: "f32[3, 3]" = torch.ops.aten.cos.default(x);  x = None
                return (cos,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

tensor([[0.8415, 0.8415, 0.8415],
        [0.8415, 0.8415, 0.8415],
        [0.8415, 0.8415, 0.8415]])
tensor([[0.5403, 0.5403, 0.5403],
        [0.5403, 0.5403, 0.5403],
        [0.5403, 0.5403, 0.5403]])

There are limitations to cond that one should be aware of:

  • The predicate (i.e. x.sum() > 0) must result in a boolean or a single-element tensor.

  • The operands (i.e. [x]) must be tensors.

  • The branch functions (i.e. true_fn and false_fn) must have signatures that match the operands, and they must both return a single tensor with the same metadata (for example, dtype, shape, etc.).

  • Branch functions cannot mutate input or global variables.

  • Branch functions cannot access closure variables, except for self if the function is defined in the scope of a method.

For more details about cond, check out the cond documentation.

We can also use map, which applies a function across the first dimension of the first tensor argument.

from torch._higher_order_ops.map import map as torch_map

class MapModule(torch.nn.Module):
    def forward(self, xs, y, z):
        def body(x, y, z):
            return x + y + z

        return torch_map(body, xs, y, z)

inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
exported_map_example = export(MapModule(), inps)
print(exported_map_example)
print(exported_map_example.module()(*inps))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[6, 4]", y: "i64[]", z: "i64[]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:236 in forward, code: return torch_map(body, xs, y, z)
            body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y, z]);  body_graph_0 = xs = y = z = None
            getitem: "f32[6, 4]" = map_impl[0];  map_impl = None
            return (getitem,)

        class body_graph_0(torch.nn.Module):
            def forward(self, xs: "f32[4]", y: "i64[]", z: "i64[]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:234 in body, code: return x + y + z
                add: "f32[4]" = torch.ops.aten.add.Tensor(xs, y);  xs = y = None
                add_1: "f32[4]" = torch.ops.aten.add.Tensor(add, z);  add = z = None
                return (add_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='xs'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

tensor([[10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.]])

Other control flow ops include while_loop, associative_scan, and scan. For more documentation on each operator, please refer to this page.

Constraints/Dynamic Shapes

This section covers dynamic behavior and representation of exported programs. Dynamic behavior is specific to the particular model being exported, so for most of this tutorial we’ll focus on this toy model (with the resulting tensor shapes annotated):

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [6, 5]
        x: torch.Tensor,  # [4]
        y: torch.Tensor,  # [8, 4]
        z: torch.Tensor,  # [32]
    ):
        x0 = x + y  # [8, 4]
        x1 = self.l(w)  # [6, 3]
        x2 = x0.flatten()  # [32]
        x3 = x2 + z  # [32]
        return x1, x3

By default, torch.export produces a static program. One consequence of this is that at runtime, the program won’t work on inputs with different shapes, even if they’re valid in eager mode.

w = torch.randn(6, 5)
x = torch.randn(4)
y = torch.randn(8, 4)
z = torch.randn(32)
model = DynamicModel()
ep = export(model, (w, x, y, z))
model(w, x, torch.randn(3, 4), torch.randn(12))
try:
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 286, in <module>
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 822, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 400, in __call__
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 387, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1772, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_unlift.py", line 49, in _check_input_constraints_pre_hook
    _check_input_constraints_for_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/utils.py", line 360, in _check_input_constraints_for_graph
    raise RuntimeError(
RuntimeError: Expected input at *args[2].shape[0] to be equal to 8, but got 3

Basic concepts: symbols and guards

To enable dynamism, export() provides a dynamic_shapes argument. The easiest way to work with dynamic shapes is to use Dim.AUTO and look at the program that’s returned. Dynamic behavior is specified per input dimension; for each input, we can specify a tuple of values, one per dimension:

from torch.export.dynamic_shapes import Dim

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

Before we look at the program that’s produced, let’s understand what specifying dynamic_shapes entails, and how that interacts with export. For every input dimension where a Dim object is specified, a symbol is allocated, taking on a range of [2, inf] (why not [0, inf] or [1, inf]? we’ll explain later in the 0/1 specialization section).

Export then runs model tracing, looking at each operation that’s performed by the model. Each individual operation can emit what are called “guards”: boolean conditions that must be true for the program to be valid. When guards involve symbols allocated for input dimensions, the program contains restrictions on what input shapes are valid, i.e. the program’s dynamic behavior. The symbolic shapes subsystem is the part responsible for taking in all the emitted guards and producing a final program representation that adheres to all of them. Before we see this “final representation” in an ExportedProgram, let’s look at the guards emitted by the toy model we’re tracing.

Here, each forward input tensor is annotated with the symbol allocated at the start of tracing:

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [s0, s1]
        x: torch.Tensor,  # [s2]
        y: torch.Tensor,  # [s3, s4]
        z: torch.Tensor,  # [s5]
    ):
        x0 = x + y  # guard: s2 == s4
        x1 = self.l(w)  # guard: s1 == 5
        x2 = x0.flatten()  # no guard added here
        x3 = x2 + z  # guard: s3 * s4 == s5
        return x1, x3

Let’s understand each of the operations and the emitted guards:

  • x0 = x + y: This is an element-wise add with broadcasting, since x is a 1-d tensor and y a 2-d tensor. Broadcasting aligns x’s only dimension with the last dimension of y and repeats x along y’s first dimension, which emits the guard s2 == s4.

  • x1 = self.l(w): Calling nn.Linear() performs a matrix multiplication with model parameters. In export, parameters, buffers, and constants are treated as program state, which is static, so this is a matmul between a dynamic input (w: [s0, s1]) and a statically-shaped tensor. This emits the guard s1 == 5.

  • x2 = x0.flatten(): This call actually doesn’t emit any guards! (at least none relevant to input shapes)

  • x3 = x2 + z: x2 has shape [s3*s4] after flattening, and this element-wise add emits s3 * s4 == s5.
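Each of these guards mirrors an eager-mode shape rule. The sketch below runs the same operations eagerly on the example shapes, plus two deliberately mismatched inputs, to show which shape relationships PyTorch itself enforces:

```python
import torch

x = torch.randn(4)
y = torch.randn(8, 4)

# x + y needs x.shape[0] == y.shape[1] (the s2 == s4 guard).
print((x + y).shape)  # torch.Size([8, 4])
try:
    torch.randn(5) + y
    broadcast_failed = False
except RuntimeError as e:
    broadcast_failed = True
    print("broadcast error:", e)

# nn.Linear(5, 3) has a fixed [3, 5] weight, so w.shape[1] must be 5 (s1 == 5).
lin = torch.nn.Linear(5, 3)
print(lin(torch.randn(6, 5)).shape)  # torch.Size([6, 3])
try:
    lin(torch.randn(6, 7))
    linear_failed = False
except RuntimeError as e:
    linear_failed = True
    print("linear error:", e)

# flatten turns [s3, s4] into [s3 * s4], which z must match (s3 * s4 == s5).
print((x + y).flatten().shape)  # torch.Size([32])
```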

Writing all of these guards down and summarizing is almost like a mathematical proof, which is what the symbolic shapes subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid:

  • w: [s0, 5]

  • x: [s2]

  • y: [s3, s2]

  • z: [s2*s3]
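As a toy illustration of that reasoning (not how the symbolic shapes subsystem is actually implemented), the guards can be written as sympy equations over the input symbols and solved for the dependent symbols s1, s4 and s5. sympy is used here only because it is already a PyTorch dependency; solving reproduces the shapes listed above:

```python
import sympy

# One symbol per input dimension, matching the annotated forward() above.
s0, s1, s2, s3, s4, s5 = sympy.symbols("s0:6", integer=True, positive=True)

guards = [
    sympy.Eq(s2, s4),       # x0 = x + y
    sympy.Eq(s1, 5),        # x1 = self.l(w)
    sympy.Eq(s3 * s4, s5),  # x3 = x2 + z
]

# Express s1, s4 and s5 in terms of the remaining free symbols s0, s2, s3.
solution = sympy.solve(guards, [s1, s4, s5], dict=True)[0]

input_shapes = {"w": (s0, s1), "x": (s2,), "y": (s3, s4), "z": (s5,)}
resolved = {
    name: tuple(dim.subs(solution) for dim in dims)
    for name, dims in input_shapes.items()
}
for name, dims in resolved.items():
    print(name, dims)
```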

And when we do finally print out the exported program to see our result, those shapes are what we see annotated on the corresponding inputs:

print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_l_weight: "f32[3, 5]", p_l_bias: "f32[3]", w: "f32[s0, 5]", x: "f32[s2]", y: "f32[s3, s2]", z: "f32[s2*s3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward, code: x0 = x + y  # [8, 4]
            add: "f32[s3, s2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward, code: x1 = self.l(w)  # [6, 3]
            linear: "f32[s0, 3]" = torch.ops.aten.linear.default(w, p_l_weight, p_l_bias);  w = p_l_weight = p_l_bias = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:270 in forward, code: x2 = x0.flatten()  # [32]
            flatten: "f32[s2*s3]" = torch.ops.aten.flatten.using_ints(add);  add = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward, code: x3 = x2 + z  # [32]
            add_1: "f32[s2*s3]" = torch.ops.aten.add.Tensor(flatten, z);  flatten = z = None
            return (linear, add_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_l_weight'), target='l.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_l_bias'), target='l.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='w'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {s0: VR[2, int_oo], s2: VR[2, int_oo], s3: VR[2, int_oo], s2*s3: VR[4, int_oo]}

Another feature to notice is the range_constraints field above, which records a valid range for each symbol. It isn’t very interesting yet: this export call doesn’t emit any guards related to symbol bounds, so each base symbol keeps the generic bound [2, int_oo] (and s2*s3 correspondingly starts at 4). This will come up later.
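The s2*s3: VR[4, int_oo] entry can be read off from the base ranges: both factors lie in [2, int_oo], so their product is at least 4. A minimal sketch of that interval reasoning for non-negative ranges, using math.inf in place of int_oo (an illustration only, not PyTorch’s own value-range implementation):

```python
import math


def _mul_bound(p, q):
    # Treat 0 * inf as 0: a factor that is always 0 keeps the product at 0.
    if p == 0 or q == 0:
        return 0
    return p * q


def mul_range(a, b):
    """Multiply two closed, non-negative ranges (lo, hi); hi may be math.inf."""
    (alo, ahi), (blo, bhi) = a, b
    if alo < 0 or blo < 0:
        raise ValueError("this sketch only handles non-negative ranges")
    # For non-negative ranges the product grows with both factors, so the
    # new bounds are the products of the matching endpoints.
    return (_mul_bound(alo, blo), _mul_bound(ahi, bhi))


s2 = (2, math.inf)  # VR[2, int_oo]
s3 = (2, math.inf)  # VR[2, int_oo]
print(mul_range(s2, s3))  # (4, inf), matching s2*s3: VR[4, int_oo]
```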

So far, because we’ve only been exporting this toy model, this experience hasn’t been representative of how hard it typically is to debug dynamic shape guards and issues. In most cases it isn’t obvious which guards are being emitted, or which operations and parts of user code are responsible for them. For this toy model we can pinpoint the exact lines, and the guards are rather intuitive.

In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment variable TORCH_LOGS="+dynamic", or interactively with torch._logging.set_logs(dynamic=10):

torch._logging.set_logs(dynamic=10)
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
I0130 20:24:35.334000 634 torch/fx/experimental/symbolic_shapes.py:3192] [12/0] create_env
I0130 20:24:35.336000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.336000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0130 20:24:35.337000 634 torch/fx/experimental/symbolic_shapes.py:6614] [12/0] runtime_assert True == True [statically known]
I0130 20:24:35.339000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.341000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.342000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.345000 634 torch/fx/experimental/symbolic_shapes.py:4423] [12/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0130 20:24:35.347000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s2, 1) == False [statically known]
V0130 20:24:35.348000 634 torch/fx/experimental/symbolic_shapes.py:6614] [12/0] runtime_assert True == True [statically known]
V0130 20:24:35.348000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s4, 1) == False [statically known]
I0130 20:24:35.350000 634 torch/fx/experimental/symbolic_shapes.py:5963] [12/0] set_replacement s4 = s2 (solve) VR[2, int_oo]
I0130 20:24:35.350000 634 torch/fx/experimental/symbolic_shapes.py:6281] [12/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # [8, 4]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
V0130 20:24:35.351000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Ne(s2, 1) == True [statically known]
V0130 20:24:35.352000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Ne(s3, 1) == True [statically known]
V0130 20:24:35.359000 634 torch/fx/experimental/symbolic_shapes.py:5802] [12/0] _update_var_to_range s1 = VR[5, 5] (update)
I0130 20:24:35.360000 634 torch/fx/experimental/symbolic_shapes.py:5963] [12/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0130 20:24:35.360000 634 torch/fx/experimental/symbolic_shapes.py:6281] [12/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0130 20:24:35.362000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s0, 1) == False [statically known]
V0130 20:24:35.367000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s2*s3, 1) == False [statically known]
V0130 20:24:35.368000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Eq(s5, 1) == False [statically known]
V0130 20:24:35.369000 634 torch/fx/experimental/symbolic_shapes.py:5802] [12/0] _update_var_to_range s5 = VR[4, int_oo] (update)
I0130 20:24:35.370000 634 torch/fx/experimental/symbolic_shapes.py:5963] [12/0] set_replacement s5 = s2*s3 (solve) VR[4, int_oo]
I0130 20:24:35.371000 634 torch/fx/experimental/symbolic_shapes.py:6281] [12/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0130 20:24:35.372000 634 torch/fx/experimental/symbolic_shapes.py:6412] [12/0] eval Ne(s2*s3, 1) == True [statically known]
I0130 20:24:35.378000 634 torch/fx/experimental/symbolic_shapes.py:4547] [12/0] produce_guards
V0130 20:24:35.379000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].size()[0] s0 None
V0130 20:24:35.379000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].size()[1] 5 None
V0130 20:24:35.379000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].stride()[0] 5 None
V0130 20:24:35.380000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].stride()[1] 1 None
V0130 20:24:35.380000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['w'].storage_offset() 0 None
V0130 20:24:35.380000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['x'].size()[0] s2 None
V0130 20:24:35.381000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['x'].stride()[0] 1 None
V0130 20:24:35.381000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['x'].storage_offset() 0 None
V0130 20:24:35.381000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].size()[0] s3 None
V0130 20:24:35.381000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].size()[1] s2 None
V0130 20:24:35.382000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].stride()[0] s2 None
V0130 20:24:35.382000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].stride()[1] 1 None
V0130 20:24:35.382000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['y'].storage_offset() 0 None
V0130 20:24:35.383000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['z'].size()[0] s2*s3 None
V0130 20:24:35.383000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['z'].stride()[0] 1 None
V0130 20:24:35.383000 634 torch/fx/experimental/symbolic_shapes.py:4755] [12/0] track_symint L['z'].storage_offset() 0 None
V0130 20:24:35.418000 634 torch/fx/experimental/symbolic_shapes.py:6412] eval Ne(s0, 1) == True [statically known]

This produces quite a lot of output, even for this simple toy model. The log lines below have been trimmed at the front and end to drop unnecessary information, but looking through the logs we can find the lines relevant to what we described above, e.g. the allocation of symbols:

"""
create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
runtime_assert True == True [statically known]
create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
"""

The lines with create_symbol show when a new symbol has been allocated, and the logs also identify the tensor variable names and dimensions they’ve been allocated for. In other lines we can also see the guards emitted:

"""
runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # output shape: [8, 4]  # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
"""

Next to the [guard added] messages, we also see the responsible lines of user code; luckily, this model is simple enough that they point straight at the cause. In many real-world cases it’s not so straightforward: high-level torch operations can have complicated fake-kernel implementations or operator decompositions that obscure where guards are emitted and which ones. In such cases, the best way to dig deeper is to follow the logs’ suggestion and re-run with the environment variable TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..." to further attribute the guard of interest.
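When the logs get long, it can help to pull out just the symbol allocations and the added guards. The sketch below assumes the log format shown above, which may change between PyTorch versions; summarize_dynamic_log and run_with_dynamic_logs are hypothetical helpers, and the script path stands for your own export script. It re-runs the script in a subprocess with TORCH_LOGS="+dynamic" (optionally also setting TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED to one guard), then maps each symbol to its source and each guard to the user code that emitted it.

```python
import os
import re
import subprocess
import sys

# e.g. "create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (...)"
CREATE_SYMBOL = re.compile(r"create_symbol (\w+) = (\d+) for (\S+)")
# e.g. "runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # [8, 4]  # ..."
GUARD_ADDED = re.compile(r"runtime_assert (.+?) \[guard added\] (.*?)(?:\s+#|$)")


def summarize_dynamic_log(text):
    """Return ({symbol: (example_value, source)}, [(guard, user_code)])."""
    symbols, guards = {}, []
    for line in text.splitlines():
        if m := CREATE_SYMBOL.search(line):
            name, value, source = m.groups()
            symbols[name] = (int(value), source)
        elif m := GUARD_ADDED.search(line):
            guards.append((m.group(1), m.group(2).strip()))
    return symbols, guards


def run_with_dynamic_logs(script, guard=None):
    """Re-run an export script with verbose symbolic-shape logging enabled
    and return everything it printed (stdout followed by stderr)."""
    env = dict(os.environ, TORCH_LOGS="+dynamic")
    if guard is not None:
        # Request extra attribution for a single guard, as the log messages suggest.
        env["TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED"] = guard
    proc = subprocess.run(
        [sys.executable, script], env=env, capture_output=True, text=True
    )
    return proc.stdout + proc.stderr


# Usage, where "my_export_script.py" is a placeholder for your own script:
# symbols, guards = summarize_dynamic_log(
#     run_with_dynamic_logs("my_export_script.py", guard="Eq(s2, s4)")
# )
# for guard, code in guards:
#     print(f"{guard:<16} <- {code}")
```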

Dim.AUTO is just one of the available options for interacting with dynamic_shapes; at the time of writing, two other options are available: Dim.DYNAMIC and Dim.STATIC. Dim.STATIC simply marks a dimension as static, while Dim.DYNAMIC behaves like Dim.AUTO in every way except one: it raises an error if the dimension is specialized to a constant, which is designed to preserve dynamism. See, for example, what happens when a static guard is emitted on a dynamically-marked dimension:

dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I0130 20:24:35.440000 634 torch/fx/experimental/symbolic_shapes.py:3192] [13/0] create_env
I0130 20:24:35.442000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.442000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0130 20:24:35.443000 634 torch/fx/experimental/symbolic_shapes.py:6614] [13/0] runtime_assert True == True [statically known]
I0130 20:24:35.445000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.447000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.448000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.450000 634 torch/fx/experimental/symbolic_shapes.py:4423] [13/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0130 20:24:35.452000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s2, 1) == False [statically known]
V0130 20:24:35.452000 634 torch/fx/experimental/symbolic_shapes.py:6614] [13/0] runtime_assert True == True [statically known]
V0130 20:24:35.453000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s4, 1) == False [statically known]
I0130 20:24:35.454000 634 torch/fx/experimental/symbolic_shapes.py:5963] [13/0] set_replacement s4 = s2 (solve) VR[2, int_oo]
I0130 20:24:35.455000 634 torch/fx/experimental/symbolic_shapes.py:6281] [13/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # [8, 4]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
V0130 20:24:35.456000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Ne(s2, 1) == True [statically known]
V0130 20:24:35.457000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Ne(s3, 1) == True [statically known]
V0130 20:24:35.463000 634 torch/fx/experimental/symbolic_shapes.py:5802] [13/0] _update_var_to_range s1 = VR[5, 5] (update)
I0130 20:24:35.464000 634 torch/fx/experimental/symbolic_shapes.py:5963] [13/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0130 20:24:35.465000 634 torch/fx/experimental/symbolic_shapes.py:6281] [13/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0130 20:24:35.466000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s0, 1) == False [statically known]
V0130 20:24:35.471000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s2*s3, 1) == False [statically known]
V0130 20:24:35.472000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Eq(s5, 1) == False [statically known]
V0130 20:24:35.473000 634 torch/fx/experimental/symbolic_shapes.py:5802] [13/0] _update_var_to_range s5 = VR[4, int_oo] (update)
I0130 20:24:35.474000 634 torch/fx/experimental/symbolic_shapes.py:5963] [13/0] set_replacement s5 = s2*s3 (solve) VR[4, int_oo]
I0130 20:24:35.475000 634 torch/fx/experimental/symbolic_shapes.py:6281] [13/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0130 20:24:35.476000 634 torch/fx/experimental/symbolic_shapes.py:6412] [13/0] eval Ne(s2*s3, 1) == True [statically known]
I0130 20:24:35.482000 634 torch/fx/experimental/symbolic_shapes.py:4547] [13/0] produce_guards
V0130 20:24:35.482000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].size()[0] s0 None
V0130 20:24:35.483000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].size()[1] 5 RelaxedUnspecConstraint(warn_only=False)
V0130 20:24:35.483000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].stride()[0] 5 None
V0130 20:24:35.483000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].stride()[1] 1 None
V0130 20:24:35.484000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['w'].storage_offset() 0 None
V0130 20:24:35.484000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['x'].size()[0] s2 None
V0130 20:24:35.484000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['x'].stride()[0] 1 None
V0130 20:24:35.484000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['x'].storage_offset() 0 None
V0130 20:24:35.485000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].size()[0] s3 None
V0130 20:24:35.485000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].size()[1] s2 None
V0130 20:24:35.485000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].stride()[0] s2 None
V0130 20:24:35.485000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].stride()[1] 1 None
V0130 20:24:35.486000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['y'].storage_offset() 0 None
V0130 20:24:35.486000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['z'].size()[0] s2*s3 None
V0130 20:24:35.486000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['z'].stride()[0] 1 None
V0130 20:24:35.487000 634 torch/fx/experimental/symbolic_shapes.py:4755] [13/0] track_symint L['z'].storage_offset() 0 None
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0] Error while creating guard:
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0] Name: ''
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]     Source: shape_env
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]     Create Function: SHAPE_ENV
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]     Guard Types: None
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]     Code List: None
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]     Object Weakref: None
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]     Guarded Class Weakref: None
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0] Traceback (most recent call last):
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]     return self.create_fn(builder, self)
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]     code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]     raise ConstraintViolationError(
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
E0130 20:24:35.488000 634 torch/_guards.py:295] [13/0]   - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).
E0130 20:24:35.490000 634 torch/_guards.py:297] [13/0] Created at:
E0130 20:24:35.490000 634 torch/_guards.py:297] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 642, in transform
E0130 20:24:35.490000 634 torch/_guards.py:297] [13/0]     tracer = InstructionTranslator(
E0130 20:24:35.490000 634 torch/_guards.py:297] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2711, in __init__
E0130 20:24:35.490000 634 torch/_guards.py:297] [13/0]     output=OutputGraph(
E0130 20:24:35.490000 634 torch/_guards.py:297] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 336, in __init__
E0130 20:24:35.490000 634 torch/_guards.py:297] [13/0]     self.init_ambient_guards()
E0130 20:24:35.490000 634 torch/_guards.py:297] [13/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 485, in init_ambient_guards
E0130 20:24:35.490000 634 torch/_guards.py:297] [13/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1614, in inner
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 852, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2303, in __init__
    guard.create(builder)
  File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
    return self.create_fn(builder, self)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
    code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 418, in <module>
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 679, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5).
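The rule behind this error can be stated compactly. The sketch below is a toy model in plain Python, not how torch.export implements it: Marker and check_specializations are hypothetical names. It records the marker chosen for each dimension, takes the dimensions that tracing forced to a constant (here w.shape[1] == 5, from the nn.Linear guard), and reports a violation only for dimensions marked DYNAMIC, mirroring the RelaxedUnspecConstraint error above.

```python
from enum import Enum


class Marker(Enum):
    """Toy stand-ins for the three markers described above (not torch.export.Dim)."""
    AUTO = "auto"        # dynamic if possible, silently specialized otherwise
    DYNAMIC = "dynamic"  # must stay dynamic; specializing it is an error
    STATIC = "static"    # fixed to the example input's size


def check_specializations(markers, specialized):
    """Toy model of the rule above.

    markers:     {(input_name, dim): Marker} chosen by the user
    specialized: {(input_name, dim): constant} inferred during tracing
    Returns one message per DYNAMIC dimension that was forced to a constant.
    """
    errors = []
    for (name, dim), value in specialized.items():
        if markers.get((name, dim)) is Marker.DYNAMIC:
            errors.append(
                f"{name}.size()[{dim}] is marked DYNAMIC but was inferred "
                f"to be a constant ({value})"
            )
    return errors


# Mirrors dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC): nn.Linear forces w.shape[1] == 5.
markers = {("w", 0): Marker.AUTO, ("w", 1): Marker.DYNAMIC}
specialized = {("w", 1): 5}
print(check_specializations(markers, specialized))

# Marking the dimension AUTO (or STATIC) instead lets the specialization through.
markers[("w", 1)] = Marker.AUTO
print(check_specializations(markers, specialized))  # []
```

In the real API, the corresponding fix for this example is to use Dim.AUTO or Dim.STATIC for w’s second dimension, as in the earlier successful export.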

Static guards aren’t always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape specializations is specifying conflicting markers for equivalent dimensions: one dynamic and the other static. The same error type is raised when this is the case for x.shape[0] and y.shape[1]:

dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)
dynamic_shapes["x"] = (Dim.STATIC,)
dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I0130 20:24:35.508000 634 torch/fx/experimental/symbolic_shapes.py:3192] [14/0] create_env
I0130 20:24:35.510000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.510000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0130 20:24:35.511000 634 torch/fx/experimental/symbolic_shapes.py:6614] [14/0] runtime_assert True == True [statically known]
I0130 20:24:35.515000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s2 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.515000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s3 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.518000 634 torch/fx/experimental/symbolic_shapes.py:4423] [14/0] create_symbol s4 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0130 20:24:35.520000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Eq(s3, 1) == False [statically known]
V0130 20:24:35.525000 634 torch/fx/experimental/symbolic_shapes.py:5802] [14/0] _update_var_to_range s3 = VR[4, 4] (update)
I0130 20:24:35.525000 634 torch/fx/experimental/symbolic_shapes.py:5963] [14/0] set_replacement s3 = 4 (range_refined_to_singleton) VR[4, 4]
I0130 20:24:35.526000 634 torch/fx/experimental/symbolic_shapes.py:6281] [14/0] runtime_assert Eq(s3, 4) [guard added] x0 = x + y  # [8, 4]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s3, 4)"
V0130 20:24:35.527000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Ne(s2, 1) == True [statically known]
V0130 20:24:35.534000 634 torch/fx/experimental/symbolic_shapes.py:5802] [14/0] _update_var_to_range s1 = VR[5, 5] (update)
I0130 20:24:35.534000 634 torch/fx/experimental/symbolic_shapes.py:5963] [14/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0130 20:24:35.535000 634 torch/fx/experimental/symbolic_shapes.py:6281] [14/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0130 20:24:35.536000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Eq(s0, 1) == False [statically known]
V0130 20:24:35.536000 634 torch/fx/experimental/symbolic_shapes.py:6614] [14/0] runtime_assert True == True [statically known]
V0130 20:24:35.543000 634 torch/fx/experimental/symbolic_shapes.py:6412] [14/0] eval Eq(s4, 1) == False [statically known]
V0130 20:24:35.550000 634 torch/fx/experimental/symbolic_shapes.py:5802] [14/0] _update_var_to_range s4 = VR[8, int_oo] (update)
I0130 20:24:35.552000 634 torch/fx/experimental/symbolic_shapes.py:5963] [14/0] set_replacement s4 = 4*s2 (solve) VR[8, int_oo]
I0130 20:24:35.553000 634 torch/fx/experimental/symbolic_shapes.py:6281] [14/0] runtime_assert Eq(4*s2, s4) [guard added] x3 = x2 + z  # [32]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(4*s2, s4)"
I0130 20:24:35.560000 634 torch/fx/experimental/symbolic_shapes.py:4547] [14/0] produce_guards
V0130 20:24:35.560000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].size()[0] s0 None
V0130 20:24:35.561000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].size()[1] 5 None
V0130 20:24:35.561000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].stride()[0] 5 None
V0130 20:24:35.561000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].stride()[1] 1 None
V0130 20:24:35.561000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['w'].storage_offset() 0 None
V0130 20:24:35.562000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['x'].size()[0] 4 None
V0130 20:24:35.562000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['x'].stride()[0] 1 None
V0130 20:24:35.562000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['x'].storage_offset() 0 None
V0130 20:24:35.562000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].size()[0] s2 None
V0130 20:24:35.563000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].size()[1] 4 RelaxedUnspecConstraint(warn_only=False)
V0130 20:24:35.563000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].stride()[0] 4 None
V0130 20:24:35.563000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].stride()[1] 1 None
V0130 20:24:35.564000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['y'].storage_offset() 0 None
V0130 20:24:35.564000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['z'].size()[0] 4*s2 None
V0130 20:24:35.564000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['z'].stride()[0] 1 None
V0130 20:24:35.564000 634 torch/fx/experimental/symbolic_shapes.py:4755] [14/0] track_symint L['z'].storage_offset() 0 None
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0] Error while creating guard:
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0] Name: ''
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]     Source: shape_env
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]     Create Function: SHAPE_ENV
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]     Guard Types: None
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]     Code List: None
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]     Object Weakref: None
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]     Guarded Class Weakref: None
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0] Traceback (most recent call last):
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]     return self.create_fn(builder, self)
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]     code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]     raise ConstraintViolationError(
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
E0130 20:24:35.566000 634 torch/_guards.py:295] [14/0]   - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).
E0130 20:24:35.567000 634 torch/_guards.py:297] [14/0] Created at:
E0130 20:24:35.567000 634 torch/_guards.py:297] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 642, in transform
E0130 20:24:35.567000 634 torch/_guards.py:297] [14/0]     tracer = InstructionTranslator(
E0130 20:24:35.567000 634 torch/_guards.py:297] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2711, in __init__
E0130 20:24:35.567000 634 torch/_guards.py:297] [14/0]     output=OutputGraph(
E0130 20:24:35.567000 634 torch/_guards.py:297] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 336, in __init__
E0130 20:24:35.567000 634 torch/_guards.py:297] [14/0]     self.init_ambient_guards()
E0130 20:24:35.567000 634 torch/_guards.py:297] [14/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 485, in init_ambient_guards
E0130 20:24:35.567000 634 torch/_guards.py:297] [14/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1614, in inner
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 852, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2303, in __init__
    guard.create(builder)
  File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
    return self.create_fn(builder, self)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
    code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 431, in <module>
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 679, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).

Here you might ask why export “specializes”, i.e. why it resolves this static/dynamic conflict by taking the static route. The answer lies in the symbolic shapes system described above, with its symbols and guards. When x.shape[0] is marked static, no symbol is allocated for it, and tracing treats this shape as the concrete integer 4. A symbol (s3) is allocated for y.shape[1], so the broadcast in x + y eventually emits the guard s3 == 4, which leads to specialization.

One feature of export is that during tracing, statements like asserts, torch._check(), and if/else conditions will also emit guards. See what happens when we augment the existing model with such statements:

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(self, w, x, y, z):
        assert w.shape[0] <= 512
        torch._check(x.shape[0] >= 4)
        if w.shape[0] == x.shape[0] + 2:
            x0 = x + y
            x1 = self.l(w)
            x2 = x0.flatten()
            x3 = x2 + z
            return x1, x3
        else:
            return w

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
try:
    ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I0130 20:24:35.584000 634 torch/fx/experimental/symbolic_shapes.py:3192] [15/0] create_env
I0130 20:24:35.586000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.586000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0130 20:24:35.587000 634 torch/fx/experimental/symbolic_shapes.py:6614] [15/0] runtime_assert True == True [statically known]
I0130 20:24:35.589000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.591000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.592000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.594000 634 torch/fx/experimental/symbolic_shapes.py:4423] [15/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0130 20:24:35.601000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s0 = VR[2, 512] (update)
I0130 20:24:35.602000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert s0 <= 512 [guard added] assert w.shape[0] <= 512  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:450 in forward (_dynamo/symbolic_convert.py:522 in inner), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s0 <= 512"
V0130 20:24:35.606000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s2 = VR[4, int_oo] (update)
I0130 20:24:35.607000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert s2 >= 4 [guard added] torch._check(x.shape[0] >= 4)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:451 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s2 >= 4"
V0130 20:24:35.615000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s2 = VR[4, 510] (update)
V0130 20:24:35.615000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s0 = VR[6, 512] (update)
I0130 20:24:35.616000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s0 = s2 + 2 (solve) VR[6, 512]
I0130 20:24:35.617000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] eval Eq(s0, s2 + 2) [guard added] if w.shape[0] == x.shape[0] + 2:  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:452 in forward (_dynamo/variables/tensor.py:1201 in evaluate_expr), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, s2 + 2)"
V0130 20:24:35.618000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s2, 1) == False [statically known]
V0130 20:24:35.619000 634 torch/fx/experimental/symbolic_shapes.py:6614] [15/0] runtime_assert True == True [statically known]
V0130 20:24:35.620000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s4, 1) == False [statically known]
V0130 20:24:35.622000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s4 = VR[4, 510] (update)
I0130 20:24:35.623000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s4 = s2 (solve) VR[4, 510]
I0130 20:24:35.624000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:453 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
V0130 20:24:35.625000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Ne(s2, 1) == True [statically known]
V0130 20:24:35.626000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Ne(s3, 1) == True [statically known]
V0130 20:24:35.633000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s1 = VR[5, 5] (update)
I0130 20:24:35.633000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5]
I0130 20:24:35.634000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:454 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
V0130 20:24:35.643000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s2*s3, 1) == False [statically known]
V0130 20:24:35.644000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Eq(s5, 1) == False [statically known]
V0130 20:24:35.651000 634 torch/fx/experimental/symbolic_shapes.py:5802] [15/0] _update_var_to_range s5 = VR[8, int_oo] (update)
I0130 20:24:35.652000 634 torch/fx/experimental/symbolic_shapes.py:5963] [15/0] set_replacement s5 = s2*s3 (solve) VR[8, int_oo]
I0130 20:24:35.653000 634 torch/fx/experimental/symbolic_shapes.py:6281] [15/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:456 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
V0130 20:24:35.654000 634 torch/fx/experimental/symbolic_shapes.py:6412] [15/0] eval Ne(s2*s3, 1) == True [statically known]
V0130 20:24:35.658000 634 torch/fx/experimental/symbolic_shapes.py:6614] [15/0] runtime_assert s2 >= 4 == True [statically known]
I0130 20:24:35.665000 634 torch/fx/experimental/symbolic_shapes.py:4547] [15/0] produce_guards
V0130 20:24:35.665000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].size()[0] s2 + 2 None
V0130 20:24:35.665000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].size()[1] 5 None
V0130 20:24:35.666000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].stride()[0] 5 None
V0130 20:24:35.666000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].stride()[1] 1 None
V0130 20:24:35.666000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['w'].storage_offset() 0 None
V0130 20:24:35.666000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['x'].size()[0] s2 None
V0130 20:24:35.667000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['x'].stride()[0] 1 None
V0130 20:24:35.667000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['x'].storage_offset() 0 None
V0130 20:24:35.667000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].size()[0] s3 None
V0130 20:24:35.668000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].size()[1] s2 None
V0130 20:24:35.668000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].stride()[0] s2 None
V0130 20:24:35.668000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].stride()[1] 1 None
V0130 20:24:35.668000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['y'].storage_offset() 0 None
V0130 20:24:35.669000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['z'].size()[0] s2*s3 None
V0130 20:24:35.669000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['z'].stride()[0] 1 None
V0130 20:24:35.669000 634 torch/fx/experimental/symbolic_shapes.py:4755] [15/0] track_symint L['z'].storage_offset() 0 None

Each of these statements emits an additional guard, and the exported program reflects the changes: s0 is eliminated in favor of s2 + 2, and s2 now has both a lower and an upper bound (VR[4, 510] in the log above), which is reflected in range_constraints.

For the if/else condition, you might ask why the True branch was taken, and why the guard emitted from tracing wasn’t w.shape[0] != x.shape[0] + 2. The answer is that export is guided by the sample inputs provided for tracing, and specializes on the branches taken. If sample input shapes that fail the if condition were provided instead, export would trace the else branch and emit guards corresponding to it. You might also ask why only the if branch was traced, and whether it’s possible to preserve control flow in your program and keep both branches alive. For that, rewrite your model code following the Control Flow Ops section above.

0/1 specialization

Since we’re talking about guards and specializations, it’s a good time to talk about the 0/1 specialization issue we brought up earlier. The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that don’t generalize to other shapes. For example, size 1 tensors can broadcast while other sizes fail; and size 0 … . This just means that you should specify 0/1 sample inputs when you’d like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desirable. See what happens at runtime when we export this linear layer:

ep = export(
    torch.nn.Linear(4, 3),
    (torch.randn(1, 4),),
    dynamic_shapes={
        "input": (Dim.AUTO, Dim.STATIC),
    },
)
try:
    ep.module()(torch.randn(2, 4))
except Exception:
    tb.print_exc()
I0130 20:24:35.738000 634 torch/fx/experimental/symbolic_shapes.py:3192] [3/1] create_env
I0130 20:24:35.752000 634 torch/fx/experimental/symbolic_shapes.py:4547] [3/1] produce_guards
V0130 20:24:35.752000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].size()[0] 1 None
V0130 20:24:35.752000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].size()[1] 4 None
V0130 20:24:35.753000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].stride()[0] 4 None
V0130 20:24:35.753000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].stride()[1] 1 None
V0130 20:24:35.753000 634 torch/fx/experimental/symbolic_shapes.py:4755] [3/1] track_symint L['args'][0].storage_offset() 0 None
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 500, in <module>
    ep.module()(torch.randn(2, 4))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 822, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 400, in __call__
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 387, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1772, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_unlift.py", line 49, in _check_input_constraints_pre_hook
    _check_input_constraints_for_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/utils.py", line 360, in _check_input_constraints_for_graph
    raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 1, but got 2

Named Dims

So far we’ve only talked about three ways to specify dynamic shapes: Dim.AUTO, Dim.DYNAMIC, and Dim.STATIC. Their appeal is the low-friction user experience: all the guards emitted during model tracing are adhered to, and dynamic behavior such as min/max ranges, relations, and static/dynamic dimensions is figured out automatically by export. The dynamic shapes subsystem essentially acts as a “discovery” process, summarizing these guards and presenting what export believes is the overall dynamic behavior of the program. The drawback of this design appears once the user has stronger expectations or beliefs about the dynamic behavior of the model. Maybe dynamism is strongly desired and specializations on particular dimensions must be avoided at all costs, or maybe we want to catch changes in dynamic behavior caused by changes to the original model code, or to underlying decompositions or meta-kernels. Such changes won’t be detected, and the export() call will most likely succeed, unless tests are in place that check the resulting ExportedProgram representation.

For such cases, we recommend the “traditional” way of specifying dynamic shapes that longer-term users of export may already know, named Dims:

dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}

This style of dynamic shapes lets the user specify which symbols are allocated for input dimensions and the min/max bounds on those symbols, and it places restrictions on the dynamic behavior of the produced ExportedProgram: ConstraintViolation errors are raised if model tracing emits guards that conflict with the given relations or static/dynamic specifications. For example, the specification above asserts the following:

  • x.shape[0] has range [4, 256] and is related to y.shape[0] by y.shape[0] == 2 * x.shape[0].

  • x.shape[1] is static.

  • y.shape[1] has range [2, 512], and is unrelated to any other dimension.

In this design, we allow relations between dimensions to be specified with univariate linear expressions: A * dim + B can be specified for any dimension. This allows users to specify more complex constraints like integer divisibility for dynamic dimensions:

dx = Dim("dx", min=4, max=512)
dynamic_shapes = {
    "x": (4 * dx, None)  # x.shape[0] has range [16, 2048], and is divisible by 4.
}

Constraint violations, suggested fixes

One common issue with this specification style (before Dim.AUTO was introduced) is that the specification often mismatched what model tracing produced. That leads to ConstraintViolation errors and export-suggested fixes. See, for example, this model and specification, where the model inherently requires equality between dimension 0 of x and y, and requires dimension 1 to be static.

class Foo(torch.nn.Module):
    def forward(self, x, y):
        w = x + y
        return w + torch.ones(4)

dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
try:
    ep = export(
        Foo(),
        (torch.randn(6, 4), torch.randn(6, 4)),
        dynamic_shapes={
            "x": (dx, d1),
            "y": (dy, d1),
        },
    )
except Exception:
    tb.print_exc()
I0130 20:24:35.901000 634 torch/fx/experimental/symbolic_shapes.py:3192] [16/0] create_env
I0130 20:24:35.904000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s0 = 6 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.904000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s1 = 4 for L['x'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0130 20:24:35.905000 634 torch/fx/experimental/symbolic_shapes.py:6614] [16/0] runtime_assert True == True [statically known]
I0130 20:24:35.908000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s2 = 6 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0130 20:24:35.908000 634 torch/fx/experimental/symbolic_shapes.py:4423] [16/0] create_symbol s3 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2861 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0130 20:24:35.914000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s1, 1) == False [statically known]
V0130 20:24:35.915000 634 torch/fx/experimental/symbolic_shapes.py:6614] [16/0] runtime_assert True == True [statically known]
V0130 20:24:35.915000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s0, 1) == False [statically known]
V0130 20:24:35.916000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s3, 1) == False [statically known]
I0130 20:24:35.919000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s3 = s1 (solve) VR[2, int_oo]
I0130 20:24:35.920000 634 torch/fx/experimental/symbolic_shapes.py:6281] [16/0] runtime_assert Eq(s1, s3) [guard added] w = x + y  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:552 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, s3)"
V0130 20:24:35.920000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Eq(s2, 1) == False [statically known]
I0130 20:24:35.922000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s2 = s0 (solve) VR[2, int_oo]
I0130 20:24:35.923000 634 torch/fx/experimental/symbolic_shapes.py:6281] [16/0] runtime_assert Eq(s0, s2) [guard added] w = x + y  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:552 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, s2)"
V0130 20:24:35.925000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Ne(s1, 1) == True [statically known]
V0130 20:24:35.925000 634 torch/fx/experimental/symbolic_shapes.py:6412] [16/0] eval Ne(s0, 1) == True [statically known]
V0130 20:24:35.932000 634 torch/fx/experimental/symbolic_shapes.py:5802] [16/0] _update_var_to_range s1 = VR[4, 4] (update)
I0130 20:24:35.933000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s1 = 4 (range_refined_to_singleton) VR[4, 4]
I0130 20:24:35.934000 634 torch/fx/experimental/symbolic_shapes.py:6281] [16/0] runtime_assert Eq(s1, 4) [guard added] return w + torch.ones(4)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:553 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 4)"
V0130 20:24:35.937000 634 torch/fx/experimental/symbolic_shapes.py:5802] [16/0] _update_var_to_range s3 = VR[4, 4] (update)
I0130 20:24:35.937000 634 torch/fx/experimental/symbolic_shapes.py:5963] [16/0] set_replacement s3 = 4 (find) VR[4, 4]
I0130 20:24:35.940000 634 torch/fx/experimental/symbolic_shapes.py:4547] [16/0] produce_guards
V0130 20:24:35.941000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].size()[0] s0 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0130 20:24:35.941000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0130 20:24:35.941000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].stride()[0] 4 None
V0130 20:24:35.942000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].stride()[1] 1 None
V0130 20:24:35.942000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['x'].storage_offset() 0 None
V0130 20:24:35.942000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].size()[0] s0 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0130 20:24:35.942000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0130 20:24:35.943000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].stride()[0] 4 None
V0130 20:24:35.943000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].stride()[1] 1 None
V0130 20:24:35.943000 634 torch/fx/experimental/symbolic_shapes.py:4755] [16/0] track_symint L['y'].storage_offset() 0 None
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0] Error while creating guard:
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0] Name: ''
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]     Source: shape_env
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]     Create Function: SHAPE_ENV
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]     Guard Types: None
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]     Code List: None
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]     Object Weakref: None
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]     Guarded Class Weakref: None
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0] Traceback (most recent call last):
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]     return self.create_fn(builder, self)
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]     code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]     raise ConstraintViolationError(
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]   - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]   - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
E0130 20:24:35.945000 634 torch/_guards.py:295] [16/0]   - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
E0130 20:24:35.946000 634 torch/_guards.py:297] [16/0] Created at:
E0130 20:24:35.946000 634 torch/_guards.py:297] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 642, in transform
E0130 20:24:35.946000 634 torch/_guards.py:297] [16/0]     tracer = InstructionTranslator(
E0130 20:24:35.946000 634 torch/_guards.py:297] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2711, in __init__
E0130 20:24:35.946000 634 torch/_guards.py:297] [16/0]     output=OutputGraph(
E0130 20:24:35.946000 634 torch/_guards.py:297] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 336, in __init__
E0130 20:24:35.946000 634 torch/_guards.py:297] [16/0]     self.init_ambient_guards()
E0130 20:24:35.946000 634 torch/_guards.py:297] [16/0]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 485, in init_ambient_guards
E0130 20:24:35.946000 634 torch/_guards.py:297] [16/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1614, in inner
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 852, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2303, in __init__
    guard.create(builder)
  File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 293, in create
    return self.create_fn(builder, self)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1868, in SHAPE_ENV
    code_parts, verbose_code_parts = output_graph.shape_env.produce_guards_verbose(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5188, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.

Suggested fixes:
  d1 = 4
  dy = dx

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 557, in <module>
    ep = export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 679, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4).
  - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.

Suggested fixes:
  d1 = 4
  dy = dx

The intent behind the suggested fixes is that the user can copy-paste them into their dynamic shapes specification and then export successfully.

Lastly, there are a couple of things worth knowing about the options for specification:

  • None is a good option for static behavior:
      - dynamic_shapes=None (default) exports with the entire model being static.
      - specifying None at an input level exports with all tensor dimensions static, and is also required for non-tensor inputs.
      - specifying None at a dimension level specializes that dimension, though this is deprecated in favor of Dim.STATIC.

  • specifying per-dimension integer values also produces static behavior, and will additionally check that the provided sample input matches the specification.

These options are combined in the inputs & dynamic shapes spec below:

inputs = (
    torch.randn(4, 4),
    torch.randn(3, 3),
    16,
    False,
)
dynamic_shapes = {
    "tensor_0": (Dim.AUTO, None),
    "tensor_1": None,
    "int_val": None,
    "bool_val": None,
}

Data-dependent errors

While trying to export models, you may have encountered errors like “Could not guard on data-dependent expression” or “Could not extract specialized integer from data-dependent expression”. These errors exist because torch.export() compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. While FakeTensors have the same symbolic properties as real tensors (e.g. sizes, strides, dtypes), they differ in that they do not contain any data values. This avoids unnecessary memory usage and expensive computation, but it also means that export may be unable to compile, out of the box, parts of user code whose compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that the value is not available.
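Here is a small sketch of the metadata-without-data behavior. It uses FakeTensorMode from the internal module torch._subclasses.fake_tensor. That module is not public API and is used here only to show that a fake tensor has a shape and dtype but cannot produce a concrete value. The exact exception raised by .item() is an internal detail, so the sketch catches it broadly.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    fake = torch.randn(3, 5)

    # Metadata is available without any real storage behind it.
    print(fake.shape, fake.dtype)

    # Extracting a concrete scalar needs data the fake tensor does not have.
    try:
        fake.sum().item()
        item_failed = False
    except Exception as e:  # exception class is an internal detail
        item_failed = True
        print(type(e).__name__)
```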

Data-dependent values appear in many places, and common sources are calls like item(), tolist(), or torch.unbind() that extract scalar values from tensors. How are these values represented in the exported program? In the Constraints/Dynamic Shapes section, we talked about allocating symbols to represent dynamic input dimensions. The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are “unbacked” symbols, in contrast to the “backed” symbols allocated for input dimensions. The “backed/unbacked” nomenclature refers to the presence/absence of a “hint” for the symbol: a concrete value backing the symbol, which can inform the compiler on how to proceed.

In the input shape symbol case (backed symbols), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the sample input properties. For data-dependent values, the symbols are taken from FakeTensor “data” during tracing, and so the compiler doesn’t know the actual values (hints) that these symbols would take on.

Let’s see how these show up in exported programs:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.tolist()
        return b + [a]

inps = (
    torch.tensor(1),
    torch.tensor([2, 3]),
)
ep = export(Foo(), inps)
print(ep)
I0130 20:24:35.962000 634 torch/fx/experimental/symbolic_shapes.py:3192] [17/0] create_env
I0130 20:24:35.966000 634 torch/fx/experimental/symbolic_shapes.py:4103] [17/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:35.966000 634 torch/fx/experimental/symbolic_shapes.py:970] [17/0] compute_unbacked_bindings [u0]
I0130 20:24:35.969000 634 torch/fx/experimental/symbolic_shapes.py:4103] [17/0] create_unbacked_symint u1 [-int_oo, int_oo] b = y.tolist()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:35.969000 634 torch/fx/experimental/symbolic_shapes.py:970] [17/0] compute_unbacked_bindings [u1]
I0130 20:24:35.971000 634 torch/fx/experimental/symbolic_shapes.py:4103] [17/0] create_unbacked_symint u2 [-int_oo, int_oo] b = y.tolist()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:35.972000 634 torch/fx/experimental/symbolic_shapes.py:970] [17/0] compute_unbacked_bindings [u2]
I0130 20:24:35.976000 634 torch/fx/experimental/symbolic_shapes.py:4547] [17/0] produce_guards
V0130 20:24:35.976000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['x'].storage_offset() 0 None
V0130 20:24:35.976000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['y'].size()[0] 2 None
V0130 20:24:35.977000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['y'].stride()[0] 1 None
V0130 20:24:35.977000 634 torch/fx/experimental/symbolic_shapes.py:4755] [17/0] track_symint L['y'].storage_offset() 0 None
I0130 20:24:35.985000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u3 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:35.986000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u4 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:35.993000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u5 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:35.994000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u5]
I0130 20:24:35.995000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u5 = u0 (rename_unbacked_to) VR[-int_oo, int_oo]
I0130 20:24:35.996000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u6 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:35.997000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u6]
I0130 20:24:35.997000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u6 = u1 (rename_unbacked_to) VR[-int_oo, int_oo]
I0130 20:24:35.999000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u7 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:35.999000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u7]
I0130 20:24:36.000000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u7 = u2 (rename_unbacked_to) VR[-int_oo, int_oo]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "i64[2]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward, code: b = y.tolist()
            select: "i64[]" = torch.ops.aten.select.int(y, 0, 0)
            item_1: "Sym(u1)" = torch.ops.aten.item.default(select);  select = None
            select_1: "i64[]" = torch.ops.aten.select.int(y, 0, 1);  y = None
            item_2: "Sym(u2)" = torch.ops.aten.item.default(select_1);  select_1 = None
            return (item_1, item_2, item)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='item'), target=None)])
Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[-int_oo, int_oo], u2: VR[-int_oo, int_oo], u3: VR[-int_oo, int_oo], u4: VR[-int_oo, int_oo], u5: VR[-int_oo, int_oo], u6: VR[-int_oo, int_oo], u7: VR[-int_oo, int_oo]}

The result is that 3 unbacked symbols (notice they’re prefixed with “u”, instead of the usual “s” for input shape/backed symbols) are allocated and returned: 1 for the item() call, and 1 for each of the elements of y with the tolist() call. Note from the range constraints field that these take on ranges of [-int_oo, int_oo], not the default [0, int_oo] range allocated to input shape symbols, since we have no information on what these values are; they don’t represent sizes, so they aren’t necessarily non-negative.

Guards, torch._check()

But the case above is easy to export, because the concrete values of these symbols aren’t used in any compiler decision-making; all that’s relevant is that the return values are unbacked symbols. The data-dependent errors highlighted in this section are cases like the following, where data-dependent guards are encountered:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

Here we actually need the “hint”, or the concrete value of a, for the compiler to decide whether to trace return y + 2 or return y * 5 as the output. Because we trace with FakeTensors, we don’t know what a // 2 >= 5 actually evaluates to, and export errors out with “Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)”.

So how do we export this toy model? Unlike torch.compile(), export requires full graph compilation, and we can’t just graph break on this. Here are some basic options:

  1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or using torch.compiler.is_compiling() to guard what’s traced at compile-time.

  2. torch.cond(): we could rewrite the control-flow code to use torch.cond() so we don’t specialize on a branch.

While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize, and torch.cond() is not a comprehensive system for handling data-dependent errors. As we will see, there are data-dependent errors that do not involve control-flow.

The generally recommended approach is to start with torch._check() calls. While these look like plain assert statements, they are in fact a way of informing the compiler about properties of symbols. A torch._check() call does act as an assertion at runtime, but when it is traced at compile time, the checked expression is sent to the symbolic shapes subsystem for reasoning, and any symbol properties that follow from the expression being true are stored (provided the subsystem is smart enough to infer them). So even if unbacked symbols don’t have hints, if we can communicate properties that are generally true for these symbols via torch._check() calls, we can potentially bypass data-dependent guards without rewriting the offending model code.

For example, in the model above, inserting torch._check(a >= 10) would tell the compiler that y + 2 can always be returned, and torch._check(a == 4) would tell it to return y * 5. Let’s see what happens when we re-export this model.

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 10)
        torch._check(a <= 60)
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

inps = (
    torch.tensor(32),
    torch.randn(4),
)
ep = export(Foo(), inps)
print(ep)
I0130 20:24:36.010000 634 torch/fx/experimental/symbolic_shapes.py:3192] [18/0] create_env
I0130 20:24:36.014000 634 torch/fx/experimental/symbolic_shapes.py:4103] [18/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:36.015000 634 torch/fx/experimental/symbolic_shapes.py:970] [18/0] compute_unbacked_bindings [u0]
V0130 20:24:36.017000 634 torch/fx/experimental/symbolic_shapes.py:5802] [18/0] _update_var_to_range u0 = VR[10, int_oo] (update)
I0130 20:24:36.018000 634 torch/fx/experimental/symbolic_shapes.py:6281] [18/0] runtime_assert u0 >= 10 [guard added] torch._check(a >= 10)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:673 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 10"
V0130 20:24:36.023000 634 torch/fx/experimental/symbolic_shapes.py:5802] [18/0] _update_var_to_range u0 = VR[10, 60] (update)
I0130 20:24:36.024000 634 torch/fx/experimental/symbolic_shapes.py:6281] [18/0] runtime_assert u0 <= 60 [guard added] torch._check(a <= 60)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:674 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 <= 60"
V0130 20:24:36.028000 634 torch/fx/experimental/symbolic_shapes.py:6412] [18/0] eval ((u0//2)) >= 5 == True [statically known]
V0130 20:24:36.031000 634 torch/fx/experimental/symbolic_shapes.py:6614] [18/0] runtime_assert u0 >= 10 == True [statically known]
V0130 20:24:36.032000 634 torch/fx/experimental/symbolic_shapes.py:6614] [18/0] runtime_assert u0 <= 60 == True [statically known]
I0130 20:24:36.036000 634 torch/fx/experimental/symbolic_shapes.py:4547] [18/0] produce_guards
V0130 20:24:36.036000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['x'].storage_offset() 0 None
V0130 20:24:36.036000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['y'].size()[0] 4 None
V0130 20:24:36.037000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['y'].stride()[0] 1 None
V0130 20:24:36.037000 634 torch/fx/experimental/symbolic_shapes.py:4755] [18/0] track_symint L['y'].storage_offset() 0 None
I0130 20:24:36.052000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:36.053000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u1]
V0130 20:24:36.053000 634 torch/fx/experimental/symbolic_shapes.py:5802] _update_var_to_range u1 = VR[10, 60] (update)
I0130 20:24:36.054000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u1 = u0 (rename_unbacked_to) VR[10, 60]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[4]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_1: "Sym(u0 >= 10)" = item >= 10
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 10 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 60)" = item <= 60;  item = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 60 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:676 in forward, code: return y + 2
            add: "f32[4]" = torch.ops.aten.add.Tensor(y, 2);  y = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {u0: VR[10, 60], u1: VR[10, 60]}

Export succeeds, and note from the range constraints field that u0 takes on a range of [10, 60].

So what information do torch._check() calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, these are generally true:

  1. Equality with non-data-dependent expressions: torch._check() calls that communicate equalities like u0 == s0 + 4 or u0 == 5.

  2. Range refinement: calls that provide lower or upper bounds for symbols, like the above.

  3. Some basic reasoning around more complicated expressions: inserting torch._check(a < 4) will typically tell the compiler that a >= 4 is false. Checks on complex expressions like torch._check(a ** 2 - 3 * a <= 10) will typically get you past identical guards.

As mentioned previously, torch._check() calls have applicability outside of data-dependent control flow. For example, here’s a model where inserting torch._check() calls succeeds while manual specialization and torch.cond() do not:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps)
except Exception:
    tb.print_exc()
I0130 20:24:36.068000 634 torch/fx/experimental/symbolic_shapes.py:3192] [19/0] create_env
I0130 20:24:36.072000 634 torch/fx/experimental/symbolic_shapes.py:4103] [19/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:701 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:36.072000 634 torch/fx/experimental/symbolic_shapes.py:970] [19/0] compute_unbacked_bindings [u0]
V0130 20:24:36.075000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0] Data dependent variable 'u0' allocated at:
V0130 20:24:36.075000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/bin/sphinx-build", line 8, in <module>
V0130 20:24:36.075000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     sys.exit(main())
V0130 20:24:36.075000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 288, in main
V0130 20:24:36.075000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]     return make_main(argv)
V0130 20:24:36.075000 634 torch/fx/experimental/symbolic_shapes.py:5727] [19/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 193, in make_main
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_meta_registrations.py", line 4874, in meta_select
    guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 407, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
    r = self.shape_env.evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60).  (Size-like symbols: none)

Caused by: return y[a]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
    return y[a]

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2604, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_meta_registrations.py", line 4874, in meta_select
    guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 407, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
    r = self.shape_env.evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
    raise self._make_data_dependent_error(
RuntimeError: Failed running call_function <built-in method select of type object at 0x7f606ac1fec0>(*(FakeTensor(..., size=(60,)), 0, u0), **{}):
Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60).  (Size-like symbols: none)

Caused by: return y[a]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
    return y[a]

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 709, in <module>
    export(Foo(), inps)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 314, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1004, in call_function
    return handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 980, in _handle_insert_op_in_graph
    return wrap_fx_proxy(tx, proxy)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2526, in get_fake_value
    raise UserError(  # noqa: B904
torch._dynamo.exc.UserError: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60).  (Size-like symbols: none)

Caused by: return y[a]  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:4874 in meta_select)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
    return y[a]

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example

from user code:
   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward
    return y[a]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Here is a scenario where inserting torch._check() is required simply to keep an operation from failing. The export call fails with “Could not guard on data-dependent expression -u0 > 60”. This means the compiler cannot tell whether the indexing operation is valid, that is, whether the value of x is out of bounds for y. The expression -u0 > 60 is the check for a negative index that reaches past the start of y. Manual specialization is too restrictive here, and torch.cond() does not apply. Instead, it is enough to tell the compiler the range of u0:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 0)
        torch._check(a < y.shape[0])
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps)
print(ep)
I0130 20:24:36.107000 634 torch/fx/experimental/symbolic_shapes.py:3192] [20/0] create_env
I0130 20:24:36.111000 634 torch/fx/experimental/symbolic_shapes.py:4103] [20/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item()  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:36.111000 634 torch/fx/experimental/symbolic_shapes.py:970] [20/0] compute_unbacked_bindings [u0]
V0130 20:24:36.113000 634 torch/fx/experimental/symbolic_shapes.py:5802] [20/0] _update_var_to_range u0 = VR[0, int_oo] (update)
I0130 20:24:36.114000 634 torch/fx/experimental/symbolic_shapes.py:6281] [20/0] runtime_assert u0 >= 0 [guard added] torch._check(a >= 0)  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:722 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V0130 20:24:36.117000 634 torch/fx/experimental/symbolic_shapes.py:5802] [20/0] _update_var_to_range u0 = VR[0, 59] (update)
I0130 20:24:36.118000 634 torch/fx/experimental/symbolic_shapes.py:6281] [20/0] runtime_assert u0 < 60 [guard added] torch._check(a < y.shape[0])  # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:723 in forward (_dynamo/utils.py:2586 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 < 60"
V0130 20:24:36.120000 634 torch/fx/experimental/symbolic_shapes.py:6412] [20/0] eval -u0 > 60 == False [statically known]
V0130 20:24:36.120000 634 torch/fx/experimental/symbolic_shapes.py:6412] [20/0] eval u0 >= 60 == False [statically known]
V0130 20:24:36.121000 634 torch/fx/experimental/symbolic_shapes.py:6412] [20/0] eval u0 >= 0 == True [statically known]
V0130 20:24:36.124000 634 torch/fx/experimental/symbolic_shapes.py:6614] [20/0] runtime_assert u0 >= 0 == True [statically known]
V0130 20:24:36.125000 634 torch/fx/experimental/symbolic_shapes.py:6614] [20/0] runtime_assert u0 <= 59 == True [statically known]
V0130 20:24:36.126000 634 torch/fx/experimental/symbolic_shapes.py:6614] [20/0] runtime_assert u0 < 60 == True [statically known]
I0130 20:24:36.129000 634 torch/fx/experimental/symbolic_shapes.py:4547] [20/0] produce_guards
V0130 20:24:36.130000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['x'].storage_offset() 0 None
V0130 20:24:36.130000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['y'].size()[0] 60 None
V0130 20:24:36.130000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['y'].stride()[0] 1 None
V0130 20:24:36.131000 634 torch/fx/experimental/symbolic_shapes.py:4755] [20/0] track_symint L['y'].storage_offset() 0 None
I0130 20:24:36.149000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:36.149000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u1]
V0130 20:24:36.150000 634 torch/fx/experimental/symbolic_shapes.py:5802] _update_var_to_range u1 = VR[0, 59] (update)
I0130 20:24:36.150000 634 torch/fx/experimental/symbolic_shapes.py:5963] set_replacement u1 = u0 (rename_unbacked_to) VR[0, 59]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[60]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_1: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 59)" = item <= 59
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 59 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

             #
            lt_1: "Sym(u0 < 60)" = item < 60
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u0 < 60 on node 'lt_1'");  lt_1 = _assert_scalar_default_2 = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:724 in forward, code: return y[a]
            select: "f32[]" = torch.ops.aten.select.int(y, 0, item);  y = item = None
            return (select,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='select'), target=None)])
Range constraints: {u0: VR[0, 59], u1: VR[0, 59]}
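The following pure-Python model shows why both checks are needed. It assumes, as the logs above suggest (eval -u0 > 60 and eval u0 >= 60), that select's bounds check for a dimension of size 60 is `-u0 > 60 or u0 >= 60`. For export to proceed, that check must evaluate to a statically known False. known_lt and select_out_of_bounds are hypothetical helpers written for illustration. The real range analysis behind the VR[...] values in the logs is far more general.

```python
import math
from typing import Optional


def known_lt(lo: float, hi: float, bound: float) -> Optional[bool]:
    """Decide `u < bound` for every u in [lo, hi] (hypothetical helper).

    Returns True/False when the answer is the same across the whole range,
    and None when it depends on the unknown runtime value of u.
    """
    if hi < bound:
        return True
    if lo >= bound:
        return False
    return None


def select_out_of_bounds(lo: float, hi: float, size: int) -> Optional[bool]:
    """Model of the bounds check `-u > size or u >= size` for u in [lo, hi]."""
    below = known_lt(lo, hi, -size)  # -u > size  <=>  u < -size
    lt_size = known_lt(lo, hi, size)
    above = None if lt_size is None else not lt_size  # u >= size
    if below is True or above is True:
        return True   # statically out of bounds
    if below is False and above is False:
        return False  # statically in bounds, so no data-dependent guard is needed
    return None       # undecidable while tracing: a "Could not guard" situation


# No information about u0: .item() creates it with range [-int_oo, int_oo].
print(select_out_of_bounds(-math.inf, math.inf, 60))  # None
# Only torch._check(a >= 0): u0 in [0, int_oo]; u0 >= 60 is still undecided.
print(select_out_of_bounds(0, math.inf, 60))           # None
# Both checks: u0 in [0, 59], matching "Range constraints: {u0: VR[0, 59]}".
print(select_out_of_bounds(0, 59, 60))                 # False
```

In this model, torch._check(a >= 0) settles the negative-index half of the check, and only torch._check(a < y.shape[0]) settles the upper half. This mirrors the log lines that narrow u0 first to VR[0, int_oo] and then to VR[0, 59].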

Specialized values

Another category of data-dependent error occurs when the program tries to extract a concrete data-dependent integer or float value while tracing. The message looks something like “Could not extract specialized integer from data-dependent expression”. These errors are the counterpart of the previous class. Data-dependent guard errors arise when a concrete boolean value must be evaluated, and these errors arise when a concrete integer or float value must be evaluated.
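The distinction can be sketched in plain Python. The two classes below are hypothetical stand-ins, not PyTorch APIs. They model a traced value whose comparisons stay symbolic. Such a value fails only when Python demands a concrete bool, which produces a guard error, or a concrete int, which produces a specialization error. The error strings are copied from the messages quoted in this section. The real SymInt machinery is considerably more involved.

```python
class UnbackedInt:
    """Hypothetical stand-in for an unbacked SymInt such as u0 from x.item().

    Comparisons stay symbolic; only conversions that demand a concrete
    Python int fail, mirroring the specialization error described above.
    """

    def __init__(self, expr: str) -> None:
        self.expr = expr

    def __gt__(self, other: int) -> "UnbackedBool":
        return UnbackedBool(f"{self.expr} > {other}")

    def __int__(self) -> int:
        raise RuntimeError(
            f"Could not extract specialized integer from data-dependent expression {self.expr}"
        )

    # range() and other index-like uses go through __index__ rather than __int__.
    __index__ = __int__


class UnbackedBool:
    """Hypothetical stand-in for a symbolic boolean such as `u0 > 0`."""

    def __init__(self, expr: str) -> None:
        self.expr = expr

    def __bool__(self) -> bool:
        raise RuntimeError(f"Could not guard on data-dependent expression {self.expr}")


def error_for(fn) -> str:
    """Call fn and return the RuntimeError message it raises, or 'ok'."""
    try:
        fn()
    except RuntimeError as e:
        return str(e)
    return "ok"


u0 = UnbackedInt("u0")
print(error_for(lambda: u0 > 0))        # the comparison itself stays symbolic
print(error_for(lambda: bool(u0 > 0)))  # guard error, as in `if a > 0:`
print(error_for(lambda: int(u0)))       # specialization error, explicit int()
print(error_for(lambda: range(u0)))     # specialization error, implicit via __index__
```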

This error typically occurs when there is an explicit or implicit int() cast on a data-dependent expression. For example, the list comprehension below calls range(a). That call implicitly converts a to a concrete int to decide how many elements to produce, and the return statement then casts a explicitly with int(a):

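class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        # range(a) needs a concrete int, which is not available while tracing
        b = torch.cat([y for _ in range(a)], dim=0)
        return b + int(a)  # explicit int() cast on a data-dependent value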

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps, strict=False)
except Exception:
    tb.print_exc()
I0130 20:24:36.168000 634 torch/fx/experimental/symbolic_shapes.py:3192] create_env
I0130 20:24:36.174000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:36.175000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u0]
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727] Data dependent variable 'u0' allocated at:
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/bin/sphinx-build", line 8, in <module>
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     sys.exit(main())
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 288, in main
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return make_main(argv)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 193, in make_main
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return make_mode.run_make_mode(argv[1:])
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return make.run_generic_build(args[0])
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return build_main(args + opts)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 272, in build_main
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 256, in __init__
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     self._init_builder()
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 314, in _init_builder
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     self.events.emit('builder-inited')
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     results.append(listener.handler(self.app, *args))
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     ) = generate_dir_rst(
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     intro, title, cost = generate_file_rst(
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     output_blocks, time_elapsed = execute_script(script_blocks,
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     output_blocks.append(execute_code_block(
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     is_last_expr, mem_max = _exec_and_get_memory(
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     mem_max, _ = gallery_conf['call_memory'](
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return 0., func()
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     exec(self.code, self.fake_main.__dict__)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     export(Foo(), inps, strict=False)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return _export(
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     ep = fn(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return _export_for_training(
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     ep = fn(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     export_artifact = export_func(  # type: ignore[operator]
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1772, in _non_strict_export
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     aten_export_artifact = _to_aten_func(  # type: ignore[operator]
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1564, in _export_to_aten_ir_make_fx
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     gm, graph_signature = transform(_make_fx_helper)(
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1702, in _aot_export_non_strict
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1485, in _make_fx_helper
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     gm = make_fx(
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2196, in wrapped
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return make_fx_tracer.trace(f, *args)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2134, in trace
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self._trace_inner(f, *args)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2105, in _trace_inner
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     t = dispatch_trace(
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return disable_fn(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1138, in dispatch_trace
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1694, in trace
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     res = super().trace(root, concrete_args)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in trace
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     (self.create_arg(fn(*args)),),
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1193, in wrapped
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     out = f(*tensors)  # type:ignore[call-arg]
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "<string>", line 1, in <lambda>
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1469, in wrapped_fn
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return tuple(flat_fn(*args))
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     tree_out = fn(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 879, in functional_call
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     out = mod(*args[params_len:], **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self.call_module(mod, forward, args, kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return Tracer.call_module(self, m, forward, args, kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     ret_val = forward(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return _orig_module_call(mod, *args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self._call_impl(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return forward_call(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1689, in forward
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     tree_out = mod(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self.call_module(mod, forward, args, kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return Tracer.call_module(self, m, forward, args, kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     ret_val = forward(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return _orig_module_call(mod, *args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self._call_impl(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return forward_call(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 747, in forward
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     a = x.item()
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1241, in __torch_function__
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return func(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1288, in __torch_function__
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return func(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 557, in __torch_function__
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return func(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 840, in handler
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return torch._library.utils.handle_dispatch_mode(
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_library/utils.py", line 295, in handle_dispatch_mode
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1343, in __torch_dispatch__
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return proxy_call(self, func, self.pre_dispatch, args, kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 912, in proxy_call
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     out = func(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 723, in __call__
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self._op(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return fn(*args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self.dispatch(func, types, args, kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return self._cached_dispatch_impl(func, types, args, kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     output = self._dispatch_impl(func, types, args, kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     op_impl_out = op_impl(self, func, *args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 403, in local_scalar_dense
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     r = fake_mode.shape_env.create_unbacked_symint()
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]     return retlog(fn(*args, **kwargs))
V0130 20:24:36.176000 634 torch/fx/experimental/symbolic_shapes.py:5727]
W0130 20:24:36.184000 634 torch/fx/experimental/symbolic_shapes.py:6307] failed during evaluate_expr(u0, hint=None, size_oblivious=False, forcing_spec=False
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299] failed while running evaluate_expr(*(u0, None), **{'fx_node': False})
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299] Traceback (most recent call last):
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299]     return retlog(fn(*args, **kwargs))
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299]     return self._evaluate_expr(
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299]     raise self._make_data_dependent_error(
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299]
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299] Caused by: (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:748 in forward)
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299] For more information, run with TORCH_LOGS="dynamic"
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299]
E0130 20:24:36.184000 634 torch/fx/experimental/recording.py:299] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
    export(Foo(), inps, strict=False)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 368, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1772, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1564, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1702, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1485, in _make_fx_helper
    gm = make_fx(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2196, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2134, in trace
    return self._trace_inner(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2105, in _trace_inner
    t = dispatch_trace(
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1138, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1694, in trace
    res = super().trace(root, concrete_args)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 843, in trace
    (self.create_arg(fn(*args)),),
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1193, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1469, in wrapped_fn
    return tuple(flat_fn(*args))
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 879, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1689, in forward
    tree_out = mod(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 814, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 748, in forward
    b = torch.cat([y for y in range(a)], dim=0)
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 427, in __index__
    return self.node.int_()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 445, in int_
    return self.guard_int("", 0)  # NB: uses Python backtrace
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 492, in guard_int
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Caused by: (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:748 in forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

For these errors, some basic options you have are:

  1. Avoid unnecessary int() casts; in this case, the int(a) in the return statement.

  2. Use torch._check() calls. Unfortunately, in this case the only option may be to specialize (with torch._check(a == 60)).

  3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a repeat() op, which doesn’t involve an int() cast. The following rewrite avoids data-dependent errors:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.unsqueeze(0).repeat(a, 1)
        return b + a

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps, strict=False)
print(ep)
I0130 20:24:36.194000 634 torch/fx/experimental/symbolic_shapes.py:3192] create_env
I0130 20:24:36.199000 634 torch/fx/experimental/symbolic_shapes.py:4103] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:403 in local_scalar_dense)
I0130 20:24:36.200000 634 torch/fx/experimental/symbolic_shapes.py:970] compute_unbacked_bindings [u0]
V0130 20:24:36.204000 634 torch/fx/experimental/symbolic_shapes.py:5802] _update_var_to_range u0 = VR[0, int_oo] (update)
I0130 20:24:36.205000 634 torch/fx/experimental/symbolic_shapes.py:6281] runtime_assert u0 >= 0 [guard added] (_refs/__init__.py:4800 in new_empty), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V0130 20:24:36.207000 634 torch/fx/experimental/symbolic_shapes.py:6412] eval Eq(u0, 0) == False [statically known]
V0130 20:24:36.210000 634 torch/fx/experimental/symbolic_shapes.py:6412] eval Eq(u0, 1) == False [statically known]
V0130 20:24:36.210000 634 torch/fx/experimental/symbolic_shapes.py:6614] runtime_assert True == True [statically known]
I0130 20:24:36.214000 634 torch/fx/experimental/symbolic_shapes.py:4547] produce_guards
V0130 20:24:36.215000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][0].storage_offset() 0 None
V0130 20:24:36.215000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][1].size()[0] 60 None
V0130 20:24:36.215000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][1].stride()[0] 1 None
V0130 20:24:36.215000 634 torch/fx/experimental/symbolic_shapes.py:4755] track_symint L['args'][0][1].storage_offset() 0 None
V0130 20:24:36.217000 634 torch/fx/experimental/symbolic_shapes.py:6614] runtime_assert u0 >= 0 == True [statically known]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[60]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

             #
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item);  sym_constrain_range_for_size_default = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
            ge: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'");  ge = _assert_scalar_default = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:770 in forward, code: b = y.unsqueeze(0).repeat(a, 1)
            unsqueeze: "f32[1, 60]" = torch.ops.aten.unsqueeze.default(y, 0);  y = None
            repeat: "f32[u0, 60]" = torch.ops.aten.repeat.default(unsqueeze, [item, 1]);  unsqueeze = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:771 in forward, code: return b + a
            add: "f32[u0, 60]" = torch.ops.aten.add.Tensor(repeat, item);  repeat = item = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {u0: VR[0, int_oo]}
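Here is a quick eager-mode check of the claim in item 3, with no export involved. The value 3 is an arbitrary stand-in for x.item(). It also surfaces one subtlety: the rewrite above returns a 2-D (a, 60) tensor, while the original list comprehension returns a 1-D tensor of length 60 * a. Flattening the result, or calling y.repeat(a) directly, reproduces the original layout.

```python
import torch

y = torch.randn(60)
a = 3  # arbitrary stand-in for the data-dependent value x.item()

# Original code: a copies of y concatenated along dim 0 -> shape (180,)
via_cat = torch.cat([y for _ in range(a)], dim=0)

# Rewrite used above: a copies of y stacked as rows -> shape (3, 60)
via_repeat_2d = y.unsqueeze(0).repeat(a, 1)

# 1-D variant with exactly the original output shape -> shape (180,)
via_repeat_1d = y.repeat(a)

print(via_cat.shape, via_repeat_2d.shape, via_repeat_1d.shape)
```

All three hold the same values. Neither repeat() form needs a Python int loop bound, so neither forces the tracer to specialize u0.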

Data-dependent errors can be much more involved, and there are many more tools for dealing with them, such as torch._check_is_size(), guard_size_oblivious(), and real-tensor tracing. For more in-depth guides, please refer to the Export Programming Model or Dealing with GuardOnDataDependentSymNode errors.

Custom Ops

torch.export can export PyTorch programs with custom operators. Please refer to this page for how to author a custom operator in either C++ or Python.

The following is an example of registering a custom operator in Python to be used by torch.export. Note that the custom op must have a FakeTensor kernel (registered here with register_fake), which lets export trace the op using only shape and dtype metadata.

@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(x: torch.Tensor) -> torch.Tensor:
    print("custom_op called!")
    return torch.relu(x)

@custom_op.register_fake
def custom_op_meta(x):
    # Returns an empty tensor with the same shape as the expected output
    return torch.empty_like(x)

Here is an example of exporting a program with the custom op.

class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x

exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
print(exported_custom_op_example)
print(exported_custom_op_example.module()(torch.randn(3, 3)))
I0130 20:24:36.234000 634 torch/fx/experimental/symbolic_shapes.py:3192] [21/0] create_env
I0130 20:24:36.244000 634 torch/fx/experimental/symbolic_shapes.py:4547] [21/0] produce_guards
V0130 20:24:36.244000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].size()[0] 3 None
V0130 20:24:36.244000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].size()[1] 3 None
V0130 20:24:36.245000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].stride()[0] 3 None
V0130 20:24:36.245000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].stride()[1] 1 None
V0130 20:24:36.245000 634 torch/fx/experimental/symbolic_shapes.py:4755] [21/0] track_symint L['x'].storage_offset() 0 None
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:812 in forward, code: x = torch.sin(x)
            sin: "f32[3, 3]" = torch.ops.aten.sin.default(x);  x = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:813 in forward, code: x = torch.ops.my_custom_library.custom_op(x)
            custom_op: "f32[3, 3]" = torch.ops.my_custom_library.custom_op.default(sin);  sin = None

             # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:814 in forward, code: x = torch.cos(x)
            cos: "f32[3, 3]" = torch.ops.aten.cos.default(custom_op);  custom_op = None
            return (cos,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}

custom_op called!
tensor([[0.5499, 0.6889, 0.7180],
        [0.5413, 1.0000, 1.0000],
        [0.8332, 0.5524, 1.0000]])

Note that in the ExportedProgram, the custom operator is included in the graph.

IR/Decompositions

The graph produced by torch.export contains only ATen operators, which are the basic unit of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.

By default, export produces the most generic IR, which contains all ATen operators, including both functional and non-functional operators. A functional operator is one that does not mutate or alias any of its inputs. You can find a list of all ATen operators here, and you can check whether an operator mutates its inputs with op._schema.is_mutable, for example:

print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)
False
True
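is_mutable only reports mutation. View operators such as aten.view are not mutable, yet their output aliases an input, so they are not functional either. The sketch below uses a hypothetical helper, is_functional, that also inspects the alias annotation on each return value (for example, "-> Tensor(a)"). It assumes the alias_info attribute exposed on schema arguments in current PyTorch releases.

```python
import torch

def is_functional(op: torch._ops.OpOverload) -> bool:
    """Hypothetical helper: True if op neither mutates nor aliases its inputs."""
    schema = op._schema
    if schema.is_mutable:
        return False
    # View ops are not mutable, but their outputs alias an input; the schema
    # records this as alias_info on the returned value.
    return all(ret.alias_info is None for ret in schema.returns)

for op in (torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor,
           torch.ops.aten.view.default):
    print(op, "is_mutable:", op._schema.is_mutable, "functional:", is_functional(op))
```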

This generic IR can be used for training with eager PyTorch Autograd. It can be obtained more explicitly through torch.export.export_for_training, introduced in PyTorch 2.5; as of PyTorch 2.6, calling torch.export.export should produce the same graph.

class DecompExample(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export_for_training(DecompExample(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph)
I0130 20:24:36.272000 634 torch/fx/experimental/symbolic_shapes.py:3192] [22/0] create_env
I0130 20:24:36.301000 634 torch/fx/experimental/symbolic_shapes.py:4547] [22/0] produce_guards
V0130 20:24:36.302000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[0] 1 None
V0130 20:24:36.302000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[1] 1 None
V0130 20:24:36.302000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[2] 3 None
V0130 20:24:36.303000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].size()[3] 3 None
V0130 20:24:36.303000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[0] 9 None
V0130 20:24:36.303000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[1] 9 None
V0130 20:24:36.303000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[2] 3 None
V0130 20:24:36.304000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].stride()[3] 1 None
V0130 20:24:36.304000 634 torch/fx/experimental/symbolic_shapes.py:4755] [22/0] track_symint L['x'].storage_offset() 0 None
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05, True), kwargs = {})
    return (batch_norm,)

We can then lower this exported program to an operator set that contains only functional ATen operators through the run_decompositions API. This API decomposes ATen operators according to the given decomposition table and functionalizes the graph. Passing an empty decomposition table performs only functionalization, with no additional decompositions. The resulting IR contains ~2000 operators (instead of the 3000 operators above) and is ideal for inference.

ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
print(ep_for_inference.graph)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

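As we can see, the previously mutable operator torch.ops.aten.add_.Tensor has been replaced with torch.ops.aten.add.Tensor, a functional operator. Likewise, torch.ops.aten.batch_norm.default has become torch.ops.aten._native_batch_norm_legit_functional.default. The updated buffers (num_batches_tracked and the running mean and variance) are now returned as graph outputs instead of being mutated in place.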

We can also lower this exported program further, to an operator set containing only the Core ATen Operator Set, a collection of only ~180 operators. This IR is optimal for backends that do not want to reimplement all ATen operators.

from torch.export import default_decompositions

core_aten_decomp_table = default_decompositions()
core_aten_ep = ep_for_training.run_decompositions(decomp_table=core_aten_decomp_table)
print(core_aten_ep.graph)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%convolution, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

We now see that torch.ops.aten.conv2d.default has been decomposed into torch.ops.aten.convolution.default. This is because convolution is a more “core” operator, as operations like conv1d and conv2d can be implemented using the same op.
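Whether an operator belongs to the Core ATen set is recorded as a tag on the operator itself. The sketch below checks the two convolution operators discussed above using a hypothetical helper, is_core_aten. It assumes the torch.Tag.core tag and the OpOverload.tags attribute, both present in recent PyTorch releases.

```python
import torch

def is_core_aten(op: torch._ops.OpOverload) -> bool:
    """Hypothetical helper: True if op carries the Core ATen tag."""
    return torch.Tag.core in op.tags

print(is_core_aten(torch.ops.aten.conv2d.default))       # conv2d is not core
print(is_core_aten(torch.ops.aten.convolution.default))  # convolution is core
```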

We can also specify our own decomposition behaviors:

my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%mul, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

Notice that instead of torch.ops.aten.conv2d.default being decomposed into torch.ops.aten.convolution.default, it is now decomposed into torch.ops.aten.convolution.default and torch.ops.aten.mul.Tensor, which matches our custom decomposition rule.

ExportDB

torch.export will only ever export a single computation graph from a PyTorch program. Because of this requirement, there will be Python or PyTorch features that are not compatible with torch.export, which will require users to rewrite parts of their model code. We have seen examples of this earlier in the tutorial – for example, rewriting if-statements using cond.

ExportDB is the standard reference that documents supported and unsupported Python/PyTorch features for torch.export. It is essentially a list of program samples, each of which represents the usage of one particular Python/PyTorch feature and its interaction with torch.export. Examples are also tagged by category so that they can be searched more easily.

For example, let’s use ExportDB to get a better understanding of how the predicate works in the cond operator. We can look at the example called cond_predicate, which has a torch.cond tag. The example code looks like:

import torch
from torch import cond


def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
    - ``torch.Tensor`` with a single element
    - boolean expression
    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

More generally, ExportDB can be used as a reference when one of the following occurs:

  1. Before attempting torch.export, you know ahead of time that your model uses some tricky Python/PyTorch features and you want to know if torch.export covers that feature.

  2. When attempting torch.export, there is a failure and it’s unclear how to work around it.

ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by torch.export.

Running the Exported Program

Because torch.export is only a graph-capturing mechanism, calling the artifact produced by torch.export eagerly is equivalent to running the original eager module; export by itself does not speed the model up. To optimize the execution of the Exported Program, we can pass this exported artifact to backends such as Inductor through torch.compile, AOTInductor, or TensorRT.

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))

# Run it eagerly
res = ep.module()(inp)
print(res)

# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
tensor([[ 0.4830, -0.5149,  0.3888],
        [-0.9247,  0.8408, -0.2184]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
tensor([[ 0.4830, -0.5149,  0.3888],
        [-0.9247,  0.8408, -0.2184]], device='cuda:0',
       grad_fn=<CompiledFunctionBackward>)
import torch._inductor

# Note: these APIs are subject to change
# Compile the exported program to a PT2 archive using ``AOTInductor``
with torch.no_grad():
    pt2_path = torch._inductor.aoti_compile_and_package(ep)

# Load and run the compiled package in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
aoti_compiled = torch._inductor.aoti_load_package(pt2_path)
res = aoti_compiled(inp)

Conclusion

We introduced torch.export, the new PyTorch 2.X way to export single computation graphs from PyTorch programs. In particular, we demonstrated several code modifications and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph.
