# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import contextlib
import copy
import dataclasses
import functools
import operator
import types
import warnings
from collections import namedtuple
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    final,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

from torch._higher_order_ops.utils import autograd_not_implemented
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_impls import (
    _deregister_op_impl,
    _is_op_registered_to_fake_rule,
    register_op_impl,
)
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts


if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.
    import sympy

    from torch.utils._sympy.value_ranges import ValueRanges

import torch
import torch.utils._pytree as pytree
from torch._export.utils import (
    _collect_all_valid_cia_ops,
    _collect_and_set_constant_attrs,
    _collect_param_buffer_metadata,
    _detect_fake_mode_from_gm,
    _get_decomp_for_cia,
    _is_preservable_cia_op,
    _name_hoo_subgraph_placeholders,
    _overwrite_signature_for_non_persistent_buffers,
    _populate_param_buffer_metadata_to_new_gm,
    _rename_without_collisions,
    _special_op_to_preserve_cia,
)
from torch._export.verifier import Verifier
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.export.decomp_utils import CustomDecompTable
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

from .graph_signature import (  # noqa: F401
    ArgumentSpec,
    ConstantArgument,
    CustomObjArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymBoolArgument,
    SymFloatArgument,
    SymIntArgument,
    TensorArgument,
    TokenArgument,
)


__all__ = [
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
    "default_decompositions",
]


PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
def _disable_prexisiting_fake_mode(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with unset_fake_temporarily():
            return fn(*args, **kwargs)

    return wrapper


def _fx_collection_equivalence_fn(
    spec1_type: Optional[type],
    spec1_context: pytree.Context,
    spec2_type: Optional[type],
    spec2_context: pytree.Context,
) -> bool:
    """Treat containers and their immutable variants as the same type. Otherwise
    compare as normal.
    """
    if spec1_type is None or spec2_type is None:
        return spec1_type is spec2_type and spec1_context == spec2_context

    if issubclass(spec1_type, (dict, immutable_dict)) and issubclass(
        spec2_type, (dict, immutable_dict)
    ):
        return spec1_context == spec2_context
    if issubclass(spec1_type, (list, immutable_list)) and issubclass(
        spec2_type, (list, immutable_list)
    ):
        return spec1_context == spec2_context

    return spec1_type is spec2_type and spec1_context == spec2_context


# This list is compiled from DispatchKey.cpp.
# The idea is that we use these keys to override
# the CIA decomp in export.
_AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [
    torch._C.DispatchKey.AutogradCPU,
    torch._C.DispatchKey.AutogradCUDA,
    torch._C.DispatchKey.AutogradMeta,
    torch._C.DispatchKey.AutogradXLA,
    torch._C.DispatchKey.AutogradLazy,
    torch._C.DispatchKey.AutogradIPU,
    torch._C.DispatchKey.AutogradXPU,
    torch._C.DispatchKey.AutogradMPS,
    torch._C.DispatchKey.AutogradHPU,
    torch._C.DispatchKey.AutogradPrivateUse1,
    torch._C.DispatchKey.AutogradPrivateUse2,
    torch._C.DispatchKey.AutogradPrivateUse3,
]

# This list is compiled from DispatchKey.cpp.
# The idea is that we use these keys to add
# python kernels that directly use the default
# CIA decomp.
# See NOTE: Registering old CIA to Backend kernel
_BACKEND_KEYS_TO_OVERRIDE = [
    torch._C.DispatchKey.CPU,
    torch._C.DispatchKey.CUDA,
    torch._C.DispatchKey.Meta,
    torch._C.DispatchKey.XLA,
    torch._C.DispatchKey.Lazy,
    torch._C.DispatchKey.IPU,
    torch._C.DispatchKey.XPU,
    torch._C.DispatchKey.MPS,
    torch._C.DispatchKey.HPU,
]


@contextmanager
def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True):
    # This function overrides the CompositeImplicitAutograd decomp for
    # functional composite ops that the user specified. Ideally we want to
    # not decompose ALL composite ops, but today's C++ functionalization relies
    # on the fact that it is working with the opset after decomp is run.
    # Hence we can only do it for functional ops. One caveat is that
    # there are some composite ops that lie about their schema (claim to be
    # functional but are not really, e.g. dropout); for these cases, we just decompose.

    # When safe=False, we will assume that ops_to_preserve can be mutating/aliasing
    # and their usual decompositions need to be shadowed rather than overridden.
    # Thus we will avoid asserting that they are valid to preserve, and will not
    # replace their CompositeImplicitAutograd kernels with NotImplemented.
    # The only current users of this mode are variants of aten::to that we will
    # replace with aten::_to_copy in FunctionalTensorMode.__torch_dispatch__.
    saved_tables = {}
    patched_ops = set()
    for op_overload, decomp_callable in cia_ops_to_callable.items():
        saved_tables[op_overload] = op_overload.py_kernels.copy()
        patched_ops.add(op_overload)

        for override_dispatch_key in _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE:
            if override_dispatch_key not in op_overload.py_kernels:
                # TODO (tmanlaibaatar) https://github.com/pytorch/pytorch/issues/129430
                op_overload.py_impl(override_dispatch_key)(
                    autograd_not_implemented(op_overload, deferred_error=True)
                )

        # See NOTE: Registering old CIA to Backend kernel
        # It is important that we cache this before we override py_kernels.
        orig_cia_callable = _get_decomp_for_cia(op_overload)
        if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels:
            del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd]

        if safe:
            op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(
                decomp_callable
            )

        # [NOTE] Directly registering fake tensor rule to CIA ops
        # The problem we are facing here is that if your CIA custom rule
        # says we want to preserve the op, we will return NotImplemented.
        # Unfortunately, this will invoke meta device tracing in fake tensor,
        # resulting in divergent behaviour for CIA kernels that have device-based
        # branching (one case is torch.ops.aten.scaled_dot_product.attention).
        # To get around this issue, we register a direct fake impl so that we
        # run the kernel before we actually try to decompose the op in FakeTensorMode.
        # Note that this is a no-op in most cases, because:
        #   1) In post-dispatch tracing, CIA would have already decomposed
        #   2) Most CIA impls are device agnostic.
        def _force_dispatch_to_orig_cia_callable(fake_tensor_mode, op, *args, **kwargs):
            orig_cia_callable = kwargs["original_callable"]
            del kwargs["original_callable"]
            with fake_tensor_mode:
                return orig_cia_callable(*args, **kwargs)

        if not _is_op_registered_to_fake_rule(op_overload):
            register_op_impl(op_overload)(
                functools.partial(
                    _force_dispatch_to_orig_cia_callable,
                    original_callable=orig_cia_callable,
                )
            )

        for key in _BACKEND_KEYS_TO_OVERRIDE:
            if key not in op_overload.py_kernels:
                # [NOTE] Registering old CIA to Backend kernel
                # We always register the original CIA behaviour to the backend key kernels.
                # The reason is that when we are fake tensor prop-ing or executing a real
                # kernel, we end up calling an operator on the respective backend, which,
                # in the python dispatcher, will resolve into the CIA key.
                # (see resolve_key in torch/_ops.py)
                # As a result, this CIA will now call into the custom user-defined
                # CIA, which can cause a problem.
                # To make it more concrete, the case we are handling is:
                #  (1) there is a tensor constant we are performing constant propagation
                #      on during tracing
                #  (2) we invoke an op underneath autograd (either because we are below autograd,
                #      or we are tracing in inference mode), so one of the backend keys gets hit
                #  (3) the op we are invoking has a CIA impl that normally runs in eager mode
                #      (and the user wants to tweak this CIA impl during tracing, but during
                #      const-prop we want the original CIA to run)
                op_overload.py_impl(key)(orig_cia_callable)

    try:
        yield
    finally:
        for op in patched_ops:
            op.py_kernels.clear()
            op.py_kernels.update(saved_tables[op])
            op._dispatch_cache.clear()
            _deregister_op_impl(op)


@contextmanager
def _override_decomp_aten_to_variants():
    # Preserve variants of aten::to, understanding that they are mutating/aliasing
    # and their CompositeImplicitAutograd kernels will not become NotImplemented.
    # We will later replace them with aten._to_copy when functionalizing.
    with _override_composite_implicit_decomp(
        {
            torch.ops.aten.to.dtype_layout: _special_op_to_preserve_cia,
            torch.ops.aten.to.dtype: _special_op_to_preserve_cia,
        },
        safe=False,
    ):
        yield


def _split_decomp_table_to_cia_and_python_decomp(
    decomp_table: Dict[torch._ops.OperatorBase, Callable]
) -> Tuple[Dict[torch._ops.OperatorBase, Callable], ...]:
    all_preservable_cia_ops = set(_collect_all_valid_cia_ops())
    cia_ops_to_callable = {}

    for op in list(decomp_table.keys()):
        # TODO: we are silently allowing non-safe (non-functional) ops through a crack
        # due to the core aten decomp table having non-functional entries. Once we have
        # a tighter check around the core aten decomp, we should warn users about them.
        # Tracking issue: https://github.com/pytorch/pytorch/issues/135759

        # If it is a valid CIA op we can mess with in export, we check whether it:
        # 1. Has been marked as to be decomposed. Example:
        #        decomp_table = decomp_table_to_core_aten()
        #        del decomp_table[aten.linear]
        #    In this case, the user says decompose everything except for aten.linear.
        # 2. Has been marked with custom decomp behaviour. Example:
        #        decomp_table = {aten.linear: some_op}
        # For (1), we want to remove all the CIA ops that weren't handled by the user,
        # as that suggests they are safe to decompose, so we should remove them from
        # the preservable list. For (2), we just plumb the custom decomp to AOTDispatcher.
        # In both cases, we want to remove this CIA op from the decomp_table, as it is
        # specially handled.
        if op in all_preservable_cia_ops:
            cia_ops_to_callable[op] = decomp_table[op]
            all_preservable_cia_ops.remove(op)
            del decomp_table[op]
        # If it is a custom op, we still want to preserve it or do whatever
        # with it if it is a functional CIA. The reason we don't remove it
        # from the CIA list is that we don't query custom ops.
        elif _is_preservable_cia_op(op):
            op_name = op.name()
            assert not op_name.startswith("aten"), "This should be a custom op"
            cia_ops_to_callable[op] = decomp_table[op]

    # If we reached here, it means the user intentionally deleted these CIA ops from
    # the decomp table.
    for k in all_preservable_cia_ops:
        cia_ops_to_callable[k] = _special_op_to_preserve_cia

    return cia_ops_to_callable, decomp_table
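
# A minimal sketch of case (1) from the comment above, assuming the default table:
# dropping an op from the table signals that it should be preserved rather than
# decomposed, because the final loop maps every remaining preservable CIA op to
# _special_op_to_preserve_cia. The helper name and choice of aten.linear are
# illustrative only; this is not part of the module's public API.
def _example_preserve_cia_op():
    # Start from the materialized default table (a plain dict of op -> callable).
    table = default_decompositions().materialize()
    # Mirrors the `del decomp_table[aten.linear]` case described above.
    table.pop(torch.ops.aten.linear.default, None)
    cia_ops_to_callable, python_decomp_table = (
        _split_decomp_table_to_cia_and_python_decomp(table)
    )
    return cia_ops_to_callable, python_decomp_table
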
def default_decompositions() -> "CustomDecompTable":
    """
    This is the default decomposition table which contains decomposition of
    all ATEN operators to core aten opset. Use this API together with
    :func:`run_decompositions()`
    """
    return CustomDecompTable()
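
# A brief usage sketch for default_decompositions() together with
# run_decompositions(), as the docstring above suggests. The tiny model and the
# helper name are hypothetical and only illustrate the intended call pattern.
def _example_default_decompositions_usage():
    class _TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    ep = torch.export.export(_TinyModel(), (torch.randn(2, 4),))
    table = default_decompositions()
    # Passing the (possibly edited) table decomposes the program to the core
    # ATen opset; an empty dict ({}) would skip decompositions entirely.
    return ep.run_decompositions(decomp_table=table)
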
def _decompose_and_get_gm_with_new_signature_constants(
    ep,
    *,
    cia_to_decomp: Dict[torch._ops.OperatorBase, Callable],
    python_decomp_table: Dict[torch._ops.OperatorBase, Callable],
    joint_loss_index: Optional[int],
):
    from torch._functorch.aot_autograd import aot_export_module
    from torch.export._trace import (
        _export_to_aten_ir,
        _fakify_params_buffers,
        _ignore_backend_decomps,
        _verify_nn_module_stack,
        _verify_placeholder_names,
        _verify_stack_trace,
    )
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    def _is_joint_ir_decomp(ep, joint_loss_index):
        return (
            joint_loss_index is not None
            or ep.graph_signature.backward_signature is not None
        )

    if not _is_joint_ir_decomp(ep, joint_loss_index):
        mod = ep.module()

        # TODO T204030333
        fake_mode = _detect_fake_mode_from_gm(ep.graph_module)
        if fake_mode is None:
            fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True)

        retracing_args = []
        for node in mod.graph.nodes:
            if node.op == "placeholder":
                if isinstance(node.meta["val"], CustomObjArgument):
                    real_script_obj = None
                    if node.meta["val"].fake_val is None:
                        real_script_obj = ep.constants[node.meta["val"].name]
                    else:
                        real_script_obj = node.meta["val"].fake_val.real_obj
                    retracing_args.append(real_script_obj)
                else:
                    retracing_args.append(node.meta["val"])

        retracing_args_unwrapped = pytree.tree_unflatten(retracing_args, mod._in_spec)

        # Fix the graph output signature to be tuple if scalar
        out_spec = mod._out_spec

        orig_arg_names = mod.graph._codegen.pytree_info.orig_args

        # aot_export expects the return type to always be a tuple.
        if out_spec.type not in (list, tuple):
            out_spec = pytree.TreeSpec(tuple, None, [out_spec])

        mod.graph._codegen = _PyTreeCodeGen(
            _PyTreeInfo(
                orig_arg_names,
                mod._in_spec,
                out_spec,
            )
        )

        mod.recompile()

        # The exported module will store constants & non-persistent buffers such that
        # retracing treats them as persistent buffers, so we inform the constants lifting pass
        # and overwrite the new graph signature using the previous program.
        _collect_and_set_constant_attrs(ep.graph_signature, ep.constants, mod)

        # get params & buffers after excluding constants
        fake_params_buffers = _fakify_params_buffers(fake_mode, mod)

        params_buffers_to_node_meta = _collect_param_buffer_metadata(mod)

        # TODO (tmanlaibaatar) Ideally run_decomp should just call _non_strict_export,
        # but the special handling of constants as non-persistent buffers makes that a
        # little difficult. We should unify this code path eventually. T206837815
        from torch._export.non_strict_utils import _fakify_script_objects

        with (
            fake_mode
        ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp(
            cia_to_decomp,
        ):
            # this requires empty kwargs, but not in pytree.flattened format
            with _fakify_script_objects(
                mod,
                (
                    *retracing_args_unwrapped[0],
                    *retracing_args_unwrapped[1].values(),
                ),
                {},
                fake_mode,
            ) as (
                patched_mod,
                new_fake_args,
                new_fake_kwargs,
                new_fake_constant_attrs,
                map_fake_to_real,
            ):
                aten_export_artifact = _export_to_aten_ir(
                    patched_mod,
                    new_fake_args,
                    new_fake_kwargs,
                    fake_params_buffers,
                    new_fake_constant_attrs,
                    decomp_table=python_decomp_table,
                    _check_autograd_state=False,
                )

                # aten_export_artifact.constants contains only fake script objects,
                # so we need to map them back.
                aten_export_artifact.constants = {
                    fqn: map_fake_to_real[obj]
                    if isinstance(obj, FakeScriptObject)
                    else obj
                    for fqn, obj in aten_export_artifact.constants.items()
                }

        gm = aten_export_artifact.gm
        new_graph_signature = aten_export_artifact.sig

        _populate_param_buffer_metadata_to_new_gm(
            params_buffers_to_node_meta, gm, new_graph_signature
        )

        # overwrite signature for non-persistent buffers
        new_graph_signature = _overwrite_signature_for_non_persistent_buffers(
            ep.graph_signature, new_graph_signature
        )

        _verify_nn_module_stack(gm)
        _verify_stack_trace(gm)
        _verify_placeholder_names(gm, new_graph_signature)

        return _remove_unneccessary_copy_op_pass(gm, new_graph_signature)

    old_placeholders = [
        node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
    ]
    fake_args = [node.meta["val"] for node in old_placeholders]

    buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()]
    for name in buffers_to_remove:
        delattr(ep.graph_module, name)

    # TODO(zhxhchen17) Return the new graph_signature directly.
    fake_mode = detect_fake_mode(fake_args)
    fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode
    with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp(
        cia_to_decomp
    ):
        gm, graph_signature = aot_export_module(
            ep.graph_module,
            fake_args,
            decompositions=python_decomp_table,
            trace_joint=True if joint_loss_index is not None else False,
            output_loss_index=(
                joint_loss_index if joint_loss_index is not None else None
            ),
        )
        gm.graph.eliminate_dead_code()

    # Update the signatures with the new placeholder names in case they
    # changed when calling aot_export
    def update_arg(old_arg, new_ph):
        if isinstance(old_arg, ConstantArgument):
            return old_arg
        elif isinstance(old_arg, TensorArgument):
            return TensorArgument(name=new_ph.name)
        elif isinstance(old_arg, SymIntArgument):
            return SymIntArgument(name=new_ph.name)
        elif isinstance(old_arg, SymFloatArgument):
            return SymFloatArgument(name=new_ph.name)
        elif isinstance(old_arg, SymBoolArgument):
            return SymBoolArgument(name=new_ph.name)
        raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")

    new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
    new_outputs = list(gm.graph.nodes)[-1].args[0]

    # rename the placeholders
    assert len(new_placeholders) == len(old_placeholders)
    for old_ph, new_ph in zip(old_placeholders, new_placeholders):
        new_ph.name = new_ph.target = old_ph.name

    # handle name collisions with newly decomposed graph nodes
    name_map = {ph.name: ph.name for ph in new_placeholders}
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue
        node.name = _rename_without_collisions(name_map, node.name, node.name)

    # propagate names to higher order op subgraphs
    _name_hoo_subgraph_placeholders(gm)

    # Run this pass before creating input/output specs, since size-related CSE/DCE
    # might affect the output signature. Overwrite output specs afterwards.
    from torch._export.passes._node_metadata_hook import (
        _node_metadata_hook,
        _set_node_metadata_hook,
    )
    from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names

    if not torch._dynamo.config.do_not_emit_runtime_asserts:
        stack_trace = (
            'File "torch/fx/passes/runtime_assert.py", line 24, '
            "in insert_deferred_runtime_asserts"
        )
        shape_env = _get_shape_env(gm)
        if shape_env is not None:
            with _set_node_metadata_hook(
                gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
            ):
                insert_deferred_runtime_asserts(
                    gm,
                    shape_env,
                    f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
                    export=True,
                )

    # update output specs
    gm.recompile()
    for i, name in enumerate(_graph_output_names(gm)):
        if isinstance(new_outputs[i], torch.fx.Node):
            new_outputs[i].name = name

    # To match the output target with the correct input for input mutations,
    # we need to find the old-to-new placeholder map.
    old_new_placeholder_map = {
        spec.arg.name: new_placeholders[i].name
        for i, spec in enumerate(ep.graph_signature.input_specs)
        if not isinstance(spec.arg, ConstantArgument)
    }

    input_specs = [
        InputSpec(
            spec.kind,
            update_arg(spec.arg, new_placeholders[i]),
            spec.target,
            spec.persistent,
        )
        for i, spec in enumerate(ep.graph_signature.input_specs)
    ]
    output_specs = [
        OutputSpec(
            OutputKind.LOSS_OUTPUT if i == joint_loss_index else spec.kind,
            update_arg(spec.arg, new_outputs[i]),
            old_new_placeholder_map.get(spec.target, spec.target),
        )
        for i, spec in enumerate(ep.graph_signature.output_specs)
    ]

    if joint_loss_index is not None:
        assert graph_signature.backward_signature is not None
        gradients = graph_signature.backward_signature.gradients_to_user_inputs
        assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs)
        specs = {
            graph_signature.user_inputs[i]: spec
            for i, spec in enumerate(ep.graph_signature.input_specs)
            if isinstance(spec.arg, TensorArgument)
        }
        for i, node in enumerate(new_outputs[len(output_specs) :]):
            source = gradients[node.name]
            spec = specs[source]  # type: ignore[index]
            if spec.kind == InputKind.PARAMETER:
                kind = OutputKind.GRADIENT_TO_PARAMETER
                target = spec.target
            elif spec.kind == InputKind.USER_INPUT:
                kind = OutputKind.GRADIENT_TO_USER_INPUT
                target = source
            else:
                raise AssertionError(f"Unknown input kind: {spec.kind}")
            output_specs.append(
                OutputSpec(
                    kind,
                    TensorArgument(name=node.name),
                    target,
                )
            )

    assert len(new_placeholders) == len(old_placeholders)

    new_graph_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )
    # NOTE: aot_export adds symint metadata for placeholders with int
    # values; since these become specialized, we replace such metadata with
    # the original values.
    # Also, set the param/buffer metadata back to the placeholders.
    for old_node, new_node in zip(old_placeholders, new_placeholders):
        if not isinstance(old_node.meta["val"], torch.Tensor):
            new_node.meta["val"] = old_node.meta["val"]

        if (
            new_node.target in new_graph_signature.inputs_to_parameters
            or new_node.target in new_graph_signature.inputs_to_buffers
        ):
            for k, v in old_node.meta.items():
                new_node.meta[k] = v

    return gm, new_graph_signature


def _remove_unneccessary_copy_op_pass(
    gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
    """
    Removes redundant copy_ nodes that were introduced due to mutated buffers.
    """
    with gm._set_replace_hook(new_graph_signature.get_replace_hook()):
        for node in gm.graph.nodes:
            if node.op == "output":
                args, _ = pytree.tree_flatten(node.args)
                for out in args:
                    if (
                        isinstance(out, torch.fx.Node)
                        and out.name in new_graph_signature.buffers_to_mutate
                    ):
                        if (
                            out.op == "call_function"
                            and out.target == torch.ops.aten.copy.default
                        ):
                            out.replace_all_uses_with(out.args[1])  # type: ignore[arg-type]
                            gm.graph.erase_node(out)
        gm.recompile()
    return gm, new_graph_signature


def _common_getitem_elimination_pass(
    gm: torch.fx.GraphModule, graph_signature, module_call_graph
):
    with gm._set_replace_hook(graph_signature.get_replace_hook()):
        for module in gm.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue

            node_id: Dict[torch.fx.Node, str] = {}
            getitems: Dict[str, torch.fx.Node] = {}
            for node in list(module.graph.nodes):
                if node.op == "call_function" and node.target == operator.getitem:
                    source, idx = node.args
                    new_id = f"{node_id[source]}.{idx}"
                    if new_id in getitems:
                        node.replace_all_uses_with(getitems[new_id])
                        for entry in module_call_graph:
                            if entry.signature is not None:
                                entry.signature.replace_all_uses_with(
                                    node, getitems[new_id]
                                )
                        module.graph.erase_node(node)
                    else:
                        getitems[new_id] = node
                        node_id[node] = new_id
                else:
                    node_id[node] = node.name


def _get_updated_module_call_graph(
    gm: torch.fx.GraphModule,
    old_module_call_graph: List[ModuleCallEntry],
):
    new_module_call_graph = copy.deepcopy(old_module_call_graph)

    # use node-level provenance metadata to create a map
    # from old node names to new node names
    provenance: Dict[str, str] = {}
    for node in gm.graph.nodes:
        if history := node.meta.get("from_node", []):
            provenance[history[-1].name] = node.name

    # map old names to new names in module call signatures
    for entry in new_module_call_graph:
        signature = entry.signature
        if signature is None:
            continue
        for x in [*signature.inputs, *signature.outputs]:
            x.name = provenance.get(x.name, x.name)

    return new_module_call_graph


def _decompose_exported_program(
    ep,
    *,
    cia_to_decomp: Dict[torch._ops.OperatorBase, Callable],
    python_decomp_table: Dict[torch._ops.OperatorBase, Callable],
    joint_loss_index: Optional[int],
):
    gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
        ep,
        cia_to_decomp=cia_to_decomp,
        python_decomp_table=python_decomp_table,
        joint_loss_index=joint_loss_index,
    )

    # The signatures of ep.module_call_graph refer to input / output nodes of
    # the original graph module. However, the new graph module may have
    # new nodes due to decompositions. So we need to update these signatures
    # in the decomposed exported program's module_call_graph.
    new_module_call_graph = _get_updated_module_call_graph(
        gm,
        ep.module_call_graph,
    )

    # TODO: unfortunately preserving graph-level metadata does not
    # work well with aot_export, so we manually copy it.
    # (The node-level meta is addressed above.)
    gm.meta.update(ep.graph_module.meta)

    new_range_constraints = _get_updated_range_constraints(
        gm,
        ep.range_constraints,
    )

    exported_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=new_graph_signature,
        state_dict=ep.state_dict,
        range_constraints=new_range_constraints,
        module_call_graph=new_module_call_graph,
        example_inputs=ep.example_inputs,
        constants=ep.constants,
    )
    return exported_program
class ExportedProgram:
    """
    Package of a program from :func:`export`. It contains
    an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing
    tensor values of all lifted parameters and buffers, and various metadata.

    The module returned by :meth:`module` can be called with the same calling
    convention as the original callable traced by :func:`export`.

    To perform transformations on the graph, use the ``.module()`` method to access
    an :class:`torch.fx.GraphModule`. You can then use
    `FX transformation <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
    to rewrite the graph. Afterwards, you can simply use :func:`export`
    again to construct a correct ExportedProgram.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: torch.fx.Graph,
        graph_signature: ExportGraphSignature,
        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
        range_constraints: "Dict[sympy.Symbol, Any]",
        module_call_graph: List[ModuleCallEntry],
        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
        constants: Optional[
            Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
        ] = None,
        *,
        verifiers: Optional[List[Type[Verifier]]] = None,
    ):
        # Remove codegen related things from the graph. It should just be a flat graph.
        graph._codegen = torch.fx.graph.CodeGen()
        self._graph_module = _create_graph_module_for_export(root, graph)
        if isinstance(root, torch.fx.GraphModule):
            self._graph_module.meta.update(root.meta)

        _common_getitem_elimination_pass(
            self._graph_module, graph_signature, module_call_graph
        )

        self._graph_signature: ExportGraphSignature = graph_signature
        self._state_dict: Dict[str, Any] = state_dict
        self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
        assert module_call_graph is not None
        self._module_call_graph: List[ModuleCallEntry] = module_call_graph
        self._example_inputs = example_inputs

        self._constants = constants or {}

        verifiers = verifiers or [Verifier]
        assert all(issubclass(v, Verifier) for v in verifiers)
        self._verifiers = verifiers
        # Validate should always be the last step of the constructor.
        self.validate()

    @property
    @compatibility(is_backward_compatible=False)
    def graph_module(self):
        return self._graph_module

    @graph_module.setter
    @compatibility(is_backward_compatible=False)
    def graph_module(self, value):
        raise RuntimeError("Unable to set ExportedProgram's graph_module attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def graph(self):
        return self.graph_module.graph

    @graph.setter
    @compatibility(is_backward_compatible=False)
    def graph(self, value):
        raise RuntimeError("Unable to set ExportedProgram's graph attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def graph_signature(self):
        return self._graph_signature

    @graph_signature.setter
    @compatibility(is_backward_compatible=False)
    def graph_signature(self, value):
        raise RuntimeError("Unable to set ExportedProgram's graph_signature attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def state_dict(self):
        return self._state_dict

    @state_dict.setter
    @compatibility(is_backward_compatible=False)
    def state_dict(self, value):
        raise RuntimeError("Unable to set ExportedProgram's state_dict attribute.")
    @compatibility(is_backward_compatible=False)
    def parameters(self) -> Iterator[torch.nn.Parameter]:
        """
        Returns an iterator over original module's parameters.
        """
        for _, param in self.named_parameters():
            yield param

    @compatibility(is_backward_compatible=False)
    def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Returns an iterator over original module parameters, yielding
        both the name of the parameter as well as the parameter itself.
        """
        for param_name in self.graph_signature.parameters:
            yield param_name, self.state_dict[param_name]

    @compatibility(is_backward_compatible=False)
    def buffers(self) -> Iterator[torch.Tensor]:
        """
        Returns an iterator over original module buffers.
        """
        for _, buf in self.named_buffers():
            yield buf

    @compatibility(is_backward_compatible=False)
    def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Returns an iterator over original module buffers, yielding
        both the name of the buffer as well as the buffer itself.
        """
        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
        for buffer_name in self.graph_signature.buffers:
            if buffer_name in non_persistent_buffers:
                yield buffer_name, self.constants[buffer_name]
            else:
                yield buffer_name, self.state_dict[buffer_name]
    @property
    @compatibility(is_backward_compatible=False)
    def range_constraints(self):
        return self._range_constraints

    @range_constraints.setter
    @compatibility(is_backward_compatible=False)
    def range_constraints(self, value):
        raise RuntimeError(
            "Unable to set ExportedProgram's range_constraints attribute."
        )

    @property
    @compatibility(is_backward_compatible=False)
    def module_call_graph(self):
        return self._module_call_graph

    @module_call_graph.setter
    @compatibility(is_backward_compatible=False)
    def module_call_graph(self, value):
        raise RuntimeError(
            "Unable to set ExportedProgram's module_call_graph attribute."
        )

    @property
    @compatibility(is_backward_compatible=False)
    def example_inputs(self):
        return self._example_inputs

    @example_inputs.setter
    @compatibility(is_backward_compatible=False)
    def example_inputs(self, value):
        # This is allowed
        if not (
            isinstance(value, tuple)
            and len(value) == 2
            and isinstance(value[0], tuple)
            and isinstance(value[1], dict)
        ):
            raise ValueError(
                "Example inputs should be a tuple containing example arguments (as "
                "a tuple), and example kwargs (as a dictionary)."
            )

        args, kwargs = value
        from ._unlift import _check_inputs_match

        _check_inputs_match(args, kwargs, self.call_spec.in_spec)

        self._example_inputs = value

    @property
    @compatibility(is_backward_compatible=False)
    def call_spec(self):
        CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])

        if len(self.module_call_graph) == 0:
            return CallSpec(in_spec=None, out_spec=None)
        assert self.module_call_graph[0].fqn == ""
        return CallSpec(
            in_spec=self.module_call_graph[0].signature.in_spec,
            out_spec=self.module_call_graph[0].signature.out_spec,
        )

    @call_spec.setter
    @compatibility(is_backward_compatible=False)
    def call_spec(self, value):
        raise RuntimeError("Unable to set ExportedProgram's call_spec attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def verifier(self) -> Any:
        return self._verifiers[0]

    @verifier.setter
    @compatibility(is_backward_compatible=False)
    def verifier(self, value):
        raise RuntimeError("Unable to set ExportedProgram's verifier attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def dialect(self) -> str:
        assert self._verifiers is not None
        return self._verifiers[0].dialect

    @dialect.setter
    @compatibility(is_backward_compatible=False)
    def dialect(self, value):
        raise RuntimeError("Unable to set ExportedProgram's dialect attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def verifiers(self):
        return self._verifiers

    @verifiers.setter
    @compatibility(is_backward_compatible=False)
    def verifiers(self, value):
        raise RuntimeError("Unable to set ExportedProgram's verifiers attribute.")

    @property
    @compatibility(is_backward_compatible=False)
    def tensor_constants(self):
        return self._constants

    @tensor_constants.setter
    @compatibility(is_backward_compatible=False)
    def tensor_constants(self, value):
        raise RuntimeError(
            "Unable to set ExportedProgram's tensor_constants attribute."
        )

    @property
    @compatibility(is_backward_compatible=False)
    def constants(self):
        return self._constants

    @constants.setter
    @compatibility(is_backward_compatible=False)
    def constants(self, value):
        raise RuntimeError("Unable to set ExportedProgram's constants attribute.")

    def _get_flat_args_with_check(self, args, kwargs):
        """Flatten args, kwargs using pytree, then check specs.

        Args:
            args: List[Any] original args passed to __call__
            kwargs: Dict[str, Any] original kwargs passed to __call__

        Returns:
            A tuple of (flat_args, received_spec)
            flat_args is the flattened args / kwargs
            received_spec is the pytree spec produced while flattening the
            tuple (args, kwargs)
        """
        in_spec = self.call_spec.in_spec
        if in_spec is not None:
            kwargs = reorder_kwargs(kwargs, in_spec)
        flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
            (args, kwargs)
        )
        self._check_input_constraints(flat_args_with_path)
        flat_args = tuple(x[1] for x in flat_args_with_path)
        return flat_args, received_spec

    def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any:
        """Transform args, kwargs of __call__ to args for graph_module.

        self.graph_module takes stuff from the state dict as inputs.
        The invariant is for ep: ExportedProgram is
        ep(args, kwargs) ==
          ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
        """
        in_spec = self.call_spec.in_spec
        flat_args, received_spec = self._get_flat_args_with_check(args, kwargs)
        if in_spec is not None and not is_equivalent(
            received_spec, in_spec, _fx_collection_equivalence_fn
        ):
            raise ValueError(
                "Trying to flatten user inputs with exported input tree spec: \n"
                f"{in_spec}\n"
                "but actually got inputs with tree spec of: \n"
                f"{received_spec}"
            )

        additional_inputs = []
        for input_ in self.graph_signature.input_specs:
            if input_.kind == InputKind.USER_INPUT:
                continue
            elif input_.kind in (
                InputKind.PARAMETER,
                InputKind.BUFFER,
            ):
                if input_.persistent is False:
                    # This is a non-persistent buffer, grab it from our
                    # constants instead of the state dict.
                    additional_inputs.append(self.constants[input_.target])
                else:
                    additional_inputs.append(self.state_dict[input_.target])
            elif input_.kind in (
                InputKind.CONSTANT_TENSOR,
                InputKind.CUSTOM_OBJ,
            ):
                additional_inputs.append(self.constants[input_.target])
        additional_inputs = tuple(additional_inputs)

        # NOTE: calling convention is first params, then buffers, then args as user supplied them.
        # See: torch/_functorch/aot_autograd.py#L1034
        return additional_inputs + flat_args

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise RuntimeError(
            "Unable to call ExportedProgram directly. "
            "You should use `exported_program.module()` instead."
        )

    def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs):
        """Process potential mutations to the input.

        Because self.graph_module is functional, mutations have to be
        written back after execution of graph_module.
        """
        import torch._export.error as error

        flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs)
        if self.call_spec.out_spec is not None:
            buffer_mutation = self.graph_signature.buffers_to_mutate
            user_input_mutation = self.graph_signature.user_inputs_to_mutate
            num_mutated = len(buffer_mutation) + len(user_input_mutation)
            mutated_values = res[:num_mutated]

            # Exclude dependency token from final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(  # noqa: B904
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                user_inputs = [
                    spec
                    for spec in self.graph_signature.input_specs
                    if spec.kind == InputKind.USER_INPUT
                ]
                for i, value in enumerate(mutated_values):
                    output_spec = self.graph_signature.output_specs[i]
                    if output_spec.kind == OutputKind.BUFFER_MUTATION:
                        assert output_spec.target is not None
                        self.state_dict[output_spec.target] = value
                    elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
                        assert output_spec.target is not None
                        index = next(
                            i
                            for i, spec in enumerate(user_inputs)
                            if spec.arg.name == output_spec.target
                        )
                        flat_args[index].copy_(value)
                    else:
                        raise AssertionError(f"Unexpected kind: {output_spec.kind}")
        return res

    def __str__(self) -> str:
        graph_module = self.graph_module.print_readable(
            print_output=False, colored=False
        ).replace("\n", "\n    ")
        string = (
            "ExportedProgram:\n"
            f"    {graph_module}\n"
            f"Graph signature: {self.graph_signature}\n"
            f"Range constraints: {self.range_constraints}\n"
        )
        return string
    def module(self) -> torch.nn.Module:
        """
        Returns a self-contained GraphModule with all the parameters/buffers inlined.
        """
        from ._unlift import _unlift_exported_program_lifted_states

        module = _unlift_exported_program_lifted_states(self)

        def _train(self, mode: bool = True):
            raise NotImplementedError("Calling train() is not supported yet.")

        def _eval(self, mode: bool = True):
            raise NotImplementedError("Calling eval() is not supported yet.")

        module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
        module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
        return module
    @_disable_prexisiting_fake_mode
    def run_decompositions(
        self,
        decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
    ) -> "ExportedProgram":
        """
        Run a set of decompositions on the exported program and return a new
        exported program. By default we will run the Core ATen decompositions to
        get operators in the
        `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.

        For now, we do not decompose joint graphs.

        Args:
            decomp_table:
                An optional argument that specifies decomp behaviour for ATen ops.
                (1) If None, we decompose to core aten decompositions.
                (2) If empty, we don't decompose any operator.

        Some examples:

        If you don't want to decompose anything:

        .. code-block:: python

            ep = torch.export.export(model, ...)
            ep = ep.run_decompositions(decomp_table={})

        If you want the core ATen operator set, but need custom behaviour for
        certain operators (for example, providing your own decomposition), you can
        do the following:

        .. code-block:: python

            ep = torch.export.export(model, ...)
            decomp_table = torch.export.default_decompositions()
            decomp_table[your_op] = your_custom_decomp
            ep = ep.run_decompositions(decomp_table=decomp_table)
        """
        _decomp_table = (
            default_decompositions()
            if decomp_table is None
            else dict(decomp_table)
        )

        if isinstance(_decomp_table, CustomDecompTable):
            _decomp_table = _decomp_table.materialize()

        # Note [Separating decomp_table into CIA decomps and non-CIA decomps]
        # At this point, we have a decomp_table that contains decomp behaviour for
        # both CIA and post-autograd ops.
        # We need to separate the ops into two categories:
        #   1. CIA ops: these are the ops whose CompositeImplicitAutograd decomp we
        #      want to override. For them, we need to use the
        #      _override_composite_implicit_decomp context manager to plumb it
        #      through AOTDispatcher.
        #   2. Non-CIA ops: these ops are only relevant after AOTDispatcher runs, so
        #      just checking whether they are statically functional is enough.
        # For the joint IR case, though, we need to use the old path because we can't
        # register custom decomps this way: we can't use the context manager, as it
        # installs an autograd_error node.
        (
            cia_to_decomp,
            python_decomp_table,
        ) = _split_decomp_table_to_cia_and_python_decomp(_decomp_table)

        return _decompose_exported_program(
            self,
            cia_to_decomp=cia_to_decomp,
            python_decomp_table=python_decomp_table,
            joint_loss_index=None,
        )
    def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram":
        pm = PassManager(list(passes))
        # Since we abstractly run the passes, we need to disable backend decomp here
        # again.
        from torch.export._trace import _ignore_backend_decomps

        with _ignore_backend_decomps():
            res = pm(self.graph_module)
        transformed_gm = res.graph_module if res is not None else self.graph_module
        assert transformed_gm is not None

        if transformed_gm is self.graph_module and not res.modified:
            return self

        # TODO(zhxchen17) Remove this.
        def _get_updated_graph_signature(
            old_signature: ExportGraphSignature,
            new_gm: torch.fx.GraphModule,
        ) -> ExportGraphSignature:
            """
            Update the graph signature's user_input/user_outputs.
            """
            new_input_specs = []
            for i, node in enumerate(new_gm.graph.nodes):
                if node.op != "placeholder":
                    break

                assert i < len(
                    old_signature.input_specs
                ), "Number of inputs changed after transformation"
                old_input_spec = old_signature.input_specs[i]
                arg = (
                    old_input_spec.arg
                    if isinstance(
                        old_input_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_input_spec.arg)(node.name)
                )
                new_input_specs.append(
                    InputSpec(
                        old_input_spec.kind,
                        arg,
                        old_input_spec.target,
                        old_input_spec.persistent,
                    )
                )

            output_node = list(new_gm.graph.nodes)[-1]
            assert output_node.op == "output"

            new_output_specs = []
            for i, node in enumerate(output_node.args[0]):
                assert i < len(
                    old_signature.output_specs
                ), "Number of outputs changed after transformation"
                old_output_spec = old_signature.output_specs[i]
                arg = (
                    old_output_spec.arg
                    if isinstance(
                        old_output_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_output_spec.arg)(node.name)
                )
                new_output_specs.append(
                    OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
                )

            new_signature = ExportGraphSignature(
                input_specs=new_input_specs, output_specs=new_output_specs
            )
            return new_signature

        transformed_ep = ExportedProgram(
            root=transformed_gm,
            graph=transformed_gm.graph,
            graph_signature=_get_updated_graph_signature(
                self.graph_signature, transformed_gm
            ),
            state_dict=self.state_dict,
            range_constraints=_get_updated_range_constraints(
                transformed_gm,
                self.range_constraints,
            ),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            constants=self.constants,
            verifiers=self.verifiers,
        )
        transformed_ep.graph_module.meta.update(self.graph_module.meta)
        transformed_ep.graph_module.meta.update(res.graph_module.meta)
        return transformed_ep

    def _check_input_constraints(self, flat_args_with_path):
        from torch._export.utils import _check_input_constraints_for_graph

        placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
        input_placeholders = [
            p
            for p, s in zip(placeholders, self.graph_signature.input_specs)
            if s.kind == InputKind.USER_INPUT
        ]
        _check_input_constraints_for_graph(
            input_placeholders, flat_args_with_path, self.range_constraints
        )

    @compatibility(is_backward_compatible=False)
    def validate(self):
        self._validate()

    # TODO: remove this
    @final
    def _validate(self):
        assert (
            len(self.verifiers) > 0
        ), "ExportedProgram must have at least one verifier."
        for v in self.verifiers:
            v().check(self)

    # TODO(zhxchen17) Formalize this.
    def _update(
        self, graph_module, graph_signature, *, state_dict=None, verifiers=None
    ) -> "ExportedProgram":
        return ExportedProgram(
            root=graph_module,
            graph=graph_module.graph,
            graph_signature=graph_signature,
            state_dict=state_dict if state_dict is not None else self.state_dict,
            range_constraints=copy.deepcopy(self.range_constraints),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            constants=self.constants,
            verifiers=verifiers if verifiers is not None else self.verifiers,
        )
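
# A brief usage sketch for the public surface of the ExportedProgram class above:
# direct calls raise (see __call__), so the unlifted module from .module() is what
# gets executed; parameters, buffers, and the graph signature are inspected via
# the accessors. The model and the helper name are illustrative only.
def _example_exported_program_usage():
    class _Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(3, 3))

        def forward(self, x):
            return x @ self.weight

    ep = torch.export.export(_Tiny(), (torch.randn(2, 3),))

    # Run the program through the self-contained GraphModule.
    out = ep.module()(torch.randn(2, 3))

    # Inspect lifted parameters/buffers and the graph signature.
    for name, param in ep.named_parameters():
        print("param:", name, tuple(param.shape))
    for name, buf in ep.named_buffers():
        print("buffer:", name, tuple(buf.shape))
    print(ep.graph_signature)
    return out
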
def _get_shape_env(gm):
    vals = [
        node.meta["val"]
        for node in gm.graph.nodes
        if node.meta.get("val", None) is not None
    ]
    from torch._guards import detect_fake_mode

    fake_mode = detect_fake_mode(vals)
    if fake_mode is not None:
        return fake_mode.shape_env
    for v in vals:
        if isinstance(v, torch.SymInt):
            return v.node.shape_env


def _get_updated_range_constraints(
    gm: torch.fx.GraphModule,
    old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None,
) -> "Dict[sympy.Symbol, Any]":
    assert old_range_constraints is not None

    shape_env = _get_shape_env(gm)
    if shape_env is None:
        return {}

    range_constraints = copy.copy(old_range_constraints)
    range_constraints = {
        k: v for k, v in range_constraints.items() if k not in shape_env.replacements
    }
    # Only when we have an unbacked symint, and it's used as constructor inputs,
    # does runtime_var_to_range make a difference compared to var_to_range.
    # e.g. [2, oo) -> [0, oo)
    for k, v in shape_env.var_to_range.items():
        if k not in shape_env.replacements and k not in range_constraints:
            range_constraints[k] = v
    return range_constraints


def _create_graph_module_for_export(root, graph):
    try:
        gm = torch.fx.GraphModule(root, graph)
    except SyntaxError:
        # If custom objects stored in memory are being used in the graph,
        # the generated python code will result in a syntax error on the custom
        # object, since it is unable to parse the in-memory object. However,
        # we can still run the graph eagerly through torch.fx.Interpreter,
        # so we will bypass this error.
        warnings.warn(
            "Unable to execute the generated python source code from "
            "the graph. The graph module will no longer be directly callable, "
            "but you can still run the ExportedProgram, and if needed, you can "
            "run the graph module eagerly using torch.fx.Interpreter."
        )
        gm = torch.fx.GraphModule(root, torch.fx.Graph())
        gm._graph = graph

    return gm
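
# A final sketch tying _get_updated_range_constraints and the range_constraints
# property to user-visible behaviour, assuming the torch.export.Dim API: exporting
# with a dynamic dimension records a symbol-to-range mapping on the resulting
# program. The model, bounds, and helper name are illustrative only.
def _example_range_constraints():
    class _Id(torch.nn.Module):
        def forward(self, x):
            return x + 1

    batch = torch.export.Dim("batch", min=2, max=64)
    ep = torch.export.export(
        _Id(),
        (torch.randn(4, 3),),
        dynamic_shapes={"x": {0: batch}},
    )
    # Maps sympy symbols (e.g. s0) to their allowed value ranges.
    print(ep.range_constraints)
    return ep.range_constraints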