torch.export Tutorial
Author: William Wen, Zhengxu Chen
Warning
torch.export and its related features are in prototype status and are subject to backward-compatibility-breaking changes. This tutorial provides a snapshot of torch.export usage as of PyTorch 2.1.
Note
The torch.export nightly tutorial demonstrates some APIs that are present in the nightly binaries, but are not present in the PyTorch 2.1 release.
torch.export() is the PyTorch 2.X way to export PyTorch models into standardized model representations, intended to be run in different (i.e. Python-less) environments.
In this tutorial, you will learn how to use torch.export() to extract ExportedPrograms (i.e. single-graph representations) from PyTorch programs. We also detail some considerations and modifications you may need to make so that your model is compatible with torch.export.
Basic Usage
torch.export extracts single-graph representations from PyTorch programs by tracing the target function, given example inputs. The signature of torch.export is:
export(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    constraints: Optional[List[Constraint]] = None
) -> ExportedProgram
torch.export traces the tensor computation graph from calling f(*args, **kwargs) and wraps it in an ExportedProgram, which can be serialized or executed later with different inputs. Note that while the output ExportedProgram is callable and can be called in the same way as the original input callable, it is not a torch.nn.Module.
We will detail the constraints argument later in the tutorial.
import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod(torch.randn(8, 100), torch.randn(8, 100)))
<class 'torch.export.ExportedProgram'>
tensor([[0.8632, 0.8407, 0.0407, 0.0000, 0.4132, 0.0000, 0.0000, 0.1538, 0.6111,
0.0000],
[0.0000, 0.0000, 0.0273, 0.8057, 0.0000, 1.0162, 0.8042, 0.0000, 0.2660,
0.0000],
[0.9481, 0.1396, 1.0225, 0.9563, 0.5832, 0.2546, 0.4095, 0.4591, 0.0000,
2.0053],
[1.1300, 0.4873, 0.0000, 0.9663, 1.2275, 1.4015, 0.0000, 0.9444, 0.0000,
0.0000],
[0.0000, 0.8724, 1.1648, 0.6867, 0.0000, 0.2833, 0.3202, 0.5848, 0.0000,
0.0833],
[1.1311, 0.1324, 0.0000, 1.7842, 0.0000, 0.3474, 0.9916, 0.3571, 0.0000,
0.0000],
[1.4348, 1.0570, 0.1771, 0.0000, 0.9510, 0.0000, 0.0000, 0.0000, 0.2618,
0.0000],
[0.8853, 0.0000, 0.0000, 0.4486, 0.0000, 0.0000, 0.5841, 0.7604, 0.0000,
0.0000]])
Let’s review some attributes of ExportedProgram that are of interest.
The graph attribute is an FX graph traced from the function we exported, that is, the computation graph of all PyTorch operations. The FX graph has some important properties:
The operations are “ATen-level” operations.
The graph is “functionalized”, meaning that no operations are mutations.
The graph_module attribute is the GraphModule that wraps the graph attribute so that it can be run as a torch.nn.Module.
print(exported_mod)
print(exported_mod.graph_module)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 100], arg1_1: f32[10], arg2_1: f32[8, 100], arg3_1: f32[8, 100]):
            #
            add: f32[8, 100] = torch.ops.aten.add.Tensor(arg2_1, arg3_1); arg2_1 = arg3_1 = None
            permute: f32[100, 10] = torch.ops.aten.permute.default(arg0_1, [1, 0]); arg0_1 = None
            addmm: f32[8, 10] = torch.ops.aten.addmm.default(arg1_1, add, permute); arg1_1 = add = permute = None
            relu: f32[8, 10] = torch.ops.aten.relu.default(addmm); addmm = None
            return (relu,)
Graph Signature: ExportGraphSignature(parameters=['L__self___lin.weight', 'L__self___lin.bias'], buffers=[], user_inputs=['arg2_1', 'arg3_1'], user_outputs=['relu'], inputs_to_parameters={'arg0_1': 'L__self___lin.weight', 'arg1_1': 'L__self___lin.bias'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
GraphModule()
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
    add = torch.ops.aten.add.Tensor(arg2_1, arg3_1); arg2_1 = arg3_1 = None
    permute = torch.ops.aten.permute.default(arg0_1, [1, 0]); arg0_1 = None
    addmm = torch.ops.aten.addmm.default(arg1_1, add, permute); arg1_1 = add = permute = None
    relu = torch.ops.aten.relu.default(addmm); addmm = None
    return (relu,)
# To see more debug info, please use `graph_module.print_readable()`
The printed code shows that the FX graph contains only ATen-level ops (such as torch.ops.aten) and that mutations were removed. For example, the mutating op torch.nn.functional.relu(..., inplace=True) is represented in the printed code by torch.ops.aten.relu.default, which does not mutate. Any later use of the input to the original mutating relu op is replaced by the new output of the replacement non-mutating relu op.
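To see why swapping the in-place relu for its out-of-place counterpart preserves the program's results, the eager-mode sketch below compares torch.relu_ (in-place) with torch.relu (functional) on the same data. It uses plain eager PyTorch rather than torch.export, so it illustrates the semantics only, not how export performs functionalization internally.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3)

# Functional (non-mutating) relu: returns a new tensor and leaves x untouched.
x_before = x.clone()
out_functional = torch.relu(x)
assert torch.equal(x, x_before)

# In-place (mutating) relu: overwrites its input and returns that same storage.
x_inplace = x.clone()
out_inplace = torch.relu_(x_inplace)
assert out_inplace.data_ptr() == x_inplace.data_ptr()

# Both yield the same values, so a functionalized graph can keep the values
# while routing later uses through the new output instead of the mutated input.
assert torch.equal(out_functional, out_inplace)
print("functional and in-place relu agree")
```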
Other attributes of interest in ExportedProgram include:
graph_signature: the inputs, outputs, parameters, buffers, etc. of the exported graph.
range_constraints and equality_constraints: constraints, covered later.
print(exported_mod.graph_signature)
ExportGraphSignature(parameters=['L__self___lin.weight', 'L__self___lin.bias'], buffers=[], user_inputs=['arg2_1', 'arg3_1'], user_outputs=['relu'], inputs_to_parameters={'arg0_1': 'L__self___lin.weight', 'arg1_1': 'L__self___lin.bias'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
See the torch.export documentation for more details.
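To connect the printed graph with graph_signature, the sketch below replays the graph's ATen ops by hand in eager mode, feeding the lifted parameters in the positions that inputs_to_parameters assigns them (arg0_1 is the weight, arg1_1 the bias; arg2_1 and arg3_1 are the user inputs). It uses a fresh torch.nn.Linear(100, 10) rather than exported_mod, so it demonstrates the structure of the computation, not the exported object itself.

```python
import torch

torch.manual_seed(0)
lin = torch.nn.Linear(100, 10)
x, y = torch.randn(8, 100), torch.randn(8, 100)

# Reference: the eager computation performed by MyModule.forward.
with torch.no_grad():
    expected = torch.nn.functional.relu(lin(x + y))

# Hand replay of the printed graph. Parameters become ordinary inputs.
arg0_1, arg1_1 = lin.weight.detach(), lin.bias.detach()  # weight, bias
arg2_1, arg3_1 = x, y                                    # user inputs
add = torch.ops.aten.add.Tensor(arg2_1, arg3_1)
permute = torch.ops.aten.permute.default(arg0_1, [1, 0])   # (10, 100) -> (100, 10)
addmm = torch.ops.aten.addmm.default(arg1_1, add, permute) # bias + add @ permute
relu = torch.ops.aten.relu.default(addmm)

print(relu.shape)  # torch.Size([8, 10])
assert torch.allclose(relu, expected, atol=1e-5)
```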
Graph Breaks
Although torch.export shares components with torch.compile, the key limitation of torch.export, especially when compared to torch.compile, is that it does not support graph breaks. This is because handling graph breaks involves interpreting the unsupported operation with default Python evaluation, which is incompatible with the export use case of producing a single graph that can run without Python. Therefore, in order to make your model code compatible with torch.export, you will need to modify your code to remove graph breaks.
A graph break is necessary in cases such as:
data-dependent control flow
def bad1(x):
    if x.sum() > 0:
        return torch.sin(x)
    return torch.cos(x)

import traceback as tb
try:
    export(bad1, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 131, in <module>
export(bad1, (torch.randn(3, 3),))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 1018, in export
return export(f, args, kwargs, constraints)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/__init__.py", line 270, in export
gm_torch_level, _ = torch._dynamo.export(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1140, in inner
result_traced = opt_f(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
return fn(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
return _compile(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
out_code = transform_code_object(code, transform)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
tracer.run()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2069, in run
super().run()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 719, in run
and self.step()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 370, in inner
raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow
from user code:
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 125, in bad1
if x.sum() > 0:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
accessing tensor data with .data
def bad2(x):
    x.data[0, 0] = 3
    return x

try:
    export(bad2, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 143, in <module>
export(bad2, (torch.randn(3, 3),))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 1018, in export
return export(f, args, kwargs, constraints)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/__init__.py", line 356, in export
gm, graph_signature = aot_export_module(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 4027, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 4209, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3438, in create_aot_dispatcher_function
raise RuntimeError(f"""
RuntimeError:
Found following user inputs located at [0] are mutated. This is currently banned in the aot_export workflow.
If you need this functionality, please file a github issue.
fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=True, mutates_metadata=False)], output_info=[OutputAliasInfo(output_type=<OutputType.is_input: 3>, raw_type=<class 'torch.Tensor'>, base_idx=0, dynamic_dims=set())], requires_grad_info=[False, False], num_intermediate_bases=0, keep_input_mutations=False, traced_tangents=[FakeTensor(..., size=(3, 3))], num_symints_saved_for_bw=None)
calling unsupported functions (such as many built-in functions)
def bad3(x):
    x = x + 1
    return x + id(x)

try:
    export(bad3, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 155, in <module>
export(bad3, (torch.randn(3, 3),))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 1018, in export
return export(f, args, kwargs, constraints)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/__init__.py", line 270, in export
gm_torch_level, _ = torch._dynamo.export(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1140, in inner
result_traced = opt_f(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
return fn(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
return _compile(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
out_code = transform_code_object(code, transform)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
tracer.run()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2069, in run
super().run()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 719, in run
and self.step()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1110, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 557, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 618, in call_function
result = handler(tx, *args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1297, in call_id
unimplemented(f"call_id with args {args}")
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 172, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_id with args (TensorVariable(),)
from user code:
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 152, in bad3
return x + id(x)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
unsupported Python language features (e.g. throwing exceptions, match statements)
def bad4(x):
    try:
        x = x + 1
        raise RuntimeError("bad")
    except:
        x = x + 2
    return x

try:
    export(bad4, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 171, in <module>
export(bad4, (torch.randn(3, 3),))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 1018, in export
return export(f, args, kwargs, constraints)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/__init__.py", line 270, in export
gm_torch_level, _ = torch._dynamo.export(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1140, in inner
result_traced = opt_f(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
return fn(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
return _compile(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
out_code = transform_code_object(code, transform)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
tracer.run()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2069, in run
super().run()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 719, in run
and self.step()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1110, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 557, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 645, in call_function
return super().call_function(tx, args, kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 306, in call_function
unimplemented(f"call_function {self} {args} {kwargs}")
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 172, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function BuiltinVariable(RuntimeError) [ConstantVariable(str)] {}
from user code:
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 165, in bad4
raise RuntimeError("bad")
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
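For the .data, built-in function, and exception cases, one option is to rewrite the offending code so that only plain tensor operations remain. The sketch below shows one possible eager-mode rewrite of each of bad2, bad3, and bad4; the names good2, good3, good4, and python_side_value are hypothetical. We only check that each rewrite computes what the original intended in eager mode; we have not verified that each one exports under PyTorch 2.1, although good2 avoids the "user inputs ... are mutated" error above by mutating only a copy.

```python
import torch

# bad2: mutates a user input through .data. Write into a copy instead,
# so the caller's tensor is never modified.
def good2(x):
    out = x.clone()
    out[0, 0] = 3
    return out

# bad3: id() has no tensor equivalent. Compute Python-only values outside
# the function you export and pass them in as tensor arguments.
def python_side_value(t):
    # Hypothetical stand-in for Python-only logic that cannot be traced.
    return float(id(t) % 7)

def good3(x, offset):
    x = x + 1
    return x + offset

# bad4: the raise is unconditional, so the except branch always runs.
# Spelling out the path that actually executes removes the try/except.
def bad4(x):
    try:
        x = x + 1
        raise RuntimeError("bad")
    except RuntimeError:
        x = x + 2
    return x

def good4(x):
    x = x + 1
    x = x + 2
    return x

inp = torch.randn(3, 3)
inp_before = inp.clone()

out2 = good2(inp)
assert torch.equal(inp, inp_before)                         # input untouched
assert out2[0, 0].item() == 3.0                             # element [0, 0] set
assert torch.equal(out2.flatten()[1:], inp.flatten()[1:])   # rest copied as-is

offset = torch.tensor(python_side_value(inp))  # computed eagerly, outside
out3 = good3(inp, offset)
assert torch.equal(out3, inp + 1 + offset)

assert torch.equal(bad4(inp), good4(inp))
print("rewrites behave as intended in eager mode")
```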
The sections below demonstrate some ways you can modify your code in order to remove graph breaks.
Control Flow Ops
torch.export does in fact support data-dependent control flow, but it must be expressed using control flow ops. For example, we can fix the control flow example above using the cond op, like so:
from functorch.experimental.control_flow import cond

def bad1_fixed(x):
    def true_fn(x):
        return torch.sin(x)

    def false_fn(x):
        return torch.cos(x)

    return cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(bad1_fixed, (torch.randn(3, 3),))
print(exported_bad1_fixed(torch.ones(3, 3)))
print(exported_bad1_fixed(-torch.ones(3, 3)))
tensor([[0.8415, 0.8415, 0.8415],
[0.8415, 0.8415, 0.8415],
[0.8415, 0.8415, 0.8415]])
tensor([[0.5403, 0.5403, 0.5403],
[0.5403, 0.5403, 0.5403],
[0.5403, 0.5403, 0.5403]])
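The printed values can be checked by hand: torch.ones(3, 3) sums to 9, so the predicate is true and every element becomes sin(1); -torch.ones(3, 3) sums to -9, so every element becomes cos(-1) = cos(1). The sketch below confirms this with a plain-Python eager reference, eager_bad1_fixed (a hypothetical name), which is only valid outside of export.

```python
import math
import torch

print(round(math.sin(1.0), 4))   # 0.8415
print(round(math.cos(-1.0), 4))  # 0.5403

def eager_bad1_fixed(x):
    # Eager-only equivalent of the cond call; usable as a reference for testing.
    return torch.sin(x) if x.sum() > 0 else torch.cos(x)

pos = eager_bad1_fixed(torch.ones(3, 3))
neg = eager_bad1_fixed(-torch.ones(3, 3))
assert torch.allclose(pos, torch.full((3, 3), math.sin(1.0)))
assert torch.allclose(neg, torch.full((3, 3), math.cos(1.0)))
```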
There are limitations to cond that one should be aware of:
The predicate (i.e. x.sum() > 0) must result in a boolean or a single-element tensor.
The operands (i.e. [x]) must be tensors.
The branch functions' (i.e. true_fn and false_fn) signatures must match the operands, and both must return a single tensor with the same metadata (for example, dtype, shape, etc.).
Branch functions cannot mutate inputs or global variables.
Branch functions cannot access closure variables, except for self if the function is defined in the scope of a method.
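Some of these rules can be sanity-checked in eager mode before exporting. The sketch below defines a hypothetical helper, check_branches_match, which is not part of functorch or torch.export. It runs both branches on example operands and compares only tensor-ness, shape, and dtype, so it cannot detect mutation or closure-variable violations.

```python
import torch

def check_branches_match(true_fn, false_fn, operands):
    """Hypothetical eager pre-flight check for a subset of cond's rules."""
    if not all(isinstance(op, torch.Tensor) for op in operands):
        return False, "operands must be tensors"
    t_out = true_fn(*operands)
    f_out = false_fn(*operands)
    if not (isinstance(t_out, torch.Tensor) and isinstance(f_out, torch.Tensor)):
        return False, "each branch must return a single tensor"
    if t_out.shape != f_out.shape:
        return False, f"shape mismatch: {tuple(t_out.shape)} vs {tuple(f_out.shape)}"
    if t_out.dtype != f_out.dtype:
        return False, f"dtype mismatch: {t_out.dtype} vs {f_out.dtype}"
    return True, "ok"

x = torch.randn(3, 3)
print(check_branches_match(torch.sin, torch.cos, [x]))          # (True, 'ok')
print(check_branches_match(torch.sin, lambda t: t.sum(), [x]))  # (False, 'shape mismatch: (3, 3) vs ()')
print(check_branches_match(torch.sin, lambda t: t.double(), [x]))
```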
Constraints
Ops can have different specializations/behaviors for different tensor shapes, so by default, torch.export requires inputs to ExportedProgram to have the same shape as the respective example inputs given to the initial torch.export call. If we try to run the ExportedProgram in the example below with a tensor of a different shape, we get an error:
class MyModule2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod2 = MyModule2()
exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))

try:
    exported_mod2(torch.randn(10, 100), torch.randn(10, 100))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 257, in <module>
exported_mod2(torch.randn(10, 100), torch.randn(10, 100))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 337, in __call__
self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 559, in _check_input_constraints
_assertion_graph(*args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 678, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 284, in __call__
raise e
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 274, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<eval_with_key>.83", line 17, in forward
_assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, 'Input arg3_1.shape[0] is specialized at 8'); scalar_tensor_1 = None
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: Input arg3_1.shape[0] is specialized at 8
We can modify the torch.export call to relax some of these constraints. We use torch.export.dynamic_dim to express shape constraints manually.
Using dynamic_dim on a tensor’s dimension marks it as dynamic (i.e. unconstrained), and we can provide additional upper and lower bound shape constraints. The first argument of dynamic_dim is the tensor variable we wish to specify a dimension constraint for. The second argument is the index of the dimension that the constraint applies to.
In the example below, our input inp1 has an unconstrained first dimension, but the size of the second dimension must be in the interval (3, 18], i.e. between 4 and 18 inclusive.
from torch.export import dynamic_dim

inp1 = torch.randn(10, 10)

def constraints_example1(x):
    x = x[:, 2:]
    return torch.relu(x)

constraints1 = [
    dynamic_dim(inp1, 0),
    3 < dynamic_dim(inp1, 1),
    dynamic_dim(inp1, 1) <= 18,
]

exported_constraints_example1 = export(constraints_example1, (inp1,), constraints=constraints1)

print(exported_constraints_example1(torch.randn(5, 5)))

try:
    exported_constraints_example1(torch.randn(8, 1))
except Exception:
    tb.print_exc()

try:
    exported_constraints_example1(torch.randn(8, 20))
except Exception:
    tb.print_exc()
tensor([[0.0000, 0.9904, 0.2659],
[0.9732, 1.2030, 0.2684],
[0.1263, 0.7264, 0.0000],
[0.6484, 1.2110, 0.5155],
[0.5219, 1.5092, 0.0000]])
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 297, in <module>
exported_constraints_example1(torch.randn(8, 1))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 337, in __call__
self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 559, in _check_input_constraints
_assertion_graph(*args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 678, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 284, in __call__
raise e
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 274, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<eval_with_key>.107", line 12, in forward
_assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, 'Input arg0_1.shape[1] is outside of specified dynamic range [4, 18]'); scalar_tensor_1 = None
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: Input arg0_1.shape[1] is outside of specified dynamic range [4, 18]
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 302, in <module>
exported_constraints_example1(torch.randn(8, 20))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 337, in __call__
self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 559, in _check_input_constraints
_assertion_graph(*args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 678, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 284, in __call__
raise e
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 274, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<eval_with_key>.112", line 9, in forward
_assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, 'Input arg0_1.shape[1] is outside of specified dynamic range [4, 18]'); scalar_tensor = None
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: Input arg0_1.shape[1] is outside of specified dynamic range [4, 18]
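The error messages state the range as [4, 18] although the constraints were written as 3 < d and d <= 18; over integer sizes these describe the same set. The pure-Python sketch below uses a hypothetical helper, in_dynamic_range (not a torch.export API), to reproduce which of the three calls above are accepted and why the accepted call returns a 5 x 3 tensor.

```python
def in_dynamic_range(size, lower_exclusive, upper_inclusive):
    # Mirrors `3 < dynamic_dim(inp1, 1)` combined with `dynamic_dim(inp1, 1) <= 18`.
    return lower_exclusive < size <= upper_inclusive

def sliced_shape(shape):
    # constraints_example1 computes x[:, 2:], dropping the first two columns.
    rows, cols = shape
    return (rows, cols - 2)

# Over the integers, the half-open interval (3, 18] is the closed range [4, 18].
allowed = [d for d in range(0, 25) if in_dynamic_range(d, 3, 18)]
print(allowed[0], allowed[-1])  # 4 18

# Dimension 0 is unconstrained, so only dimension 1 is checked.
for shape in [(5, 5), (8, 1), (8, 20)]:
    if in_dynamic_range(shape[1], 3, 18):
        print(shape, "accepted, output shape", sliced_shape(shape))
    else:
        print(shape, "rejected: dim 1 outside [4, 18]")
```

The example input inp1 has shape (10, 10), and 10 lies inside [4, 18], which is why export itself succeeds with constraints1.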
Note that if the example inputs passed to torch.export do not satisfy the constraints, we get an error at export time. Below, inp1.shape[1] is 10, which lies outside the interval (10, 18]:
constraints1_bad = [
    dynamic_dim(inp1, 0),
    10 < dynamic_dim(inp1, 1),
    dynamic_dim(inp1, 1) <= 18,
]

try:
    export(constraints_example1, (inp1,), constraints=constraints1_bad)
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/__init__.py", line 270, in export
gm_torch_level, _ = torch._dynamo.export(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1183, in inner
raise constraint_violation_error
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1140, in inner
result_traced = opt_f(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
return fn(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
return _compile(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
out_code = transform_code_object(code, transform)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 441, in transform
tracer = InstructionTranslator(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2013, in __init__
self.symbolic_locals = collections.OrderedDict(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2016, in <genexpr>
VariableBuilder(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 225, in __call__
vt = self._wrap(value).clone(**self.options())
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 370, in _wrap
return type_dispatch(self, value)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 964, in wrap_tensor
tensor_variable = wrap_fx_proxy(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1191, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1306, in wrap_fx_proxy_cls
example_value = wrap_to_fake_tensor_and_record(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1587, in wrap_to_fake_tensor_and_record
fake_e = wrap_fake_exception(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 916, in wrap_fake_exception
return fn()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1588, in <lambda>
lambda: tx.fake_mode.from_tensor(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1721, in from_tensor
return self.fake_tensor_converter(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 371, in __call__
return self.from_real_tensor(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 324, in from_real_tensor
out = self.meta_converter(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 591, in __call__
r = self.meta_tensor(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 405, in meta_tensor
sizes, strides, storage_offset = sym_sizes_strides_storage_offset(t)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 234, in sym_sizes_strides_storage_offset
return shape_env.create_symbolic_sizes_strides_storage_offset(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 2344, in create_symbolic_sizes_strides_storage_offset
size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, dynamic_dims, constraint_dims)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 2248, in _produce_dyn_sizes_from_int_tuple
size.append(self.create_symbol(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 2554, in create_symbol
raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]")
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: 10 not in range [11, 18]
from user code:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 316, in <module>
export(constraints_example1, (inp1,), constraints=constraints1_bad)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 1018, in export
return export(f, args, kwargs, constraints)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/__init__.py", line 280, in export
raise UserError(UserErrorType.CONSTRAIN_VIOLATION, str(e))
torch._dynamo.exc.UserError: 10 not in range [11, 18]
from user code:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
We can also use dynamic_dim to enforce expected equalities between dimensions, for example, in matrix multiplication:
inp2 = torch.randn(4, 8)
inp3 = torch.randn(8, 2)
def constraints_example2(x, y):
    return x @ y
constraints2 = [
    dynamic_dim(inp2, 0),
    dynamic_dim(inp2, 1) == dynamic_dim(inp3, 0),
    dynamic_dim(inp3, 1),
]
exported_constraints_example2 = export(constraints_example2, (inp2, inp3), constraints=constraints2)
print(exported_constraints_example2(torch.randn(2, 16), torch.randn(16, 4)))
try:
    exported_constraints_example2(torch.randn(4, 8), torch.randn(4, 2))
except Exception:
    tb.print_exc()
tensor([[-1.6614, 8.1254, -4.6447, -8.0367],
[-2.8395, -6.2832, 2.3965, -1.5896]])
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 341, in <module>
exported_constraints_example2(torch.randn(4, 8), torch.randn(4, 2))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 337, in __call__
self._check_input_constraints(*ordered_params, *ordered_buffers, *args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 559, in _check_input_constraints
_assertion_graph(*args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 678, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 284, in __call__
raise e
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/graph_module.py", line 274, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<eval_with_key>.136", line 11, in forward
_assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, 'Input arg0_1.shape[1] is not equal to input arg1_1.shape[0]'); scalar_tensor = None
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: Input arg0_1.shape[1] is not equal to input arg1_1.shape[0]
We can actually use torch.export to tell us which constraints are necessary. To do this, we relax all constraints (recall that if we do not provide constraints for a dimension, the default behavior is to constrain it to the exact shape value of the example input) and let torch.export error out.
inp4 = torch.randn(8, 16)
inp5 = torch.randn(16, 32)
def constraints_example3(x, y):
    if x.shape[0] <= 16:
        return x @ y[:, :16]
    return y
constraints3 = (
    [dynamic_dim(inp4, i) for i in range(inp4.dim())] +
    [dynamic_dim(inp5, i) for i in range(inp5.dim())]
)
try:
    export(constraints_example3, (inp4, inp5), constraints=constraints3)
except Exception:
    tb.print_exc()
Traceback (most recent call last):
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/__init__.py", line 270, in export
gm_torch_level, _ = torch._dynamo.export(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1183, in inner
raise constraint_violation_error
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1140, in inner
result_traced = opt_f(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
return fn(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
return _compile(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 549, in compile_inner
check_fn = CheckFunctionManager(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 937, in __init__
guard.create(local_builder, global_builder)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_guards.py", line 243, in create
return self.create_fn(self.source.select(local_builder, global_builder), self)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 607, in SHAPE_ENV
guards = output_graph.shape_env.produce_guards(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 2981, in produce_guards
raise ConstraintViolationError(f"Constraints violated!\n{err}")
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated!
1. The specified set of equalities {} is not sufficient; please also specify L['y'].size()[0] == L['x'].size()[1].
2. Not all values of L['y'].size()[1] in the specified range satisfy the generated guard Ne(L['y'].size()[1], 16). For more information about why this guard was generated, run with TORCH_LOGS=dynamic.
3. Not all values of L['x'].size()[0] in the specified range satisfy the generated guard L['x'].size()[0] <= 16. For more information about why this guard was generated, run with TORCH_LOGS=dynamic.
4. Not all values of L['y'].size()[1] in the specified range satisfy the generated guard L['y'].size()[1] >= 16. For more information about why this guard was generated, run with TORCH_LOGS=dynamic.
The following dimensions CAN be dynamic.
Please use the following code to specify the constraints they must satisfy:
```
def specify_constraints(x, y):
return [
# x:
dynamic_dim(x, 0) <= 16,
# y:
16 < dynamic_dim(y, 1),
dynamic_dim(y, 0) == dynamic_dim(x, 1),
]
```
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 366, in <module>
export(constraints_example3, (inp4, inp5), constraints=constraints3)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 1018, in export
return export(f, args, kwargs, constraints)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_export/__init__.py", line 280, in export
raise UserError(UserErrorType.CONSTRAIN_VIOLATION, str(e))
torch._dynamo.exc.UserError: Constraints violated!
1. The specified set of equalities {} is not sufficient; please also specify L['y'].size()[0] == L['x'].size()[1].
2. Not all values of L['y'].size()[1] in the specified range satisfy the generated guard Ne(L['y'].size()[1], 16). For more information about why this guard was generated, run with TORCH_LOGS=dynamic.
3. Not all values of L['x'].size()[0] in the specified range satisfy the generated guard L['x'].size()[0] <= 16. For more information about why this guard was generated, run with TORCH_LOGS=dynamic.
4. Not all values of L['y'].size()[1] in the specified range satisfy the generated guard L['y'].size()[1] >= 16. For more information about why this guard was generated, run with TORCH_LOGS=dynamic.
The following dimensions CAN be dynamic.
Please use the following code to specify the constraints they must satisfy:
```
def specify_constraints(x, y):
return [
# x:
dynamic_dim(x, 0) <= 16,
# y:
16 < dynamic_dim(y, 1),
dynamic_dim(y, 0) == dynamic_dim(x, 1),
]
```
The error message suggests code that specifies the necessary constraints. Let us use that code (the exact code may differ slightly):
def specify_constraints(x, y):
    return [
        # x:
        dynamic_dim(x, 0) <= 16,
        # y:
        16 < dynamic_dim(y, 1),
        dynamic_dim(y, 0) == dynamic_dim(x, 1),
    ]
constraints3_fixed = specify_constraints(inp4, inp5)
exported_constraints_example3 = export(constraints_example3, (inp4, inp5), constraints=constraints3_fixed)
print(exported_constraints_example3(torch.randn(4, 32), torch.randn(32, 64)))
tensor([[ -3.2486, -0.0564, 10.4182, -0.5107, 6.5052, -2.8538, 1.4333,
5.5788, -6.1642, 5.2328, 7.8676, -5.6574, -5.1638, 4.5801,
12.9064, 0.2945],
[ 1.2501, 0.8377, 0.9198, -2.6442, 2.9253, -6.0027, -2.0633,
5.5477, 1.0574, -0.1237, -15.5657, 0.9839, -4.6436, -3.3065,
-3.2065, 1.7973],
[ 2.7279, -5.7209, -8.2522, -7.9277, 5.4594, -2.9686, 6.2203,
0.9756, -13.3661, -8.4977, 1.0126, -6.3700, 3.6296, 6.0752,
7.3516, -4.2076],
[ -6.0910, 2.7838, 1.7066, 2.6283, 7.2123, 9.0221, 0.2256,
4.8568, -7.1904, -0.7193, -2.8492, 3.3305, 5.6469, 6.2556,
0.2927, 2.6650]])
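Note that in the example above, because we constrained the range of x.shape[0] for constraints_example3, the exported program is sound even though the function contains a raw if statement: every allowed value satisfies x.shape[0] <= 16, so the branch taken during tracing is the only one that can run.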
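To build intuition for why a branch on a shape produces a constraint, here is a toy model of guard recording in plain Python. It is a simplification under stated assumptions, not torch.export's implementation: the hypothetical SymSize class stands in for a symbolic dimension backed by the example input's concrete size. Comparing it uses that example value (so tracing follows one branch) and records the condition that must hold for the traced graph to stay valid. The hypothetical guard_holds_on_range helper then asks whether a declared range implies that guard.

```python
import math


class SymSize:
    """Toy symbolic dimension: a name plus the concrete size seen in the example input."""

    def __init__(self, name, hint, guards):
        self.name = name
        self.hint = hint      # concrete example value used to pick a branch
        self.guards = guards  # shared list of (name, op, threshold) conditions

    def __le__(self, threshold):
        taken = self.hint <= threshold
        # Record the condition in the direction the trace actually went.
        self.guards.append((self.name, "<=" if taken else ">", threshold))
        return taken


def trace_example3(x_rows):
    """Mimic `if x.shape[0] <= 16:` and return (branch taken, recorded guards)."""
    guards = []
    s0 = SymSize("x.shape[0]", x_rows, guards)
    branch = "x @ y[:, :16]" if s0 <= 16 else "y"
    return branch, guards


def guard_holds_on_range(guard, lo, hi):
    """True if the guard is satisfied by every size in the inclusive range [lo, hi]."""
    _, op, threshold = guard
    if op == "<=":
        return hi <= threshold
    return lo > threshold  # op == ">"


branch, guards = trace_example3(8)
print(branch, guards)
# A declared range like `dynamic_dim(x, 0) <= 16` (lower bound 2) implies the guard.
print(guard_holds_on_range(guards[0], 2, 16))
# A fully relaxed range does not, so keeping only the traced branch would be unsound.
print(guard_holds_on_range(guards[0], 2, math.inf))
```

In this model, the relaxed range failing the check corresponds to item 3 of the error message above, and narrowing the range to at most 16 is what makes the single traced branch safe.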
If you want to see why torch.export generated these constraints, you can re-run the script with the environment variable TORCH_LOGS=dynamic,dynamo set, or use torch._logging.set_logs.
import logging
torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
exported_constraints_example3 = export(constraints_example3, (inp4, inp5), constraints=constraints3_fixed)
# reset to previous values
torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)
[2023-12-04 18:15:04,439] [12/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing constraints_example3 /var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py:355
[2023-12-04 18:15:04,440] [12/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2023-12-04 18:15:04,441] [12/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s0 = 8 for L['x'].size()[0]
[2023-12-04 18:15:04,441] [12/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s1 = 16 for L['x'].size()[1]
[2023-12-04 18:15:04,445] [12/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s2 = 16 for L['y'].size()[0]
[2023-12-04 18:15:04,446] [12/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s3 = 32 for L['y'].size()[1]
[2023-12-04 18:15:04,464] [12/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s1, s2) [guard added] at ar/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py:357 in constraints_example3 (_meta_registrations.py:1813 in meta_mm)
[2023-12-04 18:15:04,466] [12/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing constraints_example3 (RETURN_VALUE)
[2023-12-04 18:15:04,467] [12/0] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function dynamo_normalization_capturing_compiler
[2023-12-04 18:15:04,467] [12/0] torch._dynamo.output_graph: [INFO] Step 2: done compiler function dynamo_normalization_capturing_compiler
[2023-12-04 18:15:04,469] [12/0] torch.fx.experimental.symbolic_shapes: [INFO] produce_guards
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] Summary of dimension constraints:
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] The following dimensions CAN be dynamic.
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] Please use the following code to specify the constraints they must satisfy:
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] ```
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] def specify_constraints(x, y):
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] return [
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] # x:
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] dynamic_dim(x, 0) <= 16,
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO]
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] # y:
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] 17 <= dynamic_dim(y, 1),
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] dynamic_dim(y, 0) == dynamic_dim(x, 1),
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] ]
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO] ```
[2023-12-04 18:15:04,497] torch._dynamo.eval_frame: [INFO]
[2023-12-04 18:15:04,501] torch.fx.experimental.symbolic_shapes: [INFO] create_env
We can view an ExportedProgram's constraints using the range_constraints and equality_constraints attributes. The logging above reveals what the symbols s0, s1, ... represent: for example, s0 corresponds to L['x'].size()[0] and s3 to L['y'].size()[1].
print(exported_constraints_example3.range_constraints)
print(exported_constraints_example3.equality_constraints)
{s0: RangeConstraint(min_val=2, max_val=16), s1: RangeConstraint(min_val=2, max_val=9223372036854775806), s2: RangeConstraint(min_val=2, max_val=9223372036854775806), s3: RangeConstraint(min_val=17, max_val=9223372036854775806)}
[(InputDim(input_name='arg1_1', dim=0), InputDim(input_name='arg0_1', dim=1))]
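The printed constraints can be read as a small specification of valid inputs. The sketch below re-encodes them in plain Python and checks candidate input shapes against them. It is only a model: the dictionaries, the use of None for an unbounded maximum, and the check_inputs helper are hypothetical, while the real ExportedProgram runs comparable checks when called (see the _check_input_constraints frame in the matmul traceback earlier). The symbol-to-dimension mapping comes from the create_symbol log lines above.

```python
# Which (input name, dim) each symbol stands for, per the create_symbol log lines.
SYMBOLS = {
    "s0": ("x", 0),
    "s1": ("x", 1),
    "s2": ("y", 0),
    "s3": ("y", 1),
}

# Inclusive ranges copied from range_constraints; None means "no upper bound".
RANGES = {"s0": (2, 16), "s1": (2, None), "s2": (2, None), "s3": (17, None)}

# From equality_constraints: arg1_1 dim 0 == arg0_1 dim 1, i.e. y.shape[0] == x.shape[1].
EQUALITIES = [(("y", 0), ("x", 1))]


def check_inputs(shapes):
    """Raise ValueError if the input shapes violate a range or equality constraint."""
    for sym, (name, dim) in SYMBOLS.items():
        size = shapes[name][dim]
        lo, hi = RANGES[sym]
        if size < lo or (hi is not None and size > hi):
            upper = "inf" if hi is None else hi
            raise ValueError(f"{name}.shape[{dim}] = {size} not in [{lo}, {upper}]")
    for (a_name, a_dim), (b_name, b_dim) in EQUALITIES:
        if shapes[a_name][a_dim] != shapes[b_name][b_dim]:
            raise ValueError(f"{a_name}.shape[{a_dim}] != {b_name}.shape[{b_dim}]")


# The shapes used in the successful call above pass without raising.
check_inputs({"x": (4, 32), "y": (32, 64)})
```

Under this model, x with 20 rows fails the s0 range, y with 16 columns fails the s3 lower bound of 17, and a y whose first dimension differs from x's second fails the equality.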
We can also place constraints on individual values in the source code itself using constrain_as_value and constrain_as_size. constrain_as_value specifies that a given integer value is expected to fall within the provided minimum/maximum bounds (inclusive). If a bound is not provided, it is assumed to be unbounded.
from torch.export import constrain_as_size, constrain_as_value
def constraints_example4(x, y):
    b = y.item()
    constrain_as_value(b, 3, 5)
    if b >= 3:
        return x.cos()
    return x.sin()
exported_constraints_example4 = export(constraints_example4, (torch.randn(3, 3), torch.tensor([4])))
print(exported_constraints_example4(torch.randn(3, 3), torch.tensor([5])))
try:
    exported_constraints_example4(torch.randn(3, 3), torch.tensor([2]))
except Exception:
    tb.print_exc()
tensor([[ 0.9048, 0.9745, -0.1208],
[-0.3705, 0.5393, 0.8433],
[ 0.9743, 0.9976, 0.9998]])
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/intermediate_source/torch_export_tutorial.py", line 430, in <module>
exported_constraints_example4(torch.randn(3, 3), torch.tensor([2]))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/export/__init__.py", line 342, in __call__
res = torch.fx.Interpreter(self.graph_module).run(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/interpreter.py", line 138, in run
self.env[node] = self.run_node(node)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/interpreter.py", line 195, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/interpreter.py", line 267, in call_function
return target(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: _local_scalar_dense is outside of inline constraint [3, 5].
While executing %_assert_async : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%scalar_tensor, _local_scalar_dense is outside of inline constraint [3, 5].), kwargs = {})
Original traceback:
NoneType: None
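Why can export handle `if b >= 3` even though b is only known at runtime? A simplified way to picture it is interval reasoning: if every value in [3, 5] satisfies b >= 3, the condition can be decided while tracing, so only the cos branch needs to be kept. The sketch below is a plain-Python model of that idea under this assumption; the decide_ge helper is hypothetical and not part of torch.export.

```python
def decide_ge(lo, hi, threshold):
    """Decide `value >= threshold` for every value in [lo, hi] (None = unbounded).

    Returns True or False when the answer is the same across the whole range,
    or None when it depends on the runtime value.
    """
    if lo is not None and lo >= threshold:
        return True   # even the smallest allowed value passes
    if hi is not None and hi < threshold:
        return False  # even the largest allowed value fails
    return None       # undecidable: tracing cannot pick a branch


print(decide_ge(3, 5, 3))        # range from constrain_as_value(b, 3, 5): True
print(decide_ge(None, None, 3))  # no constraint on b: None
```

Without the constraint, the model cannot pick a branch, which illustrates why a data-dependent if on an unconstrained value is a problem for a single-graph export.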
constrain_as_size is similar to constrain_as_value, except that it should be used on integer values that will be used to specify tensor shapes. In particular, the value must not be 0 or 1, because many operations have special behavior for tensors with a shape value of 0 or 1.
def constraints_example5(x, y):
    b = y.item()
    constrain_as_size(b)
    z = torch.ones(b, 4)
    return x.sum() + z.sum()
exported_constraints_example5 = export(constraints_example5, (torch.randn(2, 2), torch.tensor([4])))
print(exported_constraints_example5(torch.randn(2, 2), torch.tensor([5])))
try:
    exported_constraints_example5(torch.randn(2, 2), torch.tensor([1]))
except Exception:
    tb.print_exc()
tensor(16.5664)
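One example of the special behavior of size 1 is broadcasting: a dimension of size 1 stretches to match its partner, while any other mismatch is an error. The plain-Python sketch below implements the standard broadcasting rule on shape tuples (an illustration, not PyTorch's code). If a traced size could be 1, a tracer could not tell which of these cases applies, which is the kind of ambiguity that ruling out 0 and 1 avoids.

```python
from itertools import zip_longest


def broadcast_shapes(a, b):
    """Broadcast two shapes right-to-left; a size-1 dim stretches to match the other."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"incompatible sizes {x} and {y}")
    return tuple(reversed(out))


print(broadcast_shapes((4, 1), (1, 3)))  # (4, 3)
print(broadcast_shapes((4, 3), (3,)))    # (4, 3)
```

Note how the same call behaves three different ways depending only on a size: equal sizes pass through, a 1 stretches, and any other mismatch raises.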
Custom Ops
torch.export can export PyTorch programs with custom operators.
Currently, the steps to register a custom op for use by torch.export are:
1. Define the custom op using torch.library, as with any other custom op.
2. Define a "Meta" implementation of the custom op that returns an empty tensor with the same shape as the expected output:
@impl(m, "custom_op", "Meta")
def custom_op_meta(x):
    return torch.empty_like(x)
3. Call the custom op from the code you want to export using torch.ops.
4. Export the code as before.
exported_custom_op_example = export(custom_op_example, (torch.randn(3, 3),))
exported_custom_op_example.graph_module.print_readable()
print(exported_custom_op_example(torch.randn(3, 3)))
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 3]):
#
sin: f32[3, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
custom_op: f32[3, 3] = torch.ops.my_custom_library.custom_op.default(sin); sin = None
cos: f32[3, 3] = torch.ops.aten.cos.default(custom_op); custom_op = None
return (cos,)
custom_op called!
tensor([[1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 0.7589],
[1.0000, 1.0000, 0.8630]])
Note in the above outputs that the custom op is included in the exported graph. When we call the exported graph as a function, the original custom op is called, as evidenced by the print call.
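Why the "Meta" implementation from step 2 is needed can be pictured with a toy dispatcher: while tracing, export only needs each op's output shape, so the shape-only implementation is enough, and the real kernel runs only when the exported program is executed. The plain-Python sketch below models that split. Everything in it (the REGISTRY dict, the impl decorator, the MetaTensor class, the ReLU stand-in body) is hypothetical and only mirrors the sin, custom_op, cos pipeline in the graph above; it is not the torch.library API.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class MetaTensor:
    """Shape-only stand-in for a tensor: no data, like a tensor on the meta device."""
    shape: tuple


REGISTRY = {}  # (op name, dispatch key) -> implementation
CALL_LOG = []  # records each time the real custom kernel runs


def impl(op_name, key):
    """Register the decorated function as op_name's implementation for key."""
    def decorator(fn):
        REGISTRY[(op_name, key)] = fn
        return fn
    return decorator


@impl("custom_op", "Real")
def custom_op_real(x):
    CALL_LOG.append("custom_op called!")
    return [max(v, 0.0) for v in x]  # illustrative body: ReLU on a flat list


@impl("custom_op", "Meta")
def custom_op_meta(x):
    return MetaTensor(x.shape)  # same output shape, no computation


# Elementwise builtins: real versions compute, meta versions only keep the shape.
for name, fn in (("sin", math.sin), ("cos", math.cos)):
    REGISTRY[(name, "Real")] = lambda x, fn=fn: [fn(v) for v in x]
    REGISTRY[(name, "Meta")] = lambda x: MetaTensor(x.shape)


def run(ops, x, key):
    """Apply ops in order using the implementations registered under key."""
    for op in ops:
        x = REGISTRY[(op, key)](x)
    return x


PROGRAM = ["sin", "custom_op", "cos"]
traced = run(PROGRAM, MetaTensor((3,)), "Meta")  # "tracing": shapes only
print(traced, CALL_LOG)                          # the real kernel has not run yet
result = run(PROGRAM, [0.5, -1.0, 2.0], "Real")  # "execution": real kernels
print(len(result), CALL_LOG)
```

In this model, omitting the "Meta" entry would make the tracing pass fail with a missing-implementation KeyError, which loosely parallels why torch.export needs step 2.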
If you have a custom operator implemented in C++, please refer to this document to make it compatible with torch.export.
ExportDB
torch.export will only ever export a single computation graph from a PyTorch program. Because of this requirement, some Python or PyTorch features are not compatible with torch.export, and users will need to rewrite parts of their model code. We have seen examples of this earlier in the tutorial, for example, rewriting if-statements using cond.
ExportDB is the standard reference that documents supported and unsupported Python/PyTorch features for torch.export. It is essentially a list of program samples, each of which represents the usage of one particular Python/PyTorch feature and its interaction with torch.export.
Examples are also tagged by category so that they can be more easily searched.
For example, let's use ExportDB to get a better understanding of how the predicate works in the cond operator. We can look at the example called cond_predicate, which has a torch.cond tag. The example code looks like:
def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
      - torch.Tensor with a single element
      - boolean expression
    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
More generally, ExportDB can be used as a reference when one of the following occurs:
1. Before attempting torch.export, you know ahead of time that your model uses some tricky Python/PyTorch features and you want to know if torch.export covers that feature.
2. When attempting torch.export, there is a failure and it's unclear how to work around it.
ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by torch.export.
Conclusion
We introduced torch.export, the new PyTorch 2.X way to export single computation graphs from PyTorch programs. In particular, we demonstrated several code modifications and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph.