# Shortcuts
#
# Source code for torch.export

import copy
import dataclasses
import io
import pathlib
import typing
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import sympy

import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility

from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager


# Public names of this module: this list is what a star-import of the module
# brings into scope. Some of the listed names are defined outside this part of
# the file.
__all__ = [
    "ArgumentKind",
    "ArgumentSpec",
    "Constraint",
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
    "constrain_as_size",
    "constrain_as_value",
    "dynamic_dim",
    "export",
    "load",
    "save",
]


# Type of a graph pass: a callable that takes a GraphModule and returns a
# PassResult, or None.
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


@dataclasses.dataclass
class ExportBackwardSignature:
    """
    Backward-graph portion of an export signature.

    Stored in ``ExportGraphSignature.backward_signature``, which is ``None``
    when no backward graph was captured.
    """

    # NOTE(review): presumably maps gradient output names to the parameters
    # they are gradients of -- confirm against the code that builds this.
    gradients_to_parameters: Dict[str, str]
    # NOTE(review): presumably maps gradient output names to the user inputs
    # they are gradients of -- confirm against the code that builds this.
    gradients_to_user_inputs: Dict[str, str]
    # NOTE(review): presumably the name of the output that holds the loss --
    # confirm against the code that builds this.
    loss_output: str
@dataclasses.dataclass
class ExportGraphSignature:
    """
    :class:`ExportGraphSignature` models the input/output signature of Export Graph,
    which is a fx.Graph with stronger invariants guarantees.

    Export Graph is functional and does not access "states" like parameters
    or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
    guarantees that parameters and buffers are lifted out of the graph as inputs.
    Similarly, any mutations to buffers are not included in the graph either,
    instead the updated values of mutated buffers are modeled as additional outputs
    of Export Graph.

    The ordering of all inputs and outputs are::

        Inputs = [*parameters_buffers, *flattened_user_inputs]
        Outputs = [*mutated_inputs, *flattened_user_outputs]

    e.g. If following module is exported::

        class CustomModule(nn.Module):
            def __init__(self):
                super(CustomModule, self).__init__()

                # Define a parameter
                self.my_parameter = nn.Parameter(torch.tensor(2.0))

                # Define two buffers
                self.register_buffer('my_buffer1', torch.tensor(3.0))
                self.register_buffer('my_buffer2', torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0) # In-place addition

                return output

    Resulting Graph would be::

        graph():
            %arg0_1 := placeholder[target=arg0_1]
            %arg1_1 := placeholder[target=arg1_1]
            %arg2_1 := placeholder[target=arg2_1]
            %arg3_1 := placeholder[target=arg3_1]
            %arg4_1 := placeholder[target=arg4_1]
            %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
            %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
            %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
            %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
            %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
            return (add_tensor_2, add_tensor_1)

    Resulting ExportGraphSignature would be::

        ExportGraphSignature(
            # Indicates that there is one parameter named `my_parameter`
            parameters=['L__self___my_parameter'],

            # Indicates that there are two buffers, `my_buffer1` and `my_buffer2`
            buffers=['L__self___my_buffer1', 'L__self___my_buffer2'],

            # Indicates that the nodes `arg3_1` and `arg4_1` in produced graph map to
            # original user inputs, ie. x1 and x2
            user_inputs=['arg3_1', 'arg4_1'],

            # Indicates that the node `add_tensor_1` maps to output of original program
            user_outputs=['add_tensor_1'],

            # Indicates that there is one parameter (self.my_parameter) captured,
            # its name is now mangled to be `L__self___my_parameter`, which is now
            # represented by node `arg0_1` in the graph.
            inputs_to_parameters={'arg0_1': 'L__self___my_parameter'},

            # Indicates that there are two buffers (self.my_buffer1, self.my_buffer2) captured,
            # their name are now mangled to be `L__self___my_buffer1` and `L__self___my_buffer2`.
            # They are now represented by nodes `arg1_1` and `arg2_1` in the graph.
            inputs_to_buffers={'arg1_1': 'L__self___my_buffer1', 'arg2_1': 'L__self___my_buffer2'},

            # Indicates that one buffer named `L__self___my_buffer2` is mutated during execution,
            # its new value is output from the graph represented by the node named `add_tensor_2`
            buffers_to_mutate={'add_tensor_2': 'L__self___my_buffer2'},

            # Backward graph not captured
            backward_signature=None,

            # Work in progress feature, please ignore now.
            assertion_dep_token=None
        )
    """

    # A list of parameters uniquely identified by mangled fully qualified name
    parameters: List[str]

    # A list of buffers uniquely identified by mangled fully qualified name
    buffers: List[str]

    # Graph node names of pytree-flattened inputs of original program
    user_inputs: List[str]

    # Graph node names of pytree-flattened outputs of original program
    user_outputs: List[str]

    # A dictionary mapping graph input node names to parameters. If a graph input
    # name is found in this dictionary, it is guaranteed to be a lifted parameter.
    inputs_to_parameters: Dict[str, str]

    # A dictionary mapping graph input node names to buffers. If a graph input
    # name is found in this dictionary, it is guaranteed to be a lifted buffer.
    inputs_to_buffers: Dict[str, str]

    # A dictionary mapping graph output node names to buffers that are mutated in the
    # original program. Buffers that are not mutated will not be found in this dictionary.
    buffers_to_mutate: Dict[str, str]

    backward_signature: Optional[ExportBackwardSignature]

    # Map from assertion dependency token index to assertion dep token output
    # name in output. The shape of output after aot_autograd will be like:
    # (updated_inputs, user_outputs, dep_token).
    assertion_dep_token: Optional[Dict[int, str]] = None

    def __post_init__(self) -> None:
        """
        Check ``assertion_dep_token`` against the output layout.

        When a token mapping is given it must hold exactly one entry, and its
        key must equal ``len(user_outputs) + len(buffers_to_mutate)``, i.e. the
        token sits right after the mutated buffers and the user outputs.

        Raises:
            AssertionError: if either condition does not hold.
        """
        assertion_dep_token = self.assertion_dep_token
        if assertion_dep_token is None:
            return
        assert (
            len(assertion_dep_token) == 1
        ), f"expected exactly one assertion dep token, got {len(assertion_dep_token)}"
        # The single key is the position of the token among the graph outputs.
        assertion_dep_token_index = next(iter(assertion_dep_token))
        expected_index = len(self.user_outputs) + len(self.buffers_to_mutate)
        assert expected_index == assertion_dep_token_index, (
            f"assertion dep token index {assertion_dep_token_index} does not match "
            f"the number of user outputs plus mutated buffers ({expected_index})"
        )
