torch.export¶
Warning
This feature is a prototype under active development and there WILL BE BREAKING CHANGES in the future.
Overview¶
torch.export.export() takes an arbitrary Python callable (a torch.nn.Module, a function or a method) and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized.
import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: torch.export.ExportedProgram = export(
    Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
            # code: a = torch.sin(x)
            sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);
            # code: b = torch.cos(y)
            cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);
            # code: return a + b
            add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
            return (add,)
Graph signature: ExportGraphSignature(
    parameters=[],
    buffers=[],
    user_inputs=['arg0_1', 'arg1_1'],
    user_outputs=['add'],
    inputs_to_parameters={},
    inputs_to_buffers={},
    buffers_to_mutate={},
    backward_signature=None,
    assertion_dep_token=None,
)
Range constraints: {}
torch.export produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found here.
Soundness: It is guaranteed to be a sound representation of the original program, and maintains the same calling conventions as the original program.
Normalized: There are no Python semantics within the graph. Submodules from the original programs are inlined to form one fully flattened computational graph.
Graph properties: The graph is purely functional, meaning it does not contain operations with side effects such as mutations or aliasing. It does not mutate any intermediate values, parameters, or buffers.
Metadata: The graph contains metadata captured during tracing, such as a stacktrace from user’s code.
Under the hood, torch.export leverages the following technologies:
TorchDynamo (torch._dynamo) is an internal API that uses a CPython feature called the Frame Evaluation API to safely trace PyTorch graphs. This provides a much better graph capturing experience, with far fewer rewrites needed to fully trace PyTorch code.
AOT Autograd provides a functionalized PyTorch graph and ensures the graph is decomposed/lowered to the ATen operator set.
Torch FX (torch.fx) is the underlying representation of the graph, allowing flexible Python-based transformations.
Existing frameworks¶
torch.compile() also utilizes the same PT2 stack as torch.export, but is slightly different:
JIT vs. AOT: torch.compile() is a JIT compiler and is not intended to produce standalone compiled artifacts for deployment, whereas torch.export captures the graph Ahead-of-Time.
Partial vs. Full Graph Capture: When torch.compile() runs into an untraceable part of a model, it will “graph break” and fall back to running the program in the eager Python runtime. In comparison, torch.export aims to get a full graph representation of a PyTorch model, so it will error out when something untraceable is reached. Since torch.export produces a full graph disjoint from any Python features or runtime, this graph can then be saved, loaded, and run in different environments and languages.
Usability tradeoff: Since torch.compile() is able to fall back to the Python runtime whenever it reaches something untraceable, it is a lot more flexible. torch.export will instead require users to provide more information or rewrite their code to make it traceable.
Compared to torch.fx.symbolic_trace(), torch.export traces using TorchDynamo, which operates at the Python bytecode level, giving it the ability to trace arbitrary Python constructs not limited by what Python operator overloading supports. Additionally, torch.export keeps fine-grained track of tensor metadata, so that conditionals on things like tensor shapes do not fail tracing. In general, torch.export is expected to work on more user programs, and produce lower-level graphs (at the torch.ops.aten operator level). Note that users can still use torch.fx.symbolic_trace() as a preprocessing step before torch.export.
Compared to torch.jit.script(), torch.export does not capture Python control flow or data structures, but it supports more Python language features than TorchScript (as it is easier to have comprehensive coverage over Python bytecodes). The resulting graphs are simpler and only have straight line control flow (except for explicit control flow operators).
Compared to torch.jit.trace(), torch.export is sound: it is able to trace code that performs integer computation on sizes and records all of the side-conditions necessary to show that a particular trace is valid for other inputs.
Exporting a PyTorch Model¶
An Example¶
The main entrypoint is through torch.export.export(), which takes a callable (torch.nn.Module, function, or method) and sample inputs, and captures the computation graph into a torch.export.ExportedProgram. An example:
import torch
from torch.export import export

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):
            # code: a = self.conv(x)
            convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
                arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
            );
            # code: a.add_(constant)
            add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);
            # code: return self.maxpool(self.relu(a))
            relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
            max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
                relu, [3, 3], [3, 3]
            );
            getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
            return (getitem,)
