importabcimportcopyimportoperatorfromcopyimportdeepcopyfromenumimportEnumfromitertoolsimportchainfromtypingimportAny,cast,Dict,List,Optional,Unionimporttorchimporttorch.fx._pytreeasfx_pytreeimporttorch.utils._pytreeaspytreefromtorch.export._tree_utilsimportreorder_kwargsfromtorch.export.exported_programimport(ConstantArgument,ExportedProgram,ModuleCallSignature,SymIntArgument,TensorArgument,)fromtorch.fx._symbolic_traceimportis_fx_tracingfromtorch.utils._pytreeimportGetAttrKey,SequenceKey__all__=["InterpreterModule","UnflattenedModule","unflatten","FlatArgsAdapter"]class_AttrKind(Enum):PARAMETER="parameter"BUFFER="buffer"CONSTANT="constant"# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module# This installs empty Modules where none exist yet if they are subpaths of targetdef_assign_attr(from_obj:Union[torch.Tensor,torch.ScriptObject],to_module:torch.nn.Module,target:str,attr_kind:_AttrKind,persistent:bool=True,):*prefix,field=target.split(".")foriteminprefix:t=getattr(to_module,item,None)iftisNone:t=torch.nn.Module()setattr(to_module,item,t)to_module=tifattr_kind==_AttrKind.PARAMETER:assertisinstance(from_obj,torch.nn.Parameter)to_module.register_parameter(field,from_obj)elifattr_kind==_AttrKind.BUFFER:assertisinstance(from_obj,torch.Tensor)to_module.register_buffer(field,from_obj,persistent=persistent)elifattr_kind==_AttrKind.CONSTANT:assertisinstance(from_obj,(torch.Tensor,torch.ScriptObject))setattr(to_module,field,from_obj)
[docs]classInterpreterModule(torch.nn.Module):"""A module that uses torch.fx.Interpreter to execute instead of the usual codegen that GraphModule uses. This provides better stack trace information and makes it easier to debug execution. """def__init__(self,graph:torch.fx.Graph,):super().__init__()self.graph=graphself.graph.owning_module=selfdefforward(self,*args,**kwargs):assertself.graph_moduleisnotNone,"Didn't finalize this InterpreterModule"iftorch.compiler.is_dynamo_compiling():# Dynamo cannot trace through torch.fx.Interpreter, so fall back to# GraphModule codegen in this instance.returnself.graph_module(*args,**kwargs)else:ifkwargs:# Handle **kwargs. FX only natively supports positional# arguments (through placeholders). So in order to pass in# kwargs, we must correspond the names of the placeholders with# the keys in the kwarg dict.arg_list=list(args)kwarg_names=self.arg_names[len(arg_list):]forkwarg_nameinkwarg_names:ifkwarg_nameinkwargs:arg_list.append(kwargs[kwarg_name])# Assert that the kwargs passed in exactly match the positional# arguments specified by the GraphModule. This should be# guaranteed by the unflattening process.assertlen(kwarg_names)==len(kwargs)assertlen(arg_list)==len(self.arg_names)args=tuple(arg_list)returntorch.fx.Interpreter(self,graph=self.graph).run(*args,enable_io_processing=False)deffinalize(self):# We need to "finalize" because GraphModule populates its own state_dict# based on the get_attrs observed in the graph. So we need to fully# construct the graph and call _sink_params before generating this# GraphModule.# need to set `graph_module` directly on the dict to avoid it getting# registered as a submodule.self.__dict__["graph_module"]=torch.fx.GraphModule(self,self.graph)self.graph.lint()# Cache arg names for kwarg handling (see forward())self.arg_names=[]fornodeinself.graph.nodes:ifnode.op=="placeholder":self.arg_names.append(node.target)
class FlatArgsAdapter(abc.ABC):
    """
    Adapts input arguments with ``input_spec`` to align ``target_spec``.
    """

    @abc.abstractmethod
    def adapt(
        self,
        target_spec: pytree.TreeSpec,
        input_spec: pytree.TreeSpec,
        input_args: List[Any],
    ) -> List[Any]:
        """Return ``input_args`` rearranged to match ``target_spec``.

        ``input_args`` is the flat list of leaves obtained by flattening the
        caller's inputs with ``input_spec``; implementations should return a
        flat list whose length matches ``target_spec.num_leaves`` (the caller
        raises if it does not).

        NOTE: This adapter may mutate given ``input_args``.
        """
        ...
