# mypy: allow-untyped-defs
import abc
import copy
import logging
import operator
import re
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch._library.fake_class_registry import FakeScriptObject
from torch.export._tree_utils import reorder_kwargs
from torch.export.exported_program import (
    ConstantArgument,
    ExportedProgram,
    ExportGraphSignature,
    InputKind,
    ModuleCallSignature,
    SymBoolArgument,
    SymFloatArgument,
    SymIntArgument,
    TensorArgument,
)
from torch.fx._symbolic_trace import is_fx_tracing
from torch.fx.graph_module import _get_attr, _get_attr_via_attr_list, _print_readable
from torch.utils._pytree import GetAttrKey, SequenceKey

from ._remove_effect_tokens_pass import _remove_effect_tokens


log = logging.getLogger(__name__)


__all__ = [
    "FlatArgsAdapter",
    "InterpreterModule",
    "InterpreterModuleDispatcher",
    "UnflattenedModule",
    "unflatten",
]


class _AttrKind(Enum):
    # How `_assign_attr` installs an object on a module.
    PARAMETER = "parameter"
    BUFFER = "buffer"
    CONSTANT = "constant"
    MODULE = "module"


# Captured by InterpreterModule / UnflattenedModule at construction time as
# `_run_with_interpreter`; `_disable_interpreter` flips it temporarily.
RUN_WITH_INTERPRETER = True


@contextmanager
def _disable_interpreter():
    """Temporarily set the module-level ``RUN_WITH_INTERPRETER`` flag to False.

    Modules constructed inside the ``with`` body record ``False`` as their
    ``_run_with_interpreter`` setting. The previous flag value is restored on
    exit, even if the body raises.
    """
    global RUN_WITH_INTERPRETER
    saved_flag = RUN_WITH_INTERPRETER
    RUN_WITH_INTERPRETER = False
    try:
        yield
    finally:
        RUN_WITH_INTERPRETER = saved_flag


def _assign_attr(
    from_obj: Union[torch.Tensor, torch.ScriptObject, torch.nn.Module],
    to_module: torch.nn.Module,
    target: str,
    attr_kind: _AttrKind,
    persistent: bool = True,
):
    """Assign ``from_obj`` to the dotted name ``target`` on ``to_module``.

    Empty ``torch.nn.Module`` instances are installed for any missing
    intermediate path component. Every call-name variant of each component
    (``name``, ``name@1``, ``name@2``, ...) receives the same object, so all
    call copies of a submodule share the attribute. For example, if target is
    ``foo.bar.f``, foo also has call name ``foo@1`` and bar also has
    ``bar@1``/``bar@2``, then ``f`` is assigned on foo.bar, foo.bar@1,
    foo.bar@2, foo@1.bar, foo@1.bar@1 and foo@1.bar@2.

    Args:
        from_obj: The object to install.
        to_module: Root module the dotted ``target`` is resolved against.
        target: Dotted attribute path; the last component is the attribute name.
        attr_kind: Selects register_parameter / register_buffer / setattr.
        persistent: Forwarded to ``register_buffer`` for BUFFER attributes.
    """
    *path, leaf = target.split(".")

    # Breadth-first walk: `frontier` holds every module reachable through some
    # call-name variant of the path components consumed so far.
    frontier = {to_module}
    for component in path:
        next_frontier: Set[torch.nn.Module] = set()
        for holder in frontier:
            if not hasattr(holder, component):
                setattr(holder, component, torch.nn.Module())
            for child_name, child in holder._modules.items():
                if _is_call_name(child_name, component):
                    next_frontier.add(child)  # type: ignore[arg-type]
        frontier = next_frontier

    for holder in frontier:
        if attr_kind is _AttrKind.PARAMETER:
            assert isinstance(from_obj, torch.nn.Parameter)
            holder.register_parameter(leaf, from_obj)
        elif attr_kind is _AttrKind.BUFFER:
            assert isinstance(from_obj, torch.Tensor)
            holder.register_buffer(leaf, from_obj, persistent=persistent)
        elif attr_kind is _AttrKind.CONSTANT:
            assert not isinstance(
                from_obj, FakeScriptObject
            ), "FakeScriptObject should only exist during tracing."
            assert isinstance(from_obj, (torch.Tensor, torch.ScriptObject))
            setattr(holder, leaf, from_obj)
        elif attr_kind is _AttrKind.MODULE:
            assert isinstance(from_obj, torch.nn.Module)
            setattr(holder, leaf, from_obj)
[docs]classInterpreterModule(torch.nn.Module):"""A module that uses torch.fx.Interpreter to execute instead of the usual codegen that GraphModule uses. This provides better stack trace information and makes it easier to debug execution. """graph_module:Optional[torch.fx.GraphModule]def__init__(self,graph:torch.fx.Graph,):super().__init__()self.graph=graphself.graph.owning_module=selfself._run_with_interpreter=RUN_WITH_INTERPRETERdefforward(self,*args,**kwargs):assertself.graph_moduleisnotNone,"Didn't finalize this InterpreterModule"ifnotis_fx_tracing()and(torch.compiler.is_dynamo_compiling()ornotself._run_with_interpreter):# Dynamo cannot trace through torch.fx.Interpreter, so fall back to# GraphModule codegen in this instance.# Patch the codegened forward to run with this InterpreterModule,# so attribute accesses, etc. are on this module instead.returntype(self.graph_module).forward(self,*args,**kwargs)else:ifkwargs:# Handle **kwargs. FX only natively supports positional# arguments (through placeholders). So in order to pass in# kwargs, we must correspond the names of the placeholders with# the keys in the kwarg dict.arg_list=list(args)kwarg_names=self.arg_names[len(arg_list):]arg_list.extend(kwargs[kwarg_name]forkwarg_nameinkwarg_namesifkwarg_nameinkwargs)# Assert that the kwargs passed in exactly match the positional# arguments specified by the GraphModule. This should be# guaranteed by the unflattening process.assertlen(kwarg_names)==len(kwargs)assertlen(arg_list)==len(self.arg_names)args=tuple(arg_list)returntorch.fx.Interpreter(self,graph=self.graph).run(*args,enable_io_processing=False)deffinalize(self):# We need to "finalize" because GraphModule populates its own state_dict# based on the get_attrs observed in the graph. 
So we need to fully# construct the graph and call _sink_params before generating this# GraphModule.# need to set `graph_module` directly on the dict to avoid it getting# registered as a submodule.self.__dict__["graph_module"]=torch.fx.GraphModule(self,self.graph)self.graph.lint()# Cache arg names for kwarg handling (see forward())self.arg_names=[]fornodeinself.graph.nodes:ifnode.op=="placeholder":self.arg_names.append(node.target)defprint_readable(self,print_output=True,include_stride=False,include_device=False,colored=False,):return_print_readable(self,"InterpreterModule",print_output,include_stride,include_device,colored,)
class InterpreterModuleDispatcher(torch.nn.Module):
    """
    A module that carries a sequence of InterpreterModules corresponding to
    a sequence of calls of that module. Each call to the module dispatches
    to the next InterpreterModule, and wraps back around after the last.
    """

    def __init__(self, attrs: Set[str], call_modules: List[InterpreterModule]):
        """
        Args:
            attrs: Attribute names copied from the first call module onto
                this dispatcher.
            call_modules: Non-empty list of modules, one per call, invoked in
                round-robin order.
        """
        super().__init__()
        assert call_modules
        first = call_modules[0]
        # Expose the first call module's submodule table as our own.
        self._modules = first._modules
        for attr_name in attrs:
            setattr(self, attr_name, getattr(first, attr_name))
        self._call_modules = call_modules
        self._num_calls = 0

    def forward(self, *args, **kwargs):
        """Invoke the next call module in the rotation."""
        position = self._num_calls
        current = self._call_modules[position]
        self._num_calls = (position + 1) % len(self._call_modules)
        try:
            return current(*args, **kwargs)
        except Exception:
            # Restart the rotation from the first module after a failure.
            self._num_calls = 0
            raise

    def call_modules(self):
        """Return the list of per-call modules (not a copy)."""
        return self._call_modules

    def print_readable(
        self,
        print_output=True,
        include_stride=False,
        include_device=False,
        colored=False,
    ):
        """Concatenate every call module's readable form, one per line."""
        return "\n".join(
            call_mod.print_readable(
                print_output,
                include_stride,
                include_device,
                colored,
            )
            for call_mod in self._call_modules
        )
class FlatArgsAdapter(abc.ABC):
    """
    Adapts input arguments with ``input_spec`` to align ``target_spec``.

    Subclasses implement :meth:`adapt`; an instance can be passed to
    ``unflatten`` to reconcile call-time inputs whose pytree structure differs
    from the exported one.
    """

    @abc.abstractmethod
    def adapt(
        self,
        target_spec: pytree.TreeSpec,
        input_spec: pytree.TreeSpec,
        input_args: List[Any],
    ) -> List[Any]:
        """Convert ``input_args`` (flattened under ``input_spec``) into a flat
        argument list laid out according to ``target_spec``.

        NOTE: This adapter may mutate the given ``input_args``.

        Args:
            target_spec: TreeSpec the exported program expects.
            input_spec: TreeSpec of the arguments actually passed in.
            input_args: Flat leaves of the actual inputs.

        Returns:
            Flat list of arguments matching ``target_spec``.
        """
        ...
class UnflattenedModule(torch.nn.Module):
    """Module built from an ExportedProgram that restores the original
    module hierarchy (see ``unflatten``), with parameters, buffers and
    constants reattached at their original FQNs.
    """

    def __init__(
        self,
        export_module: ExportedProgram,
        flat_args_adapter: Optional[FlatArgsAdapter] = None,
    ):
        super().__init__()
        if export_module.graph_signature.backward_signature is not None:
            raise ValueError("Unflattening on JointExportModule NYI")

        fqn_list = [entry.fqn for entry in export_module.module_call_graph]
        assert fqn_list[0] == ""
        export_graph = deepcopy(export_module.graph)
        self.graph_signature = deepcopy(export_module.graph_signature)
        self.graph = torch.fx.Graph()
        self.graph.owning_module = self
        self.module_call_graph = deepcopy(export_module.module_call_graph)
        self.flat_args_adapter = flat_args_adapter
        # Flag to indicate whether args have been adapted.
        self.adapted = False
        self._run_with_interpreter = RUN_WITH_INTERPRETER

        _inplace_buffer_mutations(export_graph, self.graph_signature)

        self.ivals = _IVals()
        # record any intermediate value x that is used, with the modules that used it,
        # and generate instructions to read the corresponding attribute
        seen_modules, seen_attrs = _outline_submodules(export_graph, self)
        # for each read intermediate value x, find the module that created it,
        # and generate instructions to update the corresponding attribute;
        # finally, initialize all these attributes
        self.ivals.create(seen_modules.values(), self)
        # move attributes that correspond to graph arguments for HOPs
        # from exported program to unflattened submodules
        _copy_graph_attrs(export_module._graph_module, self, seen_attrs)

        self.range_constraints = export_module.range_constraints
        self.equality_constraints: List = []

        # aliasing/unused param or buffer issues:
        # in strict-mode export, dynamo export will deduplicate aliased tensors,
        # and ignore unused tensors. For aliasing, this causes issues when some aliases
        # are unused, and we're unable to match the placeholder node to the correct FQN.
        # This leads to the graph signature potentially having the wrong target FQN,
        # and downstream issues where parameters are assigned to the wrong target attribute,
        # mismatching the relevant placeholder node in the unflattened module.
        # To resolve this we restore (_assign_attr) all aliased/unused tensors in
        # the state_dict as module attributes, but only keep the used tensors in the
        # graph's forward pass (_sink_params).
        state_dict = export_module.state_dict
        assigned_params: Set[str] = set()  # tracking unused params
        id_to_param: Dict[int, torch.nn.Parameter] = {}  # handling weight-sharing
        for name in self.graph_signature.parameters:  # this loop adds used params
            param = state_dict[name]
            if id(param) not in id_to_param:
                id_to_param[id(param)] = torch.nn.Parameter(
                    param.clone(), requires_grad=param.requires_grad
                )

            _assign_attr(
                id_to_param[id(param)],
                self,
                name,
                attr_kind=_AttrKind.PARAMETER,
            )
            assigned_params.add(name)

        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
        assigned_buffers: Set[str] = set()  # tracking unused buffers
        # original tensor id -> (cloned buffer, persistent flag)
        id_to_buffer: Dict[int, Tuple[torch.Tensor, bool]] = {}
        for name in self.graph_signature.buffers:  # this loop adds used buffers
            if name in non_persistent_buffers:
                persistent = False
                buffer = export_module.constants[name]
            else:
                persistent = True
                buffer = state_dict[name]

            if id(buffer) not in id_to_buffer:
                id_to_buffer[id(buffer)] = (buffer.clone(), persistent)

            _assign_attr(
                id_to_buffer[id(buffer)][0],
                self,
                name,
                attr_kind=_AttrKind.BUFFER,
                persistent=persistent,
            )
            assigned_buffers.add(name)

        # restore aliased/unused params and buffers
        # these appear in state dict but not graph signature
        for name, tensor in state_dict.items():
            if name in assigned_params or name in assigned_buffers:  # already assigned
                continue

            is_buffer = False
            if id(tensor) in id_to_buffer or not isinstance(
                tensor, torch.nn.Parameter
            ):  # aliased buffer
                is_buffer = True

            if is_buffer:
                if (
                    id(tensor) not in id_to_buffer
                ):  # this is completely unused (not weight-sharing)
                    id_to_buffer[id(tensor)] = (
                        tensor,
                        True,
                    )  # assign to respect original model
                _assign_attr(
                    id_to_buffer[id(tensor)][0],
                    self,
                    name,
                    attr_kind=_AttrKind.BUFFER,
                    persistent=True,
                )
            else:
                if id(tensor) not in id_to_param:  # this is unused
                    id_to_param[id(tensor)] = tensor
                _assign_attr(
                    id_to_param[id(tensor)],
                    self,
                    name,
                    attr_kind=_AttrKind.PARAMETER,
                )

        # use id map so we don't double-clone aliased constants.
        # Key by the id of the *original* constant (which stays alive in
        # export_module.constants) so later aliases of the same object find
        # the clone made for the first occurrence instead of cloning again.
        id_to_const: Dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {}
        for fqn, constant in export_module.constants.items():
            if id(constant) not in id_to_const:
                if isinstance(constant, torch.Tensor):
                    id_to_const[id(constant)] = constant.clone()
                else:
                    id_to_const[id(constant)] = constant
            _constant = id_to_const[id(constant)]
            _assign_attr(
                _constant,
                self,
                fqn,
                attr_kind=_AttrKind.CONSTANT,
            )

        # This is to handle parameters/buffers that point to the same tensor
        # object id -> list of (node_name, target_name)
        consts_map: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
        consts_targets: Set[str] = set()

        def add_to_consts_map(obj_id, node_name, target_name):
            name_list = consts_map[obj_id]
            name_list.append((node_name, target_name))

        added_params_buffers: Set[str] = set()  # track aliased/unused params, buffers
        for s in self.graph_signature.input_specs:
            if s.kind == InputKind.PARAMETER or (
                s.kind == InputKind.BUFFER and s.persistent
            ):
                assert hasattr(s.arg, "name")
                assert isinstance(s.target, str)
                add_to_consts_map(
                    id(export_module.state_dict[s.target]), s.arg.name, s.target
                )
                consts_targets.add(s.target)
                added_params_buffers.add(s.target)
            elif (
                (s.kind == InputKind.BUFFER and not s.persistent)
                or s.kind == InputKind.CONSTANT_TENSOR
                or s.kind == InputKind.CUSTOM_OBJ
            ):
                assert hasattr(s.arg, "name")
                assert isinstance(s.target, str)
                add_to_consts_map(
                    id(export_module.constants[s.target]), s.arg.name, s.target
                )
                consts_targets.add(s.target)

        # add constants that are aliased and don't appear in graph signature
        for const_name, const in export_module.constants.items():
            if const_name not in consts_targets:
                assert (
                    id(const) in consts_map
                ), "Constants should be either aliased or appear in graph signature"
                ph_name, _ = consts_map[id(const)][0]
                add_to_consts_map(id(const), ph_name, const_name)
                # Record the alias itself (previously this used the stale loop
                # variable `s` from the input_specs loop above).
                added_params_buffers.add(const_name)

        # add aliased/unused params and buffers that don't appear in graph signature
        for fqn, tensor in export_module.state_dict.items():
            if fqn not in added_params_buffers:
                if id(tensor) not in consts_map:
                    # completely unused (no weight-sharing), ignore.
                    # this weight doesn't appear in graph module,
                    # so won't cause FQN assignment issues
                    continue
                ph_name, _ = consts_map[id(tensor)][0]
                add_to_consts_map(id(tensor), ph_name, fqn)

        # node name -> list of possible targets
        inputs_to_state: Dict[str, List[str]] = {}
        for node_target in consts_map.values():
            targets = [t[1] for t in node_target]
            for n, _ in node_target:
                inputs_to_state[n] = targets

        _sink_params(self, inputs_to_state, [])

        redirected_call_indices = _deduplicate_modules(seen_modules.values())
        fqn_list = [fqn for fqn in fqn_list if fqn not in redirected_call_indices]

        self._dispatch_modules(redirected_call_indices, consts_targets)
        fqn_list = [fqn for fqn in fqn_list if "@" not in fqn]

        # Cache so we don't have to compute this every time.
        # NOTE: this needs to be kept in sync with the placeholders in
        # self.graph, but currently we have no way to guarantee that.
        self.input_placeholders = [
            node for node in self.graph.nodes if node.op == "placeholder"
        ]
        self.check_input_constraints = True
        # TODO(zhxchen17) We can register modules ahead of time instead of reorder later.
        fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)}
        # In the case of legacy IR, we might be missing some modules from metadata.
        for name, _ in self.named_modules(remove_duplicate=False):
            if name not in fqn_order:
                fqn_order[name] = len(fqn_order)
        _reorder_submodules(self, fqn_order)
        self.graph.lint()

    def _print_graph(self):
        """Debug helper: print each submodule FQN and its fx graph, if any."""
        for fqn, mod in self.named_modules():
            print(fqn + ":")
            if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph):
                print(mod.graph)

    def _adapt_flat_args(self, flat_args, in_spec):
        """Map ``flat_args`` (flattened under ``in_spec``) onto the exported
        input spec using ``self.flat_args_adapter``.

        Raises:
            TypeError: if the specs differ and no adapter was given, or the
                adapter returns the wrong number of leaves.
        """
        signature = self.module_call_graph[0].signature
        if in_spec == signature.in_spec:
            return flat_args

        if self.flat_args_adapter is None:
            raise TypeError(
                "There is no flat args adapter specified. "
                "Are you sure you are calling this with the right arguments? "
            )
        else:
            flat_args = self.flat_args_adapter.adapt(
                target_spec=signature.in_spec,
                input_spec=in_spec,
                input_args=flat_args,
            )

            if len(flat_args) != signature.in_spec.num_leaves:
                raise TypeError(
                    f"Flat args adaption failed, number of args mismatch\n"
                    f"Adapted: {len(flat_args)}\n"
                    f"Exported module: {signature.in_spec.num_leaves}"
                )
            return flat_args

    def process_forward_inputs(self, *args, **kwargs):
        """Flatten call inputs, adapt them if their spec differs from the
        exported one, and (optionally) check input constraints.

        Returns:
            The flat list of arguments to feed to ``self.graph``.
        """
        signature = self.module_call_graph[0].signature

        reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec)

        flat_args_with_path, in_spec = pytree.tree_flatten_with_path(
            (args, reordered_kwargs)
        )
        flat_args = [x[1] for x in flat_args_with_path]

        if is_fx_tracing():
            return flat_args

        if in_spec != signature.in_spec:
            if not self.adapted:
                print(
                    "Input treespec does not match with exported module's: \n"
                    f"Input treespec: {in_spec}. ",
                    f"Exported module treespec: {signature.in_spec}",
                )
                print("Adapting flat arg to match exported module's treespec")
            flat_args = self._adapt_flat_args(flat_args, in_spec)
            self.adapted = True

        if self.check_input_constraints:
            # Import here to avoid an unfortunate circular dependency.
            # TODO(suo): untangle this.
            from torch._export.utils import _check_input_constraints_for_graph

            if self.adapted is True:
                # TODO(suo): The FlatArgsAdapter returns a list of flat args,
                # which we don't have keypaths for. For now, just create a dummy
                # keypath to associate with the arg.
                new_flat_args_with_path = [  # type: ignore[var-annotated]
                    ((SequenceKey(idx=0), GetAttrKey(name="<unknown location>")), arg)
                    for arg in flat_args
                ]
            else:
                new_flat_args_with_path = flat_args_with_path  # type: ignore[assignment]

            _check_input_constraints_for_graph(
                self.input_placeholders, new_flat_args_with_path, self.range_constraints
            )

        return flat_args

    def forward(self, *args, **kwargs):
        flat_args = torch._dynamo.disable(self.process_forward_inputs)(*args, **kwargs)
        signature = self.module_call_graph[0].signature

        if is_fx_tracing():
            return_val = torch.fx.Interpreter(self, graph=self.graph).run(
                *flat_args, enable_io_processing=False
            )
            # For scalar return value, fx.Graph wraps in a tuple
            if isinstance(return_val, tuple) and len(return_val) == 1:
                return return_val[0]
            return return_val

        if torch.compiler.is_dynamo_compiling() and not self._run_with_interpreter:
            tree_out = torch.fx.GraphModule(self, self.graph)(*flat_args)
        else:
            tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
                *flat_args, enable_io_processing=False
            )
        return pytree.tree_unflatten(tree_out, signature.out_spec)

    def _dispatch_modules(self, redirected_call_indices, consts_targets):
        """For a module whose call signatures are preserved, replace
        multiple modules corresponding to multiple calls to that module
        with a single dispatcher module that tracks which module to call.
        """

        # for each fqn whose module call signature is preserved,
        # map that fqn to a list of called modules
        called_modules = defaultdict(list)
        for entry in self.module_call_graph:
            if entry.fqn and entry.signature:
                # some modules were removed and their fqns redirected to other
                # fqns during deduplication
                fqn = entry.fqn
                mod = _get_attr(self, redirected_call_indices.get(fqn, fqn))
                base, idx = fqn.split("@") if "@" in fqn else [fqn, "0"]
                called_modules[base].append((int(idx), mod))

        attrs_map = defaultdict(set)
        for target in consts_targets:
            if "." in target:
                orig_fqn, name = target.rsplit(".", 1)
                attrs_map[orig_fqn].add(name)
            else:
                attrs_map[""].add(target)

        # replace multiple call modules with a single dispatcher module
        for orig_fqn, indexed_call_modules in called_modules.items():
            # Sort by call index only; comparing the modules themselves on a
            # tie would raise TypeError.
            call_modules = [
                mod for _, mod in sorted(indexed_call_modules, key=lambda p: p[0])
            ]
            if len(call_modules) > 1:
                for i, call_module in enumerate(call_modules):
                    fqn = _call_name(orig_fqn, i + 1)
                    if fqn not in redirected_call_indices:
                        *prefix, name = fqn.split(".")
                        _get_attr_via_attr_list(self, prefix)._modules.pop(name)
                self.set_submodule(
                    orig_fqn,
                    InterpreterModuleDispatcher(attrs_map[orig_fqn], call_modules),
                )

        # elide call indices in call modules because they are
        # tracked automatically inside the dispatcher module
        def elide_call_indices(prefix, graph):
            for node in graph.nodes:
                if node.op == "call_module":
                    fqn = node.target.split("@")[0]
                    path = f"{prefix}.{fqn}" if prefix else fqn
                    if path in called_modules:
                        node.target = fqn

        for fqn, mod in self.named_modules(remove_duplicate=False):
            if hasattr(mod, "graph"):
                elide_call_indices(fqn, mod.graph)
            elif hasattr(mod, "_call_modules"):
                for mod_ in mod._call_modules:
                    assert hasattr(mod_, "graph")
                    elide_call_indices(fqn, mod_.graph)

    def print_readable(
        self,
        print_output=True,
        include_stride=False,
        include_device=False,
        colored=False,
    ):
        return _print_readable(
            self,
            "UnflattenedModule",
            print_output,
            include_stride,
            include_device,
            colored,
        )