Graph signature: ExportGraphSignature(
    parameters=['L__self___conv.weight', 'L__self___conv.bias'],
    buffers=[],
    user_inputs=['arg2_1', 'arg3_1'],
    user_outputs=['getitem'],
    inputs_to_parameters={
        'arg0_1': 'L__self___conv.weight',
        'arg1_1': 'L__self___conv.bias',
    },
    inputs_to_buffers={},
    buffers_to_mutate={},
    backward_signature=None,
    assertion_dep_token=None,
)
Range constraints: {}
Inspecting the ExportedProgram, we can note the following:
The torch.fx.Graph contains the computation graph of the original program, along with records of the original code for easy debugging.
The graph contains only torch.ops.aten operators found here and custom operators, and is fully functional, without any in-place operators such as torch.Tensor.add_.
The parameters (weight and bias of conv) are lifted as inputs to the graph, resulting in no get_attr nodes in the graph, which previously existed in the result of torch.fx.symbolic_trace().
The torch.export.ExportGraphSignature models the input and output signature, along with specifying which inputs are parameters.
The resulting shape and dtype of the tensor produced by each node in the graph are noted. For example, the convolution node will result in a tensor of dtype torch.float32 and shape (1, 16, 256, 256).
Non-Strict Export¶
In PyTorch 2.3, we introduced a new mode of tracing called non-strict mode. It’s still going through hardening, so if you run into any issues, please file them on GitHub with the “oncall: export” tag.
In non-strict mode, we trace through the program using the Python interpreter. Your code will execute exactly as it would in eager mode; the only difference is that all Tensor objects will be replaced by ProxyTensors, which will record all their operations into a graph.
In strict mode, which is currently the default, we first trace through the program using TorchDynamo, a bytecode analysis engine. TorchDynamo does not actually execute your Python code. Instead, it symbolically analyzes it and builds a graph based on the results. This analysis allows torch.export to provide stronger guarantees about safety, but not all Python code is supported.
An example of a case where one might want to use non-strict mode is when you run into an unsupported TorchDynamo feature that might not be easily solved, and you know the Python code is not needed for the computation. For example:
import torch
from torch.export import export

class ContextManager:
    def __init__(self):
        self.count = 0

    def __enter__(self):
        self.count += 1

    def __exit__(self, exc_type, exc_value, traceback):
        self.count -= 1

class M(torch.nn.Module):
    def forward(self, x):
        with ContextManager():
            return x.sin() + x.cos()

export(M(), (torch.ones(3, 3),), strict=False)  # Non-strict traces successfully
export(M(), (torch.ones(3, 3),))  # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager
In this example, the first call, which uses non-strict mode (through the strict=False flag), traces successfully, whereas the second call, which uses strict mode (the default), fails because TorchDynamo is unable to support this context manager. One option is to rewrite the code (see Limitations of torch.export), but since the context manager does not affect the tensor computations in the model, we can go with the non-strict mode’s result.
Expressing Dynamism¶
By default torch.export will trace the program assuming all input shapes are static, and specialize the exported program to those dimensions. However, some dimensions, such as a batch dimension, can be dynamic and vary from run to run. Such dimensions must be specified by using the torch.export.Dim() API to create them and by passing them into torch.export.export() through the dynamic_shapes argument. An example:
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):
            # code: out1 = self.branch1(x1)
            permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
            addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
            relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);
            # code: out2 = self.branch2(x2)
            permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
            addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
            relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1); addmm_1 = None
            # code: return (out1 + self.buffer, out2)
            add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
            return (add, relu_1)