class UnflattenedModule(torch.nn.Module):
    """A module reconstructed from an :class:`ExportedProgram` that preserves
    the original eager module hierarchy.

    The flat exported graph is partitioned into per-submodule graphs
    (via ``_outline_submodules``), and parameters/buffers/constants are
    re-installed at their original fully-qualified names so execution matches
    eager semantics.
    """

    def __init__(
        self,
        export_module: ExportedProgram,
        flat_args_adapter: Optional[FlatArgsAdapter] = None,
    ):
        super().__init__()
        if export_module.graph_signature.backward_signature is not None:
            raise ValueError("Unflattening on JointExportModule NYI")

        export_graph = deepcopy(export_module.graph)
        self.graph_signature = deepcopy(export_module.graph_signature)
        self.graph = torch.fx.Graph()
        self.module_call_graph = deepcopy(export_module.module_call_graph)
        self.flat_args_adapter = flat_args_adapter
        # Flag to indicate whether args have been adapted.
        self.adapted = False

        # Turn functionalized buffer mutations back into in-place copy_ nodes,
        # then split the flat graph into the original submodule hierarchy.
        _inplace_buffer_mutations(export_graph, self.graph_signature)
        _outline_submodules(export_graph, self)

        self.range_constraints = export_module.range_constraints
        self.equality_constraints: List = []

        # Re-install parameters at their original fully-qualified names.
        state_dict = export_module.state_dict
        for name in self.graph_signature.parameters:
            cloned = torch.nn.Parameter(state_dict[name].clone())
            _assign_attr(
                cloned,
                self,
                name,
                attr_kind=_AttrKind.PARAMETER,
            )

        # Buffers: non-persistent ones live in ExportedProgram.constants, not
        # in the state_dict.
        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
        for name in self.graph_signature.buffers:
            if name in non_persistent_buffers:
                persistent = False
                cloned = export_module.constants[name].clone()
            else:
                persistent = True
                cloned = state_dict[name].clone()
            _assign_attr(
                cloned,
                self,
                name,
                attr_kind=_AttrKind.BUFFER,
                persistent=persistent,
            )

        # Lifted constants (tensors and custom objects) become plain
        # attributes; only tensors are cloned (ScriptObjects are not).
        for fqn in chain(
            self.graph_signature.lifted_tensor_constants,
            self.graph_signature.lifted_custom_objs,
        ):
            constant = export_module.constants[fqn]
            if isinstance(constant, torch.Tensor):
                constant = constant.clone()
            _assign_attr(
                constant,
                self,
                fqn,
                attr_kind=_AttrKind.CONSTANT,
            )

        inputs_to_state: Dict[str, str] = {
            **self.graph_signature.inputs_to_parameters,
            **self.graph_signature.inputs_to_buffers,
            **self.graph_signature.inputs_to_lifted_tensor_constants,
            **self.graph_signature.inputs_to_lifted_custom_objs,
        }

        _sink_params(self, inputs_to_state, [])
        # Check all input nodes have been processed.
        for module in self.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if node.op != "placeholder":
                    continue
                assert node.name not in inputs_to_state

        # Cache so we don't have to compute this every time.
        # NOTE: this needs to be kept in sync with the placeholders in
        # self.graph, but currently we have no way to guarantee that.
        self.input_placeholders = [
            node for node in self.graph.nodes if node.op == "placeholder"
        ]
        self.check_input_constraints = True
        assert self.module_call_graph[0].fqn == ""

    def forward(self, *args, **kwargs):
        """Run the unflattened program on eager-style args/kwargs.

        Inputs are pytree-flattened, optionally adapted when the caller's tree
        spec does not match the exported one, validated against the recorded
        range constraints, interpreted, and unflattened to the original output
        structure.
        """
        signature = self.module_call_graph[0].signature
        reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec)

        flat_args_with_path, in_spec = pytree.tree_flatten_with_path(
            (args, reordered_kwargs)
        )
        flat_args = [x[1] for x in flat_args_with_path]
        if is_fx_tracing():
            return_val = torch.fx.Interpreter(self, graph=self.graph).run(
                *flat_args, enable_io_processing=False
            )
            # For scalar return value, fx.Graph wraps in a tuple
            if isinstance(return_val, tuple) and len(return_val) == 1:
                return return_val[0]
            return return_val

        if in_spec != signature.in_spec:
            if not self.adapted:
                print(
                    "Input treespec does not match with exported module's: \n"
                    f"Input treespec: {in_spec}. ",
                    f"Exported module treespec: {signature.in_spec}",
                )
            if self.flat_args_adapter is None:
                raise TypeError(
                    # FIX: typo "sepcified" -> "specified".
                    "There is no flat args adapter specified. "
                    "Are you sure you are calling this with the right arguments? "
                )
            else:
                if not self.adapted:
                    print("Adapting flat arg to match exported module's treespec")
                flat_args = self.flat_args_adapter.adapt(
                    target_spec=signature.in_spec,
                    input_spec=in_spec,
                    input_args=flat_args,
                )
                self.adapted = True
                if len(flat_args) != signature.in_spec.num_leaves:
                    raise TypeError(
                        # FIX: typos "adaption" -> "adaptation", "Adatped" -> "Adapted".
                        f"Flat args adaptation failed, number of args mismatch "
                        f"Adapted: {len(flat_args)}\n"
                        f"Exported module: {signature.in_spec.num_leaves}"
                    )

        if self.check_input_constraints:
            # Import here to avoid an unfortunate circular dependency.
            # TODO(suo): untangle this.
            from torch._export.utils import _check_input_constraints_for_graph

            if self.adapted is True:
                # TODO(suo): The FlatArgsAdapter returns a list of flat args,
                # which we don't have keypaths for. For now, just create a dummy
                # keypath to associate with the arg.
                new_flat_args_with_path = [  # type: ignore[var-annotated]
                    ((SequenceKey(idx=0), GetAttrKey(name="<unknown location>")), arg)
                    for arg in flat_args
                ]
            else:
                new_flat_args_with_path = flat_args_with_path  # type: ignore[assignment]

            _check_input_constraints_for_graph(
                self.input_placeholders, new_flat_args_with_path, self.range_constraints
            )
        tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
            *flat_args, enable_io_processing=False
        )
        return pytree.tree_unflatten(tree_out, signature.out_spec)
