# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import contextlib
import copy
import dataclasses
import functools
import operator
import types
import warnings
from collections import namedtuple
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    final,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

from torch._higher_order_ops.utils import autograd_not_implemented
from torch._library.fake_class_registry import FakeScriptObject
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts


if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.
    import sympy

    from torch.utils._sympy.value_ranges import ValueRanges

import torch
import torch.utils._pytree as pytree
from torch._export.utils import (
    _collect_and_set_constant_attrs,
    _collect_param_buffer_metadata,
    _detect_fake_mode_from_gm,
    _name_hoo_subgraph_placeholders,
    _overwrite_signature_for_non_persistent_buffers,
    _populate_param_buffer_metadata_to_new_gm,
    _rename_without_collisions,
)
from torch._export.verifier import Verifier
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

from .graph_signature import (  # noqa: F401
    ArgumentSpec,
    ConstantArgument,
    CustomObjArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymIntArgument,
    TensorArgument,
    TokenArgument,
)


__all__ = [
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
]


PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


def _disable_prexisiting_fake_mode(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with unset_fake_temporarily():
            return fn(*args, **kwargs)

    return wrapper


def _fx_collection_equivalence_fn(
    spec1_type: Optional[type],
    spec1_context: pytree.Context,
    spec2_type: Optional[type],
    spec2_context: pytree.Context,
) -> bool:
    """Treat containers and their immutable variants as the same type. Otherwise
    compare as normal.
    """
    if spec1_type is None or spec2_type is None:
        return spec1_type is spec2_type and spec1_context == spec2_context

    if issubclass(spec1_type, (dict, immutable_dict)) and issubclass(
        spec2_type, (dict, immutable_dict)
    ):
        return spec1_context == spec2_context

    if issubclass(spec1_type, (list, immutable_list)) and issubclass(
        spec2_type, (list, immutable_list)
    ):
        return spec1_context == spec2_context

    return spec1_type is spec2_type and spec1_context == spec2_context


def _register_cia_to_meta(*args, **kwargs):
    kernel = kwargs["kernel"]
    del kwargs["kernel"]

    assert torch._C._dispatch_has_kernel_for_dispatch_key(
        kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd
    )

    return kernel._op_dk(
        torch._C.DispatchKey.CompositeImplicitAutograd, *args, **kwargs
    )


# This list is compiled from DispatchKey.cpp.
# The idea is that we use these keys to override
# CIA decomp in export.
_AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [
    torch._C.DispatchKey.AutogradCPU,
    torch._C.DispatchKey.AutogradCUDA,
    torch._C.DispatchKey.AutogradMeta,
    torch._C.DispatchKey.AutogradXLA,
    torch._C.DispatchKey.AutogradLazy,
    torch._C.DispatchKey.AutogradIPU,
    torch._C.DispatchKey.AutogradXPU,
    torch._C.DispatchKey.AutogradMPS,
    torch._C.DispatchKey.AutogradHPU,
    torch._C.DispatchKey.AutogradPrivateUse1,
    torch._C.DispatchKey.AutogradPrivateUse2,
    torch._C.DispatchKey.AutogradPrivateUse3,
]


@contextmanager
def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True):
    # This function overrides the CompositeImplicitAutograd decomp for
    # functional composite ops that the user specified. Ideally we want to not-decompose
    # ALL composite ops, but today's C++ functionalization relies on
    # the fact that it is working with the opset after decomp is run.
    # Hence we can only do it for functional ops. One caveat is that
    # there are some composite ops that lie about their schema (claimed to be
    # functional but not really, aka dropout); for these cases, we just decompose.

    # When safe=False, we will assume that ops_to_preserve can be mutating/aliasing
    # and their usual decompositions need to be shadowed rather than overridden.
    # Thus we will avoid asserting that they are valid to preserve, and will not
    # replace their CompositeImplicitAutograd kernels with NotImplemented.
    # The only current users of this mode are variants of aten::to that we will
    # replace with aten::_to_copy in FunctionalTensorMode.__torch_dispatch__.
    saved_tables = {}
    patched_ops = set()
    removed_decomps = {}
    for op_overload in ops_to_preserve:
        # Our strategy for deciding if we can preserve CIA is the following:
        # 1. The op should be known statically to be functional
        # 2. If it is maybe aliasing, we decompose because we must know if an op
        #    is mutating or aliasing.
        # TODO (tmanlaibaatar) make this a utility function and share it with functional_tensor
        # decomp part. (https://github.com/pytorch/pytorch/issues/129431)
        def assert_valid_to_preserve(op_overload):
            if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
                raise RuntimeError(
                    f"We can't detect {op_overload} as a functional op statically, "
                    "so we can't preserve it"
                )
            if op_overload in FunctionalTensor.metadata_fns:
                raise RuntimeError(
                    f"{op_overload} is a metadata query function, "
                    "it will be preserved implicitly in our tracing system. "
                    "Please file an issue on github if you see otherwise"
                )

            alias_info = len(
                [i for i in op_overload._schema.arguments if i.alias_info is not None]
            )

            is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable

            if is_mutating_or_aliasing:
                raise RuntimeError(
                    f"{op_overload} is a mutating/aliasing op, we can't preserve it as is"
                )

            if not torch._C._dispatch_has_kernel(op_overload.name()):
                raise RuntimeError(
                    f"{op_overload} is a TorchScript op, we can't preserve it as is"
                )

            return True

        if safe:
            # If we didn't error, it means we can go ahead
            assert_valid_to_preserve(op_overload)

        saved_tables[op_overload] = op_overload.py_kernels.copy()
        patched_ops.add(op_overload)

        for override_dispatch_key in _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE:
            if override_dispatch_key not in op_overload.py_kernels:
                # TODO (tmanlaibaatar) https://github.com/pytorch/pytorch/issues/129430
                op_overload.py_impl(override_dispatch_key)(
                    autograd_not_implemented(op_overload, deferred_error=True)
                )
        if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels:
            del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd]

        if safe:

            def _(*args, **kwargs):
                return NotImplemented

            op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(_)

        # For fake tensor prop, we do want to register the meta kernel directly
        if torch._C.DispatchKey.Meta not in op_overload.py_kernels:
            op_overload.py_impl(torch._C.DispatchKey.Meta)(
                functools.partial(_register_cia_to_meta, kernel=op_overload)
            )

        if op_overload in decomp_table:
            removed_decomps[op_overload] = decomp_table[op_overload]
            del decomp_table[op_overload]

    try:
        yield
    finally:
        for op in patched_ops:
            op.py_kernels.clear()
            op.py_kernels.update(saved_tables[op])
            op._dispatch_cache.clear()

        for op, decomp in removed_decomps.items():
            decomp_table[op] = decomp


@contextmanager
def _override_decomp_aten_to_variants():
    # Preserve variants of aten::to, understanding that they are mutating/aliasing
    # and their CompositeImplicitAutograd kernels will not become NotImplemented.
    # We will later replace them with aten._to_copy when functionalizing.
    with _override_composite_implicit_decomp(
        (torch.ops.aten.to.dtype_layout, torch.ops.aten.to.dtype),
        {},
        safe=False,
    ):
        yield


def _decompose_and_get_gm_with_new_signature_constants(
    ep,
    *,
    decomp_table: Dict[torch._ops.OperatorBase, Callable],
    _preserve_ops: Tuple[torch._ops.OpOverload],
    joint_loss_index: Optional[int],
):
    from torch._functorch.aot_autograd import aot_export_module
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.export._trace import (
        _export_to_aten_ir,
        _fakify_params_buffers,
        _ignore_backend_decomps,
        _verify_nn_module_stack,
        _verify_placeholder_names,
        _verify_stack_trace,
    )
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # TODO Merge this path with inference IR decomp, but it will require some
    # additional work, so I will leave it for now. T200307782
    if ep.verifier.dialect == "TRAINING":
        mod = ep.module()

        fake_args = []
        for node in mod.graph.nodes:
            if node.op == "placeholder":
                fake_args.append(node.meta["val"])

        fake_args_unwrapped = pytree.tree_unflatten(fake_args, mod._in_spec)
        fake_mode = _detect_fake_mode_from_gm(mod)
        if fake_mode is None:
            fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True)

        # Fix the graph output signature to be tuple if scalar
        out_spec = mod._out_spec

        orig_arg_names = mod.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]

        # aot_export expects the return type to always be a tuple.
        if out_spec.type not in (list, tuple):
            out_spec = pytree.TreeSpec(tuple, None, [out_spec])

        mod.graph._codegen = _PyTreeCodeGen(
            _PyTreeInfo(
                orig_arg_names,
                mod._in_spec,
                out_spec,
            )
        )

        mod.recompile()

        # the exported module will store constants & non-persistent buffers such that
        # retracing treats them as persistent buffers, so we inform the constants lifting pass
        # and overwrite the new graph signature using the previous program.
        constant_attrs = _collect_and_set_constant_attrs(
            ep.graph_signature, ep.constants, mod
        )

        # get params & buffers after excluding constants
        fake_params_buffers = _fakify_params_buffers(fake_mode, mod)

        params_buffers_to_node_meta = _collect_param_buffer_metadata(mod)

        with _ignore_backend_decomps(), (
            fake_mode
        ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp(
            _preserve_ops,
            decomp_table,
        ):
            aten_export_artifact = _export_to_aten_ir(
                mod,
                # this requires empty kwargs, but not in pytree.flattened format
                (
                    *fake_args_unwrapped[0],
                    *fake_args_unwrapped[1].values(),
                ),
                {},
                fake_params_buffers,
                constant_attrs,
                decomp_table=decomp_table,
                _check_autograd_state=False,
            )

            gm = aten_export_artifact.gm
            new_graph_signature = aten_export_artifact.sig

            _populate_param_buffer_metadata_to_new_gm(
                params_buffers_to_node_meta, gm, new_graph_signature
            )

            # overwrite signature for non-persistent buffers
            new_graph_signature = _overwrite_signature_for_non_persistent_buffers(
                ep.graph_signature, new_graph_signature
            )

            _verify_nn_module_stack(gm)
            _verify_stack_trace(gm)
            _verify_placeholder_names(gm, new_graph_signature)

            return _remove_unneccessary_copy_op_pass(gm, new_graph_signature)

    old_placeholders = [
        node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
    ]
    fake_args = [node.meta["val"] for node in old_placeholders]

    buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()]
    for name in buffers_to_remove:
        delattr(ep.graph_module, name)

    # TODO(zhxhchen17) Return the new graph_signature directly.
    fake_mode = detect_fake_mode(fake_args)
    fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode
    with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp(
        _preserve_ops,
        decomp_table,
    ):
        gm, graph_signature = aot_export_module(
            ep.graph_module,
            fake_args,
            decompositions=decomp_table,
            trace_joint=True if joint_loss_index is not None else False,
            output_loss_index=joint_loss_index
            if joint_loss_index is not None
            else None,
        )

    # Update the signatures with the new placeholder names in case they
    # changed when calling aot_export
    def update_arg(old_arg, new_ph):
        if isinstance(old_arg, ConstantArgument):
            return old_arg
        elif isinstance(old_arg, TensorArgument):
            return TensorArgument(name=new_ph.name)
        elif isinstance(old_arg, SymIntArgument):
            return SymIntArgument(name=new_ph.name)
        raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")

    new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
    new_outputs = list(gm.graph.nodes)[-1].args[0]

    # rename the placeholders
    assert len(new_placeholders) == len(old_placeholders)
    for old_ph, new_ph in zip(old_placeholders, new_placeholders):
        new_ph.name = new_ph.target = old_ph.name

    # handle name collisions with newly decomposed graph nodes
    name_map = {ph.name: ph.name for ph in new_placeholders}
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue
        node.name = _rename_without_collisions(name_map, node.name, node.name)

    # propagate names to higher order op subgraphs
    _name_hoo_subgraph_placeholders(gm)

    # Run this pass before creating input/output specs, since size-related CSE/DCE might
    # affect the output signature. Overwrite output specs afterwards.
    from torch._export.passes._node_metadata_hook import (
        _node_metadata_hook,
        _set_node_metadata_hook,
    )
    from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names

    if not torch._dynamo.config.do_not_emit_runtime_asserts:
        stack_trace = (
            'File "torch/fx/passes/runtime_assert.py", line 24, '
            "in insert_deferred_runtime_asserts"
        )
        shape_env = _get_shape_env(gm)
        if shape_env is not None:
            with _set_node_metadata_hook(
                gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
            ):
                insert_deferred_runtime_asserts(
                    gm,
                    shape_env,
                    f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
                    export=True,
                )

    # update output specs
    gm.recompile()
    for i, name in enumerate(_graph_output_names(gm)):
        if isinstance(new_outputs[i], torch.fx.Node):
            new_outputs[i].name = name

    # To match the output target with the correct input for input mutations,
    # we need to find the old-to-new placeholder map
    old_new_placeholder_map = {
        spec.arg.name: new_placeholders[i].name
        for i, spec in enumerate(ep.graph_signature.input_specs)
        if not isinstance(spec.arg, ConstantArgument)
    }

    input_specs = [
        InputSpec(
            spec.kind,
            update_arg(spec.arg, new_placeholders[i]),
            spec.target,
            spec.persistent,
        )
        for i, spec in enumerate(ep.graph_signature.input_specs)
    ]
    output_specs = [
        OutputSpec(
            spec.kind,
            update_arg(spec.arg, new_outputs[i]),
            old_new_placeholder_map.get(spec.target, spec.target),
        )
        for i, spec in enumerate(ep.graph_signature.output_specs)
    ]

    if joint_loss_index is not None:
        assert graph_signature.backward_signature is not None
        gradients = graph_signature.backward_signature.gradients_to_user_inputs
        assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs)
        specs = {
            graph_signature.user_inputs[i]: spec
            for i, spec in enumerate(ep.graph_signature.input_specs)
            if isinstance(spec.arg, TensorArgument)
        }
        for i, node in enumerate(new_outputs[len(output_specs) :]):
            source = gradients[node.name]
            spec = specs[source]  # type: ignore[index]
            if spec.kind == InputKind.PARAMETER:
                kind = OutputKind.GRADIENT_TO_PARAMETER
                target = spec.target
            elif spec.kind == InputKind.USER_INPUT:
                kind = OutputKind.GRADIENT_TO_USER_INPUT
                target = source
            else:
                raise AssertionError(f"Unknown input kind: {spec.kind}")
            output_specs.append(
                OutputSpec(
                    kind,
                    TensorArgument(name=node.name),
                    target,
                )
            )

    assert len(new_placeholders) == len(old_placeholders)

    new_graph_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )
    # NOTE: aot_export adds symint metadata for placeholders with int
    # values; since these become specialized, we replace such metadata with
    # the original values.
    # Also, set the param/buffer metadata back to the placeholders.
    for old_node, new_node in zip(old_placeholders, new_placeholders):
        if not isinstance(old_node.meta["val"], torch.Tensor):
            new_node.meta["val"] = old_node.meta["val"]

        if (
            new_node.target in new_graph_signature.inputs_to_parameters
            or new_node.target in new_graph_signature.inputs_to_buffers
        ):
            for k, v in old_node.meta.items():
                new_node.meta[k] = v

    return gm, new_graph_signature


def _remove_unneccessary_copy_op_pass(
    gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
    """
    Removes redundant copy_ nodes that were introduced due to mutated buffers.
    """
    with gm._set_replace_hook(new_graph_signature.get_replace_hook()):
        for node in gm.graph.nodes:
            if node.op == "output":
                args, _ = pytree.tree_flatten(node.args)
                for out in args:
                    if (
                        isinstance(out, torch.fx.Node)
                        and out.name in new_graph_signature.buffers_to_mutate
                    ):
                        if (
                            out.op == "call_function"
                            and out.target == torch.ops.aten.copy.default
                        ):
                            out.replace_all_uses_with(out.args[1])  # type: ignore[arg-type]
                            gm.graph.erase_node(out)
        gm.recompile()
    return gm, new_graph_signature


def _common_getitem_elimination_pass(
    gm: torch.fx.GraphModule, graph_signature, module_call_graph
):
    with gm._set_replace_hook(graph_signature.get_replace_hook()):
        for module in gm.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue

            node_id: Dict[torch.fx.Node, str] = {}
            getitems: Dict[str, torch.fx.Node] = {}
            for node in list(module.graph.nodes):
                if node.op == "call_function" and node.target == operator.getitem:
                    source, idx = node.args
                    new_id = f"{node_id[source]}.{idx}"
                    if new_id in getitems:
                        node.replace_all_uses_with(getitems[new_id])
                        for entry in module_call_graph:
                            if entry.signature is not None:
                                entry.signature.replace_all_uses_with(
                                    node, getitems[new_id]
                                )
                        module.graph.erase_node(node)
                    else:
                        getitems[new_id] = node
                        node_id[node] = new_id
                else:
                    node_id[node] = node.name


def _decompose_exported_program(
    ep,
    *,
    decomp_table: Dict[torch._ops.OperatorBase, Callable],
    _preserve_ops: Tuple[torch._ops.OpOverload],
    joint_loss_index: Optional[int],
):
    gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
        ep,
        decomp_table=decomp_table,
        _preserve_ops=_preserve_ops,
        joint_loss_index=joint_loss_index,
    )

    # TODO unfortunately preserving graph-level metadata is not
    # working well with aot_export. So we manually copy it.
    # (The node-level meta is addressed above.)
    gm.meta.update(ep.graph_module.meta)

    new_range_constraints = _get_updated_range_constraints(
        gm,
        ep.range_constraints,
    )

    exported_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=new_graph_signature,
        state_dict=ep.state_dict,
        range_constraints=new_range_constraints,
        module_call_graph=copy.deepcopy(ep.module_call_graph),
        example_inputs=ep.example_inputs,
        constants=ep.constants,
    )
    return exported_program


class ExportedProgram:
    """
    Package of a program from :func:`export`. It contains
    an :class:`torch.fx.Graph` that represents Tensor computation, a ``state_dict`` containing
    tensor values of all lifted parameters and buffers, and various metadata.

    You can call an ExportedProgram like the original callable traced by
    :func:`export` with the same calling convention.

    To perform transformations on the graph, use the ``.module()`` method to access
    an :class:`torch.fx.GraphModule`. You can then use
    `FX transformations <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
    to rewrite the graph. Afterwards, you can simply use :func:`export`
    again to construct a correct ExportedProgram.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: torch.fx.Graph,
        graph_signature: ExportGraphSignature,
        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
        range_constraints: "Dict[sympy.Symbol, Any]",
        module_call_graph: List[ModuleCallEntry],
        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
        constants: Optional[
            Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
        ] = None,
        *,
        verifiers: Optional[List[Type[Verifier]]] = None,
    ):
        # Remove codegen related things from the graph. It should just be a flat graph.
        graph._codegen = torch.fx.graph.CodeGen()
        self._graph_module = _create_graph_module_for_export(root, graph)
        if isinstance(root, torch.fx.GraphModule):
            self._graph_module.meta.update(root.meta)

        _common_getitem_elimination_pass(
            self._graph_module, graph_signature, module_call_graph
        )

        self._graph_signature: ExportGraphSignature = graph_signature
        self._state_dict: Dict[str, Any] = state_dict
        self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
        assert module_call_graph is not None
        self._module_call_graph: List[ModuleCallEntry] = module_call_graph
        self._example_inputs = example_inputs

        self._constants = constants or {}

        verifiers = verifiers or [Verifier]
        assert all(issubclass(v, Verifier) for v in verifiers)
        self._verifiers = verifiers
        # Validation should always be the last step of the constructor.
        self.validate()

    @property
    @compatibility(is_backward_compatible=False)
    def graph_module(self):
        return self._graph_module

    @property
    @compatibility(is_backward_compatible=False)
    def graph(self):
        return self.graph_module.graph

    @property
    @compatibility(is_backward_compatible=False)
    def graph_signature(self):
        return self._graph_signature

    @property
    @compatibility(is_backward_compatible=False)
    def state_dict(self):
        return self._state_dict
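
    # Minimal usage sketch for the accessors above (``MyModel`` is a hypothetical
    # nn.Module used only for illustration):
    #
    #     ep = torch.export.export(MyModel(), (torch.randn(2, 4),))
    #     print(ep.graph)             # the flat ATen-level torch.fx.Graph
    #     print(ep.graph_signature)   # how placeholders map to params/buffers/user inputs
    #     print(list(ep.state_dict))  # names of lifted parameters and buffers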

    @compatibility(is_backward_compatible=False)
    def parameters(self) -> Iterator[torch.nn.Parameter]:
        """
        Returns an iterator over original module's parameters.
        """
        for _, param in self.named_parameters():
            yield param

    @compatibility(is_backward_compatible=False)
    def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Returns an iterator over original module parameters, yielding
        both the name of the parameter as well as the parameter itself.
        """
        for param_name in self.graph_signature.parameters:
            yield param_name, self.state_dict[param_name]

    @compatibility(is_backward_compatible=False)
    def buffers(self) -> Iterator[torch.Tensor]:
        """
        Returns an iterator over original module buffers.
        """
        for _, buf in self.named_buffers():
            yield buf

    @compatibility(is_backward_compatible=False)
    def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Returns an iterator over original module buffers, yielding
        both the name of the buffer as well as the buffer itself.
        """
        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
        for buffer_name in self.graph_signature.buffers:
            if buffer_name in non_persistent_buffers:
                yield buffer_name, self.constants[buffer_name]
            else:
                yield buffer_name, self.state_dict[buffer_name]
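
    # Sketch of how the parameter/buffer iterators are typically consumed (``ep`` is
    # assumed to be an ExportedProgram produced elsewhere). Non-persistent buffers are
    # stored in ``ep.constants`` rather than ``ep.state_dict``, which is why
    # ``named_buffers`` consults both.
    #
    #     total_params = sum(p.numel() for p in ep.parameters())
    #     for name, buf in ep.named_buffers():
    #         print(name, tuple(buf.shape))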

    @property
    @compatibility(is_backward_compatible=False)
    def range_constraints(self):
        return self._range_constraints

    @property
    @compatibility(is_backward_compatible=False)
    def module_call_graph(self):
        return self._module_call_graph

    @property
    @compatibility(is_backward_compatible=False)
    def example_inputs(self):
        return self._example_inputs

    @property
    @compatibility(is_backward_compatible=False)
    def call_spec(self):
        CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])

        if len(self.module_call_graph) == 0:
            return CallSpec(in_spec=None, out_spec=None)
        assert self.module_call_graph[0].fqn == ""
        return CallSpec(
            in_spec=self.module_call_graph[0].signature.in_spec,
            out_spec=self.module_call_graph[0].signature.out_spec,
        )

    @property
    @compatibility(is_backward_compatible=False)
    def verifier(self) -> Any:
        return self._verifiers[0]

    @property
    @compatibility(is_backward_compatible=False)
    def dialect(self) -> str:
        assert self._verifiers is not None
        return self._verifiers[0].dialect

    @property
    @compatibility(is_backward_compatible=False)
    def verifiers(self):
        return self._verifiers

    @property
    @compatibility(is_backward_compatible=False)
    def tensor_constants(self):
        return self._constants

    @property
    @compatibility(is_backward_compatible=False)
    def constants(self):
        return self._constants

    def _get_flat_args_with_check(self, args, kwargs):
        """Flatten args and kwargs using pytree, then check specs.

        Args:
            args: List[Any] original args passed to __call__
            kwargs: Dict[str, Any] original kwargs passed to __call__

        Returns:
            A tuple of (flat_args, received_spec)
            flat_args is the flattened args / kwargs
            received_spec is the pytree spec produced while flattening the
            tuple (args, kwargs)
        """
        in_spec = self.call_spec.in_spec
        if in_spec is not None:
            kwargs = reorder_kwargs(kwargs, in_spec)
        flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
            (args, kwargs)
        )  # type: ignore[possibly-undefined]
        self._check_input_constraints(flat_args_with_path)
        flat_args = tuple(x[1] for x in flat_args_with_path)
        return flat_args, received_spec

    def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any:
        """Transform args, kwargs of __call__ to args for graph_module.

        self.graph_module takes values from the state dict as additional inputs.
        The invariant for ep: ExportedProgram is
        ep(args, kwargs) ==
          ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
        """

        in_spec = self.call_spec.in_spec
        flat_args, received_spec = self._get_flat_args_with_check(args, kwargs)
        if in_spec is not None and not is_equivalent(
            received_spec, in_spec, _fx_collection_equivalence_fn
        ):
            raise ValueError(
                "Trying to flatten user inputs with exported input tree spec: \n"
                f"{in_spec}\n"
                "but actually got inputs with tree spec of: \n"
                f"{received_spec}"
            )

        additional_inputs = []
        for input_ in self.graph_signature.input_specs:
            if input_.kind == InputKind.USER_INPUT:
                continue
            elif input_.kind in (
                InputKind.PARAMETER,
                InputKind.BUFFER,
            ):
                if input_.persistent is False:
                    # This is a non-persistent buffer, grab it from our
                    # constants instead of the state dict.
                    additional_inputs.append(self.constants[input_.target])
                else:
                    additional_inputs.append(self.state_dict[input_.target])
            elif input_.kind in (
                InputKind.CONSTANT_TENSOR,
                InputKind.CUSTOM_OBJ,
            ):
                additional_inputs.append(self.constants[input_.target])
        additional_inputs = tuple(additional_inputs)

        # NOTE: calling convention is first params, then buffers, then args as user supplied them.
        # See: torch/_functorch/aot_autograd.py#L1034
        return additional_inputs + flat_args

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise RuntimeError(
            "Unable to call ExportedProgram directly. "
            "You should use `exported_program.module()` instead."
        )

    def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs):
        """Process potential mutations to the input.

        Because self.graph_module is functional, mutations have to be written
        back after execution of graph_module.
        """
        import torch._export.error as error

        flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs)
        if self.call_spec.out_spec is not None:
            buffer_mutation = self.graph_signature.buffers_to_mutate
            user_input_mutation = self.graph_signature.user_inputs_to_mutate
            num_mutated = len(buffer_mutation) + len(user_input_mutation)
            mutated_values = res[:num_mutated]

            # Exclude dependency token from final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(  # noqa: B904
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                user_inputs = [
                    spec
                    for spec in self.graph_signature.input_specs
                    if spec.kind == InputKind.USER_INPUT
                ]
                for i, value in enumerate(mutated_values):
                    output_spec = self.graph_signature.output_specs[i]
                    if output_spec.kind == OutputKind.BUFFER_MUTATION:
                        assert output_spec.target is not None
                        self.state_dict[output_spec.target] = value
                    elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
                        assert output_spec.target is not None
                        index = next(
                            i
                            for i, spec in enumerate(user_inputs)
                            if spec.arg.name == output_spec.target
                        )
                        flat_args[index].copy_(value)
                    else:
                        raise AssertionError(f"Unexpected kind: {output_spec.kind}")
        return res

    def __str__(self) -> str:
        graph_module = self.graph_module.print_readable(
            print_output=False, colored=False
        ).replace("\n", "\n    ")
        string = (
            "ExportedProgram:\n"
            f"    {graph_module}\n"
            f"Graph signature: {self.graph_signature}\n"
            f"Range constraints: {self.range_constraints}\n"
        )
        return string

    def module(self) -> torch.nn.Module:
        """
        Returns a self-contained GraphModule with all the parameters/buffers inlined.
        """
        from ._unlift import _unlift_exported_program_lifted_states

        module = _unlift_exported_program_lifted_states(self)

        def _train(self, mode: bool = True):
            raise NotImplementedError("Calling train() is not supported yet.")

        def _eval(self, mode: bool = True):
            raise NotImplementedError("Calling eval() is not supported yet.")

        module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
        module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
        return module
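
    # Illustrative sketch (``ep`` is assumed to be an ExportedProgram and ``args`` a tuple
    # of inputs matching the original calling convention): since __call__ raises,
    # ``module()`` is the supported way to run the program, and its result can be
    # re-exported after further transformations.
    #
    #     unlifted = ep.module()
    #     out = unlifted(*args)                      # runs with params/buffers inlined
    #     ep2 = torch.export.export(unlifted, args)  # round-trips to a new ExportedProgram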

    @_disable_prexisiting_fake_mode
    def run_decompositions(
        self,
        decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
        _preserve_ops: Tuple[torch._ops.OpOverload, ...] = (),
    ) -> "ExportedProgram":
        """
        Runs a set of decompositions on the exported program and returns a new
        exported program. By default we will run the Core ATen decompositions to
        get operators in the
        `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.

        For now, we do not decompose joint graphs.
        """
        from torch._decomp import core_aten_decompositions

        if decomp_table is None:
            decomp_table = core_aten_decompositions()

        return _decompose_exported_program(
            self,
            decomp_table=decomp_table,
            _preserve_ops=_preserve_ops,  # type: ignore[arg-type]
            joint_loss_index=None,
        )
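
    # Illustrative sketch (``ep`` is assumed to be an ExportedProgram): calling with no
    # arguments uses the Core ATen decomposition table; an explicit table and/or
    # _preserve_ops controls which composite ops get decomposed. aten.linear is an
    # assumed example of a functional op that can be preserved.
    #
    #     core_ir_ep = ep.run_decompositions()  # lower to the Core ATen opset
    #     custom_ep = ep.run_decompositions(
    #         decomp_table=torch._decomp.core_aten_decompositions(),
    #         _preserve_ops=(torch.ops.aten.linear.default,),
    #     )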

    def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram":
        pm = PassManager(list(passes))
        # Since we abstractly run the passes, we need to disable backend decomp here
        # again.
        from torch.export._trace import _ignore_backend_decomps

        with _ignore_backend_decomps():
            res = pm(self.graph_module)
        transformed_gm = res.graph_module if res is not None else self.graph_module
        assert transformed_gm is not None

        if transformed_gm is self.graph_module and not res.modified:
            return self

        # TODO(zhxchen17) Remove this.
        def _get_updated_graph_signature(
            old_signature: ExportGraphSignature,
            new_gm: torch.fx.GraphModule,
        ) -> ExportGraphSignature:
            """
            Update the graph signature's user_input/user_outputs.
            """
            new_input_specs = []
            for i, node in enumerate(new_gm.graph.nodes):
                if node.op != "placeholder":
                    break

                assert i < len(
                    old_signature.input_specs
                ), "Number of inputs changed after transformation"
                old_input_spec = old_signature.input_specs[i]
                arg = (
                    old_input_spec.arg
                    if isinstance(
                        old_input_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_input_spec.arg)(node.name)
                )
                new_input_specs.append(
                    InputSpec(
                        old_input_spec.kind,
                        arg,
                        old_input_spec.target,
                        old_input_spec.persistent,
                    )
                )

            output_node = list(new_gm.graph.nodes)[-1]
            assert output_node.op == "output"

            new_output_specs = []
            for i, node in enumerate(output_node.args[0]):
                assert i < len(
                    old_signature.output_specs
                ), "Number of outputs changed after transformation"
                old_output_spec = old_signature.output_specs[i]
                arg = (
                    old_output_spec.arg
                    if isinstance(
                        old_output_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_output_spec.arg)(node.name)
                )
                new_output_specs.append(
                    OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
                )

            new_signature = ExportGraphSignature(
                input_specs=new_input_specs, output_specs=new_output_specs
            )
            return new_signature

        transformed_ep = ExportedProgram(
            root=transformed_gm,
            graph=transformed_gm.graph,
            graph_signature=_get_updated_graph_signature(
                self.graph_signature, transformed_gm
            ),
            state_dict=self.state_dict,
            range_constraints=_get_updated_range_constraints(
                transformed_gm,
                self.range_constraints,
            ),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            constants=self.constants,
            verifiers=self.verifiers,
        )
        transformed_ep.graph_module.meta.update(self.graph_module.meta)
        transformed_ep.graph_module.meta.update(res.graph_module.meta)
        return transformed_ep

    def _check_input_constraints(self, flat_args_with_path):
        from torch._export.utils import _check_input_constraints_for_graph

        placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
        input_placeholders = [
            p
            for p, s in zip(placeholders, self.graph_signature.input_specs)
            if s.kind == InputKind.USER_INPUT
        ]
        _check_input_constraints_for_graph(
            input_placeholders, flat_args_with_path, self.range_constraints
        )

    @compatibility(is_backward_compatible=False)
    def validate(self):
        self._validate()

    # TODO: remove this
    @final
    def _validate(self):
        assert (
            len(self.verifiers) > 0
        ), "ExportedProgram must have at least one verifier."
        for v in self.verifiers:
            v().check(self)

    # TODO(zhxchen17) Formalize this.
    def _update(
        self, graph_module, graph_signature, *, state_dict=None, verifiers=None
    ) -> "ExportedProgram":
        return ExportedProgram(
            root=graph_module,
            graph=graph_module.graph,
            graph_signature=graph_signature,
            state_dict=state_dict if state_dict is not None else self.state_dict,
            range_constraints=copy.deepcopy(self.range_constraints),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            constants=self.constants,
            verifiers=verifiers if verifiers is not None else self.verifiers,
        )


def _get_shape_env(gm):
    vals = [
        node.meta["val"]
        for node in gm.graph.nodes
        if node.meta.get("val", None) is not None
    ]
    from torch._guards import detect_fake_mode

    fake_mode = detect_fake_mode(vals)
    if fake_mode is not None:
        return fake_mode.shape_env
    for v in vals:
        if isinstance(v, torch.SymInt):
            return v.node.shape_env


def _get_updated_range_constraints(
    gm: torch.fx.GraphModule,
    old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None,
) -> "Dict[sympy.Symbol, Any]":
    assert old_range_constraints is not None

    shape_env = _get_shape_env(gm)
    if shape_env is None:
        return {}

    range_constraints = copy.copy(old_range_constraints)
    range_constraints = {
        k: v for k, v in range_constraints.items() if k not in shape_env.replacements
    }
    # Only when we have an unbacked symint, and it's used as constructor inputs,
    # runtime_var_to_range will make a difference compared to var_to_range.
    # e.g. [2, oo) -> [0, oo)
    for k, v in shape_env.var_to_range.items():
        if k not in shape_env.replacements and k not in range_constraints:
            range_constraints[k] = v
    return range_constraints


def _create_graph_module_for_export(root, graph):
    try:
        gm = torch.fx.GraphModule(root, graph)
    except SyntaxError:
        # If custom objects stored in memory are being used in the graph,
        # the generated python code will result in a syntax error on the custom
        # object, since it is unable to parse the in-memory object. However,
        # we can still run the graph eagerly through torch.fx.Interpreter,
        # so we will bypass this error.
        warnings.warn(
            "Unable to execute the generated python source code from "
            "the graph. The graph module will no longer be directly callable, "
            "but you can still run the ExportedProgram, and if needed, you can "
            "run the graph module eagerly using torch.fx.Interpreter."
        )
        gm = torch.fx.GraphModule(root, torch.fx.Graph())
        gm._graph = graph

    return gm