Graph signature: ExportGraphSignature(
    parameters=[
        'branch1.0.weight',
        'branch1.0.bias',
        'branch2.0.weight',
        'branch2.0.bias',
    ],
    buffers=['L__self___buffer'],
    user_inputs=['arg5_1', 'arg6_1'],
    user_outputs=['add', 'relu_1'],
    inputs_to_parameters={
        'arg0_1': 'branch1.0.weight',
        'arg1_1': 'branch1.0.bias',
        'arg2_1': 'branch2.0.weight',
        'arg3_1': 'branch2.0.bias',
    },
    inputs_to_buffers={'arg4_1': 'L__self___buffer'},
    buffers_to_mutate={},
    backward_signature=None,
    assertion_dep_token=None,
)
Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
Some additional things to note:
Through the torch.export.Dim() API and the dynamic_shapes argument, we specified the first dimension of each input to be dynamic. Looking at the inputs arg5_1 and arg6_1, they have a symbolic shape of (s0, 64) and (s0, 128), instead of the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs. s0 is a symbol representing that this dimension can be a range of values.
exported_program.range_constraints describes the ranges of each symbol appearing in the graph. In this case, we see that s0 has the range [2, inf]. For technical reasons that are difficult to explain here, such dimensions are assumed not to be 0 or 1. This is not a bug, and does not necessarily mean that the exported program will not work for dimensions 0 or 1. See The 0/1 Specialization Problem for an in-depth discussion of this topic.
We can also specify more expressive relationships between input shapes, such as where a pair of shapes might differ by one, a shape might be double another, or a shape is even. An example:
import torch

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1

exported_program = torch.export.export(
    M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"):
            # code: return x + y[1:]
            slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807); arg1_1 = None
            add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1); arg0_1 = slice_1 = None
            return (add,)
Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]
)
Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}
Some things to note:
By specifying {0: dimx} for the first input, we see that the resulting shape of the first input is now dynamic, being [s0]. Similarly, by specifying {0: dimy} for the second input, we see that the resulting shape of the second input is also dynamic. However, because we expressed dimy = dimx + 1, arg1_1's shape does not contain a new symbol; it is represented with the same symbol used in arg0_1, s0. The relationship dimy = dimx + 1 shows up as s0 + 1.
Looking at the range constraints, we see that s0 has the range [3, 6], which was specified initially, and that s0 + 1 has the solved range of [4, 7].
Serialization¶
To save the ExportedProgram, users can use the torch.export.save() and torch.export.load() APIs. A convention is to save the ExportedProgram using a .pt2 file extension.
An example:
import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

# Example inputs are passed as a tuple
exported_program = torch.export.export(MyModule(), (torch.randn(5),))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
Specializations¶
A key concept in understanding the behavior of torch.export is the difference between static and dynamic values.
A dynamic value is one that can change from run to run. These behave like normal arguments to a Python function—you can pass different values for an argument and expect your function to do the right thing. Tensor data is treated as dynamic.
A static value is a value that is fixed at export time and cannot change between executions of the exported program. When the value is encountered during tracing, the exporter will treat it as a constant and hard-code it into the graph.
When an operation is performed (e.g. x + y) and all inputs are static, then the output of the operation will be directly hard-coded into the graph, and the operation won’t show up (i.e. it will get constant-folded).
When a value has been hard-coded into the graph, we say that the graph has been specialized to that value.
The following values are static:
Input Tensor Shapes¶
By default, torch.export will trace the program specializing on the input tensors’ shapes, unless a dimension is specified as dynamic via the dynamic_shapes argument to torch.export.export(). This means that if there exists shape-dependent control flow, torch.export will specialize on the branch that is being taken with the given sample inputs. For example:
import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x):
        if x.shape[0] > 5:
            return x + 1
        else:
            return x - 1

example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 2]):
            add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            return (add,)
The conditional (x.shape[0] > 5) does not appear in the ExportedProgram because the example inputs have the static shape of (10, 2). Since torch.export specializes on the inputs’ static shapes, the else branch (x - 1) will never be reached. To preserve the dynamic branching behavior based on the shape of a tensor in the traced graph, the dimension of the input tensor (x.shape[0]) will need to be specified as dynamic with torch.export.Dim() and the dynamic_shapes argument, and the source code will need to be rewritten (see Data/Shape-Dependent Control Flow).
Note that tensors that are part of the module state (e.g. parameters and buffers) always have static shapes.
Python Primitives¶
torch.export also specializes on Python primitives, such as int, float, bool, and str. However, they do have dynamic variants such as SymInt, SymFloat, and SymBool.
For example:
import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, const: int, times: int):
        for i in range(times):
            x = x + const
        return x