class ArgumentKind(Enum):
    """
    Kind tag stored in ``ArgumentSpec.kind``.

    ``ArgumentSpec`` requires its ``value`` to be a ``str`` for the ``Tensor``
    and ``SymInt`` kinds; ``Constant`` carries no such restriction.
    """

    Tensor = auto()
    SymInt = auto()
    Constant = auto()
@dataclasses.dataclass
class ArgumentSpec:
    """
    Describes a single argument as a ``(kind, value)`` pair.

    For ``ArgumentKind.Tensor`` and ``ArgumentKind.SymInt`` the value must be a
    ``str`` (presumably a graph node name -- confirm against the producer);
    for other kinds the value is unconstrained.
    """

    kind: ArgumentKind
    value: Any

    def __post_init__(self) -> None:
        """
        Validate that tensor-like kinds carry a string value.

        Raises:
            AssertionError: if ``kind`` is ``Tensor`` or ``SymInt`` and
                ``value`` is not a ``str``.
        """
        if self.kind in (ArgumentKind.Tensor, ArgumentKind.SymInt):
            assert isinstance(self.value, str), (
                f"{self.kind} argument expects a str value, "
                f"got {type(self.value).__name__}"
            )
@dataclasses.dataclass
class ModuleCallSignature:
    """
    Signature attached to a :class:`ModuleCallEntry`: argument specs for the
    inputs and outputs, together with the pytree specs of each side.
    """

    # One ArgumentSpec per input / output.
    # NOTE(review): presumably in pytree-flattened order -- confirm.
    inputs: List[ArgumentSpec]
    outputs: List[ArgumentSpec]
    # Pytree structure of the inputs / outputs.
    in_spec: pytree.TreeSpec
    out_spec: pytree.TreeSpec
@dataclasses.dataclass
class ModuleCallEntry:
    """
    One entry of the ``module_call_graph`` list held by ``ExportedProgram``.
    """

    # NOTE(review): presumably the fully qualified name of the called
    # submodule -- confirm against the code that builds these entries.
    fqn: str
    # Optional call signature; defaults to None when not provided.
    signature: Optional[ModuleCallSignature] = None