def unflatten(
    module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None
) -> UnflattenedModule:
    """Unflatten an ExportedProgram, producing a module with the same module
    hierarchy as the original eager module. This can be useful if you are trying
    to use :mod:`torch.export` with another system that expects a module
    hierarchy instead of the flat graph that :mod:`torch.export` usually produces.

    .. note:: The args/kwargs of unflattened modules will not necessarily match
        the eager module, so doing a module swap (e.g. :code:`self.submod =
        new_mod`) will not necessarily work. If you need to swap a module out, you
        need to set the :code:`preserve_module_call_signature` parameter of
        :func:`torch.export.export`.

    Args:
        module (ExportedProgram): The ExportedProgram to unflatten.
        flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's.

    Returns:
        An instance of :class:`UnflattenedModule`, which has the same module
        hierarchy as the original eager module pre-export.
    """
    return UnflattenedModule(module, flat_args_adapter)
def _inplace_buffer_mutations(graph: torch.fx.Graph, graph_signature) -> None:
    """Transform buffer mutations from their functionalized form into a copy_
    node in the graph.

    Functionalization represents buffer mutation by passing the buffer as an
    input and output. So for example, the eager code:
        def forward(self, x):
            self.buffer += x
            return x * x

    Will become a graph that looks like:
        def forward(self, buffer, x):
            mutated_buffer = aten.add(buffer, x)
            mul = aten.mul(x, x)
            return (mutated_buffer, mul)

    We want to inplace this into something that looks like the original eager
    code:
        def forward(self, buffer, x):
            mutated_buffer = aten.add(buffer, x)
            buffer.copy_(mutated_buffer)
            mul = aten.mul(x, x)
            return (mul,)
    """
    output_node = next(iter(reversed(graph.nodes)))
    assert output_node.op == "output" and len(output_node.args) == 1
    return_args = output_node.args[0]

    mutation_node_to_buffer = graph_signature.buffers_to_mutate
    # Mutated-buffer outputs come first in the output tuple (export convention
    # relied on here via buffers_to_mutate ordering).
    mutations = return_args[: len(mutation_node_to_buffer)]
    buffers_to_inputs = {v: k for k, v in graph_signature.inputs_to_buffers.items()}
    input_name_to_node = {
        node.name: node for node in graph.nodes if node.op == "placeholder"
    }

    for mutation in mutations:
        buffer_name = mutation_node_to_buffer[mutation.name]
        input_name = buffers_to_inputs[buffer_name]
        input_node = input_name_to_node[input_name]

        with graph.inserting_after(mutation):
            new_node = graph.create_node(
                "call_function", torch.ops.aten.copy_, (input_node, mutation)
            )
            for k, v in mutation.meta.items():
                new_node.meta[k] = v
        # Replace all uses of the previously functional mutation with our copy_
        # output (excluding the copy_ node itself, which consumes the mutation).
        mutation.replace_all_uses_with(new_node, lambda x: x is not new_node)

    # Remove the mutated buffer from the graph outputs, since we don't need to
    # thread it through anymore. We don't need to handle the inputs, which will
    # be handled by _sink_params.
    user_outputs = tuple(
        return_args[len(mutation_node_to_buffer):],
    )
    output_node.args = ((user_outputs),)


def _is_prefix(candidate, target):
    """Check whether `candidate` is a prefix of `target`."""
    return len(candidate) < len(target) and target[: len(candidate)] == candidate


def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
    # Return the attribute path from parent to child, e.g.
    # parent "a.b", child "a.b.c.d" -> "c.d".
    if parent_fqn == "":
        # Handle the root module correctly.
        return child_fqn

    parent_split = parent_fqn.split(".")
    child_split = child_fqn.split(".")

    assert (
        child_split[: len(parent_split)] == parent_split
    ), f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'"
    return ".".join(child_split[len(parent_split):])


def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
    # Assert that two modules' graphs are structurally identical, comparing a
    # canonical string dump in which nodes are referred to by position.
    def graph_dump(graph: torch.fx.Graph) -> str:
        ret = []
        nodes_idx: Dict[int, int] = {}

        def arg_dump(arg) -> str:
            if isinstance(arg, torch.fx.Node):
                return "%" + str(nodes_idx[id(arg)])
            return str(arg)

        for i, node in enumerate(graph.nodes):
            args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)]
            args_dump += [
                f"{key}={value}"
                for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
            ]
            target = node.target if node.op == "call_function" else ""
            ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
            nodes_idx[id(node)] = i
        return "\n".join(ret)

    assert graph_dump(x.graph) == graph_dump(y.graph)


def _add_spec(gm: torch.nn.Module, spec) -> str:
    # Store a pytree spec on the module under a fresh "_spec_{i}" attribute and
    # return the chosen name.
    i = 0
    while hasattr(gm, f"_spec_{i}"):
        i += 1
    name = f"_spec_{i}"
    setattr(gm, name, spec)
    return name


def _generate_flatten(gm: torch.nn.Module, node, spec) -> torch.fx.Node:
    # Emit a tree_flatten_spec call in gm's graph that flattens `node` with
    # `spec` (stored on the module and fetched via get_attr).
    name = _add_spec(gm, spec)
    spec_node = gm.graph.get_attr(name)
    return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node))


def _generate_unflatten(gm: torch.nn.Module, nodes, spec) -> torch.fx.Node:
    # Emit a tree_unflatten call in gm's graph that rebuilds the structure
    # described by `spec` from the flat `nodes`.
    name = _add_spec(gm, spec)
    spec_node = gm.graph.get_attr(name)
    return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node))


def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module):
    # Install `module_to_add` at dotted path `target` under `mod`, creating
    # empty intermediate Modules as needed. Returns False if a non-Module is
    # found along the path; otherwise returns None.
    *prefix, field = target.split(".")

    for item in prefix:
        submod = getattr(mod, item, None)

        if submod is None:
            submod = torch.nn.Module()
            setattr(mod, item, submod)

        if not isinstance(submod, torch.nn.Module):
            return False

        mod = submod

    mod.add_module(field, module_to_add)


class _ModuleFrame:
    """One stack frame of the submodule-outlining walk.

    Each frame owns the graph for a single module fqn and copies the flat
    graph's nodes belonging to that fqn into it, recursing into child frames
    when the nn_module_stack metadata indicates a deeper module.
    """

    def __init__(
        self,
        flat_graph,
        nodes,
        seen_nodes,
        seen_modules,
        parent,
        module_stack,
        module_id,
        module_call_graph: Dict[str, ModuleCallSignature],
        module: Optional[torch.nn.Module] = None,
    ):
        self.flat_graph = flat_graph
        self.nodes = nodes
        self.seen_nodes = seen_nodes
        self.seen_modules = seen_modules
        self.parent = parent
        self.module_stack = module_stack
        self.module_id = module_id

        self.module_call_graph = module_call_graph
        self.verbose = False

        self.fqn = self.module_stack[-1]
        if module is not None:
            self.module = module
        else:
            self.module = InterpreterModule(torch.fx.Graph())
        # If this module id was already outlined (shared/duplicated module),
        # remember the cached version so we can verify graph equivalence later.
        if self.module_id in self.seen_modules:
            self.cached_graph_module = self.seen_modules[self.module_id]
        else:
            self.cached_graph_module = None
            self.seen_modules[self.module_id] = self.module

        self.graph = self.module.graph

        # Mapping of nodes in the flat graph to nodes in this graph.
        self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {}
        self.node_to_placeholder = {}

        self.parent_call_module: Optional[torch.fx.Node] = None
        if parent is not None:
            accessor = _compute_accessor(parent.fqn, self.fqn)
            _add_submodule(
                parent.module,
                accessor,
                self.module
                if self.cached_graph_module is None
                else self.cached_graph_module,
            )
            self.parent_call_module = parent.graph.call_module(accessor)

        signature = module_call_graph.get(self.fqn)
        if signature is not None and self.parent is not None:
            # Preserved call signature: reconstruct (args, kwargs) placeholders
            # in this graph and flatten them to match the exported inputs.
            assert signature.in_spec.num_children == 2
            args_spec = signature.in_spec.children_specs[0]
            kwargs_spec = signature.in_spec.children_specs[1]
            assert args_spec.context is None
            assert kwargs_spec.context is not None

            with self.graph.inserting_after(None):
                arg_nodes = []
                for idx in range(args_spec.num_children):
                    arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}"))
                kwarg_nodes = {}
                for name in kwargs_spec.context:
                    kwarg_nodes[name] = self.graph.placeholder(name)
                flat_args = _generate_flatten(
                    self.module,
                    (tuple(arg_nodes), kwarg_nodes),
                    signature.in_spec,
                )
                for idx, arg in enumerate(signature.inputs):
                    flat_arg_node = self.graph.create_node(
                        op="call_function",
                        target=operator.getitem,
                        args=(flat_args, idx),
                        name=arg.name
                        if not isinstance(arg, ConstantArgument)
                        else f"_constant_{idx}",
                    )
                    if isinstance(arg, ConstantArgument):
                        continue
                    flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
                    self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node

            # On the parent side: unflatten the inputs back into (args, kwargs)
            # and wire them into the call_module node for this frame.
            with self.parent.graph.inserting_before(self.parent_call_module):
                input_nodes: List[Optional[torch.fx.Node]] = []
                for input in signature.inputs:
                    if isinstance(input, ConstantArgument) and input.value is None:
                        input_nodes.append(None)
                    else:
                        assert isinstance(input, (TensorArgument, SymIntArgument))
                        input_nodes.append(
                            self.parent.remap_input(self.seen_nodes[input.name])
                        )

                inputs_node = _generate_unflatten(
                    self.parent.module,
                    input_nodes,
                    signature.in_spec,
                )

                args_node = self.parent.graph.call_function(
                    operator.getitem, (inputs_node, 0)
                )
                kwargs_node = self.parent.graph.call_function(
                    operator.getitem, (inputs_node, 1)
                )
                arg_nodes = [
                    self.parent.graph.call_function(operator.getitem, (args_node, i))
                    for i in range(args_spec.num_children)
                ]
                kwarg_nodes = {
                    k: self.parent.graph.call_function(
                        operator.getitem, (kwargs_node, k)
                    )
                    for k in kwargs_spec.context
                }
            assert self.parent_call_module is not None
            self.parent_call_module.args = tuple(arg_nodes)
            self.parent_call_module.kwargs = kwarg_nodes

    def add_placeholder(self, x):
        assert x.graph is self.flat_graph
        # x is not in subgraph, create a new placeholder for subgraph
        with self.graph.inserting_before(None):
            placeholder_node = self.graph.placeholder(x.name, type_expr=x.type)
        # copy all meta fields, even if some fields might be irrelevant for
        # the placeholder node
        placeholder_node.meta = copy.copy(x.meta)
        self.node_to_placeholder[x] = placeholder_node

    def remap_input(self, x):
        # Map a flat-graph node to its counterpart in this frame's graph,
        # creating a placeholder (and threading a new argument through the
        # parent's call_module node) if we haven't seen it here yet.
        assert x.graph is self.flat_graph
        if x in self.node_map:
            return self.node_map[x]
        if x not in self.node_to_placeholder:
            self.add_placeholder(x)
            if self.parent_call_module is not None:
                # Important to *prepend* the output to match how we are
                # inserting placeholder nodes.
                self.parent_call_module.insert_arg(0, self.parent.remap_input(x))
        return self.node_to_placeholder[x]

    def finalize_outputs(self):
        # Decide this frame's graph outputs: either the preserved signature's
        # outputs (re-unflattened to the original structure) or every copied
        # node that still has users outside this frame.
        orig_outputs = []

        signature = self.module_call_graph.get(self.fqn)
        if signature is not None and self.parent is not None:
            for output in signature.outputs:
                if isinstance(output, (TensorArgument, SymIntArgument)):
                    orig_outputs.append(self.seen_nodes[output.name])
                else:
                    raise RuntimeError(
                        f"Unsupported data type for output node: {output}"
                    )

            tree_out_node = _generate_unflatten(
                self.module,
                tuple(
                    self.node_map[self.seen_nodes[output.name]]
                    for output in orig_outputs
                ),
                signature.out_spec,
            )
            parent_out: Optional[torch.fx.Node] = _generate_flatten(
                self.parent.module, self.parent_call_module, signature.out_spec
            )
            graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node
        else:
            graph_outputs = []
            # Iterate through nodes we have copied into self.graph.
            for orig_node in self.node_map.keys():
                for user_node in orig_node.users:
                    if user_node.name not in self.seen_nodes:
                        # external user node, need to expose as an output
                        orig_outputs.append(orig_node)
                        graph_outputs.append(self.node_map[orig_node])
                        break

            parent_out = self.parent_call_module
            if len(graph_outputs) == 1:
                graph_outputs = graph_outputs[0]

        assert isinstance(graph_outputs, (list, torch.fx.Node))

        self.graph.output(graph_outputs)

        # Rewrite outputs in parent module
        if parent_out is None:
            return

        parent_out.meta["val"] = (
            graph_outputs.meta.get("val")
            if isinstance(graph_outputs, torch.fx.Node)
            else [o.meta.get("val") for o in graph_outputs]
        )

        if len(orig_outputs) == 1 and signature is None:
            self.parent.node_map[orig_outputs[0]] = parent_out
        else:
            for i, orig_output in enumerate(orig_outputs):
                # Use Proxy to record getitem access.
                proxy_out = torch.fx.Proxy(parent_out)[i].node  # type: ignore[index]
                proxy_out.meta["val"] = orig_output.meta.get("val")
                self.parent.node_map[orig_output] = proxy_out

        if self.cached_graph_module is not None:
            _verify_graph_equivalence(self.cached_graph_module, self.module)

    def copy_node(self, node):
        self.print("copying", node.format_node())
        self.node_map[node] = self.graph.node_copy(node, self.remap_input)
        self.seen_nodes[node.name] = node

    def run_outer(self):
        # Entry point for the root frame: copy placeholders, walk the body,
        # then copy the output node verbatim.
        i = 0
        for node in self.flat_graph.nodes:
            self.print(i, node.meta.get("nn_module_stack"), node.format_node())
            i += 1

        # Copy all graph inputs
        node_idx: int = 0
        node = self.nodes[node_idx]
        while node.op == "placeholder":
            self.copy_node(node)
            node_idx += 1
            node = self.nodes[node_idx]

        self.run_from(node_idx)

        # Copy graph outputs
        for node in self.flat_graph.nodes:
            if node.op == "output":
                self.copy_node(node)

    def print(self, *args, **kwargs):
        # Debug printing, gated on self.verbose.
        if self.verbose:
            print(*args, **kwargs)

    def run_from(self, node_idx):
        module_idx = 0
        # Walk through the graph, building up a new graph with the right submodules
        while node_idx < len(self.nodes):
            node = self.nodes[node_idx]
            assert node.op != "placeholder"

            self.print()
            self.print("STEP", node_idx, node.format_node())
            self.print(self.module_stack)
            if node.op == "output":
                if len(self.module_stack) == 1:
                    # We want the output node of the original graph to be handled
                    # specially by the outermost stack frame (in run_outer). So
                    # skip finalization here.
                    return node_idx

                # We've reached the end of the graph. Wrap up all the existing stack frames.
                self.finalize_outputs()
                return node_idx

            node_module_stack = (
                [path for path, ty in node.meta["nn_module_stack"].values()]
                if "nn_module_stack" in node.meta
                else self.module_stack
            )
            if node_module_stack[: len(self.module_stack)] != self.module_stack:
                # This means that the current module is done executing and the
                # current node is the beginning of a new module.
                #
                # In this case, we should finalize this module and return without
                # incrementing the node counter.
                self.finalize_outputs()
                self.print("outlining", self.fqn)
                self.print(self.graph)
                return node_idx

            assert node_module_stack is not None

            if _is_prefix(self.module_stack, node_module_stack):
                # This means that the current node represents the execution of a new
                # module.
                next_module = node_module_stack[len(self.module_stack)]
                self.print("Creating new stack frame for", next_module)
                # Run a nested version of module outliner from the current node
                # counter. Once it is complete, continue from that point.
                node_idx = _ModuleFrame(
                    self.flat_graph,
                    self.nodes,
                    self.seen_nodes,
                    self.seen_modules,
                    self,
                    self.module_stack + [next_module],
                    list(node.meta["nn_module_stack"].keys())[len(self.module_stack)],
                    self.module_call_graph,
                ).run_from(node_idx)
                module_idx += 1
                continue

            # The only remaining possibility is that we are in the right stack
            # frame. Copy the node into this frame's graph and increment the node counter.
            assert node_module_stack == self.module_stack
            self.copy_node(node)
            node_idx += 1