example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
            add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
            add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
            return (add_2,)
Because integers are specialized, the torch.ops.aten.add.Tensor operations are all computed with the hard-coded constant 1, rather than arg1_1. If a user passes a different value for arg1_1 at runtime (like 2) than the one used at export time (1), this will result in an error. Additionally, the times argument used in the for loop is also “inlined” in the graph through the 3 repeated torch.ops.aten.add.Tensor calls, and the input arg2_1 is never used.
Python Containers¶
Python containers (List, Dict, NamedTuple, etc.) are considered to have static structure.
Limitations of torch.export¶
Graph Breaks¶
As torch.export is a one-shot process for capturing a computation graph from a PyTorch program, it might ultimately run into untraceable parts of programs, as it is nearly impossible to support tracing all PyTorch and Python features. In the case of torch.compile, an unsupported operation will cause a “graph break” and the unsupported operation will be run with default Python evaluation. In contrast, torch.export will require users to provide additional information or rewrite parts of their code to make it traceable. As the tracing is based on TorchDynamo, which evaluates at the Python bytecode level, there will be significantly fewer rewrites required compared to previous tracing frameworks.
When a graph break is encountered, ExportDB is a great resource for learning about the kinds of programs that are supported and unsupported, along with ways to rewrite programs to make them traceable.
One way to get past these graph breaks is to use non-strict export.
Data/Shape-Dependent Control Flow¶
Graph breaks can also be encountered on data-dependent control flow (if x.shape[0] > 2) when shapes are not being specialized, as a tracing compiler cannot possibly deal with it without generating code for a combinatorially exploding number of paths. In such cases, users will need to rewrite their code using special control flow operators. Currently, we support torch.cond to express if-else like control flow (more coming soon!).
Missing Fake/Meta/Abstract Kernels for Operators¶
When tracing, a FakeTensor kernel (aka meta kernel, abstract impl) is required for all operators. This is used to reason about the input/output shapes for this operator.
Please see torch.library.register_fake() for more details.
In the unfortunate case where your model uses an ATen operator that does not have a FakeTensor kernel implementation yet, please file an issue.
Read More¶
API Reference¶
- torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())¶
export() takes an arbitrary Python callable (an nn.Module, a function or a method) along with example inputs, and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized. The traced graph (1) produces normalized operators in the functional ATen operator set (as well as any user-specified custom operators), (2) has eliminated all Python control flow and data structures (with certain exceptions), and (3) records the set of shape constraints needed to show that this normalization and control-flow elimination is sound for future inputs.
Soundness Guarantee
While tracing, export() takes note of shape-related assumptions made by the user program and the underlying PyTorch operator kernels. The output ExportedProgram is considered valid only when these assumptions hold true.
Tracing makes assumptions on the shapes (not values) of input tensors. Such assumptions must be validated at graph capture time for export() to succeed. Specifically:
Assumptions on static shapes of input tensors are automatically validated without additional effort.
Assumptions on dynamic shapes of input tensors require explicit specification by using the Dim() API to construct dynamic dimensions and by associating them with example inputs through the dynamic_shapes argument.
If any assumption cannot be validated, a fatal error will be raised. When that happens, the error message will include suggested fixes to the specification that are needed to validate the assumptions. For example, export() might suggest the following fix to the definition of a dynamic dimension dim0_x, say appearing in the shape associated with input x, that was previously defined as Dim("dim0_x"):
dim0_x = Dim("dim0_x", max=5)
This example means the generated code requires dimension 0 of input x to be less than or equal to 5 to be valid. You can inspect the suggested fixes to dynamic dimension definitions and then copy them verbatim into your code without needing to change the dynamic_shapes argument to your export() call.
- Parameters
mod (Module) – We will trace the forward method of this module.
args (Tuple[Any, ...]) – Example positional inputs.
kwargs (Optional[Dict[str, Any]]) – Optional example keyword inputs.
dynamic_shapes (Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]]) – An optional argument where the type should either be: 1) a dict from argument names of f to their dynamic shape specifications, 2) a tuple that specifies dynamic shape specifications for each input in original order. If you are specifying dynamism on keyword args, you will need to pass them in the order that is defined in the original function signature.
The dynamic shape of a tensor argument can be specified as either (1) a dict from dynamic dimension indices to Dim() types, where it is not required to include static dimension indices in this dict, but when they are, they should be mapped to None; or (2) a tuple / list of Dim() types or None, where the Dim() types correspond to dynamic dimensions, and static dimensions are denoted by None. Arguments that are dicts or tuples / lists of tensors are recursively specified by using mappings or sequences of contained specifications.
strict (bool) – When enabled (default), the export function will trace the program through TorchDynamo, which will ensure the soundness of the resulting graph. Otherwise, the exported program will not validate the implicit assumptions baked into the graph and may cause behavior divergence between the original model and the exported one. This is useful when users need to work around bugs in the tracer, or simply want to incrementally enable safety in their models. Note that this does not cause the resulting IR spec to be different, and the model will be serialized in the same way regardless of what value is passed here. WARNING: This option is experimental; use it at your own risk.
- Returns
An ExportedProgram containing the traced callable.
- Return type
ExportedProgram
Acceptable input/output types
Acceptable types of inputs (for args and kwargs) and outputs include:
Primitive types, i.e. torch.Tensor, int, float, bool and str.
Dataclasses, but they must be registered by calling register_dataclass() first.
(Nested) Data structures comprising dict, list, tuple, namedtuple and OrderedDict containing all above types.
- torch.export.dynamic_shapes.dynamic_dim(t, index, debug_name=None)¶
Warning
(This feature is DEPRECATED. See Dim() instead.)
dynamic_dim() constructs a _Constraint object that describes the dynamism of a dimension index of tensor t. _Constraint objects should be passed to the constraints argument of export().
- Parameters
t (torch.Tensor) – Example input tensor that has dynamic dimension size(s)
index (int) – Index of dynamic dimension
- Returns
A _Constraint object that describes shape dynamism. It can be passed to export() so that export() does not assume a static size for the specified tensor, i.e. keeping it dynamic as a symbolic size rather than specializing according to the size of the example tracing input.
Specifically, dynamic_dim() can be used to express the following types of dynamism.
Size of a dimension is dynamic and unbounded:
t0 = torch.rand(2, 3)
t1 = torch.rand(3, 4)
# First dimension of t0 can be dynamic size rather than always being static size 2
constraints = [dynamic_dim(t0, 0)]
ep = export(fn, (t0, t1), constraints=constraints)
Size of a dimension is dynamic with a lower bound:
t0 = torch.rand(10, 3)
t1 = torch.rand(3, 4)
# First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive)
# Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive)
constraints = [
    dynamic_dim(t0, 0) >= 5,
    dynamic_dim(t1, 1) > 2,
]
ep = export(fn, (t0, t1), constraints=constraints)
Size of a dimension is dynamic with an upper bound:
t0 = torch.rand(10, 3)
t1 = torch.rand(3, 4)
# First dimension of t0 can be dynamic size with an upper bound of 16 (inclusive)
# Second dimension of t1 can be dynamic size with an upper bound of 8 (exclusive)
constraints = [
    dynamic_dim(t0, 0) <= 16,
    dynamic_dim(t1, 1) < 8,
]
ep = export(fn, (t0, t1), constraints=constraints)
Size of a dimension is dynamic and it is always equal to size of another dynamic dimension:
t0 = torch.rand(10, 3)
t1 = torch.rand(3, 4)
# Sizes of second dimension of t0 and first dimension of t1 are always equal
constraints = [
    dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
]
ep = export(fn, (t0, t1), constraints=constraints)
Mix and match all types above as long as they do not express conflicting requirements
- torch.export.save(ep, f, *, extra_files=None, opset_version=None)¶
Warning
Under active development, saved files may not be usable in newer versions of PyTorch.
Saves an ExportedProgram to a file-like object. It can then be loaded using the Python API torch.export.load.
- Parameters
ep (ExportedProgram) – The exported program to save.
f (Union[str, os.PathLike, io.BytesIO]) – A file-like object (has to implement write and flush) or a string containing a file name.
extra_files (Optional[Dict[str, Any]]) – Map from filename to contents which will be stored as part of f.
opset_version (Optional[Dict[str, int]]) – A map of opset names to the version of this opset
Example:
import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, 'exported_program.pt2')

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)¶
Warning
Under active development, saved files may not be usable in newer versions of PyTorch.
Loads an ExportedProgram previously saved with torch.export.save.
- Parameters
f (Union[str, os.PathLike, io.BytesIO]) – A readable file-like object or a string containing a file name.
extra_files (Optional[Dict[str, Any]]) – The extra filenames given in this map would be loaded and their content would be stored in the provided map.
expected_opset_version (Optional[Dict[str, int]]) – A map of opset names to expected opset versions
- Returns
An ExportedProgram object
- Return type
ExportedProgram
Example:
import torch
import io

# Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2')

# Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {'foo.txt': ''}  # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
print(extra_files['foo.txt'])
print(ep.module()(torch.randn(5)))
- torch.export.register_dataclass(cls, *, serialized_type_name=None)¶
Registers a dataclass as a valid input/output type for torch.export.export().
- Parameters
Example:
from dataclasses import dataclass
import torch

@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int

@dataclass
class OutputDataClass:
    res: torch.Tensor

torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)

class Mod(torch.nn.Module):
    def forward(self, o: InputDataClass) -> OutputDataClass:
        res = o.feature + o.bias
        return OutputDataClass(res=res)

ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),))
print(ep)
- torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)¶
Dim()
constructs a type analogous to a named symbolic integer with a range. It can be used to describe multiple possible values of a dynamic tensor dimension. Note that different dynamic dimensions of the same tensor, or of different tensors, can be described by the same type.
- class torch.export.dynamic_shapes.ShapesCollection¶
Builder for dynamic_shapes. Used to assign dynamic shape specifications to tensors that appear in inputs.
Example:
args = ({"x": tensor_x, "others": [tensor_y, tensor_z]})

dim = torch.export.Dim(…)
dynamic_shapes = torch.export.ShapesCollection()
dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
dynamic_shapes[tensor_y] = {0: dim * 2}
# This is equivalent to the following (now auto-generated):
# dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