class ExportedProgram:
    """
    Package of a program from :func:`export`. It contains
    an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing
    tensor values of all lifted parameters and buffers, and various metadata.

    You can call an ExportedProgram like the original callable traced by
    :func:`export` with the same calling convention.

    To perform transformations on the graph, call :meth:`module` to obtain
    a self contained :class:`torch.fx.GraphModule`. You can then use
    `FX transformation <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
    to rewrite the graph. Afterwards, you can simply use :func:`export`
    again to construct a correct ExportedProgram.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: torch.fx.Graph,
        graph_signature: ExportGraphSignature,
        call_spec: Any,
        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
        range_constraints: Dict[sympy.Symbol, Any],
        equality_constraints: List[Tuple[Any, Any]],
        module_call_graph: List[ModuleCallEntry],
        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
    ):
        """
        Build the program's graph module from ``root`` and ``graph`` and store
        the accompanying metadata.

        Args:
            root: passed with ``graph`` to ``_create_graph_module_for_export``;
                when it is a ``torch.fx.GraphModule`` its ``meta`` is copied
                onto the new graph module.
            graph: the graph to wrap. Its codegen is reset in place.
            graph_signature: input/output signature of ``graph``.
            call_spec: object whose ``in_spec`` / ``out_spec`` are used by
                ``__call__`` to flatten inputs and unflatten outputs.
            state_dict: values looked up by parameter/buffer name at call time.
            range_constraints: per-symbol range constraints.
            equality_constraints: pairs of dimensions constrained to be equal.
            module_call_graph: list of :class:`ModuleCallEntry`.
            example_inputs: optional ``(args, kwargs)`` pair, stored as given.
        """
        from torch._export.exported_program import (
            _create_graph_module_for_export,
            CallSpec,
        )
        from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
            InputDim,
            RangeConstraint,
        )

        # Remove codegen related things from the graph. It should just be a flat graph.
        graph._codegen = torch.fx.graph.CodeGen()
        self._graph_module = _create_graph_module_for_export(root, graph)
        if isinstance(root, torch.fx.GraphModule):
            self._graph_module.meta.update(root.meta)

        self._graph_signature: ExportGraphSignature = graph_signature
        self._call_spec: CallSpec = call_spec
        self._state_dict: Dict[str, Any] = state_dict
        self._range_constraints: Dict[sympy.Symbol, RangeConstraint] = range_constraints
        self._equality_constraints: List[
            Tuple[InputDim, InputDim]
        ] = equality_constraints
        self._module_call_graph: List[ModuleCallEntry] = module_call_graph
        self._example_inputs = example_inputs

    @property
    @compatibility(is_backward_compatible=False)
    def graph_module(self):
        return self._graph_module

    @property
    @compatibility(is_backward_compatible=False)
    def graph(self):
        return self.graph_module.graph

    @property
    @compatibility(is_backward_compatible=False)
    def graph_signature(self):
        return self._graph_signature

    @property
    @compatibility(is_backward_compatible=False)
    def state_dict(self):
        return self._state_dict

    @property
    @compatibility(is_backward_compatible=False)
    def call_spec(self):
        return self._call_spec

    @property
    @compatibility(is_backward_compatible=False)
    def range_constraints(self):
        return self._range_constraints

    @property
    @compatibility(is_backward_compatible=False)
    def equality_constraints(self):
        return self._equality_constraints

    @property
    @compatibility(is_backward_compatible=False)
    def module_call_graph(self):
        return self._module_call_graph

    @property
    @compatibility(is_backward_compatible=False)
    def example_inputs(self):
        return self._example_inputs

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """
        Run the exported graph on user inputs.

        Inputs are flattened with ``call_spec.in_spec`` (when set), checked
        against the recorded constraints, and fed to the graph after the
        parameters and buffers taken from ``state_dict``. When
        ``call_spec.out_spec`` is set, the leading mutated-buffer outputs are
        written back into ``state_dict``, the assertion dependency token (if
        any) is dropped, and the remaining outputs are unflattened.

        Raises:
            torch._export.error.InternalError: if the inputs cannot be
                flattened with ``in_spec`` or the outputs cannot be
                unflattened with ``out_spec``.
        """
        import torch._export.error as error
        from torch._export import combine_args_kwargs

        if self.call_spec.in_spec is not None:
            # Kept outside the ``try`` so ``user_args`` is always bound when the
            # handler below inspects it.
            user_args = combine_args_kwargs(args, kwargs)
            try:
                args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec)  # type: ignore[assignment]
            except Exception:
                _, received_spec = pytree.tree_flatten(user_args)
                raise error.InternalError(
                    "Trying to flatten user inputs with exported input tree spec: \n"
                    f"{self.call_spec.in_spec}\n"
                    "but actually got inputs with tree spec of: \n"
                    f"{received_spec}"
                )

        ordered_params = tuple(
            self.state_dict[name] for name in self.graph_signature.parameters
        )
        ordered_buffers = tuple(
            self.state_dict[name] for name in self.graph_signature.buffers
        )
        self._check_input_constraints(*ordered_params, *ordered_buffers, *args)

        with torch.no_grad():
            # NOTE: calling convention is first params, then buffers, then args as user supplied them.
            # See: torch/_functorch/aot_autograd.py#L1034
            res = torch.fx.Interpreter(self.graph_module).run(
                *ordered_params, *ordered_buffers, *args, enable_io_processing=False
            )

        if self.call_spec.out_spec is not None:
            mutation = self.graph_signature.buffers_to_mutate
            num_mutated = len(mutation)
            mutated_buffers = res[:num_mutated]

            # Exclude dependency token from final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                assertion_dep_token_index = next(iter(assertion_dep_token))
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                # Buffer updates are written back even when unflattening fails.
                # Indexing (rather than zip) keeps the IndexError if the graph
                # produced fewer outputs than there are mutated buffers.
                for ix, buffer in enumerate(mutation.values()):
                    self.state_dict[buffer] = mutated_buffers[ix]
        return res

    def __str__(self) -> str:
        graph_module = self.graph_module.print_readable(print_output=False).replace(
            "\n", "\n    "
        )
        string = (
            "ExportedProgram:\n"
            f"    {graph_module}\n"
            f"Graph Signature: {self.graph_signature}\n"
            f"Symbol to range: {self.range_constraints}\n"
        )
        return string

    def module(self) -> torch.nn.Module:
        """
        Returns a self contained GraphModule with all the parameters/buffers inlined.
        """
        from torch._export.exported_program import unlift_exported_program_lifted_states

        return unlift_exported_program_lifted_states(self)

    def _transform(self, *passes: PassType) -> "ExportedProgram":
        """
        Run ``passes`` over the graph module and wrap the result in a new
        :class:`ExportedProgram` whose signature and range constraints are
        refreshed from the transformed graph. ``state_dict`` and
        ``example_inputs`` are shared with ``self``; the other metadata is
        deep-copied.
        """
        from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
            RangeConstraint,
        )

        pm = PassManager(list(passes))
        res = pm(self.graph_module)
        transformed_gm = res.graph_module if res is not None else self.graph_module
        assert transformed_gm is not None

        def _get_updated_range_constraints(
            gm: torch.fx.GraphModule,
        ) -> Dict[sympy.Symbol, RangeConstraint]:
            def get_shape_env(gm):
                vals = [
                    node.meta["val"]
                    for node in gm.graph.nodes
                    if node.meta.get("val", None) is not None
                ]
                from torch._guards import detect_fake_mode

                fake_mode = detect_fake_mode(vals)
                if fake_mode is not None:
                    return fake_mode.shape_env
                for v in vals:
                    if isinstance(v, torch.SymInt):
                        return v.node.shape_env
                # No fake mode and no symbolic ints: there is no shape env.
                return None

            shape_env = get_shape_env(gm)
            if shape_env is None:
                return {}
            range_constraints = {
                k: RangeConstraint(v.lower, v.upper)
                for k, v in shape_env.var_to_range.items()
            }
            return range_constraints

        def get_output_node_names(gm):
            output_node = list(gm.graph.nodes)[-1]
            assert output_node.op == "output"

            return [str(arg) for arg in output_node.args[0]]

        def get_input_node_names(gm):
            return [node.name for node in gm.graph.nodes if node.op == "placeholder"]

        def _generate_new_graph_signature(old_ep, new_gm):
            """
            Update graph_signature according to graph after transformation.
            Transformations can lead to node name changes, which are used in
            graph_signature to identify inputs and outputs. Therefore, after each
            transformation, we need to update the graph_signature according to
            new node names.

            WARNING: This implementation makes a few assumptions
                - The transformation doesn't change number of inputs/outputs
                - Each input/output still has the same meaning.
                    - For inputs, that means that the inputs in transformed
                        graph map to the same lifted parameter/buffer or user
                        input as the input of the same position in the graph
                        before transformation.
                    - Similarly for outputs, each output should correspond to the
                        same mutated buffer or user output as the output value of
                        the same position  in the graph before transformation.

            It is difficult to programmatically validate these assumptions, but they
            should hold true most of the time as inputs/outputs of the graph rarely
            need to be changed.
            """
            old_signature = old_ep.graph_signature
            old_gm = old_ep.graph_module

            old_graph_input_node_names = get_input_node_names(old_gm)
            new_graph_input_node_names = get_input_node_names(new_gm)
            assert len(old_graph_input_node_names) == len(
                new_graph_input_node_names
            ), f"""
                Number of input nodes changed from {len(old_graph_input_node_names)}
                to {len(new_graph_input_node_names)} after transformation. This
                transformation is currently not supported.
                """

            old_graph_output_node_names = get_output_node_names(old_gm)
            new_graph_output_node_names = get_output_node_names(new_gm)
            assert len(old_graph_output_node_names) == len(
                new_graph_output_node_names
            ), f"""
                Number of output values changed from {len(old_graph_output_node_names)}
                to {len(new_graph_output_node_names)} after transformation. This
                transformation is currently not supported.
                """

            # Positional old-name -> new-name mapping over inputs then outputs.
            node_names_mapping = dict(
                zip(
                    old_graph_input_node_names + old_graph_output_node_names,
                    new_graph_input_node_names + new_graph_output_node_names,
                )
            )

            new_signature = copy.deepcopy(old_signature)
            new_signature.user_inputs = [
                node_names_mapping[old_user_input]
                for old_user_input in old_signature.user_inputs
            ]
            new_signature.user_outputs = [
                node_names_mapping[old_user_output]
                for old_user_output in old_signature.user_outputs
            ]
            new_signature.inputs_to_parameters = {
                node_names_mapping[old_input_name]: parameter_name
                for old_input_name, parameter_name in old_signature.inputs_to_parameters.items()
            }
            new_signature.inputs_to_buffers = {
                node_names_mapping[old_input_name]: buffer_name
                for old_input_name, buffer_name in old_signature.inputs_to_buffers.items()
            }
            new_signature.buffers_to_mutate = {
                node_names_mapping[old_output_name]: buffer_name
                for old_output_name, buffer_name in old_signature.buffers_to_mutate.items()
            }
            return new_signature

        new_graph_signature = _generate_new_graph_signature(self, transformed_gm)

        transformed_ep = ExportedProgram(
            transformed_gm,
            transformed_gm.graph,
            new_graph_signature,
            copy.deepcopy(self.call_spec),
            self.state_dict,
            _get_updated_range_constraints(transformed_gm),
            copy.deepcopy(self.equality_constraints),
            copy.deepcopy(self._module_call_graph),
            self.example_inputs,
        )
        transformed_ep.graph_module.meta.update(self.graph_module.meta)
        # Use ``transformed_gm`` rather than ``res.graph_module``: it is the same
        # object whenever ``res`` is not None, and it stays valid when ``res`` is
        # None (which the selection of ``transformed_gm`` above allows for).
        transformed_ep.graph_module.meta.update(transformed_gm.meta)
        return transformed_ep

    def _check_input_constraints(self, *args):
        """
        Check ``args`` against the range and equality constraints by building
        a placeholder-only graph, running the runtime-assertion pass on it,
        and executing the result with ``args``.
        """
        from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
            _AddRuntimeAssertionsForConstraintsPass,
        )

        # TODO(zhxchen17) Don't generate a runtime graph on the fly.
        _assertion_graph = torch.fx.GraphModule({}, torch.fx.Graph())
        for p in self.graph.nodes:
            if p.op != "placeholder":
                continue
            new_p = _assertion_graph.graph.placeholder(p.name)
            new_p.meta = p.meta
        _assertion_graph.graph.output(())
        _assertion_graph_res = _AddRuntimeAssertionsForConstraintsPass(
            self.range_constraints,
            self.equality_constraints,
        )(_assertion_graph)
        assert _assertion_graph_res is not None
        _assertion_graph = _assertion_graph_res.graph_module
        _assertion_graph(*args)

    def _validate(self):
        """
        Assert that no ``call_function`` node in any nested GraphModule
        targets ``torch.ops.higher_order._export_tracepoint``.
        """
        # TODO(zhxchen17) check for get_attr
        # TODO(zhxchen17) check for functional ops
        for gm in self.graph_module.modules():
            if not isinstance(gm, torch.fx.GraphModule):
                continue
            for node in gm.graph.nodes:
                if node.op == "call_function":
                    assert node.target != torch.ops.higher_order._export_tracepoint
@dataclasses.dataclass
class _ConstraintTarget:
    """
    Identifies a single dimension of an input tensor.

    Not meant to be instantiated by users; obtain one indirectly through
    :func:`dynamic_dim`.
    """

    # Weak reference to the torch.Tensor this target refers to.
    w_tensor: Any
    # TODO: We don't need t_id; we can get it off of w_tensor
    t_id: int
    # Index of the dimension on that tensor.
    dim: int


class _ConstraintFactory(type):
    """
    Metaclass that gives :class:`Constraint` a private constructor.

    Calling the class directly always fails; instances can only be produced
    through :meth:`_create`.
    """

    def __call__(cls, *args, **kwargs):
        """Reject direct construction with a :class:`TypeError`."""
        message = (
            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
            f"Please use torch.export.dynamic_dim() to create one"
        )
        raise TypeError(message)

    def _create(cls, w_tensor, t_id, dim, constraint_range, shared=None):
        """Build an instance of ``cls`` via the regular ``type`` call protocol,
        bypassing the blocking ``__call__`` defined above."""
        instance = super().__call__(w_tensor, t_id, dim, constraint_range, shared)
        return instance


def _create_constraint(w_tensor, t_id, dim, constraint_range, shared=None):
    """Module-level shortcut for ``Constraint._create`` with the same arguments."""
    return Constraint._create(
        w_tensor,
        t_id,
        dim,
        constraint_range,
        shared,
    )
@dataclasses.dataclass
class Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
    """
    .. warning::
        Do not construct :class:`Constraint` directly, use :func:`dynamic_dim` instead.

    Describes a constraint on a dimension of an input tensor, e.g. that the
    dimension is fully polymorphic or that its size lies within some range.

    The comparison operators (``>=``, ``>``, ``<=``, ``<``, ``==``) do not
    return booleans: each one returns a new :class:`Constraint` carrying the
    refined range.
    """

    # NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
    constraint_range: StrictMinMaxConstraint
    # Set when `constraint_range` is shared with another _ConstraintTarget, which
    # typically arises because of a specified equality with another dynamic dimension.
    shared: Optional[_ConstraintTarget] = None

    def _clone_with_range(self, lower=2, upper=sympy.oo):
        """Return a new constraint on the same target whose range is the
        current range combined (``&``) with ``ValueRanges(lower, upper)``."""
        from torch.utils._sympy.value_ranges import ValueRanges

        current_vr = self.constraint_range.vr
        requested_vr = ValueRanges(lower=lower, upper=upper)
        refined = StrictMinMaxConstraint(
            vr=current_vr & requested_vr,
            warn_only=False,
        )
        return _create_constraint(
            self.w_tensor, self.t_id, self.dim, refined, self.shared
        )

    def __ge__(self, lower):
        """``constraint >= lower``: clone with ``lower`` as the lower bound."""
        return self._clone_with_range(lower=lower)

    def __gt__(self, lower):
        """``constraint > lower``: clone with ``lower + 1`` as the lower bound."""
        exclusive_lower = lower + 1
        return self._clone_with_range(lower=exclusive_lower)

    def __le__(self, upper):
        """``constraint <= upper``: clone with ``upper`` as the upper bound."""
        return self._clone_with_range(upper=upper)

    def __lt__(self, upper):
        """``constraint < upper``: clone with ``upper - 1`` as the upper bound."""
        exclusive_upper = upper - 1
        return self._clone_with_range(upper=exclusive_upper)

    def __bool__(self):
        # NOTE(avik): We do not support compound expressions like a <= x <= b.
        # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
        # and moreover, enforces that any overload of __bool__ must return True or False.
        # FWIW, sympy also raises TypeError in this case.
        raise TypeError(
            "Cannot determine truth value of Constraint. "
            "If you are trying to combine Constraint's with logical connectives, "
            "you can specify them separately instead."
        )

    @property
    def serializable_spec(self):
        """Plain-dict view of this constraint (``t_id``, ``dim``, ``min``,
        ``max``, ``shared``)."""
        # We need a serialization compatible format of the constraint so that it
        # can be saved in the graph module w/o breaking the module serialization.
        # The saved constraints will be used directly for the post-exporting pass
        # that converts constraints to runtime assertion. The saved constraints
        # will not be saved in the serialized module.
        # TODO: A better way is needed. Currently we use 't_id' to map the constraint,
        # which is not reliable
        spec = {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }
        if self.shared is None:
            spec["shared"] = None
        else:
            spec["shared"] = {
                "t_id": self.shared.t_id,
                "dim": self.shared.dim,
            }
        return spec

    def __eq__(self, other):
        """``constraint == other``: return a new constraint on this target
        whose range is both ranges combined (``&``) and whose ``shared`` field
        points at ``other``'s tensor dimension.

        Raises:
            TypeError: if ``other`` is not a :class:`Constraint`.
        """
        if not isinstance(other, Constraint):
            raise TypeError(
                "A dynamic dim can be specified equal only to another dynamic dim. "
                f"Equality with {type(other)} is not supported."
            )
        merged_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & other.constraint_range.vr,
            warn_only=False,
        )
        partner = _ConstraintTarget(other.w_tensor, other.t_id, other.dim)
        return _create_constraint(
            self.w_tensor,
            self.t_id,
            self.dim,
            merged_range,
            shared=partner,
        )
def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = None):
    """
    Hint :func:`export` about the range of an intermediate scalar value so that
    subsequent branches that test that value can be traced soundly.

    .. warning::
        If the intermediate scalar value will be used like a size (including
        being passed as a size argument to a tensor factory or view), call
        :func:`constrain_as_size` instead.

    Args:
        symbol: Intermediate scalar value (int-only now) to apply range constraint on.
        min (Optional[int]): Minimum possible value of given symbol (inclusive)
        max (Optional[int]): Maximum possible value of given symbol (inclusive)

    Returns:
        None

    For example, the following program can not be traced soundly::

        def fn(x):
            v = x.max().item()
            if v > 1024:
                return x
            else:
                return x * 2

    ``v`` is a data-dependent value, which is assumed to have a range of (-inf, inf).
    Without a hint about which branch to take, :func:`export` would not be able to
    determine whether the traced branching decision is correct. Thus :func:`export`
    would give the following error::

        torch._dynamo.exc.UserError: Consider annotating your code using
        torch.export.constrain_as_size() or torch.export().constrain_as_value() APIs.
        It appears that you're trying to get a value out of symbolic int/float whose value
        is data-dependent (and thus we do not know the true value.)  The expression we were
        trying to evaluate is f0 > 1024 (unhinted: f0 > 1024).

    Assuming the actual range of ``v`` can be between [10, 200], you can add a call to
    :func:`constrain_as_value` in the source code like this::

        def fn(x):
            v = x.max().item()

            # Give export() a hint
            torch.export.constrain_as_value(v, min=10, max=200)

            if v > 1024:
                return x
            else:
                return x * 2

    With the additional hint, :func:`export` would be able to trace the program correctly
    by taking the ``else`` branch, resulting in the following graph::

        graph():
            %arg0_1 := placeholder[target=arg0_1]

            # v = x.max().item()
            %max_1 := call_function[target=torch.ops.aten.max.default](args = (%arg0_1,))
            %_local_scalar_dense := call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%max_1,))

            # Asserting 10 <= v <= 200
            %ge := call_function[target=operator.ge](args = (%_local_scalar_dense, 10))
            %scalar_tensor := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,))
            %_assert_async := call_function[target=torch.ops.aten._assert_async.msg](
                args = (%scalar_tensor, _local_scalar_dense is outside of inline constraint [10, 200].))
            %le := call_function[target=operator.le](args = (%_local_scalar_dense, 200))
            %scalar_tensor_1 := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%le,))
            %_assert_async_1 := call_function[target=torch.ops.aten._assert_async.msg](
                args = (%scalar_tensor_1, _local_scalar_dense is outside of inline constraint [10, 200].))
            %sym_constrain_range := call_function[target=torch.ops.aten.sym_constrain_range.default](
                args = (%_local_scalar_dense,), kwargs = {min: 10, max: 200})

            # Always taking `else` branch to multiply elements `x` by 2 due to hints above
            %mul := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, 2), kwargs = {})
            return (mul,)

    """
    # Imported lazily, at call time, and delegated to the private implementation.
    import torch._export.constraints as _constraints_impl

    return _constraints_impl.constrain_as_value(symbol, min, max)
def constrain_as_size(symbol, min: Optional[int] = None, max: Optional[int] = None):
    """
    Hint :func:`export` about the range of an intermediate scalar value that
    represents the shape of a tensor, so that subsequent tensor constructors can
    be traced correctly (many operators need to make assumptions about the range
    of sizes).

    Args:
        symbol: Intermediate scalar value (int-only now) to apply range constraint on.
        min (Optional[int]): Minimum possible value of given symbol (inclusive)
        max (Optional[int]): Maximum possible value of given symbol (inclusive)

    Returns:
        None

    For example, the following program can not be traced soundly without using
    :func:`constrain_as_size` to give :func:`export` a hint about shape ranges::

        def fn(x):
            d = x.max().item()
            return torch.ones(d)

    :func:`export` would give the following error::

        torch._dynamo.exc.Unsupported: guard on data-dependent symbolic int/float

    Assuming the actual range of ``d`` can be between [3, 10], you can add a call to
    :func:`constrain_as_size` in the source code like this::

        def fn(x):
            d = x.max().item()
            torch.export.constrain_as_size(d, min=3, max=10)
            return torch.ones(d)

    With the additional hint, :func:`export` would be able to trace the program
    correctly, resulting in the following graph::

        graph():
            %arg0_1 := placeholder[target=arg0_1]

            # d = x.max().item()
            %max_1 := call_function[target=torch.ops.aten.max.default](args = (%arg0_1,))
            %_local_scalar_dense := call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%max_1,))

            # Asserting 3 <= d <= 10
            %ge := call_function[target=operator.ge](args = (%_local_scalar_dense, 3))
            %scalar_tensor := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,))
            %_assert_async := call_function[target=torch.ops.aten._assert_async.msg](
                args = (%scalar_tensor, _local_scalar_dense is outside of inline constraint [3, 10].))
            %le := call_function[target=operator.le](args = (%_local_scalar_dense, 10))
            %scalar_tensor_1 := call_function[target=torch.ops.aten.scalar_tensor.default](args = (%le,))
            %_assert_async_1 := call_function[target=torch.ops.aten._assert_async.msg](
                args = (%scalar_tensor_1, _local_scalar_dense is outside of inline constraint [3, 10].))
            %sym_constrain_range_for_size := call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](
                args = (%_local_scalar_dense,), kwargs = {min: 3, max: 10})

            # Constructing new tensor with d
            %full := call_function[target=torch.ops.aten.full.default](
                args = ([%_local_scalar_dense], 1),
                kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})

            ......


    .. warning::
        If your size is intended to be dynamic, do NOT test if sizes are equal to 0 or 1,
        these will SILENTLY report false and be bypassed

    """
    # Imported lazily, at call time, and delegated to the private implementation.
    import torch._export.constraints as _constraints_impl

    return _constraints_impl.constrain_as_size(symbol, min, max)
def dynamic_dim(t: torch.Tensor, index: int):
    """
    :func:`dynamic_dim` constructs a :class:`Constraint` object that describes the dynamism of
    dimension ``index`` of tensor ``t``. :class:`Constraint` objects should be passed to the
    ``constraints`` argument of :func:`export`.

    Args:
        t (torch.Tensor): Example input tensor that has dynamic dimension size(s)
        index (int): Index of dynamic dimension

    Returns:
        A :class:`Constraint` object that describes shape dynamism. It can be passed to :func:`export` so
        that :func:`export` does not assume static size of specified tensor, i.e. keeping it dynamic
        as a symbolic size rather than specializing according to size of example tracing input.

    Specifically :func:`dynamic_dim` can be used to express the following types of dynamism.

    - Size of a dimension is dynamic and unbounded::

        t0 = torch.rand(2, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size rather than always being static size 2
        constraints = [dynamic_dim(t0, 0)]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic with a lower bound::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive)
        # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive)
        constraints = [
            dynamic_dim(t0, 0) >= 5,
            dynamic_dim(t1, 1) > 2,
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic with an upper bound::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size with an upper bound of 16 (inclusive)
        # Second dimension of t1 can be dynamic size with an upper bound of 8 (exclusive)
        constraints = [
            dynamic_dim(t0, 0) <= 16,
            dynamic_dim(t1, 1) < 8,
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic and it is always equal to size of another dynamic dimension::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # Sizes of second dimension of t0 and first dimension of t1 are always equal
        constraints = [
            dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Mix and match all types above as long as they do not express conflicting requirements

    """
    # Imported lazily, at call time, and delegated to the private implementation.
    import torch._export as _export_impl

    return _export_impl.dynamic_dim(t, index)
def export(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    constraints: Optional[List[Constraint]] = None,
) -> ExportedProgram:
    """
    :func:`export` takes an arbitrary Python callable (an nn.Module, a function or
    a method) and produces a traced graph representing only the Tensor
    computation of the function in an Ahead-of-Time (AOT) fashion, which can
    subsequently be executed with different inputs or serialized.  The traced
    graph (1) produces a normalized operator set consisting only of functional
    `Core ATen Operator Set <https://pytorch.org/docs/stable/ir.html>`_
    and user specified custom operators, (2) has eliminated all Python control
    flow and data structures (except for certain conditions), and (3) has the
    set of shape constraints needed to show that this normalization and
    control flow elimination is sound for a future input.

    **Soundness Guarantee**

    While tracing, :func:`export()` takes note of shape-related assumptions
    made by the user program and the underlying PyTorch operator kernels.
    The output :class:`ExportedProgram` is considered valid only when these
    assumptions hold true.

    There are 2 types of assumptions made during tracing

    - Shapes (not values) of input tensors.
    - Ranges (lower and upper bound) of values extracted from intermediate tensors via ``.item()`` or direct indexing.

    All assumptions must be validated at graph capture time for :func:`export`
    to succeed. Specifically:

    - Assumptions on static shapes of input tensors are automatically validated without additional effort.
    - Assumptions on dynamic shape of input tensors require explicit `Input Constraint`
      constructed with :func:`dynamic_dim` APIs
    - Assumptions on range of intermediate values require explicit `Inline Constraint`,
      constructed using :func:`constrain_as_size` and :func:`constrain_as_value` APIs.

    If any assumption can not be validated, a fatal error will be raised. When that happens,
    the error message will include suggested code needed to construct necessary
    constraints to validate the assumptions, for example :func:`export` would suggest
    the following code for input constraints::

        def specify_constraints(x):
            return [
                # x:
                dynamic_dim(x, 0) <= 5,
            ]

    This example means the program requires the dim 0 of input ``x`` to be less
    than or equal to 5 to be valid. You can inspect the constraints needed and
    then copy this exact function into your code to generate the constraints
    that need to be passed into the ``constraints`` argument.

    Args:
        f: The callable to trace.

        args: Example positional inputs.

        kwargs: Optional example keyword inputs.

        constraints: An optional list of constraints on the dynamic arguments
         that specify their possible range of shapes. By default, shapes of
         input torch.Tensors are assumed to be static. If an input torch.Tensor
         is expected to have dynamic shapes, please use :func:`dynamic_dim`
         to define :class:`Constraint` objects that specify the dynamics and the possible
         range of shapes. See :func:`dynamic_dim` docstring for examples on
         how to use it.

    Returns:
        An :class:`ExportedProgram` containing the traced callable.

    **Acceptable input/output types**

    Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:

    - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
    - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and
      ``OrderedDict`` containing all above types.

    """
    # Imported lazily, at call time, and delegated to the private implementation.
    import torch._export as _export_impl

    return _export_impl.export(f, args, kwargs, constraints)
def save(
    ep: ExportedProgram,
    f: Union[str, pathlib.Path, io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    opset_version: Optional[Dict[str, int]] = None,
) -> None:
    """

    .. warning::
        Under active development, saved files may not be usable in newer versions
        of PyTorch.

    Saves an :class:`ExportedProgram` to a file-like object. It can then be
    loaded using the Python API :func:`torch.export.load <torch.export.load>`.

    Args:
        ep (ExportedProgram): The exported program to save.

        f (Union[str, pathlib.Path, io.BytesIO]): A file-like object (has to
         implement write and flush) or a string containing a file name.

        extra_files (Optional[Dict[str, Any]]): Map from filename to contents
         which will be stored as part of f.

        opset_version (Optional[Dict[str, int]]): A map of opset names
         to the version of this opset


    Example::

        import torch
        import io

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        ep = torch.export.export(MyModule(), (torch.randn(5),))

        # Save to file
        torch.export.save(ep, 'exported_program.pt2')

        # Save to io.BytesIO buffer
        buffer = io.BytesIO()
        torch.export.save(ep, buffer)

        # Save with extra files
        extra_files = {'foo.txt': b'bar'}
        torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)

    """
    # Imported lazily, at call time, and delegated to the private implementation.
    import torch._export as _export_impl

    _export_impl.save(
        ep,
        f,
        extra_files=extra_files,
        opset_version=opset_version,
    )
def load(
    f: Union[str, pathlib.Path, io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    expected_opset_version: Optional[Dict[str, int]] = None,
) -> ExportedProgram:
    """

    .. warning::
        Under active development, saved files may not be usable in newer versions
        of PyTorch.

    Loads an :class:`ExportedProgram` previously saved with
    :func:`torch.export.save <torch.export.save>`.

    Args:
        f (Union[str, pathlib.Path, io.BytesIO]): A file-like object to read the
         saved program from, or a string / path containing a file name.

        extra_files (Optional[Dict[str, Any]]): The extra filenames given in
         this map would be loaded and their content would be stored in the
         provided map.

        expected_opset_version (Optional[Dict[str, int]]): A map of opset names
         to expected opset versions

    Returns:
        An :class:`ExportedProgram` object

    Example::

        import torch
        import io

        # Load ExportedProgram from file
        ep = torch.export.load('exported_program.pt2')

        # Load ExportedProgram from io.BytesIO object
        with open('exported_program.pt2', 'rb') as f:
            buffer = io.BytesIO(f.read())
        buffer.seek(0)
        ep = torch.export.load(buffer)

        # Load with extra files.
        extra_files = {'foo.txt': ''}  # values will be replaced with data
        ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
        print(extra_files['foo.txt'])

    """
    # Imported lazily, at call time, and delegated to the private implementation.
    import torch._export as _export_impl

    loaded_program = _export_impl.load(
        f,
        extra_files=extra_files,
        expected_opset_version=expected_opset_version,
    )
    return loaded_program

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources