.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/torch_export_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_torch_export_tutorial.py:

torch.export Tutorial
===================================================
**Author:** William Wen, Zhengxu Chen, Angela Yi

.. GENERATED FROM PYTHON SOURCE LINES 10-29

.. warning::

    ``torch.export`` and its related features are in prototype status and are subject to backwards compatibility
    breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.3.

:func:`torch.export` is the PyTorch 2.X way to export PyTorch models into
standardized model representations, intended
to be run in different (that is, Python-less) environments.
See the official ``torch.export`` documentation for the full reference.

In this tutorial, you will learn how to use :func:`torch.export` to extract
``ExportedProgram`` objects (single-graph representations) from PyTorch programs.
We also detail the modifications you may need to make
so that your model is compatible with ``torch.export``.

**Contents**

.. contents::
    :local:

.. GENERATED FROM PYTHON SOURCE LINES 32-60

Basic Usage
-----------

``torch.export`` extracts single-graph representations from PyTorch programs
by tracing the target function, given example inputs.
``torch.export.export()`` is the main entry point for ``torch.export``.

In this tutorial, ``torch.export`` and ``torch.export.export()`` are practically synonymous,
though ``torch.export`` generally refers to the PyTorch 2.X export process, and ``torch.export.export()``
generally refers to the actual function call.

The signature of ``torch.export.export()`` is:

..
.. code-block:: python

    export(
        f: Callable,
        args: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]] = None,
        *,
        dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
    ) -> ExportedProgram

``torch.export.export()`` traces the tensor computation graph from calling
``f(*args, **kwargs)`` and wraps it in an ``ExportedProgram``, which can be
serialized or executed later with different inputs. Note that the output
``ExportedProgram`` is not a ``torch.nn.Module``; the example below calls
``exported_mod.module()`` to get a callable module that takes the same inputs
as the original callable.
We will detail the ``dynamic_shapes`` argument later in the tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 60-78

.. code-block:: default

    import torch
    from torch.export import export

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x, y):
            return torch.nn.functional.relu(self.lin(x + y), inplace=True)

    mod = MyModule()
    exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
    print(type(exported_mod))
    print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    <class 'torch.export.exported_program.ExportedProgram'>
    tensor([[0.8632, 0.8407, 0.0407, 0.0000, 0.4132, 0.0000, 0.0000, 0.1538, 0.6111, 0.0000],
            [0.0000, 0.0000, 0.0273, 0.8057, 0.0000, 1.0162, 0.8042, 0.0000, 0.2660, 0.0000],
            [0.9481, 0.1396, 1.0225, 0.9563, 0.5832, 0.2546, 0.4095, 0.4591, 0.0000, 2.0053],
            [1.1300, 0.4873, 0.0000, 0.9663, 1.2275, 1.4015, 0.0000, 0.9444, 0.0000, 0.0000],
            [0.0000, 0.8724, 1.1648, 0.6867, 0.0000, 0.2833, 0.3202, 0.5848, 0.0000, 0.0833],
            [1.1311, 0.1324, 0.0000, 1.7842, 0.0000, 0.3474, 0.9916, 0.3571, 0.0000, 0.0000],
            [1.4348, 1.0570, 0.1771, 0.0000, 0.9510, 0.0000, 0.0000, 0.0000, 0.2618, 0.0000],
            [0.8853, 0.0000, 0.0000, 0.4486, 0.0000, 0.0000, 0.5841, 0.7604, 0.0000, 0.0000]],
           grad_fn=<ReluBackward0>)

.. GENERATED FROM PYTHON SOURCE LINES 79-90

Let's review some attributes of ``ExportedProgram`` that are of interest.
The ``graph`` attribute is an FX graph traced from the function we exported,
that is, the computation graph of all PyTorch operations.
The FX graph has some important properties:

- The operations are "ATen-level" operations.
- The graph is "functionalized", meaning that no operations are mutations.

The ``graph_module`` attribute is the ``GraphModule`` that wraps the ``graph`` attribute
so that it can be run as a ``torch.nn.Module``.

.. GENERATED FROM PYTHON SOURCE LINES 90-94

.. code-block:: default

    print(exported_mod)
    print(exported_mod.graph_module)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[10, 100]", arg1_1: "f32[10]", arg2_1: "f32[8, 100]", arg3_1: "f32[8, 100]"):
                # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:70 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
                add: "f32[8, 100]" = torch.ops.aten.add.Tensor(arg2_1, arg3_1); arg2_1 = arg3_1 = None
                t: "f32[100, 10]" = torch.ops.aten.t.default(arg0_1); arg0_1 = None
                addmm: "f32[8, 10]" = torch.ops.aten.addmm.default(arg1_1, add, t); arg1_1 = add = t = None
                relu: "f32[8, 10]" = torch.ops.aten.relu.default(addmm); addmm = None
                return (relu,)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target='lin.weight', persistent=None), InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target='lin.bias', persistent=None), InputSpec(kind=, arg=TensorArgument(name='arg2_1'), target=None, persistent=None), InputSpec(kind=, arg=TensorArgument(name='arg3_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='relu'), target=None)])
    Range constraints: {}

    GraphModule()

    def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
        add = torch.ops.aten.add.Tensor(arg2_1, arg3_1); arg2_1 = arg3_1 = None
        t = torch.ops.aten.t.default(arg0_1); arg0_1 = None
        addmm =
    torch.ops.aten.addmm.default(arg1_1, add, t); arg1_1 = add = t = None
        relu = torch.ops.aten.relu.default(addmm); addmm = None
        return (relu,)

    # To see more debug info, please use `graph_module.print_readable()`

.. GENERATED FROM PYTHON SOURCE LINES 95-105

The printed code shows that the FX graph contains only ATen-level ops (those in the
``torch.ops.aten`` namespace) and that mutations were removed.
For example, the mutating op ``torch.nn.functional.relu(..., inplace=True)``
is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate.
Any later use of the input to the original mutating ``relu`` op is replaced by the output
of the non-mutating ``relu`` op.

Other attributes of interest in ``ExportedProgram`` include:

- ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.
- ``range_constraints`` -- the value ranges of dynamic dimensions, covered later

.. GENERATED FROM PYTHON SOURCE LINES 105-108

.. code-block:: default

    print(exported_mod.graph_signature)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target='lin.weight', persistent=None), InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target='lin.bias', persistent=None), InputSpec(kind=, arg=TensorArgument(name='arg2_1'), target=None, persistent=None), InputSpec(kind=, arg=TensorArgument(name='arg3_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='relu'), target=None)])

.. GENERATED FROM PYTHON SOURCE LINES 109-111

See the ``torch.export`` documentation for more details.

.. GENERATED FROM PYTHON SOURCE LINES 113-127

Graph Breaks
------------

Although ``torch.export`` shares components with ``torch.compile``,
the key limitation of ``torch.export``, especially when compared to ``torch.compile``, is that it does not
support graph breaks.
This is because handling graph breaks involves interpreting the unsupported operation with default Python evaluation, which is incompatible with the export use case. Therefore, in order to make your model code compatible with ``torch.export``, you will need to modify your code to remove graph breaks. A graph break is necessary in cases such as: - data-dependent control flow .. GENERATED FROM PYTHON SOURCE LINES 127-140 .. code-block:: default class Bad1(torch.nn.Module): def forward(self, x): if x.sum() > 0: return torch.sin(x) return torch.cos(x) import traceback as tb try: export(Bad1(), (torch.randn(3, 3),)) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. code-block:: none Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 136, in export(Bad1(), (torch.randn(3, 3),)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 174, in export return _export( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 635, in wrapper raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 618, in wrapper ep = fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 83, in wrapper return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 860, in _export gm_torch_level = _export_to_torch_ir( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 347, in _export_to_torch_ir gm_torch_level, _ = torch._dynamo.export( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1311, in inner result_traced = opt_f(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File 
"/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors return callback(frame, cache_entry, hooks, frame_state, skip=1) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert return _compile( File "/opt/conda/envs/py_3.10/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper r = func(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner out_code = transform_code_object(code, transform) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object transformations(instructions, code_options) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform tracer.run() File 
"/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run super().run() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run and self.step() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step getattr(self, inst.opname)(inst) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 464, in inner raise exc.UserError( torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands from user code: File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 130, in forward if x.sum() > 0: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information .. GENERATED FROM PYTHON SOURCE LINES 141-142 - accessing tensor data with ``.data`` .. GENERATED FROM PYTHON SOURCE LINES 142-153 .. code-block:: default class Bad2(torch.nn.Module): def forward(self, x): x.data[0, 0] = 3 return x try: export(Bad2(), (torch.randn(3, 3),)) except Exception: tb.print_exc() .. GENERATED FROM PYTHON SOURCE LINES 154-155 - calling unsupported functions (such as many built-in functions) .. GENERATED FROM PYTHON SOURCE LINES 155-166 .. code-block:: default class Bad3(torch.nn.Module): def forward(self, x): x = x + 1 return x + id(x) try: export(Bad3(), (torch.randn(3, 3),)) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. 
code-block:: none Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 162, in export(Bad3(), (torch.randn(3, 3),)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 174, in export return _export( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 635, in wrapper raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 618, in wrapper ep = fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 83, in wrapper return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 860, in _export gm_torch_level = _export_to_torch_ir( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 347, in _export_to_torch_ir gm_torch_level, _ = torch._dynamo.export( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1311, in inner result_traced = opt_f(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in 
catch_errors return callback(frame, cache_entry, hooks, frame_state, skip=1) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert return _compile( File "/opt/conda/envs/py_3.10/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper r = func(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner out_code = transform_code_object(code, transform) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object transformations(instructions, code_options) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform tracer.run() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run super().run() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run and self.step() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step getattr(self, inst.opname)(inst) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper return inner_fn(self, inst) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION self.call_function(fn, args, {}) File 
"/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function self.push(fn.call_function(self, args, kwargs)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 687, in call_function result = handler(tx, *args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1520, in call_id unimplemented(f"call_id with args {args}") File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 190, in unimplemented raise Unsupported(msg) torch._dynamo.exc.Unsupported: call_id with args (TensorVariable(),) from user code: File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 159, in forward return x + id(x) Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information .. GENERATED FROM PYTHON SOURCE LINES 167-168 - unsupported Python language features (e.g. throwing exceptions, match statements) .. GENERATED FROM PYTHON SOURCE LINES 168-183 .. code-block:: default class Bad4(torch.nn.Module): def forward(self, x): try: x = x + 1 raise RuntimeError("bad") except: x = x + 2 return x try: export(Bad4(), (torch.randn(3, 3),)) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. 
code-block:: none Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 179, in export(Bad4(), (torch.randn(3, 3),)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 174, in export return _export( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 635, in wrapper raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 618, in wrapper ep = fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 83, in wrapper return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 860, in _export gm_torch_level = _export_to_torch_ir( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 347, in _export_to_torch_ir gm_torch_level, _ = torch._dynamo.export( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1311, in inner result_traced = opt_f(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in 
catch_errors return callback(frame, cache_entry, hooks, frame_state, skip=1) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert return _compile( File "/opt/conda/envs/py_3.10/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper r = func(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner out_code = transform_code_object(code, transform) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object transformations(instructions, code_options) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform tracer.run() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run super().run() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run and self.step() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step getattr(self, inst.opname)(inst) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper return inner_fn(self, inst) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION self.call_function(fn, args, {}) File 
"/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function self.push(fn.call_function(self, args, kwargs)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 730, in call_function return super().call_function(tx, args, kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 349, in call_function unimplemented(f"call_function {self} {args} {kwargs}") File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 190, in unimplemented raise Unsupported(msg) torch._dynamo.exc.Unsupported: call_function BuiltinVariable(RuntimeError) [ConstantVariable()] {} from user code: File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 173, in forward raise RuntimeError("bad") Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information .. GENERATED FROM PYTHON SOURCE LINES 184-202 Non-Strict Export ----------------- To trace the program, ``torch.export`` uses TorchDynamo, a byte code analysis engine, to symbolically analyze the Python code and build a graph based on the results. This analysis allows ``torch.export`` to provide stronger guarantees about safety, but not all Python code is supported, causing these graph breaks. To address this issue, in PyTorch 2.3, we introduced a new mode of exporting called non-strict mode, where we trace through the program using the Python interpreter executing it exactly as it would in eager mode, allowing us to skip over unsupported Python features. This is done through adding a ``strict=False`` flag. Looking at some of the previous examples which resulted in graph breaks: - Accessing tensor data with ``.data`` now works correctly .. GENERATED FROM PYTHON SOURCE LINES 202-211 .. 
.. code-block:: default

    class Bad2(torch.nn.Module):
        def forward(self, x):
            x.data[0, 0] = 3
            return x

    bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False)
    print(bad2_nonstrict.module()(torch.ones(3, 3)))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[3., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]])

.. GENERATED FROM PYTHON SOURCE LINES 212-216

- Calling unsupported functions (such as many built-in functions) traces through, but
  in this case, ``id(x)`` gets specialized as a constant integer in the graph. This is because
  ``id(x)`` is not a tensor operation, so the operation is not recorded in the graph.

.. GENERATED FROM PYTHON SOURCE LINES 216-226

.. code-block:: default

    class Bad3(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            return x + id(x)

    bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
    print(bad3_nonstrict)
    print(bad3_nonstrict.module()(torch.ones(3, 3)))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: "f32[3, 3]"):
                # No stacktrace found for following nodes
                add: "f32[3, 3]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
                add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, 140588796781856); add = None
                return (add_1,)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add_1'), target=None)])
    Range constraints: {}

    tensor([[1.4059e+14, 1.4059e+14, 1.4059e+14],
            [1.4059e+14, 1.4059e+14, 1.4059e+14],
            [1.4059e+14, 1.4059e+14, 1.4059e+14]])

.. GENERATED FROM PYTHON SOURCE LINES 227-229

- Unsupported Python language features (such as throwing exceptions, match statements) now also get traced through.

.. GENERATED FROM PYTHON SOURCE LINES 229-243

..
.. code-block:: default

    class Bad4(torch.nn.Module):
        def forward(self, x):
            try:
                x = x + 1
                raise RuntimeError("bad")
            except:
                x = x + 2
            return x

    bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False)
    print(bad4_nonstrict.module()(torch.ones(3, 3)))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[4., 4., 4.],
            [4., 4., 4.],
            [4., 4., 4.]])

.. GENERATED FROM PYTHON SOURCE LINES 244-246

However, there are still some features that require rewrites to the original
module:

.. GENERATED FROM PYTHON SOURCE LINES 248-254

Control Flow Ops
----------------

``torch.export`` does support data-dependent control flow, but it must be
expressed using control flow ops. For example,
we can fix the control flow example above using the ``cond`` op, like so:

.. GENERATED FROM PYTHON SOURCE LINES 254-269

.. code-block:: default

    from functorch.experimental.control_flow import cond

    class Bad1Fixed(torch.nn.Module):
        def forward(self, x):
            def true_fn(x):
                return torch.sin(x)
            def false_fn(x):
                return torch.cos(x)
            return cond(x.sum() > 0, true_fn, false_fn, [x])

    exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
    print(exported_bad1_fixed.module()(torch.ones(3, 3)))
    print(exported_bad1_fixed.module()(-torch.ones(3, 3)))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[0.8415, 0.8415, 0.8415],
            [0.8415, 0.8415, 0.8415],
            [0.8415, 0.8415, 0.8415]])
    tensor([[0.5403, 0.5403, 0.5403],
            [0.5403, 0.5403, 0.5403],
            [0.5403, 0.5403, 0.5403]])

.. GENERATED FROM PYTHON SOURCE LINES 270-281

There are limitations to ``cond`` that one should be aware of:

- The predicate (here, ``x.sum() > 0``) must result in a boolean or a single-element tensor.
- The operands (here, ``[x]``) must be tensors.
- The signatures of the branch functions (here, ``true_fn`` and ``false_fn``) must match the
  operands, and both branches must return a single tensor with the same metadata (for example, ``dtype``, ``shape``, etc.).
- Branch functions cannot mutate input or global variables.
- Branch functions cannot access closure variables, except for ``self`` if the function is
  defined in the scope of a method.

For more details, see the ``cond`` documentation.

.. GENERATED FROM PYTHON SOURCE LINES 283-302

..
    [NOTE] map is not documented at the moment
    We can also use ``map``, which applies a function across the first dimension
    of the first tensor argument.

    from functorch.experimental import control_flow

    def map_example(xs):
        def map_fn(x, const):
            def true_fn(x):
                return x + const
            def false_fn(x):
                return x - const
            return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
        return control_flow.map(map_fn, xs, torch.tensor([2.0]))

    exported_map_example = export(map_example, (torch.randn(4, 3),))
    inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
    print(exported_map_example(inp))

.. GENERATED FROM PYTHON SOURCE LINES 304-312

Constraints/Dynamic Shapes
--------------------------

Ops can have different specializations/behaviors for different tensor shapes, so by default,
``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
example inputs given to the initial ``torch.export.export()`` call.
If we try to run the ``ExportedProgram`` in the example below with a tensor
with a different shape, we get an error:

.. GENERATED FROM PYTHON SOURCE LINES 312-329

.. code-block:: default

    class MyModule2(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x, y):
            return torch.nn.functional.relu(self.lin(x + y), inplace=True)

    mod2 = MyModule2()
    exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))

    try:
        exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
    except Exception:
        tb.print_exc()

.. rst-class:: sphx-glr-script-out

..
.. code-block:: none

    Traceback (most recent call last):
      File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 325, in <module>
        exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 737, in call_wrapped
        return self._wrapped_call(self, *args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 317, in __call__
        raise e
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 304, in __call__
        return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
        args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
        return fn(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
        return fn(*args, **kwargs)
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_unlift.py", line 32, in _check_input_constraints_pre_hook
        return _check_input_constraints_for_graph(
      File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/utils.py", line 129, in _check_input_constraints_for_graph
        raise RuntimeError(
    RuntimeError: Expected input at *args[0].shape[0] to be equal to 8, but got 10

.. GENERATED FROM PYTHON SOURCE LINES 330-351

We can relax this constraint using the ``dynamic_shapes`` argument of
``torch.export.export()``, which allows us to specify, using ``torch.export.Dim``,
which dimensions of the input tensors are dynamic.
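
To make the difference between static and dynamic dimensions concrete, the
following plain-Python sketch models the kind of per-dimension check behind the
``RuntimeError`` above. It is an illustration only, not how ``torch.export``
implements the check: ``DimSpec`` and ``check_shape`` are hypothetical names
introduced here, and the messages only loosely mirror the real ones. A static
dimension must equal the size seen in the example input, while a dynamic
dimension accepts any size within its optional inclusive bounds.

.. code-block:: python

    from dataclasses import dataclass
    from typing import Dict, List, Optional, Tuple


    @dataclass(frozen=True)
    class DimSpec:
        """Hypothetical stand-in for a named dynamic dimension (not torch.export.Dim)."""
        name: str
        min: Optional[int] = None  # inclusive lower bound, if any
        max: Optional[int] = None  # inclusive upper bound, if any


    def check_shape(
        example_shape: Tuple[int, ...],
        dynamic: Dict[int, DimSpec],
        actual_shape: Tuple[int, ...],
    ) -> List[str]:
        """Return one message per violated dimension; an empty list means accepted."""
        if len(actual_shape) != len(example_shape):
            return [f"expected {len(example_shape)} dimensions, but got {len(actual_shape)}"]
        errors = []
        for i, (expected, got) in enumerate(zip(example_shape, actual_shape)):
            spec = dynamic.get(i)
            if spec is None:
                # Static dimension: must match the example input exactly.
                if got != expected:
                    errors.append(f"shape[{i}] to be equal to {expected}, but got {got}")
            else:
                # Dynamic dimension: any size within the optional bounds is accepted.
                if spec.min is not None and got < spec.min:
                    errors.append(f"shape[{i}] to be >= {spec.min}, but got {got}")
                if spec.max is not None and got > spec.max:
                    errors.append(f"shape[{i}] to be <= {spec.max}, but got {got}")
        return errors


    # With no dynamic dimensions, a batch of 10 is rejected for an example batch of 8.
    print(check_shape((8, 100), {}, (10, 100)))
    # Marking dimension 0 as dynamic accepts the same input.
    print(check_shape((8, 100), {0: DimSpec("batch")}, (10, 100)))

The dictionary keyed by dimension index is a deliberate choice in this sketch:
any index that is absent is treated as static, so only the dimensions that
should vary need to be listed.
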
For each tensor argument of the input callable, we can specify a mapping from dimension index to a ``torch.export.Dim``. A ``torch.export.Dim`` is essentially a named symbolic integer with optional minimum and maximum bounds. The ``dynamic_shapes`` argument of ``torch.export.export()`` is then a mapping from the input callable's tensor argument names to these dimension-to-``Dim`` mappings. If no ``torch.export.Dim`` is given for a dimension of a tensor argument, that dimension is assumed to be static. The first argument of ``torch.export.Dim`` is the name of the symbolic integer, used for debugging. We can then optionally specify minimum and maximum bounds (both inclusive). In the example below, our input ``inp1`` has an unconstrained first dimension, but the size of the second dimension must be in the interval [4, 18]. .. GENERATED FROM PYTHON SOURCE LINES 351-386 .. code-block:: default from torch.export import Dim inp1 = torch.randn(10, 10, 2) class DynamicShapesExample1(torch.nn.Module): def forward(self, x): x = x[:, 2:] return torch.relu(x) inp1_dim0 = Dim("inp1_dim0") inp1_dim1 = Dim("inp1_dim1", min=4, max=18) dynamic_shapes1 = { "x": {0: inp1_dim0, 1: inp1_dim1}, } exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1) print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2))) try: exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2)) except Exception: tb.print_exc() try: exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2)) except Exception: tb.print_exc() try: exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3)) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. 
code-block:: none tensor([[[0.0000, 0.0000], [0.8850, 1.2371], [0.0000, 0.0000]], [[0.0000, 0.0000], [0.0000, 0.3487], [0.2520, 1.2545]], [[0.5863, 0.2831], [0.0000, 0.4669], [0.1059, 0.0000]], [[0.7833, 0.0000], [0.4480, 0.0523], [0.0000, 0.0000]], [[0.9306, 0.0000], [0.0000, 0.7895], [0.1160, 0.0000]]]) Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 372, in exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 737, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 317, in __call__ raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 304, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_unlift.py", line 32, in _check_input_constraints_pre_hook return _check_input_constraints_for_graph( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/utils.py", line 117, in _check_input_constraints_for_graph raise RuntimeError( RuntimeError: Expected input at *args[0].shape[1] to be >= 4, but got 1 Traceback (most recent call last): 
File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 377, in exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 737, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 317, in __call__ raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 304, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_unlift.py", line 32, in _check_input_constraints_pre_hook return _check_input_constraints_for_graph( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/utils.py", line 123, in _check_input_constraints_for_graph raise RuntimeError( RuntimeError: Expected input at *args[0].shape[1] to be <= 18, but got 20 Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 382, in exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 737, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File 
"/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 317, in __call__ raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 304, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_unlift.py", line 32, in _check_input_constraints_pre_hook return _check_input_constraints_for_graph( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/utils.py", line 129, in _check_input_constraints_for_graph raise RuntimeError( RuntimeError: Expected input at *args[0].shape[2] to be equal to 2, but got 3 .. GENERATED FROM PYTHON SOURCE LINES 387-389 Note that if our example inputs to ``torch.export`` do not satisfy the constraints given by ``dynamic_shapes``, then we get an error. .. GENERATED FROM PYTHON SOURCE LINES 389-400 .. code-block:: default inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18) dynamic_shapes1_bad = { "x": {0: inp1_dim0, 1: inp1_dim1_bad}, } try: export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. 
code-block:: none Traceback (most recent call last): File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 347, in _export_to_torch_ir gm_torch_level, _ = torch._dynamo.export( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1354, in inner raise constraint_violation_error File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1311, in inner result_traced = opt_f(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors return callback(frame, cache_entry, hooks, frame_state, skip=1) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert return _compile( File "/opt/conda/envs/py_3.10/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper r = func(*args, 
**kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner out_code = transform_code_object(code, transform) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object transformations(instructions, code_options) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 482, in transform tracer = InstructionTranslator( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2115, in __init__ self.symbolic_locals = VariableTracker.apply( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 217, in apply result = { File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 218, in k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items()) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 203, in apply result = fn(update_object_dict(value)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2116, in lambda x: x.realize(), self.symbolic_locals File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 58, in realize self._cache.realize(self.parents_tracker) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 24, in realize self.vt = VariableBuilder(tx, self.source)(self.value) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 269, in __call__ vt = self._wrap(value) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 402, in _wrap return 
type_dispatch(self, value) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1073, in wrap_tensor tensor_variable = wrap_fx_proxy( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1330, in wrap_fx_proxy return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1440, in wrap_fx_proxy_cls example_value = wrap_to_fake_tensor_and_record( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1880, in wrap_to_fake_tensor_and_record fake_e = wrap_fake_exception( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1190, in wrap_fake_exception return fn() File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1881, in lambda: tx.fake_mode.from_tensor( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1666, in from_tensor return self.fake_tensor_converter( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 349, in __call__ return self.from_real_tensor( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 306, in from_real_tensor out = self.meta_converter( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 967, in __call__ r = self.meta_tensor( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 782, in meta_tensor ) = sym_sizes_strides_storage_offset(t, source, symbolic_context) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 269, in sym_sizes_strides_storage_offset return shape_env.create_symbolic_sizes_strides_storage_offset( File 
"/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 2339, in create_symbolic_sizes_strides_storage_offset return self._create_symbolic_sizes_strides_storage_offset( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 231, in wrapper return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 2395, in _create_symbolic_sizes_strides_storage_offset size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 2271, in _produce_dyn_sizes_from_int_tuple size.append(self.create_symbol( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 231, in wrapper return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 2709, in create_symbol raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]") torch.fx.experimental.symbolic_shapes.ConstraintViolationError: 10 not in range [11, 18] Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 396, in export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 174, in export return _export( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 635, in wrapper raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 618, in wrapper ep = fn(*args, **kwargs) File 
"/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 83, in wrapper return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 860, in _export gm_torch_level = _export_to_torch_ir( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 359, in _export_to_torch_ir raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200 torch._dynamo.exc.UserError: 10 not in range [11, 18] Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information .. GENERATED FROM PYTHON SOURCE LINES 401-403 We can enforce that equalities between dimensions of different tensors by using the same ``torch.export.Dim`` object, for example, in matrix multiplication: .. GENERATED FROM PYTHON SOURCE LINES 403-429 .. code-block:: default inp2 = torch.randn(4, 8) inp3 = torch.randn(8, 2) class DynamicShapesExample2(torch.nn.Module): def forward(self, x, y): return x @ y inp2_dim0 = Dim("inp2_dim0") inner_dim = Dim("inner_dim") inp3_dim1 = Dim("inp3_dim1") dynamic_shapes2 = { "x": {0: inp2_dim0, 1: inner_dim}, "y": {0: inner_dim, 1: inp3_dim1}, } exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2) print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4))) try: exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2)) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. 
code-block:: none tensor([[-2.9354, -2.2066, -0.2080, 4.6121], [-0.5658, -0.6108, 0.8887, 1.5908]]) Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 425, in exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 737, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 317, in __call__ raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 304, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_unlift.py", line 32, in _check_input_constraints_pre_hook return _check_input_constraints_for_graph( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/utils.py", line 85, in _check_input_constraints_for_graph raise RuntimeError( RuntimeError: Expected input at *args[1].shape[0] to be equal to 8, but got 4 .. GENERATED FROM PYTHON SOURCE LINES 430-433 We can also describe one dimension in terms of another.
There are some restrictions on how one dimension can be expressed in terms of another, but in general, expressions of the form ``A * Dim + B`` should work. .. GENERATED FROM PYTHON SOURCE LINES 433-470 .. code-block:: default class DerivedDimExample1(torch.nn.Module): def forward(self, x, y): return x + y[1:] foo = DerivedDimExample1() x, y = torch.randn(5), torch.randn(6) dimx = torch.export.Dim("dimx", min=3, max=6) dimy = dimx + 1 derived_dynamic_shapes1 = ({0: dimx}, {0: dimy}) derived_dim_example1 = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes1) print(derived_dim_example1.module()(torch.randn(4), torch.randn(5))) try: derived_dim_example1.module()(torch.randn(4), torch.randn(6)) except Exception: tb.print_exc() class DerivedDimExample2(torch.nn.Module): def forward(self, z, y): return z[1:] + y[1::3] foo = DerivedDimExample2() z, y = torch.randn(4), torch.randn(10) dx = torch.export.Dim("dx", min=3, max=6) dz = dx + 1 dy = dx * 3 + 1 derived_dynamic_shapes2 = ({0: dz}, {0: dy}) derived_dim_example2 = export(foo, (z, y), dynamic_shapes=derived_dynamic_shapes2) print(derived_dim_example2.module()(torch.randn(7), torch.randn(19))) .. rst-class:: sphx-glr-script-out .. 
code-block:: none tensor([ 0.3007, -1.7282, -0.0729, 0.1139]) Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 450, in derived_dim_example1.module()(torch.randn(4), torch.randn(6)) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 737, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 317, in __call__ raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 304, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_unlift.py", line 32, in _check_input_constraints_pre_hook return _check_input_constraints_for_graph( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/utils.py", line 85, in _check_input_constraints_for_graph raise RuntimeError( RuntimeError: Expected input at *args[1].shape[0] to be equal to 5, but got 6 tensor([ 2.5416, -0.2760, 0.9003, -1.7479, -2.9716, 0.1013]) .. GENERATED FROM PYTHON SOURCE LINES 471-476 We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes`` constraints are necessary. 
We can do this by relaxing all constraints (recall that if we do not provide constraints for a dimension, the default behavior is to constrain to the exact shape value of the example input) and letting ``torch.export`` error out. .. GENERATED FROM PYTHON SOURCE LINES 476-496 .. code-block:: default inp4 = torch.randn(8, 16) inp5 = torch.randn(16, 32) class DynamicShapesExample3(torch.nn.Module): def forward(self, x, y): if x.shape[0] <= 16: return x @ y[:, :16] return y dynamic_shapes3 = { "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())}, "y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())}, } try: export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. code-block:: none Traceback (most recent call last): File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 347, in _export_to_torch_ir gm_torch_level, _ = torch._dynamo.export( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1354, in inner raise constraint_violation_error File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1311, in inner result_traced = opt_f(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return 
forward_call(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors return callback(frame, cache_entry, hooks, frame_state, skip=1) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert return _compile( File "/opt/conda/envs/py_3.10/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper r = func(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in compile_inner check_fn = CheckFunctionManager( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1048, in __init__ guard.create(builder) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_guards.py", line 249, in create return self.create_fn(builder, self) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 705, in SHAPE_ENV guards = output_graph.shape_env.produce_guards( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3308, in produce_guards raise ConstraintViolationError( torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (inp4_dim0, inp5_dim0, inp5_dim1)! For more information, run with TORCH_LOGS="+dynamic". - The values of inp5_dim0 = L['y'].size()[0] and inp4_dim1 = L['x'].size()[1] must always be equal. - Not all values of inp5_dim1 = L['y'].size()[1] in the specified range satisfy the generated guard Ne(L['y'].size()[1], 16). 
- Not all values of inp4_dim0 = L['x'].size()[0] in the specified range satisfy the generated guard 2 <= L['x'].size()[0] and L['x'].size()[0] <= 16 - Not all values of inp5_dim1 = L['y'].size()[1] in the specified range satisfy the generated guard 16 <= L['y'].size()[1] and L['y'].size()[1] <= 9223372036854775806 Suggested fixes: inp4_dim0 = Dim('inp4_dim0', max=16) inp4_dim1 = Dim('inp4_dim1') inp4_dim1 = Dim('inp4_dim1', min=2, max=9223372036854775806) # 2 <= inp4_dim1 <= 9223372036854775806 inp5_dim1 = Dim('inp5_dim1', min=16) inp5_dim0 = inp4_dim1 During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 492, in export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 174, in export return _export( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 635, in wrapper raise e File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 618, in wrapper ep = fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/exported_program.py", line 83, in wrapper return fn(*args, **kwargs) File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 860, in _export gm_torch_level = _export_to_torch_ir( File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/_trace.py", line 359, in _export_to_torch_ir raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200 torch._dynamo.exc.UserError: Constraints violated (inp4_dim0, inp5_dim0, inp5_dim1)! For more information, run with TORCH_LOGS="+dynamic". - The values of inp5_dim0 = L['y'].size()[0] and inp4_dim1 = L['x'].size()[1] must always be equal. 
- Not all values of inp5_dim1 = L['y'].size()[1] in the specified range satisfy the generated guard Ne(L['y'].size()[1], 16). - Not all values of inp4_dim0 = L['x'].size()[0] in the specified range satisfy the generated guard 2 <= L['x'].size()[0] and L['x'].size()[0] <= 16 - Not all values of inp5_dim1 = L['y'].size()[1] in the specified range satisfy the generated guard 16 <= L['y'].size()[1] and L['y'].size()[1] <= 9223372036854775806 Suggested fixes: inp4_dim0 = Dim('inp4_dim0', max=16) inp4_dim1 = Dim('inp4_dim1') inp4_dim1 = Dim('inp4_dim1', min=2, max=9223372036854775806) # 2 <= inp4_dim1 <= 9223372036854775806 inp5_dim1 = Dim('inp5_dim1', min=16) inp5_dim0 = inp4_dim1 .. GENERATED FROM PYTHON SOURCE LINES 497-500 We can see that the error message gives us suggested fixes to our dynamic shape constraints. Let us follow those suggestions (exact suggestions may differ slightly): .. GENERATED FROM PYTHON SOURCE LINES 500-517 .. code-block:: default def suggested_fixes(): inp4_dim1 = Dim('shared_dim') # suggested fixes below inp4_dim0 = Dim('inp4_dim0', max=16) inp5_dim1 = Dim('inp5_dim1', min=17) inp5_dim0 = inp4_dim1 # end of suggested fixes return { "x": {0: inp4_dim0, 1: inp4_dim1}, "y": {0: inp5_dim0, 1: inp5_dim1}, } dynamic_shapes3_fixed = suggested_fixes() exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed) print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64))) .. rst-class:: sphx-glr-script-out .. 
code-block:: none tensor([[ 12.5915, -0.7265, -1.8981, -8.0323, -2.1447, 14.4020, 13.7854, 1.5568, 3.7933, 7.2591, -2.2477, 1.6366, -5.9276, 8.5279, 7.9349, -1.1328], [ -5.2210, 9.4576, -0.2372, 9.0035, 6.2572, -8.4716, 6.0191, 4.8424, -0.4486, 0.1885, -0.1749, 2.4314, 3.8271, 8.1822, 6.5064, 0.6512], [ -3.6856, 7.5222, 4.8073, 13.1255, 3.6440, -4.1587, 2.9806, 0.3689, 1.1133, -1.7169, -2.1537, 1.1841, 6.7619, 9.3401, -1.1372, -8.9628], [ -4.3608, 5.1219, -0.6240, -6.8640, 2.3344, -2.0273, 0.2769, -0.9930, 2.0298, -10.3922, -1.7186, -5.0928, 11.7383, 6.4864, 8.0827, 0.5863]]) .. GENERATED FROM PYTHON SOURCE LINES 518-525 Note that in the example above, because we constrained the value of ``x.shape[0]`` in ``dynamic_shapes3_fixed`` (the ``max=16`` bound guarantees that the traced branch is always the one taken), the exported program is sound even though there is a raw ``if`` statement. If you want to see why ``torch.export`` generated these constraints, you can re-run the script with the environment variable ``TORCH_LOGS=dynamic,dynamo``, or use ``torch._logging.set_logs``. .. GENERATED FROM PYTHON SOURCE LINES 525-533 .. code-block:: default import logging torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO) exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed) # reset to previous values torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING) .. rst-class:: sphx-glr-script-out .. 
code-block:: none I0516 23:19:03.221000 140594383835776 torch/_dynamo/logging.py:55] [16/0] Step 1: torchdynamo start tracing forward /var/lib/workspace/intermediate_source/torch_export_tutorial.py:481 I0516 23:19:03.224000 140594383835776 torch/fx/experimental/symbolic_shapes.py:2724] [16/0] create_symbol s0 = 8 for L['x'].size()[0] [2, 16] (_dynamo/variables/builder.py:1881 in ) I0516 23:19:03.225000 140594383835776 torch/fx/experimental/symbolic_shapes.py:2724] [16/0] create_symbol s1 = 16 for L['x'].size()[1] [2, 9223372036854775806] (_dynamo/variables/builder.py:1881 in ) I0516 23:19:03.227000 140594383835776 torch/fx/experimental/symbolic_shapes.py:2724] [16/0] create_symbol s2 = 16 for L['y'].size()[0] [2, 9223372036854775806] (_dynamo/variables/builder.py:1881 in ) I0516 23:19:03.228000 140594383835776 torch/fx/experimental/symbolic_shapes.py:2724] [16/0] create_symbol s3 = 32 for L['y'].size()[1] [17, 9223372036854775806] (_dynamo/variables/builder.py:1881 in ) I0516 23:19:03.236000 140594383835776 torch/fx/experimental/symbolic_shapes.py:3809] [16/0] set_replacement s2 = s1 (solve_backed) ValueRanges(lower=2, upper=9223372036854775806, is_bool=False) I0516 23:19:03.236000 140594383835776 torch/fx/experimental/symbolic_shapes.py:4035] [16/0] eval Eq(s1, s2) [guard added] at ar/lib/workspace/intermediate_source/torch_export_tutorial.py:483 in forward (_meta_registrations.py:2014 in meta_mm) I0516 23:19:03.237000 140594383835776 torch/_dynamo/logging.py:55] [16/0] Step 1: torchdynamo done tracing forward (RETURN_VALUE) I0516 23:19:03.239000 140594383835776 torch/fx/experimental/symbolic_shapes.py:3809] [16/0] set_replacement s2 = s1 (find) ValueRanges(lower=2, upper=9223372036854775806, is_bool=False) I0516 23:19:03.239000 140594383835776 torch/_dynamo/logging.py:55] [16/0] Step 2: calling compiler function dynamo_normalization_capturing_compiler I0516 23:19:03.239000 140594383835776 torch/_dynamo/logging.py:55] [16/0] Step 2: done compiler function 
dynamo_normalization_capturing_compiler I0516 23:19:03.241000 140594383835776 torch/fx/experimental/symbolic_shapes.py:2806] [16/0] produce_guards I0516 23:19:03.267000 140594383835776 torch/_dynamo/eval_frame.py:1339] Summary of dimension constraints: I0516 23:19:03.267000 140594383835776 torch/_dynamo/eval_frame.py:1339] Suggested fixes: I0516 23:19:03.267000 140594383835776 torch/_dynamo/eval_frame.py:1339] inp4_dim0 = Dim('inp4_dim0', max=16) I0516 23:19:03.267000 140594383835776 torch/_dynamo/eval_frame.py:1339] inp5_dim1 = Dim('inp5_dim1', min=17) I0516 23:19:03.267000 140594383835776 torch/_dynamo/eval_frame.py:1339] shared_dim = Dim('shared_dim') I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] Dynamo captured graph: I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] class GraphModule(torch.nn.Module): I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] l_x_ = L_x_ I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] l_y_ = L_y_ I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:482 in forward, code: if x.shape[0] <= 16: I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] size = l_x_.size() I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] getitem = size[0]; size = None I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] le = getitem <= 16; getitem = None I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] # File: 
/var/lib/workspace/intermediate_source/torch_export_tutorial.py:483 in forward, code: return x @ y[:, :16] I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] getitem_2 = l_y_[(slice(None, None, None), slice(None, 16, None))]; l_y_ = None I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] matmul = l_x_ @ getitem_2; l_x_ = getitem_2 = None I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] return (matmul,) I0516 23:19:03.268000 140594383835776 torch/_dynamo/eval_frame.py:1363] .. GENERATED FROM PYTHON SOURCE LINES 534-536 We can view an ``ExportedProgram``'s symbolic shape ranges using the ``range_constraints`` field. .. GENERATED FROM PYTHON SOURCE LINES 536-539 .. code-block:: default print(exported_dynamic_shapes_example3.range_constraints) .. rst-class:: sphx-glr-script-out .. code-block:: none {s0: ValueRanges(lower=2, upper=16, is_bool=False), s1: ValueRanges(lower=2, upper=oo, is_bool=False), s3: ValueRanges(lower=17, upper=oo, is_bool=False)} .. GENERATED FROM PYTHON SOURCE LINES 540-549 Custom Ops ---------- ``torch.export`` can export PyTorch programs with custom operators. Currently, the steps to register a custom op for use by ``torch.export`` are: - Define the custom op using ``torch.library`` (`reference `__) as with any other custom op .. GENERATED FROM PYTHON SOURCE LINES 549-561 .. code-block:: default from torch.library import Library, impl, impl_abstract m = Library("my_custom_library", "DEF") m.define("custom_op(Tensor input) -> Tensor") @impl(m, "custom_op", "CompositeExplicitAutograd") def custom_op(x): print("custom_op called!") return torch.relu(x) .. GENERATED FROM PYTHON SOURCE LINES 562-564 - Define a ``"Meta"`` implementation of the custom op that returns an empty tensor with the same shape as the expected output .. GENERATED FROM PYTHON SOURCE LINES 564-569 .. 
code-block:: default @impl_abstract("my_custom_library::custom_op") def custom_op_meta(x): return torch.empty_like(x) .. GENERATED FROM PYTHON SOURCE LINES 570-571 - Call the custom op from the code you want to export using ``torch.ops`` .. GENERATED FROM PYTHON SOURCE LINES 571-579 .. code-block:: default class CustomOpExample(torch.nn.Module): def forward(self, x): x = torch.sin(x) x = torch.ops.my_custom_library.custom_op(x) x = torch.cos(x) return x .. GENERATED FROM PYTHON SOURCE LINES 580-581 - Export the code as before .. GENERATED FROM PYTHON SOURCE LINES 581-586 .. code-block:: default exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),)) exported_custom_op_example.graph_module.print_readable() print(exported_custom_op_example.module()(torch.randn(3, 3))) .. rst-class:: sphx-glr-script-out .. code-block:: none class GraphModule(torch.nn.Module): def forward(self, arg0_1: "f32[3, 3]"): # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:574 in forward, code: x = torch.sin(x) sin: "f32[3, 3]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:575 in forward, code: x = torch.ops.my_custom_library.custom_op(x) custom_op: "f32[3, 3]" = torch.ops.my_custom_library.custom_op.default(sin); sin = None # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:576 in forward, code: x = torch.cos(x) cos: "f32[3, 3]" = torch.ops.aten.cos.default(custom_op); custom_op = None return (cos,) custom_op called! tensor([[0.6387, 0.5722, 1.0000], [0.8776, 0.9846, 0.8223], [0.9983, 1.0000, 0.9979]]) .. GENERATED FROM PYTHON SOURCE LINES 587-594 Note in the above outputs that the custom op is included in the exported graph. And when we call the exported graph as a function, the original custom op is called, as evidenced by the ``print`` call. 
If you have a custom operator implemented in C++, please refer to `this document `__ to make it compatible with ``torch.export``. .. GENERATED FROM PYTHON SOURCE LINES 596-606 Decompositions -------------- By default, ``torch.export`` produces a graph containing only functional ATen operators. This functional ATen operator set (or "opset") contains around 2000 operators, all of which are functional, that is, they do not mutate or alias inputs. You can find a list of all ATen operators `here `__ and you can check whether an operator is functional by inspecting ``op._schema.is_mutable``, for example: .. GENERATED FROM PYTHON SOURCE LINES 606-610 .. code-block:: default print(torch.ops.aten.add.Tensor._schema.is_mutable) print(torch.ops.aten.add_.Tensor._schema.is_mutable) .. rst-class:: sphx-glr-script-out .. code-block:: none False True .. GENERATED FROM PYTHON SOURCE LINES 611-633 By default, the environment in which you want to run the exported graph should support all ~2000 of these operators. However, you can use the following API on the exported program if your specific environment is only able to support a subset of the ~2000 operators. .. code-block:: python def run_decompositions( self: ExportedProgram, decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] ) -> ExportedProgram ``run_decompositions`` takes in a decomposition table, which is a mapping from operators to functions specifying how to reduce, or decompose, each operator into an equivalent sequence of other ATen operators. The default decomposition table for ``run_decompositions`` is the `Core ATen decomposition table `__, which decomposes all ATen operators to the `Core ATen Operator Set `__, which consists of only ~180 operators. .. GENERATED FROM PYTHON SOURCE LINES 633-648 ..
code-block:: default class M(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(3, 4) def forward(self, x): return self.linear(x) ep = export(M(), (torch.randn(2, 3),)) print(ep.graph) core_ir_ep = ep.run_decompositions() print(core_ir_ep.graph) .. rst-class:: sphx-glr-script-out .. code-block:: none graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%arg0_1,), kwargs = {}) %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %t), kwargs = {}) return (addmm,) graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {}) %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %permute), kwargs = {}) return (addmm,) .. GENERATED FROM PYTHON SOURCE LINES 649-661 Notice that after running ``run_decompositions`` the ``torch.ops.aten.t.default`` operator, which is not part of the Core ATen Opset, has been replaced with ``torch.ops.aten.permute.default`` which is part of the Core ATen Opset. Most ATen operators already have decompositions, which are located `here `__. If you would like to use some of these existing decomposition functions, you can pass in a list of operators you would like to decompose to the `get_decompositions `__ function, which will return a decomposition table using existing decomposition implementations. .. GENERATED FROM PYTHON SOURCE LINES 661-678 .. 
code-block:: default class M(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(3, 4) def forward(self, x): return self.linear(x) ep = export(M(), (torch.randn(2, 3),)) print(ep.graph) from torch._decomp import get_decompositions decomp_table = get_decompositions([torch.ops.aten.t.default, torch.ops.aten.transpose.int]) core_ir_ep = ep.run_decompositions(decomp_table) print(core_ir_ep.graph) .. rst-class:: sphx-glr-script-out .. code-block:: none graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%arg0_1,), kwargs = {}) %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %t), kwargs = {}) return (addmm,) graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {}) %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %permute), kwargs = {}) return (addmm,) .. GENERATED FROM PYTHON SOURCE LINES 679-682 If there is no existing decomposition function for an ATen operator that you would like to decompose, feel free to send a pull request into PyTorch implementing the decomposition! .. GENERATED FROM PYTHON SOURCE LINES 684-699 ExportDB -------- ``torch.export`` will only ever export a single computation graph from a PyTorch program. Because of this requirement, there will be Python or PyTorch features that are not compatible with ``torch.export``, which will require users to rewrite parts of their model code. We have seen examples of this earlier in the tutorial -- for example, rewriting if-statements using ``cond``. 
`ExportDB `__ is the standard reference that documents supported and unsupported Python/PyTorch features for ``torch.export``. It is essentially a list of program samples, each of which represents the usage of one particular Python/PyTorch feature and its interaction with ``torch.export``. Examples are also tagged by category so that they can be more easily searched. For example, let's use ExportDB to get a better understanding of how the predicate works in the ``cond`` operator. We can look at the example called ``cond_predicate``, which has a ``torch.cond`` tag. The example code looks like: .. GENERATED FROM PYTHON SOURCE LINES 699-710 .. code-block:: default def cond_predicate(x): """ The conditional statement (aka predicate) passed to ``cond()`` must be one of the following: - ``torch.Tensor`` with a single element - boolean expression NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized. """ pred = x.dim() > 2 and x.shape[2] > 10 return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x]) .. GENERATED FROM PYTHON SOURCE LINES 711-719 More generally, ExportDB can be used as a reference when one of the following occurs: 1. Before attempting ``torch.export``, you know ahead of time that your model uses some tricky Python/PyTorch features and you want to know if ``torch.export`` covers that feature. 2. When attempting ``torch.export``, there is a failure and it's unclear how to work around it. ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by ``torch.export``. .. GENERATED FROM PYTHON SOURCE LINES 721-730 Running the Exported Program ---------------------------- As ``torch.export`` is only a graph capturing mechanism, calling the artifact produced by ``torch.export`` eagerly will be equivalent to running the eager module.
To optimize the execution of the exported program, we can pass this exported artifact to backends such as Inductor through ``torch.compile``, `AOTInductor `__, or `TensorRT `__. .. GENERATED FROM PYTHON SOURCE LINES 730-752 .. code-block:: default class M(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(3, 3) def forward(self, x): x = self.linear(x) return x inp = torch.randn(2, 3, device="cuda") m = M().to(device="cuda") ep = torch.export.export(m, (inp,)) # Run it eagerly res = ep.module()(inp) print(res) # Run it with torch.compile res = torch.compile(ep.module(), backend="inductor")(inp) print(res) .. rst-class:: sphx-glr-script-out .. code-block:: none tensor([[ 1.3676, -0.4303, -0.2113], [-0.5053, -0.0877, 0.5134]], device='cuda:0', grad_fn=) tensor([[ 1.3676, -0.4303, -0.2113], [-0.5053, -0.0877, 0.5134]], device='cuda:0', grad_fn=) .. GENERATED FROM PYTHON SOURCE LINES 753-767 .. code-block:: python import torch._export import torch._inductor # Note: these APIs are subject to change # Compile the exported program to a .so using ``AOTInductor`` with torch.no_grad(): so_path = torch._inductor.aot_compile(ep.module(), [inp]) # Load and run the .so file in Python. # To load and run it in a C++ environment, see: # https://pytorch.org/docs/main/torch.compiler_aot_inductor.html res = torch._export.aot_load(so_path, device="cuda")(inp) .. GENERATED FROM PYTHON SOURCE LINES 769-775 Conclusion ---------- We introduced ``torch.export``, the new PyTorch 2.X way to export single computation graphs from PyTorch programs. In particular, we demonstrated several code modifications and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph. .. rst-class:: sphx-glr-timing **Total running time of the script:** ( 0 minutes 1.912 seconds) .. _sphx_glr_download_intermediate_torch_export_tutorial.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example ..
container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: torch_export_tutorial.py ` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: torch_export_tutorial.ipynb ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_