torch.export.export(…, args, dynamic_shapes=dynamic_shapes)
- torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)
Helper for working with export's suggested fixes for dynamic shapes and/or automatic dynamic shapes. Refines the given dynamic shapes spec, given a ConstraintViolation error message and the original dynamic shapes.
In most cases the behavior is straightforward: for suggested fixes that specialize or refine a Dim's range, or that suggest a derived relation, the new dynamic shapes spec is updated accordingly.
For example, given these suggested fixes:
dim = Dim('dim', min=3, max=6)  # this just refines the dim's range
dim = 4  # this specializes to a constant
dy = dx + 1  # dy was specified as an independent dim, but is actually tied to dx with this relation
However, suggested fixes associated with derived dims can be more complicated. For example, if a suggested fix is provided for a root dim, the new derived dim value is evaluated based on the root.
For example, given:
dx = Dim('dx')
dy = dx + 2
dynamic_shapes = {"x": (dx,), "y": (dy,)}
Suggested fixes:
dx = 4  # specialization will lead to dy also specializing = 6
dx = Dim('dx', max=6)  # dy now has max = 8
Suggested fixes for derived dims can also be used to express divisibility constraints. This involves creating new root dims that aren't tied to a particular input shape. In this case the root dims won't appear directly in the new spec, but as the root of one of the dims.
For example, given these suggested fixes:
_dx = Dim('_dx', max=1024)  # this won't appear in the return result, but dx will
dx = 4*_dx  # dx is now divisible by 4, with a max value of 4096
- class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, verifier=None, tensor_constants=None, constants=None)
Package of a program from export(). It contains a torch.fx.Graph that represents Tensor computation, a state_dict containing tensor values of all lifted parameters and buffers, and various metadata.
You can call an ExportedProgram like the original callable traced by export(), with the same calling convention.
To perform transformations on the graph, use the .module() method to access a torch.fx.GraphModule. You can then use FX transformations to rewrite the graph. Afterwards, you can call export() again to construct a correct ExportedProgram.
- module()
Returns a self-contained GraphModule with all the parameters/buffers inlined.
- Return type
- Module
- buffers()
Returns an iterator over original module buffers.
Warning
This API is experimental and is NOT backward-compatible.
- named_buffers()
Returns an iterator over original module buffers, yielding both the name of the buffer as well as the buffer itself.
Warning
This API is experimental and is NOT backward-compatible.
- parameters()
Returns an iterator over original module’s parameters.
Warning
This API is experimental and is NOT backward-compatible.
- named_parameters()
Returns an iterator over original module parameters, yielding both the name of the parameter as well as the parameter itself.
Warning
This API is experimental and is NOT backward-compatible.
- run_decompositions(decomp_table=None)
Runs a set of decompositions on the exported program and returns a new exported program. By default, the Core ATen decompositions are run to get operators in the Core ATen Operator Set.
For now, we do not decompose joint graphs.
- Return type
ExportedProgram
- class torch.export.ExportBackwardSignature(gradients_to_parameters: Dict[str, str], gradients_to_user_inputs: Dict[str, str], loss_output: str)
- class torch.export.ExportGraphSignature(input_specs, output_specs)
ExportGraphSignature models the input/output signature of Export Graph, which is an fx.Graph with stronger invariant guarantees.
Export Graph is functional and does not access "states" like parameters or buffers within the graph via getattr nodes. Instead, export() guarantees that parameters, buffers, and constant tensors are lifted out of the graph as inputs. Similarly, mutations to buffers are not included in the graph; instead, the updated values of mutated buffers are modeled as additional outputs of Export Graph.
The ordering of all inputs and outputs is:
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]
For example, if the following module is exported:
import torch
from torch import nn

class CustomModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))
        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2
        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0)  # In-place addition
        return output
Resulting Graph would be:
graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)
Resulting ExportGraphSignature would be:
ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_tensor_2'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_tensor_1'), target=None)
    ]
)
- class torch.export.ModuleCallSignature(inputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec)
- class torch.export.ModuleCallEntry(fqn: str, signature: Union[torch.export.exported_program.ModuleCallSignature, NoneType] = None)
- class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Union[str, NoneType], persistent: Union[bool, NoneType] = None)
- class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Union[str, NoneType])
- class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)
- class torch.export.unflatten.FlatArgsAdapter
Adapts input arguments with input_spec to align with target_spec.
- class torch.export.unflatten.InterpreterModule(graph)
A module that uses torch.fx.Interpreter to execute instead of the usual codegen that GraphModule uses. This provides better stack trace information and makes it easier to debug execution.
- torch.export.unflatten.unflatten(module, flat_args_adapter=None)
Unflattens an ExportedProgram, producing a module with the same module hierarchy as the original eager module. This can be useful if you are trying to use torch.export with another system that expects a module hierarchy instead of the flat graph that torch.export usually produces.
Note
The args/kwargs of unflattened modules will not necessarily match the eager module, so doing a module swap (e.g. self.submod = new_mod) will not necessarily work. If you need to swap a module out, you need to set the preserve_module_call_signature parameter of torch.export.export().
- Parameters
module (ExportedProgram) – The ExportedProgram to unflatten.
flat_args_adapter (Optional[FlatArgsAdapter]) – Adapt flat args if input TreeSpec does not match with exported module's.
- Returns
An instance of UnflattenedModule, which has the same module hierarchy as the original eager module pre-export.
- Return type
UnflattenedModule