
Source code for torch.export.exported_program

# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import contextlib
import copy
import dataclasses
import functools
import operator
import types
import warnings
from collections import namedtuple
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    final,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

from torch._higher_order_ops.utils import autograd_not_implemented
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_impls import (
    _deregister_op_impl,
    _is_op_registered_to_fake_rule,
    register_op_impl,
)
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts


if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.

    import sympy

    from torch.utils._sympy.value_ranges import ValueRanges

import torch
import torch.utils._pytree as pytree
from torch._export.utils import (
    _collect_all_valid_cia_ops,
    _collect_and_set_constant_attrs,
    _collect_param_buffer_metadata,
    _detect_fake_mode_from_gm,
    _get_decomp_for_cia,
    _is_preservable_cia_op,
    _name_hoo_subgraph_placeholders,
    _overwrite_signature_for_non_persistent_buffers,
    _populate_param_buffer_metadata_to_new_gm,
    _rename_without_collisions,
    _special_op_to_preserve_cia,
)
from torch._export.verifier import Verifier
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.export.decomp_utils import CustomDecompTable
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

from .graph_signature import (  # noqa: F401
    ArgumentSpec,
    ConstantArgument,
    CustomObjArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymBoolArgument,
    SymFloatArgument,
    SymIntArgument,
    TensorArgument,
    TokenArgument,
)


__all__ = [
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
    "default_decompositions",
]


PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


@dataclasses.dataclass
class ModuleCallSignature:
    inputs: List[ArgumentSpec]
    outputs: List[ArgumentSpec]
    in_spec: pytree.TreeSpec
    out_spec: pytree.TreeSpec
    forward_arg_names: Optional[List[str]] = None

    def replace_all_uses_with(self, original_node, new_node):
        for i in self.inputs:
            if i.name == original_node.name:
                i.name = new_node.name
        for o in self.outputs:
            if o.name == original_node.name:
                o.name = new_node.name

@dataclasses.dataclass
class ModuleCallEntry:
    fqn: str
    signature: Optional[ModuleCallSignature] = None

def _disable_prexisiting_fake_mode(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with unset_fake_temporarily():
            return fn(*args, **kwargs)

    return wrapper


def _fx_collection_equivalence_fn(
    spec1_type: Optional[type],
    spec1_context: pytree.Context,
    spec2_type: Optional[type],
    spec2_context: pytree.Context,
) -> bool:
    """Treat containers and their immutable variants as the same type. Otherwise
    compare as normal.
    """
    if spec1_type is None or spec2_type is None:
        return spec1_type is spec2_type and spec1_context == spec2_context

    if issubclass(spec1_type, (dict, immutable_dict)) and issubclass(
        spec2_type, (dict, immutable_dict)
    ):
        return spec1_context == spec2_context

    if issubclass(spec1_type, (list, immutable_list)) and issubclass(
        spec2_type, (list, immutable_list)
    ):
        return spec1_context == spec2_context

    return spec1_type is spec2_type and spec1_context == spec2_context


# This list is compiled from DispatchKey.cpp.
# The idea is that we use these keys to override
# CIA decomps in export.
_AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [
    torch._C.DispatchKey.AutogradCPU,
    torch._C.DispatchKey.AutogradCUDA,
    torch._C.DispatchKey.AutogradMeta,
    torch._C.DispatchKey.AutogradXLA,
    torch._C.DispatchKey.AutogradLazy,
    torch._C.DispatchKey.AutogradIPU,
    torch._C.DispatchKey.AutogradXPU,
    torch._C.DispatchKey.AutogradMPS,
    torch._C.DispatchKey.AutogradHPU,
    torch._C.DispatchKey.AutogradPrivateUse1,
    torch._C.DispatchKey.AutogradPrivateUse2,
    torch._C.DispatchKey.AutogradPrivateUse3,
]

# This list is compiled from DispatchKey.cpp.
# The idea is that we use these keys to add
# python kernels that directly use the default
# CIA decomp.
# See NOTE: Registering old CIA to Backend kernel
_BACKEND_KEYS_TO_OVERRIDE = [
    torch._C.DispatchKey.CPU,
    torch._C.DispatchKey.CUDA,
    torch._C.DispatchKey.Meta,
    torch._C.DispatchKey.XLA,
    torch._C.DispatchKey.Lazy,
    torch._C.DispatchKey.IPU,
    torch._C.DispatchKey.XPU,
    torch._C.DispatchKey.MPS,
    torch._C.DispatchKey.HPU,
]


@contextmanager
def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True):
    # This function overrides the CompositeImplicitAutograd decomp for
    # functional composite ops that the user specified. Ideally we want to
    # not-decompose ALL composite ops, but today's C++ functionalization relies on
    # the fact that it is working with the opset after decomp is run.
    # Hence we can only do it for functional ops. One caveat is that
    # there are some composite ops that lie about their schema (claimed to be
    # functional but not really, aka dropout); for these cases, we just decompose.

    # When safe=False, we will assume that ops_to_preserve can be mutating/aliasing
    # and their usual decompositions need to be shadowed rather than overridden.
    # Thus we will avoid asserting that they are valid to preserve, and will not
    # replace their CompositeImplicitAutograd kernels with NotImplemented.
    # The only current users of this mode are variants of aten::to that we will
    # replace with aten::_to_copy in FunctionalTensorMode.__torch_dispatch__.
    saved_tables = {}
    patched_ops = set()
    for op_overload, decomp_callable in cia_ops_to_callable.items():
        saved_tables[op_overload] = op_overload.py_kernels.copy()
        patched_ops.add(op_overload)

        for override_dispatch_key in _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE:
            if override_dispatch_key not in op_overload.py_kernels:
                # TODO (tmanlaibaatar) https://github.com/pytorch/pytorch/issues/129430
                op_overload.py_impl(override_dispatch_key)(
                    autograd_not_implemented(op_overload, deferred_error=True)
                )
        # See NOTE: Registering old CIA to Backend kernel
        # It is important that we cache this before we override py_kernels.
        orig_cia_callable = _get_decomp_for_cia(op_overload)
        if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels:
            del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd]

        if safe:
            op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(
                decomp_callable
            )

            # [NOTE] Directly registering fake tensor rule to CIA ops
            # The problem we are facing here is that if your CIA custom rule
            # says we want to preserve the op, we will return NotImplemented.
            # Unfortunately, this will invoke meta device tracing in fake tensor,
            # resulting in divergent behaviour for CIA kernels that have device-based
            # branching (one case is torch.ops.aten.scaled_dot_product.attention).
            # To get around this issue, we register a direct fake impl so that we
            # run the kernel before we actually try to decompose the op in FakeTensorMode.
            # Note that this is a no-op in most cases, because:
            #   1) In post dispatch tracing, CIA would have already decomposed
            #   2) Most CIA impls are device agnostic.
            def _force_dispatch_to_orig_cia_callable(fake_tensor_mode, op, *args, **kwargs):
                orig_cia_callable = kwargs["original_callable"]
                del kwargs["original_callable"]
                with fake_tensor_mode:
                    return orig_cia_callable(*args, **kwargs)

            if not _is_op_registered_to_fake_rule(op_overload):
                register_op_impl(op_overload)(
                    functools.partial(
                        _force_dispatch_to_orig_cia_callable,
                        original_callable=orig_cia_callable,
                    )
                )

        for key in _BACKEND_KEYS_TO_OVERRIDE:
            if key not in op_overload.py_kernels:
                # [NOTE] Registering old CIA to Backend kernel
                # We always register the original CIA behavior to the backend keys' kernels.
                # The reason is that when we are fake tensor prop-ing or executing a real kernel,
                # we end up calling an operator on the respective backend, which, in the python
                # dispatcher, will resolve into the CIA key. (see resolve_key in torch/_ops.py)
                # As a result, this CIA will now call into the custom user-defined
                # CIA, which can cause a problem.
                # To make it more concrete, the case we are handling is:
                #  (1) there is a tensor constant we are performing constant propagation
                #      on during tracing
                #  (2) we invoke an op underneath autograd (either because we are below autograd,
                #      or we are tracing in inference mode), so one of the backend keys gets hit
                #  (3) the op we are invoking has a CIA impl that normally runs in eager mode
                #      (and the user wants to tweak this CIA impl during tracing, but during
                #      const-prop we want the original CIA to run)
                op_overload.py_impl(key)(orig_cia_callable)

    try:
        yield
    finally:
        for op in patched_ops:
            op.py_kernels.clear()
            op.py_kernels.update(saved_tables[op])
            op._dispatch_cache.clear()
            _deregister_op_impl(op)


@contextmanager
def _override_decomp_aten_to_variants():
    # Preserve variants of aten::to, understanding that they are mutating/aliasing
    # and their CompositeImplicitAutograd kernels will not become NotImplemented.
    # We will later replace them with aten._to_copy when functionalizing.
    with _override_composite_implicit_decomp(
        {
            torch.ops.aten.to.dtype_layout: _special_op_to_preserve_cia,
            torch.ops.aten.to.dtype: _special_op_to_preserve_cia,
        },
        safe=False,
    ):
        yield


def _split_decomp_table_to_cia_and_python_decomp(
    decomp_table: Dict[torch._ops.OperatorBase, Callable]
) -> Tuple[Dict[torch._ops.OperatorBase, Callable], ...]:
    all_preservable_cia_ops = set(_collect_all_valid_cia_ops())
    cia_ops_to_callable = {}

    for op in list(decomp_table.keys()):
        # TODO: we are silently allowing non-safe (non-functional) ops through a crack
        # due to the core aten decomp table having non-functional entries. Once we have
        # a tighter check around core aten decomps, we should warn users about them.
        # Tracking issue: https://github.com/pytorch/pytorch/issues/135759

        # If it is a valid CIA op that we can mess with in export, we check whether it:
        # 1. Has been marked as to be decomposed. Example:
        #        decomp_table = decomp_table_to_core_aten()
        #        del decomp_table[aten.linear]
        #    In this case, the user says: decompose everything except for aten.linear.
        # 2. Has been marked with custom decomp behaviour. Example:
        #        decomp_table = {aten.linear: some_op}
        # For (1), we want to remove all the CIA ops that weren't handled by the user, as
        # that suggests they are safe to decompose, so we should remove them from the
        # preservable list.
        # For (2), we just plumb the custom decomp to AOTDispatcher.
        # In both cases, we want to remove this CIA op from the decomp_table, as it is
        # specially handled.
        if op in all_preservable_cia_ops:
            cia_ops_to_callable[op] = decomp_table[op]
            all_preservable_cia_ops.remove(op)
            del decomp_table[op]
        # If it is a custom op, we want to still preserve or do whatever
        # with it if it is a functional CIA. The reason we don't remove it
        # from the CIA list is that we don't query custom ops.
        elif _is_preservable_cia_op(op):
            op_name = op.name()
            assert not op_name.startswith("aten"), "This should be a custom op"
            cia_ops_to_callable[op] = decomp_table[op]

    # If we reached here, it means the user intentionally deleted these CIA ops from
    # the decomp table.
    for k in all_preservable_cia_ops:
        cia_ops_to_callable[k] = _special_op_to_preserve_cia

    return cia_ops_to_callable, decomp_table

def default_decompositions() -> "CustomDecompTable":
    """
    This is the default decomposition table, which contains decompositions of
    all ATen operators to the core aten opset. Use this API together with
    :func:`run_decompositions()`.
    """
    return CustomDecompTable()
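
# Example usage (illustrative sketch, not part of the module): the typical flow is to
# start from the default table and pass it to run_decompositions. The model and the
# example input below are hypothetical placeholders.
#
#   ep = torch.export.export(MyModel(), (torch.randn(2, 3),))
#   decomp_table = default_decompositions()           # a CustomDecompTable
#   core_aten_ep = ep.run_decompositions(decomp_table)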

def _decompose_and_get_gm_with_new_signature_constants(
    ep,
    *,
    cia_to_decomp: Dict[torch._ops.OperatorBase, Callable],
    python_decomp_table: Dict[torch._ops.OperatorBase, Callable],
    joint_loss_index: Optional[int],
):
    from torch._functorch.aot_autograd import aot_export_module
    from torch.export._trace import (
        _export_to_aten_ir,
        _fakify_params_buffers,
        _ignore_backend_decomps,
        _verify_nn_module_stack,
        _verify_placeholder_names,
        _verify_stack_trace,
    )
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    def _is_joint_ir_decomp(ep, joint_loss_index):
        return (
            joint_loss_index is not None
            or ep.graph_signature.backward_signature is not None
        )

    if not _is_joint_ir_decomp(ep, joint_loss_index):
        mod = ep.module()

        wrapped_params = dict(mod.named_parameters(remove_duplicate=False))

        from torch._functorch._aot_autograd.subclass_parametrization import (
            unwrap_tensor_subclass_parameters,
        )

        # [NOTE] Unwrapping subclasses AOT
        # In torch.compile, the subclass unwrapping/wrapping happens at runtime,
        # but at export this is impossible, as the result is intended to be run in a
        # C++ environment. As a result, we unwrap subclass parameters AOT. After this,
        # the ExportedProgram state_dict won't be the same as the eager model's, because
        # the eager model could have subclass weights while the ExportedProgram will have
        # desugared versions. This is fine because run_decompositions is supposed to
        # specialize to the post-autograd graph, where the subclass desugaring is
        # supposed to happen.
        unwrap_tensor_subclass_parameters(mod)
        unwrapped_params = dict(mod.named_parameters(remove_duplicate=False))

        # TODO T204030333
        fake_mode = _detect_fake_mode_from_gm(ep.graph_module)
        if fake_mode is None:
            fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True)

        retracing_args = []
        for node in mod.graph.nodes:
            if node.op == "placeholder":
                if isinstance(node.meta["val"], CustomObjArgument):
                    real_script_obj = None
                    if node.meta["val"].fake_val is None:
                        real_script_obj = ep.constants[node.meta["val"].name]
                    else:
                        real_script_obj = node.meta["val"].fake_val.real_obj
                    retracing_args.append(real_script_obj)
                else:
                    retracing_args.append(node.meta["val"])

        retracing_args_unwrapped = pytree.tree_unflatten(retracing_args, mod._in_spec)

        # Fix the graph output signature to be a tuple if scalar
        out_spec = mod._out_spec

        orig_arg_names = mod.graph._codegen.pytree_info.orig_args

        # aot_export expects the return type to always be a tuple.
        if out_spec.type not in (list, tuple):
            out_spec = pytree.TreeSpec(tuple, None, [out_spec])

        mod.graph._codegen = _PyTreeCodeGen(
            _PyTreeInfo(
                orig_arg_names,
                mod._in_spec,
                out_spec,
            )
        )

        mod.recompile()

        # The exported module will store constants & non-persistent buffers such that
        # retracing treats them as persistent buffers, so we inform the constants lifting pass
        # and overwrite the new graph signature using the previous program.
        _collect_and_set_constant_attrs(ep.graph_signature, ep.constants, mod)

        # get params & buffers after excluding constants
        fake_params_buffers = _fakify_params_buffers(fake_mode, mod)

        params_buffers_to_node_meta = _collect_param_buffer_metadata(mod)

        # TODO (tmanlaibaatar) Ideally run_decomp should just call _non_strict_export,
        # but the special handling of constants as non-persistent buffers makes that a
        # little difficult. We should unify these code paths.
        # T206837815
        from torch._export.non_strict_utils import _fakify_script_objects

        with (
            fake_mode
        ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp(
            cia_to_decomp,
        ):
            # this requires empty kwargs, but not in pytree.flattened format
            with _fakify_script_objects(
                mod,
                (
                    *retracing_args_unwrapped[0],
                    *retracing_args_unwrapped[1].values(),
                ),
                {},
                fake_mode,
            ) as (
                patched_mod,
                new_fake_args,
                new_fake_kwargs,
                new_fake_constant_attrs,
                map_fake_to_real,
            ):
                aten_export_artifact = _export_to_aten_ir(
                    patched_mod,
                    new_fake_args,
                    new_fake_kwargs,
                    fake_params_buffers,
                    new_fake_constant_attrs,
                    decomp_table=python_decomp_table,
                    _check_autograd_state=False,
                )

                # aten_export_artifact.constants contains only fake script objects;
                # we need to map them back
                aten_export_artifact.constants = {
                    fqn: map_fake_to_real[obj]
                    if isinstance(obj, FakeScriptObject)
                    else obj
                    for fqn, obj in aten_export_artifact.constants.items()
                }

        gm = aten_export_artifact.gm
        new_graph_signature = aten_export_artifact.sig

        _populate_param_buffer_metadata_to_new_gm(
            params_buffers_to_node_meta, gm, new_graph_signature
        )

        # overwrite signature for non-persistent buffers
        new_graph_signature = _overwrite_signature_for_non_persistent_buffers(
            ep.graph_signature, new_graph_signature
        )

        _verify_nn_module_stack(gm)
        _verify_stack_trace(gm)
        _verify_placeholder_names(gm, new_graph_signature)

        gm, new_graph_signature = _remove_unneccessary_copy_op_pass(
            gm, new_graph_signature
        )

        # When we apply the parametrization rule to unwrap
        # subclasses, the state dict will now have different
        # desugared parameters. We need to manually filter those
        # and update ep.state_dict. Ideally, we should just return
        # the state dict of ep.module(), but ep.module() only stores params/
        # buffers that participate in forward. If we undo this behaviour,
        # it would break some downstream users.
        for name, p in unwrapped_params.items():
            if name not in wrapped_params:
                ep.state_dict[name] = p

        for name, p in wrapped_params.items():
            assert name in ep.state_dict
            if name not in unwrapped_params:
                ep.state_dict.pop(name)

        return gm, new_graph_signature, ep.state_dict

    old_placeholders = [
        node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
    ]
    fake_args = [node.meta["val"] for node in old_placeholders]

    buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()]
    for name in buffers_to_remove:
        delattr(ep.graph_module, name)

    # TODO(zhxhchen17) Return the new graph_signature directly.
    fake_mode = detect_fake_mode(fake_args)
    fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode
    with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp(
        cia_to_decomp
    ):
        gm, graph_signature = aot_export_module(
            ep.graph_module,
            fake_args,
            decompositions=python_decomp_table,
            trace_joint=True if joint_loss_index is not None else False,
            output_loss_index=(
                joint_loss_index if joint_loss_index is not None else None
            ),
        )
        gm.graph.eliminate_dead_code()

    # Update the signatures with the new placeholder names in case they
    # changed when calling aot_export
    def update_arg(old_arg, new_ph):
        if isinstance(old_arg, ConstantArgument):
            return old_arg
        elif isinstance(old_arg, TensorArgument):
            return TensorArgument(name=new_ph.name)
        elif isinstance(old_arg, SymIntArgument):
            return SymIntArgument(name=new_ph.name)
        elif isinstance(old_arg, SymFloatArgument):
            return SymFloatArgument(name=new_ph.name)
        elif isinstance(old_arg, SymBoolArgument):
            return SymBoolArgument(name=new_ph.name)
        raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")

    new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
    new_outputs = list(gm.graph.nodes)[-1].args[0]

    # rename the placeholders
    assert len(new_placeholders) == len(old_placeholders)
    for old_ph, new_ph in zip(old_placeholders, new_placeholders):
        new_ph.name = new_ph.target = old_ph.name

    # handle name collisions with newly decomposed graph nodes
    name_map = {ph.name: ph.name for ph in new_placeholders}
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue
        node.name = _rename_without_collisions(name_map, node.name, node.name)

    # propagate names to higher order op subgraphs
    _name_hoo_subgraph_placeholders(gm)

    # Run this pass before creating input/output specs, since size-related CSE/DCE might
    # affect the output signature. Overwrite output specs afterwards.
    from torch._export.passes._node_metadata_hook import (
        _node_metadata_hook,
        _set_node_metadata_hook,
    )
    from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names

    if not torch._dynamo.config.do_not_emit_runtime_asserts:
        stack_trace = (
            'File "torch/fx/passes/runtime_assert.py", line 24, '
            "in insert_deferred_runtime_asserts"
        )
        shape_env = _get_shape_env(gm)
        if shape_env is not None:
            with _set_node_metadata_hook(
                gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
            ):
                insert_deferred_runtime_asserts(
                    gm,
                    shape_env,
                    f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
                    export=True,
                )

    # update output specs
    gm.recompile()
    for i, name in enumerate(_graph_output_names(gm)):
        if isinstance(new_outputs[i], torch.fx.Node):
            new_outputs[i].name = name

    # To match the output target with the correct input for input mutations,
    # we need to find the old-to-new placeholder map
    old_new_placeholder_map = {
        spec.arg.name: new_placeholders[i].name
        for i, spec in enumerate(ep.graph_signature.input_specs)
        if not isinstance(spec.arg, ConstantArgument)
    }

    input_specs = [
        InputSpec(
            spec.kind,
            update_arg(spec.arg, new_placeholders[i]),
            spec.target,
            spec.persistent,
        )
        for i, spec in enumerate(ep.graph_signature.input_specs)
    ]
    output_specs = [
        OutputSpec(
            OutputKind.LOSS_OUTPUT if i == joint_loss_index else spec.kind,
            update_arg(spec.arg, new_outputs[i]),
            old_new_placeholder_map.get(spec.target, spec.target),
        )
        for i, spec in enumerate(ep.graph_signature.output_specs)
    ]

    if joint_loss_index is not None:
        assert graph_signature.backward_signature is not None
        gradients = graph_signature.backward_signature.gradients_to_user_inputs
        assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs)
        specs = {
            graph_signature.user_inputs[i]: spec
            for i, spec in enumerate(ep.graph_signature.input_specs)
            if isinstance(spec.arg, TensorArgument)
        }
        for i, node in enumerate(new_outputs[len(output_specs) :]):
            source = gradients[node.name]
            spec = specs[source]  # type: ignore[index]
            if spec.kind == InputKind.PARAMETER:
                kind = OutputKind.GRADIENT_TO_PARAMETER
                target = spec.target
            elif spec.kind == InputKind.USER_INPUT:
                kind = OutputKind.GRADIENT_TO_USER_INPUT
                target = source
            else:
                raise AssertionError(f"Unknown input kind: {spec.kind}")
            output_specs.append(
                OutputSpec(
                    kind,
                    TensorArgument(name=node.name),
                    target,
                )
            )

    assert len(new_placeholders) == len(old_placeholders)

    new_graph_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )
    # NOTE: aot_export adds symint metadata for placeholders with int
    # values; since these become specialized, we replace such metadata with
    # the original values.
    # Also, set the param/buffer metadata back on the placeholders.
    for old_node, new_node in zip(old_placeholders, new_placeholders):
        if not isinstance(old_node.meta["val"], torch.Tensor):
            new_node.meta["val"] = old_node.meta["val"]

        if (
            new_node.target in new_graph_signature.inputs_to_parameters
            or new_node.target in new_graph_signature.inputs_to_buffers
        ):
            for k, v in old_node.meta.items():
                new_node.meta[k] = v

    return gm, new_graph_signature, ep.state_dict


def _remove_unneccessary_copy_op_pass(
    gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
    """
    Removes redundant copy_ nodes that were introduced due to mutated buffers.
""" with gm._set_replace_hook(new_graph_signature.get_replace_hook()): for node in gm.graph.nodes: if node.op == "output": args, _ = pytree.tree_flatten(node.args) for out in args: if ( isinstance(out, torch.fx.Node) and out.name in new_graph_signature.buffers_to_mutate ): if ( out.op == "call_function" and out.target == torch.ops.aten.copy.default ): out.replace_all_uses_with(out.args[1]) # type: ignore[arg-type] gm.graph.erase_node(out) gm.recompile() return gm, new_graph_signature def _common_getitem_elimination_pass( gm: torch.fx.GraphModule, graph_signature, module_call_graph ): with gm._set_replace_hook(graph_signature.get_replace_hook()): for module in gm.modules(): if not isinstance(module, torch.fx.GraphModule): continue node_id: Dict[torch.fx.Node, str] = {} getitems: Dict[str, torch.fx.Node] = {} for node in list(module.graph.nodes): if node.op == "call_function" and node.target == operator.getitem: source, idx = node.args new_id = f"{node_id[source]}.{idx}" if new_id in getitems: node.replace_all_uses_with(getitems[new_id]) for entry in module_call_graph: if entry.signature is not None: entry.signature.replace_all_uses_with( node, getitems[new_id] ) module.graph.erase_node(node) else: getitems[new_id] = node node_id[node] = new_id else: node_id[node] = node.name def _get_updated_module_call_graph( gm: torch.fx.GraphModule, old_module_call_graph: List[ModuleCallEntry], ): new_module_call_graph = copy.deepcopy(old_module_call_graph) # use node-level provenance metadata to create a map # from old node names to new node names provenance: Dict[str, str] = {} for node in gm.graph.nodes: if history := node.meta.get("from_node", []): provenance[history[-1].name] = node.name # map old names to new names in module call signatures for entry in new_module_call_graph: signature = entry.signature if signature is None: continue for x in [*signature.inputs, *signature.outputs]: x.name = provenance.get(x.name, x.name) return new_module_call_graph def _decompose_exported_program( ep, *, cia_to_decomp: Dict[torch._ops.OperatorBase, Callable], python_decomp_table: Dict[torch._ops.OperatorBase, Callable], joint_loss_index: Optional[int], ): ( gm, new_graph_signature, state_dict, ) = _decompose_and_get_gm_with_new_signature_constants( ep, cia_to_decomp=cia_to_decomp, python_decomp_table=python_decomp_table, joint_loss_index=joint_loss_index, ) # The signatures of ep.module_call_graph refer to input / output nodes of # the original graph module. However, the new graph module may have # new nodes due to decompositions. So we need to update these signatures # in the decomposed exported program's module_call_graph. new_module_call_graph = _get_updated_module_call_graph( gm, ep.module_call_graph, ) # TODO unfortunately preserving graph-level metadata is not # working well with aot_export. So we manually copy it. # (The node-level meta is addressed above.) gm.meta.update(ep.graph_module.meta) new_range_constraints = _get_updated_range_constraints( gm, ep.range_constraints, ) exported_program = ExportedProgram( root=gm, graph=gm.graph, graph_signature=new_graph_signature, state_dict=state_dict, range_constraints=new_range_constraints, module_call_graph=new_module_call_graph, example_inputs=ep.example_inputs, constants=ep.constants, ) return exported_program

class ExportedProgram:
    """
    Package of a program from :func:`export`. It contains
    an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing
    tensor values of all lifted parameters and buffers, and various metadata.

    You can call an ExportedProgram like the original callable traced by
    :func:`export` with the same calling convention.

    To perform transformations on the graph, use the ``.module()`` method to access
    a :class:`torch.fx.GraphModule`. You can then use
    `FX transformation <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
    to rewrite the graph. Afterwards, you can simply use :func:`export`
    again to construct a correct ExportedProgram.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: torch.fx.Graph,
        graph_signature: ExportGraphSignature,
        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
        range_constraints: "Dict[sympy.Symbol, Any]",
        module_call_graph: List[ModuleCallEntry],
        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
        constants: Optional[
            Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
        ] = None,
        *,
        verifiers: Optional[List[Type[Verifier]]] = None,
    ):
        # Remove codegen related things from the graph. It should just be a flat graph.
        graph._codegen = torch.fx.graph.CodeGen()
        self._graph_module = _create_graph_module_for_export(root, graph)
        if isinstance(root, torch.fx.GraphModule):
            self._graph_module.meta.update(root.meta)

        _common_getitem_elimination_pass(
            self._graph_module, graph_signature, module_call_graph
        )

        self._graph_signature: ExportGraphSignature = graph_signature
        self._state_dict: Dict[str, Any] = state_dict
        self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
        assert module_call_graph is not None
        self._module_call_graph: List[ModuleCallEntry] = module_call_graph
        self._example_inputs = example_inputs

        self._constants = constants or {}

        verifiers = verifiers or [Verifier]
        assert all(issubclass(v, Verifier) for v in verifiers)
        self._verifiers = verifiers
        # Validate should always be the last step of the constructor.
        self.validate()

    @property
    @compatibility(is_backward_compatible=False)
    def graph_module(self):
        return self._graph_module

    @graph_module.setter
    @compatibility(is_backward_compatible=False)
    def graph_module(self, value):
        raise RuntimeError("Unable to set ExportedProgram's graph_module attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def graph(self):
        return self.graph_module.graph

    @graph.setter
    @compatibility(is_backward_compatible=False)
    def graph(self, value):
        raise RuntimeError("Unable to set ExportedProgram's graph attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def graph_signature(self):
        return self._graph_signature

    @graph_signature.setter
    @compatibility(is_backward_compatible=False)
    def graph_signature(self, value):
        raise RuntimeError(
            "Unable to set ExportedProgram's graph_signature attribute."
        )

    @property
    @compatibility(is_backward_compatible=False)
    def state_dict(self):
        return self._state_dict

    @state_dict.setter
    @compatibility(is_backward_compatible=False)
    def state_dict(self, value):
        raise RuntimeError("Unable to set ExportedProgram's state_dict attribute.")

    @compatibility(is_backward_compatible=False)
    def parameters(self) -> Iterator[torch.nn.Parameter]:
        """
        Returns an iterator over the original module's parameters.
        """
        for _, param in self.named_parameters():
            yield param

    @compatibility(is_backward_compatible=False)
    def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Returns an iterator over the original module's parameters, yielding
        both the name of the parameter as well as the parameter itself.
        """
        for param_name in self.graph_signature.parameters:
            yield param_name, self.state_dict[param_name]

    @compatibility(is_backward_compatible=False)
    def buffers(self) -> Iterator[torch.Tensor]:
        """
        Returns an iterator over the original module's buffers.
        """
        for _, buf in self.named_buffers():
            yield buf

    @compatibility(is_backward_compatible=False)
    def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Returns an iterator over the original module's buffers, yielding
        both the name of the buffer as well as the buffer itself.
        """
        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
        for buffer_name in self.graph_signature.buffers:
            if buffer_name in non_persistent_buffers:
                yield buffer_name, self.constants[buffer_name]
            else:
                yield buffer_name, self.state_dict[buffer_name]
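
    # Example usage (illustrative sketch, not part of the module): inspecting the lifted
    # state of an ExportedProgram `ep` obtained from torch.export.export.
    #
    #   for name, param in ep.named_parameters():
    #       print(name, param.shape)
    #   for name, buf in ep.named_buffers():
    #       print(name, buf.shape)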

    @property
    @compatibility(is_backward_compatible=False)
    def range_constraints(self):
        return self._range_constraints

    @range_constraints.setter
    @compatibility(is_backward_compatible=False)
    def range_constraints(self, value):
        raise RuntimeError(
            "Unable to set ExportedProgram's range_constraints attribute."
        )

    @property
    @compatibility(is_backward_compatible=False)
    def module_call_graph(self):
        return self._module_call_graph

    @module_call_graph.setter
    @compatibility(is_backward_compatible=False)
    def module_call_graph(self, value):
        raise RuntimeError(
            "Unable to set ExportedProgram's module_call_graph attribute."
        )

    @property
    @compatibility(is_backward_compatible=False)
    def example_inputs(self):
        return self._example_inputs

    @example_inputs.setter
    @compatibility(is_backward_compatible=False)
    def example_inputs(self, value):
        # This is allowed
        if not (
            isinstance(value, tuple)
            and len(value) == 2
            and isinstance(value[0], tuple)
            and isinstance(value[1], dict)
        ):
            raise ValueError(
                "Example inputs should be a tuple containing example arguments (as "
                "a tuple), and example kwargs (as a dictionary)."
            )

        args, kwargs = value
        from ._unlift import _check_inputs_match

        _check_inputs_match(args, kwargs, self.call_spec.in_spec)

        self._example_inputs = value

    @property
    @compatibility(is_backward_compatible=False)
    def call_spec(self):
        CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])

        if len(self.module_call_graph) == 0:
            return CallSpec(in_spec=None, out_spec=None)
        assert self.module_call_graph[0].fqn == ""
        return CallSpec(
            in_spec=self.module_call_graph[0].signature.in_spec,
            out_spec=self.module_call_graph[0].signature.out_spec,
        )

    @call_spec.setter
    @compatibility(is_backward_compatible=False)
    def call_spec(self, value):
        raise RuntimeError("Unable to set ExportedProgram's call_spec attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def verifier(self) -> Any:
        return self._verifiers[0]

    @verifier.setter
    @compatibility(is_backward_compatible=False)
    def verifier(self, value):
        raise RuntimeError("Unable to set ExportedProgram's verifier attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def dialect(self) -> str:
        assert self._verifiers is not None
        return self._verifiers[0].dialect

    @dialect.setter
    @compatibility(is_backward_compatible=False)
    def dialect(self, value):
        raise RuntimeError("Unable to set ExportedProgram's dialect attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def verifiers(self):
        return self._verifiers

    @verifiers.setter
    @compatibility(is_backward_compatible=False)
    def verifiers(self, value):
        raise RuntimeError("Unable to set ExportedProgram's verifiers attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def tensor_constants(self):
        return self._constants

    @tensor_constants.setter
    @compatibility(is_backward_compatible=False)
    def tensor_constants(self, value):
        raise RuntimeError(
            "Unable to set ExportedProgram's tensor_constants attribute."
        )

    @property
    @compatibility(is_backward_compatible=False)
    def constants(self):
        return self._constants

    @constants.setter
    @compatibility(is_backward_compatible=False)
    def constants(self, value):
        raise RuntimeError("Unable to set ExportedProgram's constants attribute.")

    def _get_flat_args_with_check(self, args, kwargs):
        """Flatten args, kwargs using pytree, then check specs.
        Args:
            args: List[Any] original args passed to __call__
            kwargs: Dict[str, Any] original kwargs passed to __call__

        Returns:
            A tuple of (flat_args, received_spec)
            flat_args is the flattened args / kwargs
            received_spec is the pytree spec produced while flattening the
            tuple (args, kwargs)
        """
        in_spec = self.call_spec.in_spec
        if in_spec is not None:
            kwargs = reorder_kwargs(kwargs, in_spec)
        flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
            (args, kwargs)
        )
        self._check_input_constraints(flat_args_with_path)
        flat_args = tuple(x[1] for x in flat_args_with_path)
        return flat_args, received_spec

    def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any:
        """Transform args, kwargs of __call__ to args for graph_module.

        self.graph_module takes stuff from the state dict as inputs.
        The invariant is, for ep: ExportedProgram,
        ep(args, kwargs) ==
          ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
        """
        in_spec = self.call_spec.in_spec
        flat_args, received_spec = self._get_flat_args_with_check(args, kwargs)
        if in_spec is not None and not is_equivalent(
            received_spec, in_spec, _fx_collection_equivalence_fn
        ):
            raise ValueError(
                "Trying to flatten user inputs with exported input tree spec: \n"
                f"{in_spec}\n"
                "but actually got inputs with tree spec of: \n"
                f"{received_spec}"
            )

        additional_inputs = []
        for input_ in self.graph_signature.input_specs:
            if input_.kind == InputKind.USER_INPUT:
                continue
            elif input_.kind in (
                InputKind.PARAMETER,
                InputKind.BUFFER,
            ):
                if input_.persistent is False:
                    # This is a non-persistent buffer, grab it from our
                    # constants instead of the state dict.
                    additional_inputs.append(self.constants[input_.target])
                else:
                    additional_inputs.append(self.state_dict[input_.target])
            elif input_.kind in (
                InputKind.CONSTANT_TENSOR,
                InputKind.CUSTOM_OBJ,
            ):
                additional_inputs.append(self.constants[input_.target])
        additional_inputs = tuple(additional_inputs)

        # NOTE: the calling convention is first params, then buffers, then args as the
        # user supplied them.
        # See: torch/_functorch/aot_autograd.py#L1034
        return additional_inputs + flat_args

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise RuntimeError(
            "Unable to call ExportedProgram directly. "
            "You should use `exported_program.module()` instead."
        )

    def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs):
        """Process potential mutations to the input.

        Because self.graph_module is functional, mutations have to be written
        back after execution of graph_module.
        """
        import torch._export.error as error

        flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs)
        if self.call_spec.out_spec is not None:
            buffer_mutation = self.graph_signature.buffers_to_mutate
            user_input_mutation = self.graph_signature.user_inputs_to_mutate
            num_mutated = len(buffer_mutation) + len(user_input_mutation)
            mutated_values = res[:num_mutated]

            # Exclude dependency token from final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(  # noqa: B904
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                user_inputs = [
                    spec
                    for spec in self.graph_signature.input_specs
                    if spec.kind == InputKind.USER_INPUT
                ]
                for i, value in enumerate(mutated_values):
                    output_spec = self.graph_signature.output_specs[i]
                    if output_spec.kind == OutputKind.BUFFER_MUTATION:
                        assert output_spec.target is not None
                        self.state_dict[output_spec.target] = value
                    elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
                        assert output_spec.target is not None
                        index = next(
                            i
                            for i, spec in enumerate(user_inputs)
                            if spec.arg.name == output_spec.target
                        )
                        flat_args[index].copy_(value)
                    else:
                        raise AssertionError(f"Unexpected kind: {output_spec.kind}")
        return res

    def __str__(self) -> str:
        graph_module = self.graph_module.print_readable(
            print_output=False, colored=False
        ).replace("\n", "\n    ")
        string = (
            "ExportedProgram:\n"
            f"    {graph_module}\n"
            f"Graph signature: {self.graph_signature}\n"
            f"Range constraints: {self.range_constraints}\n"
        )
        return string

    def module(self) -> torch.nn.Module:
        """
        Returns a self-contained GraphModule with all the parameters/buffers inlined.
        """
        from ._unlift import _unlift_exported_program_lifted_states

        module = _unlift_exported_program_lifted_states(self)

        def _train(self, mode: bool = True):
            raise NotImplementedError("Calling train() is not supported yet.")

        def _eval(self, mode: bool = True):
            raise NotImplementedError("Calling eval() is not supported yet.")

        module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
        module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
        return module
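
    # Example usage (illustrative sketch, not part of the module): running an exported
    # program through its unlifted module. `MyModel` and the example input are
    # hypothetical placeholders.
    #
    #   ep = torch.export.export(MyModel(), (torch.randn(2, 3),))
    #   unlifted = ep.module()                 # a runnable torch.nn.Module
    #   out = unlifted(torch.randn(2, 3))      # same calling convention as the original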

    def _num_lifted_params_buffers(self):
        return next(
            (
                i
                for i, s in enumerate(self._graph_signature.input_specs)
                if s.kind == InputKind.USER_INPUT
            ),
            len(self._graph_signature.input_specs),
        )

    @_disable_prexisiting_fake_mode
    def run_decompositions(
        self,
        decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
    ) -> "ExportedProgram":
        """
        Run a set of decompositions on the exported program and returns a new
        exported program. By default we will run the Core ATen decompositions to
        get operators in the
        `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.

        For now, we do not decompose joint graphs.

        Args:
            decomp_table:
                An optional argument that specifies decomp behaviour for ATen ops.
                (1) If None, we decompose to core aten decompositions.
                (2) If empty, we don't decompose any operator.

        Some examples:

        If you don't want to decompose anything

        .. code-block:: python

            ep = torch.export.export(model, ...)
            ep = ep.run_decompositions(decomp_table={})

        If you want the core aten opset, but with a custom decomposition for certain
        operators, you can do the following:

        .. code-block:: python

            ep = torch.export.export(model, ...)
            decomp_table = torch.export.default_decompositions()
            decomp_table[your_op] = your_custom_decomp
            ep = ep.run_decompositions(decomp_table=decomp_table)
        """
        _decomp_table = (
            default_decompositions() if decomp_table is None else dict(decomp_table)
        )

        if isinstance(_decomp_table, CustomDecompTable):
            _decomp_table = _decomp_table.materialize()

        # Note [Separating decomp_table into CIA decomps and non-CIA decomps]
        # At this point, we have a decomp_table that contains decomp behaviour for
        # both CIA and post-autograd ops.
        # We need to separate the ops into two categories:
        #   1. CIA ops: These are the ops that we want to override the
        #      CompositeImplicitAutograd decomp for. For them, we need to use the
        #      _override_composite_implicit_decomp context manager to plumb it
        #      through AOTDispatcher.
        #   2. Non-CIA ops: These ops are only relevant after AOTDispatcher runs, so just
        #      checking if they are statically functional is enough.
        # For the joint IR case, though, we need to use the old path, because we can't
        # register custom decomps this way: we can't use the context manager, as it
        # installs an autograd_error node.
        (
            cia_to_decomp,
            python_decomp_table,
        ) = _split_decomp_table_to_cia_and_python_decomp(_decomp_table)

        return _decompose_exported_program(
            self,
            cia_to_decomp=cia_to_decomp,
            python_decomp_table=python_decomp_table,
            joint_loss_index=None,
        )
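
    # Example usage (illustrative sketch, not part of the module): preserving a
    # CompositeImplicitAutograd op by deleting it from the default table, mirroring the
    # pattern described in _split_decomp_table_to_cia_and_python_decomp above.
    # `aten.linear` is used purely for illustration.
    #
    #   decomp_table = default_decompositions()
    #   del decomp_table[torch.ops.aten.linear.default]   # keep aten.linear un-decomposed
    #   ep = ep.run_decompositions(decomp_table)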

    def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram":
        pm = PassManager(list(passes))
        # Since we abstractly run the passes, we need to disable backend decomps here
        # again.
        from torch.export._trace import _ignore_backend_decomps

        with _ignore_backend_decomps():
            res = pm(self.graph_module)
        transformed_gm = res.graph_module if res is not None else self.graph_module
        assert transformed_gm is not None

        if transformed_gm is self.graph_module and not res.modified:
            return self

        # TODO(zhxchen17) Remove this.
        def _get_updated_graph_signature(
            old_signature: ExportGraphSignature,
            new_gm: torch.fx.GraphModule,
        ) -> ExportGraphSignature:
            """
            Update the graph signature's user_input/user_outputs.
            """
            new_input_specs = []
            for i, node in enumerate(new_gm.graph.nodes):
                if node.op != "placeholder":
                    break

                assert i < len(
                    old_signature.input_specs
                ), "Number of inputs changed after transformation"
                old_input_spec = old_signature.input_specs[i]
                arg = (
                    old_input_spec.arg
                    if isinstance(
                        old_input_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_input_spec.arg)(node.name)
                )
                new_input_specs.append(
                    InputSpec(
                        old_input_spec.kind,
                        arg,
                        old_input_spec.target,
                        old_input_spec.persistent,
                    )
                )

            output_node = list(new_gm.graph.nodes)[-1]
            assert output_node.op == "output"

            new_output_specs = []
            for i, node in enumerate(output_node.args[0]):
                assert i < len(
                    old_signature.output_specs
                ), "Number of outputs changed after transformation"
                old_output_spec = old_signature.output_specs[i]
                arg = (
                    old_output_spec.arg
                    if isinstance(
                        old_output_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_output_spec.arg)(node.name)
                )
                new_output_specs.append(
                    OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
                )

            new_signature = ExportGraphSignature(
                input_specs=new_input_specs, output_specs=new_output_specs
            )
            return new_signature

        transformed_ep = ExportedProgram(
            root=transformed_gm,
            graph=transformed_gm.graph,
            graph_signature=_get_updated_graph_signature(
                self.graph_signature, transformed_gm
            ),
            state_dict=self.state_dict,
            range_constraints=_get_updated_range_constraints(
                transformed_gm,
                self.range_constraints,
            ),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            constants=self.constants,
            verifiers=self.verifiers,
        )
        transformed_ep.graph_module.meta.update(self.graph_module.meta)
        transformed_ep.graph_module.meta.update(res.graph_module.meta)
        return transformed_ep

    def _check_input_constraints(self, flat_args_with_path):
        from torch._export.utils import _check_input_constraints_for_graph

        placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
        input_placeholders = [
            p
            for p, s in zip(placeholders, self.graph_signature.input_specs)
            if s.kind == InputKind.USER_INPUT
        ]
        _check_input_constraints_for_graph(
            input_placeholders, flat_args_with_path, self.range_constraints
        )

    @compatibility(is_backward_compatible=False)
    def validate(self):
        self._validate()

    # TODO: remove this
    @final
    def _validate(self):
        assert (
            len(self.verifiers) > 0
        ), "ExportedProgram must have at least one verifier."
        for v in self.verifiers:
            v().check(self)

    # TODO(zhxchen17) Formalize this.
    def _update(
        self, graph_module, graph_signature, *, state_dict=None, verifiers=None
    ) -> "ExportedProgram":
        return ExportedProgram(
            root=graph_module,
            graph=graph_module.graph,
            graph_signature=graph_signature,
            state_dict=state_dict if state_dict is not None else self.state_dict,
            range_constraints=copy.deepcopy(self.range_constraints),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            constants=self.constants,
            verifiers=verifiers if verifiers is not None else self.verifiers,
        )

def _get_shape_env(gm):
    vals = [
        node.meta["val"]
        for node in gm.graph.nodes
        if node.meta.get("val", None) is not None
    ]
    from torch._guards import detect_fake_mode

    fake_mode = detect_fake_mode(vals)
    if fake_mode is not None:
        return fake_mode.shape_env
    for v in vals:
        if isinstance(v, torch.SymInt):
            return v.node.shape_env


def _get_updated_range_constraints(
    gm: torch.fx.GraphModule,
    old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None,
) -> "Dict[sympy.Symbol, Any]":
    assert old_range_constraints is not None

    shape_env = _get_shape_env(gm)
    if shape_env is None:
        return {}

    range_constraints = copy.copy(old_range_constraints)
    range_constraints = {
        k: v for k, v in range_constraints.items() if k not in shape_env.replacements
    }
    # Only when we have an unbacked symint, and it's used as constructor inputs,
    # will runtime_var_to_range make a difference compared to var_to_range.
    # e.g. [2, oo) -> [0, oo)
    for k, v in shape_env.var_to_range.items():
        if k not in shape_env.replacements and k not in range_constraints:
            range_constraints[k] = v
    return range_constraints


def _create_graph_module_for_export(root, graph):
    try:
        gm = torch.fx.GraphModule(root, graph)
    except SyntaxError:
        # If custom objects stored in memory are being used in the graph,
        # the generated python code will result in a syntax error on the custom
        # object, since it is unable to parse the in-memory object. However,
        # we can still run the graph eagerly through torch.fx.Interpreter,
        # so we will bypass this error.
        warnings.warn(
            "Unable to execute the generated python source code from "
            "the graph. The graph module will no longer be directly callable, "
            "but you can still run the ExportedProgram, and if needed, you can "
            "run the graph module eagerly using torch.fx.Interpreter."
        )
        gm = torch.fx.GraphModule(root, torch.fx.Graph())
        gm._graph = graph

    return gm
