# mypy: allow-untyped-defs
import copy
import dataclasses
import functools
import re
import types
import warnings
from collections import namedtuple
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

from torch.fx.immutable_collections import immutable_dict, immutable_list

if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.

    import sympy

    from torch.utils._sympy.value_ranges import ValueRanges

import torch
import torch.utils._pytree as pytree
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.fx._compatibility import compatibility
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts

from .graph_signature import (  # noqa: F401
    _sig_to_specs,
    ArgumentSpec,
    ConstantArgument,
    CustomObjArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymIntArgument,
    TensorArgument,
    TokenArgument,
)


# Names exported by ``from <this module> import *``.
__all__ = [
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
]


# A pass takes a GraphModule and returns a PassResult, or None.
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
def _disable_prexisiting_fake_mode(fn):
    """Decorator: run ``fn`` inside ``maybe_disable_fake_tensor_mode()``.

    The wrapper forwards all positional and keyword arguments unchanged and
    returns whatever ``fn`` returns.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with maybe_disable_fake_tensor_mode():
            return fn(*args, **kwargs)

    return wrapper


def _fx_collection_equivalence_fn(
    spec1_type: Optional[type],
    spec1_context: pytree.Context,
    spec2_type: Optional[type],
    spec2_context: pytree.Context,
) -> bool:
    """Treat containers and their immutable variants as the same type. Otherwise
    compare as normal.

    Returns True when the two (type, context) pairs should be considered
    equivalent: dict/immutable_dict and list/immutable_list are interchangeable
    as long as the contexts are equal; every other pair must have the identical
    type object and equal contexts.
    """
    # A missing type can only be equivalent to another missing type.
    if spec1_type is None or spec2_type is None:
        return spec1_type is spec2_type and spec1_context == spec2_context

    if issubclass(spec1_type, (dict, immutable_dict)) and issubclass(
        spec2_type, (dict, immutable_dict)
    ):
        return spec1_context == spec2_context

    if issubclass(spec1_type, (list, immutable_list)) and issubclass(
        spec2_type, (list, immutable_list)
    ):
        return spec1_context == spec2_context

    return spec1_type is spec2_type and spec1_context == spec2_context


def _rename_without_collisions(
    name_map: Dict[str, str],
    orig_name: str,
    name: str,
    is_placeholder: bool = False,
):
    """
    Renames nodes to avoid name collisions, with suffixing.

    Args:
        name_map: map from original name to new name; updated in place with
            the entry ``orig_name -> chosen name``.
        orig_name: mapping key
        name: candidate name (potentially suffixed, e.g. mul_2)
        is_placeholder: if the node is a placeholder, avoid detecting suffix

    Returns:
        The name that was stored in ``name_map[orig_name]``.
    """
    if name in name_map.values():
        # non-placeholder nodes may be suffixed with the count
        # instead of adding another suffix, we will try to increment it
        #
        # The whole name must have the form "<base>_<digits>". An unanchored
        # match would also accept names such as "x_1_y", and then rename them
        # to "x_2", silently dropping the trailing "_y".
        match = re.fullmatch(r"(.*)_(\d+)", name)
        if match and not is_placeholder:
            name, n = match.group(1), int(match.group(2))
        else:
            n = 0
        # Bump the counter until the candidate is not already taken.
        while (dup_name := f"{name}_{n + 1}") in name_map.values():
            n += 1
        name_map[orig_name] = dup_name
    else:
        name_map[orig_name] = name
    return name_map[orig_name]


def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
    """
    Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs,
    and handle collisions with non-placeholders by count suffixing.
    Different HOO subgraph types have different input schemas, so we first enumerate them
    and gather the top-level named placeholder nodes.

    Only ``cond``, ``wrap_with_set_grad_enabled`` and ``map_impl`` calls are
    handled. Each subgraph found is renamed in place, processed recursively
    and recompiled.
    """
    # gather all HOO subgraphs and their top-level named placeholder nodes
    subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = []
    for node in gm.graph.nodes:
        if node.op == "call_function" and isinstance(
            node.target, torch._ops.HigherOrderOperator
        ):
            # HOO subgraphs have varying input schemas, so we enumerate them there
            if node.target._name == "cond":
                _, true_graph, false_graph, cond_args = node._args
                subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args))
                subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args))
            elif node.target._name == "wrap_with_set_grad_enabled":
                subgraph, phs = node._args[1], node._args[2:]
                subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs))
            elif node.target._name == "map_impl":
                body_graph, array, args = node._args
                subgraph_ph_tuples.append(
                    (getattr(gm, body_graph.target), array + args)
                )

    # propagate names
    for subgraph, hoo_phs in subgraph_ph_tuples:
        name_map: Dict[str, str] = {}
        for i, node in enumerate(subgraph.graph.nodes):
            if i < len(hoo_phs):  # placeholder, retain name
                name_map[node.name] = hoo_phs[i].name
                node.name = node.target = hoo_phs[i].name
            else:  # non-placeholder, check for collisions
                node.name = _rename_without_collisions(name_map, node.name, node.name)

        # recurse and recompile
        _name_hoo_subgraph_placeholders(subgraph)
        subgraph.recompile()
class ExportedProgram:
    """
    Package of a program from :func:`export`. It contains
    an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing
    tensor values of all lifted parameters and buffers, and various metadata.

    You can call an ExportedProgram like the original callable traced by
    :func:`export` with the same calling convention.

    To perform transformations on the graph, use ``.module`` property to access
    an :class:`torch.fx.GraphModule`. You can then use
    `FX transformation <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
    to rewrite the graph. Afterwards, you can simply use :func:`export`
    again to construct a correct ExportedProgram.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: torch.fx.Graph,
        graph_signature: ExportGraphSignature,
        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
        range_constraints: "Dict[sympy.Symbol, Any]",
        module_call_graph: List[ModuleCallEntry],
        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
        verifier: Optional[Type[Any]] = None,  # TODO Change typing hint to Verifier.
        tensor_constants: Optional[
            Dict[str, torch.Tensor]
        ] = None,  # TODO: deprecate this
        constants: Optional[
            Dict[str, Union[torch.Tensor, torch._C.ScriptObject]]
        ] = None,
    ):
        """Build the program and validate it with the verifier.

        ``graph`` has its codegen reset in place. ``tensor_constants`` takes
        precedence over ``constants`` when both are non-empty. ``verifier``
        defaults to ``torch._export.verifier.Verifier`` and must be a subclass
        of it.
        """
        # Remove codegen related things from the graph. It should just be a flat graph.
        graph._codegen = torch.fx.graph.CodeGen()
        self._graph_module = _create_graph_module_for_export(root, graph)
        if isinstance(root, torch.fx.GraphModule):
            self._graph_module.meta.update(root.meta)

        self._graph_signature: ExportGraphSignature = graph_signature
        self._state_dict: Dict[str, Any] = state_dict
        self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
        assert module_call_graph is not None
        self._module_call_graph: List[ModuleCallEntry] = module_call_graph
        self._example_inputs = example_inputs

        self._constants = tensor_constants or constants or {}
        assert self._constants is not None

        from torch._export.verifier import Verifier

        if verifier is None:
            verifier = Verifier
        assert issubclass(verifier, Verifier)
        self._verifier = verifier
        # Validate should be always the last step of the constructor.
        self.verifier().check(self)

    @property
    @compatibility(is_backward_compatible=False)
    def graph_module(self):
        return self._graph_module

    @property
    @compatibility(is_backward_compatible=False)
    def graph(self):
        return self.graph_module.graph

    @property
    @compatibility(is_backward_compatible=False)
    def graph_signature(self):
        return self._graph_signature

    @property
    @compatibility(is_backward_compatible=False)
    def state_dict(self):
        return self._state_dict

    @compatibility(is_backward_compatible=False)
    def parameters(self) -> Iterator[torch.nn.Parameter]:
        """
        Returns an iterator over original module's parameters.
        """
        for _, param in self.named_parameters():
            yield param

    @compatibility(is_backward_compatible=False)
    def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Returns an iterator over original module parameters, yielding
        both the name of the parameter as well as the parameter itself.
        """
        for param_name in self.graph_signature.parameters:
            yield param_name, self.state_dict[param_name]

    @compatibility(is_backward_compatible=False)
    def buffers(self) -> Iterator[torch.Tensor]:
        """
        Returns an iterator over original module buffers.
        """
        for _, buf in self.named_buffers():
            yield buf

    @compatibility(is_backward_compatible=False)
    def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Returns an iterator over original module buffers, yielding
        both the name of the buffer as well as the buffer itself.
        """
        # Non-persistent buffers are looked up in the constants table; all
        # other buffers come from the state dict.
        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
        for buffer_name in self.graph_signature.buffers:
            if buffer_name in non_persistent_buffers:
                yield buffer_name, self.constants[buffer_name]
            else:
                yield buffer_name, self.state_dict[buffer_name]

    @property
    @compatibility(is_backward_compatible=False)
    def range_constraints(self):
        return self._range_constraints

    @property
    @compatibility(is_backward_compatible=False)
    def module_call_graph(self):
        return self._module_call_graph

    @property
    @compatibility(is_backward_compatible=False)
    def example_inputs(self):
        return self._example_inputs

    @property
    @compatibility(is_backward_compatible=False)
    def call_spec(self):
        # Named pair of the in/out specs taken from the first entry of the
        # module call graph; both are None when the call graph is empty.
        CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])

        if len(self.module_call_graph) == 0:
            return CallSpec(in_spec=None, out_spec=None)
        assert self.module_call_graph[0].fqn == ""
        return CallSpec(
            in_spec=self.module_call_graph[0].signature.in_spec,
            out_spec=self.module_call_graph[0].signature.out_spec,
        )

    @property
    @compatibility(is_backward_compatible=False)
    def verifier(self) -> Any:
        return self._verifier

    @property
    @compatibility(is_backward_compatible=False)
    def dialect(self) -> str:
        return self._verifier.dialect

    @property
    @compatibility(is_backward_compatible=False)
    def tensor_constants(self):
        # Same underlying object as ``constants``.
        return self._constants

    @property
    @compatibility(is_backward_compatible=False)
    def constants(self):
        return self._constants

    def _get_flat_args_with_check(self, args, kwargs):
        """Flatten args, kwargs using pytree, then, check specs.

        Args:
            args: List[Any] original args passed to __call__
            kwargs: Dict[str, Any] original kwargs passed to __call

        Returns:
            A tuple of (flat_args, received_spec)
            flat_args is flattend args / kwargs
            received_spec is the pytree spec produced while flattening the
            tuple (args, kwargs)
        """
        in_spec = self.call_spec.in_spec
        if in_spec is not None:
            kwargs = reorder_kwargs(kwargs, in_spec)
        flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
            (args, kwargs)
        )  # type: ignore[possibly-undefined]
        self._check_input_constraints(flat_args_with_path)
        # Drop the key paths, keep only the leaf values.
        flat_args = tuple(x[1] for x in flat_args_with_path)
        return flat_args, received_spec

    def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any:
        """Transform args, kwargs of __call__ to args for graph_module.

        self.graph_module takes stuff from state dict as inputs.
        The invariant is for ep: ExportedProgram is
        ep(args, kwargs) ==
          ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))

        Raises:
            ValueError: if the tree spec of the received inputs is not
                equivalent to the exported input spec.
        """

        in_spec = self.call_spec.in_spec
        flat_args, received_spec = self._get_flat_args_with_check(args, kwargs)
        if in_spec is not None and not is_equivalent(
            received_spec, in_spec, _fx_collection_equivalence_fn
        ):
            raise ValueError(
                "Trying to flatten user inputs with exported input tree spec: \n"
                f"{in_spec}\n"
                "but actually got inputs with tree spec of: \n"
                f"{received_spec}"
            )

        additional_inputs = []
        for input_ in self.graph_signature.input_specs:
            if input_.kind == InputKind.USER_INPUT:
                continue
            elif input_.kind in (
                InputKind.PARAMETER,
                InputKind.BUFFER,
            ):
                if input_.persistent is False:
                    # This is a non-persistent buffer, grab it from our
                    # constants instead of the state dict.
                    additional_inputs.append(self.constants[input_.target])
                else:
                    additional_inputs.append(self.state_dict[input_.target])
            elif input_.kind in (
                InputKind.CONSTANT_TENSOR,
                InputKind.CUSTOM_OBJ,
            ):
                additional_inputs.append(self.constants[input_.target])
        additional_inputs = tuple(additional_inputs)

        # NOTE: calling convention is first params, then buffers, then args as user supplied them.
        # See: torch/_functorch/aot_autograd.py#L1034
        return additional_inputs + flat_args

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise RuntimeError(
            "Unable to call ExportedProgram directly. "
            "You should use `exported_program.module()` instead."
        )

    def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs):
        """Process potential mutations to the input.

        Because self.graph_module is functional, so mutations has to be written
        back after execution of graph_module.

        The leading entries of ``res`` are the mutated buffer / user-input
        values; they are written back into ``self.state_dict`` or copied into
        the matching flattened user argument. The remaining entries are
        unflattened with the exported output spec and returned.

        Raises:
            torch._export.error.InternalError: if the outputs cannot be
                unflattened with the exported output spec.
        """
        import torch._export.error as error

        flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs)
        if self.call_spec.out_spec is not None:
            buffer_mutation = self.graph_signature.buffers_to_mutate
            user_input_mutation = self.graph_signature.user_inputs_to_mutate
            num_mutated = len(buffer_mutation) + len(user_input_mutation)
            mutated_values = res[:num_mutated]

            # Exclude dependency token from final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(  # noqa: B904
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                # The write-back runs whether or not unflattening succeeded.
                user_inputs = [
                    spec
                    for spec in self.graph_signature.input_specs
                    if spec.kind == InputKind.USER_INPUT
                ]
                for i, value in enumerate(mutated_values):
                    output_spec = self.graph_signature.output_specs[i]
                    if output_spec.kind == OutputKind.BUFFER_MUTATION:
                        assert output_spec.target is not None
                        self.state_dict[output_spec.target] = value
                    elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
                        assert output_spec.target is not None
                        # Position of the mutated input among the user inputs.
                        index = next(
                            i
                            for i, spec in enumerate(user_inputs)
                            if spec.arg.name == output_spec.target
                        )
                        flat_args[index].copy_(value)
                    else:
                        raise AssertionError(f"Unexpected kind: {output_spec.kind}")
        return res

    def __str__(self) -> str:
        # NOTE(review): the indent width inside the two literals below is
        # reproduced exactly as found (one space); confirm it was not meant to
        # be wider.
        graph_module = self.graph_module.print_readable(print_output=False).replace(
            "\n", "\n "
        )
        string = (
            "ExportedProgram:\n"
            f" {graph_module}\n"
            f"Graph signature: {self.graph_signature}\n"
            f"Range constraints: {self.range_constraints}\n"
        )
        return string

    def module(self) -> torch.nn.Module:
        """
        Returns a self contained GraphModule with all the parameters/buffers inlined.

        ``train()`` and ``eval()`` on the returned module are replaced with
        stubs that raise ``NotImplementedError``.
        """
        from ._unlift import _unlift_exported_program_lifted_states

        module = _unlift_exported_program_lifted_states(self)

        def _train(self, mode: bool = True):
            raise NotImplementedError("Calling train() is not supported yet.")

        def _eval(self, mode: bool = True):
            raise NotImplementedError("Calling eval() is not supported yet.")

        module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
        module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
        return module

    @_disable_prexisiting_fake_mode
    def run_decompositions(
        self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
    ) -> "ExportedProgram":
        """
        Run a set of decompositions on the exported program and returns a new
        exported program. By default we will run the Core ATen decompositions to
        get operators in the
        `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.

        For now, we do not decompose joint graphs.

        Note that this method has side effects on ``self``: every buffer
        attribute is deleted from ``self.graph_module`` before re-tracing, and
        constants lifted from the new graph are added to ``self.constants``,
        which is shared with the returned program.
        """
        from torch._decomp import core_aten_decompositions
        from torch._export.passes.lift_constants_pass import (
            ConstantAttrMap,
            lift_constants_pass,
        )
        from torch._export.passes.replace_sym_size_ops_pass import (
            _replace_sym_size_ops_pass,
        )
        from torch._functorch.aot_autograd import aot_export_module

        def _get_placeholders(gm):
            # Collect the leading run of placeholder nodes.
            placeholders = []
            for node in gm.graph.nodes:
                if node.op != "placeholder":
                    break
                placeholders.append(node)
            return placeholders

        if decomp_table is None:
            decomp_table = core_aten_decompositions()

        old_placeholders = _get_placeholders(self.graph_module)
        fake_args = [node.meta["val"] for node in old_placeholders]

        buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()]
        for name in buffers_to_remove:
            delattr(self.graph_module, name)
        # TODO(zhxhchen17) Return the new graph_signature directly.
        from torch.export._trace import _ignore_backend_decomps

        with _ignore_backend_decomps():
            gm, graph_signature = aot_export_module(
                self.graph_module,
                fake_args,
                decompositions=decomp_table,
                trace_joint=False,
            )

        # Update the signatures with the new placeholder names in case they
        # changed when calling aot_export
        def update_arg(old_arg, new_ph):
            if isinstance(old_arg, ConstantArgument):
                return old_arg
            elif isinstance(old_arg, TensorArgument):
                return TensorArgument(name=new_ph.name)
            elif isinstance(old_arg, SymIntArgument):
                return SymIntArgument(name=new_ph.name)
            raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")

        new_placeholders = _get_placeholders(gm)
        new_outputs = list(gm.graph.nodes)[-1].args[0]

        # rename the placeholders
        assert len(new_placeholders) == len(old_placeholders)
        for old_ph, new_ph in zip(old_placeholders, new_placeholders):
            new_ph.name = new_ph.target = old_ph.name

        # handle name collisions with newly decomposed graph nodes
        name_map = {ph.name: ph.name for ph in new_placeholders}
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                continue
            node.name = _rename_without_collisions(name_map, node.name, node.name)

        # propagate names to higher order op subgraphs
        _name_hoo_subgraph_placeholders(gm)

        # To match the output target with correct input for input mutations
        # need to find the old to new placeholder map
        old_new_placeholder_map = {
            spec.arg.name: new_placeholders[i].name
            for i, spec in enumerate(self.graph_signature.input_specs)
            if not isinstance(spec.arg, ConstantArgument)
        }

        input_specs = [
            InputSpec(
                spec.kind,
                update_arg(spec.arg, new_placeholders[i]),
                spec.target,
                spec.persistent,
            )
            for i, spec in enumerate(self.graph_signature.input_specs)
        ]
        output_specs = [
            OutputSpec(
                spec.kind,
                update_arg(spec.arg, new_outputs[i]),
                old_new_placeholder_map.get(spec.target, spec.target),
            )
            for i, spec in enumerate(self.graph_signature.output_specs)
        ]

        new_graph_signature = ExportGraphSignature(
            input_specs=input_specs, output_specs=output_specs
        )
        # NOTE: aot_export adds symint metadata for placeholders with int
        # values; since these become specialized, we replace such metadata with
        # the original values.
        # Also, set the param/buffer metadata back to the placeholders.
        for old_node, new_node in zip(old_placeholders, new_placeholders):
            if not isinstance(old_node.meta["val"], torch.Tensor):
                new_node.meta["val"] = old_node.meta["val"]

            if (
                new_node.target in new_graph_signature.inputs_to_parameters
                or new_node.target in new_graph_signature.inputs_to_buffers
            ):
                for k, v in old_node.meta.items():
                    new_node.meta[k] = v

        # TODO unfortunately preserving graph-level metadata is not
        # working well with aot_export. So we manually copy it.
        # (The node-level meta is addressed above.)
        gm.meta.update(self.graph_module.meta)

        new_range_constraints = _get_updated_range_constraints(
            gm,
            self.range_constraints,
            _is_executorch=False,
        )

        constants = lift_constants_pass(gm, new_graph_signature, ConstantAttrMap())
        for k, v in constants.items():
            assert k not in self.constants
            self.constants[k] = v

        _replace_sym_size_ops_pass(gm)

        from torch._dynamo import config as _dynamo_config
        from torch._export.passes._node_metadata_hook import (
            _node_metadata_hook,
            _set_node_metadata_hook,
        )

        if not _dynamo_config.do_not_emit_runtime_asserts:
            stack_trace = (
                'File "torch/fx/passes/runtime_assert.py", line 24, '
                "in insert_deferred_runtime_asserts"
            )
            shape_env = _get_shape_env(gm)
            if shape_env is not None:
                with _set_node_metadata_hook(
                    gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
                ):
                    insert_deferred_runtime_asserts(
                        gm,
                        shape_env,
                        f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
                        export=True,
                    )

        exported_program = ExportedProgram(
            root=gm,
            graph=gm.graph,
            graph_signature=new_graph_signature,
            state_dict=self.state_dict,
            range_constraints=new_range_constraints,
            module_call_graph=copy.deepcopy(self.module_call_graph),
            example_inputs=self.example_inputs,
            verifier=self.verifier,
            constants=self.constants,
        )
        return exported_program

    def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram":
        """Run ``passes`` over the graph module and wrap the result.

        Returns ``self`` when the pass manager reports the same graph module
        object as unmodified; otherwise returns a new ExportedProgram whose
        input/output specs are renamed after the transformed graph's nodes.
        """
        pm = PassManager(list(passes))
        # Since we abstractly run the passes, we need to disable backend decomp here
        # again.
        from torch.export._trace import _ignore_backend_decomps

        with _ignore_backend_decomps():
            res = pm(self.graph_module)
        transformed_gm = res.graph_module if res is not None else self.graph_module
        assert transformed_gm is not None

        # Without a result object there is no ``modified`` flag to consult, so
        # only take the early exit when a result is actually available.
        if (
            res is not None
            and transformed_gm is self.graph_module
            and not res.modified
        ):
            return self

        # TODO(zhxchen17) Remove this.
        def _get_updated_graph_signature(
            old_signature: ExportGraphSignature,
            new_gm: torch.fx.GraphModule,
        ) -> ExportGraphSignature:
            """
            Update the graph signature's user_input/user_outputs.
            """
            new_input_specs = []
            for i, node in enumerate(new_gm.graph.nodes):
                if node.op != "placeholder":
                    break

                assert i < len(
                    old_signature.input_specs
                ), "Number of inputs changed after transformation"
                old_input_spec = old_signature.input_specs[i]
                # Constant / custom-object args carry no node name to refresh.
                arg = (
                    old_input_spec.arg
                    if isinstance(
                        old_input_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_input_spec.arg)(node.name)
                )
                new_input_specs.append(
                    InputSpec(
                        old_input_spec.kind,
                        arg,
                        old_input_spec.target,
                        old_input_spec.persistent,
                    )
                )

            output_node = list(new_gm.graph.nodes)[-1]
            assert output_node.op == "output"

            new_output_specs = []
            for i, node in enumerate(output_node.args[0]):
                assert i < len(
                    old_signature.output_specs
                ), "Number of outputs changed after transformation"
                old_output_spec = old_signature.output_specs[i]
                arg = (
                    old_output_spec.arg
                    if isinstance(
                        old_output_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_output_spec.arg)(node.name)
                )
                new_output_specs.append(
                    OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
                )

            new_signature = ExportGraphSignature(
                input_specs=new_input_specs, output_specs=new_output_specs
            )
            return new_signature

        transformed_ep = ExportedProgram(
            root=transformed_gm,
            graph=transformed_gm.graph,
            graph_signature=_get_updated_graph_signature(
                self.graph_signature, transformed_gm
            ),
            state_dict=self.state_dict,
            range_constraints=_get_updated_range_constraints(
                transformed_gm,
                self.range_constraints,
                _is_executorch=False,
            ),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            verifier=self.verifier,
            constants=self.constants,
        )
        transformed_ep.graph_module.meta.update(self.graph_module.meta)
        if res is not None:
            transformed_ep.graph_module.meta.update(res.graph_module.meta)
        return transformed_ep

    def _check_input_constraints(self, flat_args_with_path):
        """Check the flattened user inputs against the range constraints."""
        from torch._export.utils import _check_input_constraints_for_graph

        placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
        # Keep only the placeholders whose input spec is a user input.
        input_placeholders = [
            p
            for p, s in zip(placeholders, self.graph_signature.input_specs)
            if s.kind == InputKind.USER_INPUT
        ]
        _check_input_constraints_for_graph(
            input_placeholders, flat_args_with_path, self.range_constraints
        )

    def _validate(self):
        self.verifier().check(self)

    # TODO(zhxchen17) Formalize this.
    def _update(
        self, graph_module, graph_signature, state_dict=None
    ) -> "ExportedProgram":
        """Return a new program with a replaced graph module and signature.

        ``state_dict`` replaces the current state dict when it is given (an
        explicitly passed empty dict is honoured); when it is ``None`` the
        current state dict is reused. Range constraints and the module call
        graph are deep-copied; the constants table is shared.
        """
        return ExportedProgram(
            root=graph_module,
            graph=graph_module.graph,
            graph_signature=graph_signature,
            state_dict=state_dict if state_dict is not None else self.state_dict,
            range_constraints=copy.deepcopy(self.range_constraints),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            verifier=self.verifier,
            constants=self.constants,
        )
def_get_shape_env(gm):vals=[node.meta["val"]fornodeingm.graph.nodesifnode.meta.get("val",None)isnotNone]fromtorch._guardsimportdetect_fake_modefake_mode=detect_fake_mode(vals)iffake_modeisnotNone:returnfake_mode.shape_envforvinvals:ifisinstance(v,torch.SymInt):returnv.node.shape_envdef_get_updated_range_constraints(gm:torch.fx.GraphModule,old_range_constraints:"Optional[Dict[sympy.Symbol, Any]]"=None,_is_executorch:bool=True,)->"Dict[sympy.Symbol, Any]":# FIXME(tmanlaibaatar) Remove this whole branch once https://github.com/pytorch/pytorch/pull/123764if_is_executorch:assertold_range_constraintsisNoneshape_env=_get_shape_env(gm)ifshape_envisNone:return{}range_constraints={k:vfork,vinshape_env.var_to_range.items()ifknotinshape_env.replacements}# Only when we have an unbacked symint, and it's used as constructor inputs,# runtime_var_to_range will make a difference compated to var_to_range.# e.g. [2, oo) -> [0, oo)fork,vinshape_env.var_to_range.items():ifknotinshape_env.replacements:range_constraints[k]=vreturnrange_constraintsassertold_range_constraintsisnotNoneshape_env=_get_shape_env(gm)ifshape_envisNone:return{}range_constraints=copy.copy(old_range_constraints)range_constraints={k:vfork,vinrange_constraints.items()ifknotinshape_env.replacements}# Only when we have an unbacked symint, and it's used as constructor inputs,# runtime_var_to_range will make a difference compated to var_to_range.# e.g. [2, oo) -> [0, oo)fork,vinshape_env.var_to_range.items():ifknotinshape_env.replacementsandknotinrange_constraints:range_constraints[k]=vreturnrange_constraintsdef_create_graph_module_for_export(root,graph):try:gm=torch.fx.GraphModule(root,graph)exceptSyntaxError:# If custom objects stored in memory are being used in the graph,# the generated python code will result in a syntax error on the custom# object, since it is unable to parse the in-memory object. 
However# we can still run the graph eagerly through torch.fx.Interpreter,# so we will bypass this error.warnings.warn("Unable to execute the generated python source code from ""the graph. The graph module will no longer be directly callable, ""but you can still run the ExportedProgram, and if needed, you can ""run the graph module eagerly using torch.fx.Interpreter.")gm=torch.fx.GraphModule(root,torch.fx.Graph())gm._graph=graphreturngm
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.