def unflatten(
    module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None
) -> UnflattenedModule:
    """Unflatten an ExportedProgram, producing a module with the same module
    hierarchy as the original eager module. This can be useful if you are trying
    to use :mod:`torch.export` with another system that expects a module
    hierarchy instead of the flat graph that :mod:`torch.export` usually produces.

    .. note:: The args/kwargs of unflattened modules will not necessarily match
        the eager module, so doing a module swap (e.g. :code:`self.submod =
        new_mod`) will not necessarily work. If you need to swap a module out, you
        need to set the :code:`preserve_module_call_signature` parameter of
        :func:`torch.export.export`.

    Args:
        module (ExportedProgram): The ExportedProgram to unflatten.
        flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's.

    Returns:
        An instance of :class:`UnflattenedModule`, which has the same module
        hierarchy as the original eager module pre-export.
    """
    # Run the effect-token removal pass on the program before rebuilding the
    # module hierarchy from it.
    module = _remove_effect_tokens(module)
    return UnflattenedModule(module, flat_args_adapter)
def_inplace_buffer_mutations(graph:torch.fx.Graph,graph_signature:ExportGraphSignature,)->None:"""Transform buffer mutations from their functionalized form into a copy_ node in the graph. Functionalization represents buffer mutation by passing the buffer as an input and output. So for example, the eager code: def forward(self, x): self.buffer += x return x * x Will become a graph that looks like: def forward(self, buffer, x): mutated_buffer = aten.add(buffer, x) mul = aten.mul(x, x) return (mutated_buffer, mul) We want to inplace this into something that looks like the original eager code: def forward(self, buffer, x): mutated_buffer = aten.add(buffer, x) buffer.copy_(mutated_buffer) mul = aten.mul(x, x) return (mul,) """output_node=next(iter(reversed(graph.nodes)))assertoutput_node.op=="output"andlen(output_node.args)==1return_args=output_node.args[0]mutation_node_to_buffer=graph_signature.buffers_to_mutatemutations=return_args[:len(mutation_node_to_buffer)]buffers_to_inputs={v:kfork,vingraph_signature.inputs_to_buffers.items()}input_name_to_node={node.name:nodefornodeingraph.nodesifnode.op=="placeholder"}formutationinmutations:buffer_name=mutation_node_to_buffer[mutation.name]input_name=buffers_to_inputs[buffer_name]input_node=input_name_to_node[input_name]withgraph.inserting_after(mutation):new_node=graph.create_node("call_function",torch.ops.aten.copy_,(input_node,mutation))fork,vinmutation.meta.items():new_node.meta[k]=v# Replace all uses of the previously functional mutation with our copy_ output.mutation.replace_all_uses_with(new_node,lambdax:xisnotnew_node)# Remove the mutated buffer from the graph outputs, since we don't need to# thread it through anymore. 
We don't need to handle the inputs, which will# be handled by _sink_params.user_outputs=tuple(return_args[len(mutation_node_to_buffer):],)output_node.args=((user_outputs),)def_is_prefix(candidate,target):"""Check whether `candidate` is a prefix of `target`."""returnlen(candidate)<len(target)andtarget[:len(candidate)]==candidatedef_compute_accessor(parent_fqn:str,child_fqn:str)->str:ifparent_fqn=="":# Handle the root module correctly.returnchild_fqnparent_split=parent_fqn.split(".")child_split=child_fqn.split(".")# TODO: support skip connection by inlining the child module.ifchild_split[:len(parent_split)]!=parent_split:raiseRuntimeError(f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'.""This is currently unsupported.""Please try to make child module attach to parent module directly.")return".".join(child_split[len(parent_split):])def_check_graph_equivalence(x:torch.nn.Module,y:torch.nn.Module):defgraph_dump(graph:torch.fx.Graph)->str:ret=[]nodes_idx:Dict[int,int]={}defarg_dump(arg)->str:ifisinstance(arg,torch.fx.Node):return"%"+str(nodes_idx[id(arg)])returnstr(arg)fori,nodeinenumerate(graph.nodes):args_dump=[str(arg)forarginpytree.tree_map(arg_dump,node.args)]args_dump+=[f"{key}={value}"forkey,valueinpytree.tree_map(arg_dump,node.kwargs).items()]target=node.targetifnode.opin("call_function","get_attr")else""ret.append(f"{i}: {node.op}[{target}]({', 
'.join(args_dump)})")nodes_idx[id(node)]=ireturn"\n".join(ret)assertisinstance(x.graph,torch.fx.Graph)assertisinstance(y.graph,torch.fx.Graph)returngraph_dump(x.graph)==graph_dump(y.graph)def_add_spec(gm:torch.nn.Module,spec)->str:i=0whilehasattr(gm,f"_spec_{i}"):i+=1name=f"_spec_{i}"setattr(gm,name,spec)returnnamedef_generate_flatten(gm:torch.fx.GraphModule,node)->torch.fx.Node:flatten=gm.graph.call_function(pytree.tree_flatten,(node,))getitem_0=gm.graph.call_function(operator.getitem,(flatten,0))returngetitem_0def_generate_flatten_spec(gm:Union[torch.fx.GraphModule,InterpreterModule,UnflattenedModule],node,spec)->torch.fx.Node:name=_add_spec(gm,spec)spec_node=gm.graph.get_attr(name)returngm.graph.call_function(fx_pytree.tree_flatten_spec,(node,spec_node))def_generate_unflatten(gm:Union[torch.fx.GraphModule,InterpreterModule,UnflattenedModule],nodes,spec)->torch.fx.Node:name=_add_spec(gm,spec)spec_node=gm.graph.get_attr(name)returngm.graph.call_function(pytree.tree_unflatten,(nodes,spec_node))def_get_submodule(mod:torch.nn.Module,target:str):*prefix,field=target.split(".")foriteminprefix:submod=getattr(mod,item,None)ifsubmodisNone:returnNoneifnotisinstance(submod,torch.nn.Module):returnNonemod=submodreturngetattr(mod,field,None)def_add_submodule(mod:torch.nn.Module,target:str,module_to_add:torch.nn.Module,create_module:Optional[Callable[[str],torch.nn.Module]]=None,):*prefix,field=target.split(".")fori,iteminenumerate(prefix):submod=getattr(mod,item,None)ifsubmodisNone:ifcreate_moduleisnotNone:submod=create_module(".".join(prefix[:i+1]))else:submod=torch.nn.Module()setattr(mod,item,submod)ifnotisinstance(submod,torch.nn.Module):returnFalsemod=submodmod.add_module(field,module_to_add)def_call_name(base:str,n:int)->str:# Given n >= 0, generate call names to a submodule `base` of the form# `base`, `base@1`, `base@2`, etc.returnbaseifn==1elsef"{base}@{n-1}"def_is_call_name(call_name:str,base:str)->bool:# Recognize when call_name = _call_name(base, n) for some n >= 
0.returnre.match(re.escape(base)+r"(@\d+)?$",call_name)isnotNoneclass_ModuleFrame:def__init__(self,flat_graph:torch.fx.Graph,nodes:Tuple[torch.fx.Node,...],seen_nodes,seen_modules,seen_attrs,created_modules,parent,module_stack:List[Tuple[str,int]],module_id,module_call_graph:Dict[str,ModuleCallSignature],module:Optional[Union[torch.fx.GraphModule,UnflattenedModule]]=None,):self.flat_graph=flat_graphself.nodes=nodesself.seen_nodes=seen_nodesself.seen_modules=seen_modulesself.seen_attrs=seen_attrsself.created_modules=created_modulesself.parent=parentself.module_stack=module_stackself.module_id=module_idself.module_call_graph=module_call_graphself.verbose=Falseself.fqn,num_calls=self.module_stack[-1]# generate call name for self.fqnself.child_fqn=_call_name(self.fqn,num_calls+1)self.module:Union[torch.fx.GraphModule,UnflattenedModule,InterpreterModule]ifmoduleisnotNone:self.module=moduleself.ivals=module.ivalsifhasattr(module,"ivals")else{}# type: ignore[var-annotated]else:self.module=self.created_modules.get(self.fqn,InterpreterModule(torch.fx.Graph()),)self.ivals=parent.ivalsself.graph=self.module.graph# Mapping of nodes in the flat graph to nodes in this 
graph.self.node_map:Dict[torch.fx.Node,torch.fx.Node]={}self.node_to_placeholder={}self.parent_call_module:Optional[torch.fx.Node]=NoneifparentisnotNone:accessor=_compute_accessor(parent.fqn,self.child_fqn)defcreate_module(fqn):path=f"{parent.fqn}.{fqn}"ifparent.fqnelsefqnifpathinself.created_modules:returnself.created_modules[path]submod=InterpreterModule(torch.fx.Graph())self.created_modules[path]=submodreturnsubmod_add_submodule(parent.module,accessor,self.module,create_module)self.parent_call_module=parent.graph.call_module(accessor)ifself.seen_modules[self.module_id]:base_module_frame=self.seen_modules[self.module_id][0]self.module._modules=base_module_frame.module._modulesself.seen_modules[self.module_id].append(_SubmoduleEntry(parent_fqn=self.parent.fqn,parent_module=self.parent.module,parent_call_module=self.parent_call_module,fqn=self.fqn,call_idx=num_calls+1,module=self.module,))signature=module_call_graph.get(self.child_fqn)ifsignatureisnotNoneandself.parentisnotNone:assertsignature.in_spec.num_children==2args_spec=signature.in_spec.children_specs[0]kwargs_spec=signature.in_spec.children_specs[1]assertargs_spec.contextisNoneassertkwargs_spec.contextisnotNonewithself.graph.inserting_after(None):arg_nodes=[self.graph.placeholder(f"_positional_arg_{idx}")foridxinrange(args_spec.num_children)]kwarg_nodes={}fornameinkwargs_spec.context:kwarg_nodes[name]=self.graph.placeholder(name)flat_args=_generate_flatten_spec(self.module,(tuple(arg_nodes),kwarg_nodes),signature.in_spec,)foridx,arginenumerate(signature.inputs):flat_arg_node=self.graph.create_node(op="call_function",target=operator.getitem,args=(flat_args,idx),name=(arg.nameifnotisinstance(arg,ConstantArgument)elsef"_constant_{idx}"),)ifisinstance(arg,ConstantArgument):continueifarg.nameinself.seen_nodes:flat_arg_node.meta=copy.copy(self.seen_nodes[arg.name].meta)self.node_to_placeholder[self.seen_nodes[arg.name]]=flat_arg_nodewithself.parent.graph.inserting_before(self.parent_call_module):input_nodes:List[O
ptional[torch.fx.Node]]=[]forinputinsignature.inputs:ifisinstance(input,ConstantArgument):input_nodes.append(input.value)# type: ignore[arg-type]elifinput.namenotinself.seen_nodes:input_nodes.append(None)else:assertisinstance(input,(TensorArgument,SymIntArgument,SymBoolArgument,SymFloatArgument,),)input_nodes.append(self.parent.remap_input(self.seen_nodes[input.name]))inputs_node=_generate_unflatten(self.parent.module,input_nodes,signature.in_spec,)args_node=self.parent.graph.call_function(operator.getitem,(inputs_node,0))kwargs_node=self.parent.graph.call_function(operator.getitem,(inputs_node,1))arg_nodes=[self.parent.graph.call_function(operator.getitem,(args_node,i))foriinrange(args_spec.num_children)]kwarg_nodes={k:self.parent.graph.call_function(operator.getitem,(kwargs_node,k))forkinkwargs_spec.context}assertself.parent_call_moduleisnotNoneself.parent_call_module.args=tuple(arg_nodes)self.parent_call_module.kwargs=kwarg_nodes# type: ignore[assignment]defadd_placeholder(self,x):assertself.fqn!="",f"Cannot add placeholder {x} to root module"assertx.graphisself.flat_graph# x is not in subgraph, create a new placeholder for subgraphwithself.graph.inserting_before(None):placeholder_node=self.graph.placeholder(x.name,type_expr=x.type)# copy all meta fields, even if some fields might be irrelevant for# the placeholder nodeplaceholder_node.meta=copy.copy(x.meta)self.node_to_placeholder[x]=placeholder_nodedefcopy_sym_call_function(self,x):# This only exists because we deduplicate sym_size nodes in the flat export graph,# and if preserve_module_call_signature is set, we may not be able to pass sym_size# nodes, or their downstream users, as inputs to submodule calls.# To avoid this we copy these call_function nodes with sym_type results.# This should however only be done for sym_type nodes - call_function nodes on tensors# should not be deduplicated in the first 
place.args=pytree.tree_map_only(torch.fx.Node,self.remap_input,x.args)kwargs=pytree.tree_map_only(torch.fx.Node,self.remap_input,x.kwargs)node=self.graph.call_function(x.target,args,kwargs)node.meta=copy.copy(x.meta)self.node_map[x]=nodereturnnodedefremap_input(self,x):assertx.graphisself.flat_graphifxinself.node_map:returnself.node_map[x]self.print(f"remap_input({x})")ifxinself.node_to_placeholder:returnself.node_to_placeholder[x]elif(x.op=="placeholder"orself.module_call_graph.get(self.fqn)isNone# allow placeholder creation if we are not preserving module call signature):self.add_placeholder(x)ifself.parent_call_moduleisnotNone:# Important to *prepend* the output to match how we are# inserting placeholder nodes.withself.parent.graph.inserting_before(self.parent_call_module):self.parent_call_module.insert_arg(0,self.parent.remap_input(x))returnself.node_to_placeholder[x]elifx.op=="call_function"and(x.targetin(torch.ops.aten.sym_size.int,torch.ops.aten.item.default,torch.ops.aten.unbind.int,torch.ops.aten.sum.dim_IntList,torch.ops.aten.view.default,torch.ops.aten.diff.default,)or(hasattr(x.target,"__module__")andx.target.__module__=="_operator")):# export deduplicates sym_size nodes, and may need to re-copy them# if module call signature needs to be preservedself.copy_sym_call_function(x)returnself.node_map[x]elifself.module_call_graph.get(self.fqn)isnotNone:# x is an ival that is not in placeholders, so create a# get_attr node corresponding to attribute __ival__xreturnself.ivals.read(self.fqn,self.graph,x)# type: ignore[operator, union-attr]else:raiseRuntimeError(f"Could not run remap_input() on op type: {x.op} for node 
{x}")deffinalize_outputs(self):self.created_modules.pop(self.fqn,None)orig_outputs=[]signature=self.module_call_graph.get(self.child_fqn)ifsignatureisnotNoneandself.parentisnotNone:foroutputinsignature.outputs:ifisinstance(output,(TensorArgument,SymIntArgument,SymBoolArgument,SymFloatArgument),):ifoutput.nameinself.seen_nodes:orig_outputs.append(self.seen_nodes[output.name])else:orig_outputs.append(None)else:raiseRuntimeError(f"Unsupported data type for output node: {output}")defget_actual_output_node(output):ifoutputisNone:returnNoneseen_node=self.seen_nodes[output.name]ifseen_nodeinself.node_map:returnself.node_map[seen_node]elifseen_nodeinself.node_to_placeholder:returnself.node_to_placeholder[seen_node]else:raiseRuntimeError(f"Could not find output node {output}. Graph: {self.graph}")tree_out_node=_generate_unflatten(self.module,tuple(get_actual_output_node(output)foroutputinorig_outputs),signature.out_spec,)parent_out:Optional[torch.fx.Node]=_generate_flatten_spec(self.parent.module,self.parent_call_module,signature.out_spec)graph_outputs:Union[torch.fx.Node,List[torch.fx.Node]]=tree_out_nodeelse:graph_outputs=[]# Iterate through nodes we have copied into self.graph.fororig_nodeinself.node_map.keys():foruser_nodeinorig_node.users:ifuser_node.namenotinself.seen_nodes:# external user node, need to expose as an outputorig_outputs.append(orig_node)graph_outputs.append(self.node_map[orig_node])breakparent_out=self.parent_call_moduleiflen(graph_outputs)==1:graph_outputs=graph_outputs[0]assertisinstance(graph_outputs,(list,torch.fx.Node))self.graph.output(graph_outputs)# Rewrite outputs in parent moduleifparent_outisNone:returnparent_out.meta["val"]=(graph_outputs.meta.get("val")ifisinstance(graph_outputs,torch.fx.Node)else[o.meta.get("val")foroingraph_outputs])iflen(orig_outputs)==1andsignatureisNone:self.parent.node_map[orig_outputs[0]]=parent_outelse:fori,orig_outputinenumerate(orig_outputs):iforig_outputisNone:continue# Use Proxy to record getitem 
access.proxy_out=torch.fx.Proxy(parent_out)[i].node# type: ignore[index]proxy_out.meta["val"]=orig_output.meta.get("val")self.parent.node_map[orig_output]=proxy_outdefcopy_node(self,node):self.print("copying",node.format_node())self.node_map[node]=self.graph.node_copy(node,self.remap_input)self.seen_nodes[node.name]=nodedefrun_outer(self):i=0fornodeinself.flat_graph.nodes:self.print(i,node.meta.get("nn_module_stack"),node.format_node())i+=1# Copy all graph inputsnode_idx:int=0node=self.nodes[node_idx]whilenode.op=="placeholder":self.copy_node(node)node_idx+=1node=self.nodes[node_idx]self.run_from(node_idx)# Copy graph outputsfornodeinself.flat_graph.nodes:ifnode.op=="output":self.copy_node(node)defprint(self,*args,**kwargs):ifself.verbose:print(*args,**kwargs)defrun_from(self,node_idx):module_idx=0# Walk through the graph, building up a new graph with the right submoduleswhilenode_idx<len(self.nodes):node=self.nodes[node_idx]assertnode.op!="placeholder"self.print()self.print("STEP",node_idx,node.format_node())self.print(self.module_stack)depth=len(self.module_stack)ifnode.op=="output":ifdepth==1:# We want the output node of the original graph to be handled# specially by the outermost stack frame (in run_outer). So# skip finalization here.returnnode_idx# We've reached the end of the graph. 
Wrap up all the existing stack frames.self.finalize_outputs()returnnode_idxiflen(node.meta.get("nn_module_stack",{}))==0:raiseRuntimeError(f"Unable to find nn_module_stack for node {node}")nn_module_stack=node.meta["nn_module_stack"]fromtorch._export.passes._node_metadata_hookimport(_EMPTY_NN_MODULE_STACK_KEY,)if(len(nn_module_stack)==1and_EMPTY_NN_MODULE_STACK_KEYinnn_module_stack):# Empty case from the node_metadata_hooknode_module_stack=self.module_stackelse:node_module_stack=[(path,int(k.split("@")[-1])if"@"inkelse0)fork,(path,ty)innode.meta["nn_module_stack"].items()]ifnode_module_stack[:depth]!=self.module_stack:# This means that the current module is done executing and the# current node is the beginning of a new module.## In this case, we should finalize this module and return without# incrementing the node counter.self.finalize_outputs()self.print("outlining",self.fqn)self.print(self.graph)returnnode_idxassertnode_module_stackisnotNoneif_is_prefix(self.module_stack,node_module_stack):# This means that the current node represents the execution of a new# module.next_module=node_module_stack[depth]self.print("Creating new stack frame for",next_module)# Run a nested version of module outliner from the current node# counter. Once it is complete, continue from that point.next_module_key=list(node.meta["nn_module_stack"].keys())[depth]node_idx=_ModuleFrame(self.flat_graph,self.nodes,self.seen_nodes,self.seen_modules,self.seen_attrs,self.created_modules,self,self.module_stack+[next_module],next_module_key.split("@")[0],self.module_call_graph,).run_from(node_idx)module_idx+=1continue# The only remaining possibility is that we are in the right stack# frame. 
Copy the node into this frame's graph and increment the node counter.assertnode_module_stack==self.module_stackifnode.op=="get_attr":# this must be a graph argument for a HOPself.seen_attrs[self.child_fqn].add(node.target)self.copy_node(node)node_idx+=1@dataclassclass_SubmoduleEntry:parent_fqn:strparent_module:torch.nn.Moduleparent_call_module:torch.fx.Nodefqn:strcall_idx:intmodule:torch.nn.Moduledef_outline_submodules(orig_graph:torch.fx.Graph,root_module:UnflattenedModule):seen_nodes:Dict[str,torch.fx.Node]={}seen_modules:Dict[int,List[_SubmoduleEntry]]=defaultdict(list)seen_attrs:Dict[str,Set[str]]=defaultdict(set)created_modules:Dict[str,torch.nn.Module]={}_ModuleFrame(orig_graph,tuple(orig_graph.nodes),seen_nodes,seen_modules,seen_attrs,created_modules,None,[("",0)],"",{entry.fqn:entry.signatureforentryinroot_module.module_call_graphifentry.signature},module=root_module,).run_outer()returnseen_modules,seen_attrsdef_reorder_submodules(parent:torch.nn.Module,fqn_order:Dict[str,int],prefix:str=""):# TODO Can be optimized by adding submodules ahead of time.ifprefix=="":forfqninlist(fqn_order.keys())[1:]:if_get_submodule(parent,fqn)isNone:_add_submodule(parent,fqn,torch.nn.Module())children=[]forname,childinlist(parent._modules.items()):ifchildisNone:continuefqn=prefix+name_reorder_submodules(child,fqn_order,prefix=fqn+".")delattr(parent,name)children.append((fqn_order[fqn],name,child))children.sort(key=operator.itemgetter(0))for_,name,childinchildren:parent.register_module(name,child)class_IVals:""" Collect the intermediate values of buffer mutations in a graph, along with the module call fqns that create and use them. Later, in each fqn associated with an intermediate value we will install a corresponding attribute, so that it can be updated and read. Example: in the following graph, suppose that buf_in and buf_out are the input and output values of a buffer. buf_in = placeholder() ... ival1 = f0(buf_in, ...) # inside self.n0(...) ... ival2 = f1(ival1, ...) 
# inside self.n1(...) ... buf_out = f2(ival2, ...) # inside self.n2(...) return buf_out, ... Here ival1 and ival2 are intermediate values created inside calls to n0 and n1 respectively, and used inside calls to n1 and n2 respectively. Thus our analysis will produce {ival1: {n0, n1}, ival2: {n1, n2}}. """def__init__(self):# ival node name -> set of fqns that create and use itself.fqns=defaultdict(set)# ival node name -> tensor storage for corresponding attributeself.storage={}defread(self,fqn,graph,node):""" Read attribute corresponding to a given intermediate value. """# to read ival x, get attribute __ival__xwithgraph.inserting_before(None):ival_node=graph.get_attr("__ival__"+node.name,type_expr=node.type)ival_node.meta=copy.copy(node.meta)ifnode.namenotinself.storage:# create empty tensor matching fake, using a cache# to ensure the same tensor is returned per ival_namefake=node.meta["val"]self.storage[node.name]=torch.empty(fake.shape,dtype=fake.dtype)self.fqns[node.name].add(fqn)returnival_nodedefupdate(self,fqn,graph,node):""" Update attribute corresponding to a given intermediate value. """self.fqns[node.name].add(fqn)# to update ival x, get attribute __ival__x and copy x to __ival__xwithgraph.inserting_after(node):ival_node=graph.get_attr("__ival__"+node.name,type_expr=node.type)ival_node.meta=copy.copy(node.meta)withgraph.inserting_after(ival_node):new_ival_node=graph.create_node("call_function",torch.ops.aten.copy_,(ival_node,node))new_ival_node.meta=copy.copy(node.meta)defcreate(self,partitions,root_module):""" Update attributes corresponding to intermediate values that were read. Finally, initialize attributes in all modules that read or update corresponding intermediate values. 
"""entries=[("",root_module)]forshared_submodulesinpartitions:forentryinshared_submodules:entries.append((entry.fqn,entry.module))graph=entry.module.graphfornodeingraph.nodes:ifnode.nameinself.storage:self.update(entry.fqn,graph,node)# fqn -> list of ival node names read or updated through itivals=defaultdict(list)forname,fqnsinself.fqns.items():forfqninfqns:ivals[fqn].append(name)forfqn,modinentries:fornameinivals[fqn]:ival_name=f"__ival__{name}"# for a ival named x created in module call m,# create attribute m.__ival__x, initially emptysetattr(mod,ival_name,self.storage[name])def_copy_graph_attrs(gm:torch.fx.GraphModule,root_module:UnflattenedModule,seen_attrs:Dict[str,Set[str]],):forchild_fqn,namesinseen_attrs.items():module=_get_attr(root_module,child_fqn)ifchild_fqnelseroot_modulefornameinnames:val=getattr(gm,name)setattr(module,name,val)def_deduplicate_modules(partitions):redirected_call_indices={}forshared_submodulesinpartitions:fori,entryinenumerate(shared_submodules):child_fqn=_call_name(entry.fqn,entry.call_idx)target=_compute_accessor(entry.parent_fqn,child_fqn)deduplicated=False# Iterate over all previously seen modules, and deduplicate if possibleforseeninshared_submodules[:i]:if_check_graph_equivalence(seen.module,entry.module):parent=entry.parent_module# Since graphs are equivalent, we can deduplicate.# There are two cases.ifseen.fqn==entry.fqn:# Case 1: The current module has the same fqn as the seen module.# In this case we have generated a call name that can be optimized away.# So we remove the current module from the hierarchy and replace# the current call name with the seen call name in the parent graph.*prefix,name=target.split(".")_get_attr_via_attr_list(parent,prefix)._modules.pop(name)seen_child_fqn=_call_name(seen.fqn,seen.call_idx)seen_target=_compute_accessor(entry.parent_fqn,seen_child_fqn)entry.parent_call_module.target=seen_targetredirected_call_indices[child_fqn]=seen_child_fqnbreakelifnotdeduplicated:# Case 2: The current module has 
a different fqn than the seen module.# In this case we replace the current module with the seen module.# There should be nothing pointing to the current module any more,# so it can be garbage collected.# NOTE: We *do not* replace the current call name with the seen call name# in the parent graph, because this will lose information on which fqn# was actually called. However, it is possible that the current call name# will be optimized away when we find another seen module with the same fqn,# so we do not break out of the loop yet.parent.set_submodule(target,seen.module)deduplicated=Truereturnredirected_call_indicesdef_sink_params(module:torch.nn.Module,inputs_to_state:Dict[str,List[str]],scope:List[str],module_id_to_inputs_removed:Optional[Dict[int,Set[str]]]=None,):"""Sink params, buffers, and constants from graph inputs into get_attr nodes. Exported modules are purely functional, so they pass their parameters and buffers in as inputs to the graph. To replicate eager's semantics, we need to get them from the module state via get_attr instead. module: GraphModule, potentially containing nested submodules. inputs_to_state: mapping graph input names to the corresponding key in the state_dict. scope: tracks where we are in the module hierarchy, so that we can emit the right `getattr(self, "foo.bar")` calls, etc. module_id_to_inputs_removed: records inputs removed by child modules, mapping the module object id to the list of placeholder node names in the child module that were removed. 
"""ifmodule_id_to_inputs_removedisNone:module_id_to_inputs_removed=defaultdict(set)ifid(module)inmodule_id_to_inputs_removed:return{id(module):module_id_to_inputs_removed[id(module)]}# We need to use _modules here instead of named_children(), because we# explicitly want duplicate modules to show up in the traversal.forname,submoduleinmodule._modules.items():submod_id_to_inputs_removed=_sink_params(cast(torch.nn.Module,submodule),inputs_to_state,scope+[name],module_id_to_inputs_removed,)fork,vinsubmod_id_to_inputs_removed.items():module_id_to_inputs_removed[k].update(v)graph=getattr(module,"graph",None)ifgraphisNoneorlen(graph.nodes)==0:# Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList)returnmodule_id_to_inputs_removedassertisinstance(graph,torch.fx.Graph)inputs=list(filter(lambdan:n.op=="placeholder",graph.nodes))the_last_input=inputs[-1]# Also remove from call_module nodescall_module_nodes=filter(lambdan:n.op=="call_module",graph.nodes)fornodeincall_module_nodes:submodule=_get_attr(module,node.target)# remove placeholder from call_module node arguments, only if we've# erased the placeholder node in the corresponding _sink_params() callifsubmoduleisnotNoneandid(submodule)inmodule_id_to_inputs_removed:node.args=tuple(filter(lambdan:n.namenotinmodule_id_to_inputs_removed[id(submodule)],node.args,))# Filter out inputs_to_state corresponding to current scope.inputs_to_state_of_scope:Dict[torch.fx.Node,list[str]]={}fornodeininputs:ifnode.namenotininputs_to_state:continuestate_name=Noneforsnininputs_to_state[node.name]:sn_split=sn.split(".")ifsn_split[:len(scope)]==[x.split("@")[0]forxinscope]:state_name=sn_splitbreak# If there's a mismatch between scope name and state name, then# there must be multiple scopes pointing to the same state name,# meaning some modules are shared. 
In such case, we can simply skip# updating the current node because another later iteration will# take care of this input node when the unique match between scope# and state name occurs. To make sure this always happen, we should# enforce the invariant that no placeholder node in the unflattened# graph appears in inputs_to_state dict, which means all the extra# input nodes have been handled.ifstate_nameisNone:continueinputs_to_state_of_scope[node]=state_name# Record name of remove inputs for return purpose.inputs_removed:Set[str]=set()fornode,state_nameininputs_to_state_of_scope.items():iflen(node.users)>0:attr_path=state_name[len(scope):]state_attr=_get_attr_via_attr_list(module,attr_path)assertisinstance(state_attr,(torch.Tensor,torch.ScriptObject))# Make sure the newly created get_attr node is placed after the last placeholder nodewithgraph.inserting_after(the_last_input):new_node=graph.create_node("get_attr",".".join(attr_path))node.replace_all_uses_with(new_node,propagate_meta=True)graph.erase_node(node)inputs_removed.add(node.name)ifisinstance(module,InterpreterModule):module.finalize()return{id(module):inputs_removed}
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.