def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
    # Drive the _ModuleFrame walk starting at the root module (fqn "").
    seen_nodes: Dict[str, torch.fx.Node] = {}
    seen_modules: Dict[int, torch.nn.Module] = {}
    _ModuleFrame(
        orig_graph,
        tuple(orig_graph.nodes),
        seen_nodes,
        seen_modules,
        None,
        [""],
        "",
        {
            entry.fqn: entry.signature
            for entry in root_module.module_call_graph
            if entry.signature
        },
        module=root_module,
    ).run_outer()


def _sink_params(
    module: torch.nn.Module,
    inputs_to_state: Dict[str, str],
    scope: List[str],
):
    """Sink params, buffers, and constants from graph inputs into get_attr nodes.

    Exported modules are purely functional, so they pass their parameters and
    buffers in as inputs to the graph.

    To replicate eager's semantics, we need to get them from the module state
    via get_attr instead.

    module: GraphModule, potentially containing nested submodules.
    inputs_to_state: mapping graph input names to the corresponding key in the state_dict.
    scope: tracks where we are in the module hierarchy, so that we can emit the
        right `getattr(self, "foo.bar")` calls, etc.
    """
    # We need to use _modules here instead of named_children(), because we
    # explicitly want duplicate modules to show up in the traversal.
    for name, submodule in module._modules.items():
        _sink_params(cast(torch.nn.Module, submodule), inputs_to_state, scope + [name])

    if not hasattr(module, "graph"):
        # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList)
        return

    graph = module.graph
    inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes))
    # NOTE(review): this assumes every graph here has at least one placeholder;
    # an empty `inputs` would raise IndexError — appears to hold for graphs
    # produced by _outline_submodules, but confirm.
    the_last_input = inputs[-1]

    # Also remove from call_module nodes
    call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
    for node in call_module_nodes:
        node.args = tuple(filter(lambda n: n.name not in inputs_to_state, node.args))

    for node in inputs:
        if node.name not in inputs_to_state:
            continue

        if len(node.users) > 0:
            state_name = inputs_to_state[node.name].split(".")
            # If there's a mismatch between scope name and state name, then there must be multiple scopes
            # pointing to the same state name, meaning some modules are shared. In such case, we can simply
            # skip updating the current node because another later iteration will take care of this input
            # node when the unique match between scope and state name occurs.
            # To make sure this always happens, we should enforce the invariant that no placeholder node
            # in the unflattened graph appears in inputs_to_state dict, which means all the extra input
            # nodes have been handled.
            if state_name[: len(scope)] != scope:
                continue
            attr_path = state_name[len(scope):]
            state_attr = _recursive_getattr(module, attr_path)
            assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject))

            # Make sure the newly created get_attr node is placed after the last placeholder node
            with graph.inserting_after(the_last_input):
                new_node = graph.create_node("get_attr", ".".join(attr_path))

            node.replace_all_uses_with(new_node, propagate_meta=True)

        graph.erase_node(node)

    if isinstance(module, InterpreterModule):
        module.finalize()


def _recursive_getattr(obj, attr_path):
    # Follow a pre-split dotted attribute path, e.g. ["sub", "weight"].
    for attr in attr_path:
        obj = getattr(obj, attr)

    return obj
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.