@compatibility(is_backward_compatible=True)
class Interpreter:
    """
    An Interpreter executes an FX graph Node-by-Node. This pattern
    can be useful for many things, including writing code
    transformations as well as analysis passes.

    Methods in the Interpreter class can be overridden to customize
    the behavior of execution. The map of overrideable methods
    in terms of call hierarchy::

        run()
            +-- run_node
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- output()

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass Interpreter like so::

            class NegSigmSwapInterpreter(Interpreter):
                def call_function(self, target : Target,
                                  args : Tuple, kwargs : Dict) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : Target,
                                args : Tuple, kwargs : Dict) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)
            input = torch.randn(3, 4)
            result = NegSigmSwapInterpreter(gm).run(input)
            torch.testing.assert_close(result, torch.neg(input).sigmoid())

    Args:
        module (GraphModule): The module to be executed
        garbage_collect_values (bool): Whether to delete values after their last
            use within the Module's execution. This ensures optimal memory usage during
            execution. This can be disabled to, for example, examine all of the intermediate
            values in the execution by looking at the ``Interpreter.env`` attribute.
    """
    @compatibility(is_backward_compatible=True)
    def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
        assert isinstance(module, GraphModule)
        self.module = module
        self.submodules = dict(self.module.named_modules())
        self.env: Dict[Node, Any] = {}
        self.name = "Interpreter"
        self.garbage_collect_values = garbage_collect_values

        if self.garbage_collect_values:
            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use: Dict[Node, Node] = {}
            self.user_to_last_uses: Dict[Node, List[Node]] = {}

            def register_last_uses(n: Node, user: Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    self.user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(self.module.graph.nodes):
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))
@compatibility(is_backward_compatible=True)
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
    """
    Run `module` via interpretation and return the result.

    Args:
        *args: The arguments to the Module to run, in positional order
        initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
            This is a dict mapping `Node` to any value. This can be used, for example, to
            pre-populate results for certain `Nodes` so as to do only partial evaluation within
            the interpreter.
        enable_io_processing (bool): If true, we process the inputs and outputs with graph's
            process_inputs and process_outputs function first before using them.

    Returns:
        Any: The value returned from executing the Module
    """
    self.env = {} if initial_env is None else initial_env

    # Positional function args are consumed left-to-right by
    # `placeholder` nodes. Use an iterator to keep track of
    # position and extract those values.
    if enable_io_processing:
        args = self.module.graph.process_inputs(*args)
    self.args_iter: Iterator[Any] = iter(args)

    graph_nodes = self.module.graph.nodes
    progress = tqdm(
        total=len(graph_nodes),
        desc=f"{self.name}: {str(list(self.module.graph.nodes)) if config.verbose_progress else ''}",
        initial=0,
        position=0,
        leave=True,
        disable=config.disable_progress,
        delay=0,
    )

    for node in graph_nodes:
        progress.update(1)

        # Short circuit if we have this value. This could
        # be used, for example, for partial evaluation
        # where the caller has pre-populated `env` with
        # values for a subset of the program.
        if node in self.env:
            continue

        try:
            self.env[node] = self.run_node(node)
        except Exception as e:
            # Prefix the original error with the node being executed, and
            # append the recorded traceback so failures point back at the
            # traced source location.
            msg = f"While executing {node.format_node()}"
            msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg)
            msg += f"\nOriginal traceback:\n{node.stack_trace}"
            e.args = (msg,) + e.args[1:]
            if isinstance(e, KeyError):
                # KeyError's repr mangles multi-line messages; re-raise as
                # RuntimeError so the augmented message stays readable.
                raise RuntimeError(*e.args) from e
            raise

        if self.garbage_collect_values:
            # Free any value whose last use was this node.
            for stale in self.user_to_last_uses.get(node, []):
                del self.env[stale]

        if node.op == 'output':
            output_val = self.env[node]
            if enable_io_processing:
                return self.module.graph.process_outputs(output_val)
            return output_val
@compatibility(is_backward_compatible=True)
def run_node(self, n: Node) -> Any:
    """
    Run a specific node ``n`` and return the result.
    Calls into placeholder, get_attr, call_function,
    call_method, call_module, or output depending
    on ``node.op``

    Args:
        n (Node): The Node to execute

    Returns:
        Any: The result of executing ``n``
    """
    with self._set_current_node(n):
        arg_vals, kwarg_vals = self.fetch_args_kwargs_from_env(n)
        assert isinstance(arg_vals, tuple)
        assert isinstance(kwarg_vals, dict)
        # Dispatch on the node's opcode: each op name maps directly to a
        # method on this class (placeholder / get_attr / call_* / output).
        handler = getattr(self, n.op)
        return handler(n.target, arg_vals, kwarg_vals)
# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``placeholder`` node. Note that this is stateful:
    ``Interpreter`` maintains an internal iterator over
    arguments passed to ``run`` and this method returns
    next() on that iterator.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Returns:
        Any: The argument value that was retrieved.
    """
    assert isinstance(target, str)
    if target.startswith('*'):
        # A starred parameter (e.g. `*args`) swallows every remaining
        # positional value.
        return list(self.args_iter)
    try:
        return next(self.args_iter)
    except StopIteration as exhausted:
        # No positional value left: fall back to the placeholder's
        # default value when one was recorded on the node.
        if args:
            return args[0]
        raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from exhausted
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``get_attr`` node. Will retrieve an attribute
    value from the ``Module`` hierarchy of ``self.module``.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Return:
        Any: The value of the attribute that was retrieved
    """
    assert isinstance(target, str)
    # `args` and `kwargs` are unused for get_attr nodes; the target is
    # the fully-qualified attribute path to resolve.
    return self.fetch_attr(target)
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``call_function`` node and return the result.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Return:
        Any: The value returned by the function invocation
    """
    # For call_function nodes the target is the callable itself, never a
    # string name.
    assert not isinstance(target, str)
    return target(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``call_method`` node and return the result.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Return:
        Any: The value returned by the method invocation
    """
    # The first positional arg is the receiver (`self`) of the method call.
    receiver, *rest = args
    # For call_method nodes the target is the method's name on the receiver.
    assert isinstance(target, str)
    return getattr(receiver, target)(*rest, **kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``call_module`` node and return the result.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Return:
        Any: The value returned by the module invocation
    """
    # The target is the qualified path of the submodule; resolve it and
    # invoke it with the already-concretized args and kwargs.
    assert isinstance(target, str)
    module_instance = self.fetch_attr(target)
    return module_instance(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute an ``output`` node. This really just retrieves
    the value referenced by the ``output`` node and returns it.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Return:
        Any: The return value referenced by the output node
    """
    # An output node carries exactly one already-evaluated argument: the
    # value the whole graph returns.
    return args[0]
# Helper methods
@compatibility(is_backward_compatible=True)
def fetch_attr(self, target: str):
    """
    Fetch an attribute from the ``Module`` hierarchy of ``self.module``.

    Args:
        target (str): The fully-qualified name of the attribute to fetch

    Return:
        Any: The value of the attribute.

    Raises:
        RuntimeError: If any component of the dotted path does not exist.
    """
    target_atoms = target.split('.')
    attr_itr = self.module
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Join atoms[:i + 1] so the message names the attribute that is
            # actually missing, not just its (existing) parent path.
            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
@compatibility(is_backward_compatible=True)
def fetch_args_kwargs_from_env(self, n: Node) -> Tuple[Tuple, Dict]:
    """
    Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
    from the current execution environment.

    Args:
        n (Node): The node for which ``args`` and ``kwargs`` should be fetched.

    Return:
        Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
    """
    # Resolve positional and keyword arguments separately; each must come
    # back as the same container kind the node stores.
    concrete_args = self.map_nodes_to_values(n.args, n)
    assert isinstance(concrete_args, tuple)
    concrete_kwargs = self.map_nodes_to_values(n.kwargs, n)
    assert isinstance(concrete_kwargs, dict)
    return concrete_args, concrete_kwargs
@compatibility(is_backward_compatible=True)
def map_nodes_to_values(self, args: Argument, n: Node) -> Argument:
    """
    Recursively descend through ``args`` and look up the concrete value
    for each ``Node`` in the current execution environment.

    Args:
        args (Argument): Data structure within which to look up concrete values

        n (Node): Node to which ``args`` belongs. This is only used for error reporting.
    """
    def lookup(ref: Node) -> Any:
        # Every Node reachable from `args` must already have been executed
        # and recorded in `self.env`; anything else indicates a malformed graph.
        if ref not in self.env:
            raise RuntimeError(f'Node {n} referenced nonexistent value {ref}! Run Graph.lint() '
                               f'to diagnose such issues')
        return self.env[ref]
    return map_arg(args, lookup)
@compatibility(is_backward_compatible=True)
class Transformer(Interpreter):
    """
    ``Transformer`` is a special type of interpreter that produces a
    new ``Module``. It exposes a ``transform()`` method that returns
    the transformed ``Module``. ``Transformer`` does not require
    arguments to run, as ``Interpreter`` does. ``Transformer`` works
    entirely symbolically.

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass ``Transformer`` like so::

            class NegSigmSwapXformer(Transformer):
                def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)

            transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
            input = torch.randn(3, 4)
            torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

    Args:
        module (GraphModule): The ``Module`` to be transformed.
    """
    @compatibility(is_backward_compatible=True)
    def __init__(self, module):
        super().__init__(module)
        self.new_graph = Graph()
        # Preserve the original graph's codegen behavior (e.g. custom
        # argument packing) on the output graph.
        self.new_graph.set_codegen(module.graph._codegen)

        class TransformerTracer(Tracer):
            def __init__(self, graph: Graph):
                super().__init__()
                self.graph = graph

            # Treat every module as a leaf so call_module nodes are
            # re-emitted rather than traced through.
            def is_leaf_module(self, _, __) -> bool:
                return True

        self.tracer = TransformerTracer(self.new_graph)
        self.tracer.root = module
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Proxy:
    """
    Execute a ``placeholder`` node. In ``Transformer``, this is
    overridden to insert a new ``placeholder`` into the output
    graph.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation
    """
    assert isinstance(target, str)
    # A placeholder node's sole recorded arg, when present, is its
    # default value; otherwise signal "no default" to the new graph.
    if args:
        default = next(iter(args))
    else:
        default = inspect.Signature.empty
    new_node = self.new_graph.placeholder(target, default_value=default)
    return Proxy(new_node, self.tracer)
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Proxy:
    """
    Execute a ``get_attr`` node. In ``Transformer``, this is
    overridden to insert a new ``get_attr`` node into the output
    graph.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation
    """
    assert isinstance(target, str)
    # Re-emit the attribute access symbolically instead of fetching the
    # concrete value.
    attr_node = self.new_graph.get_attr(target)
    return Proxy(attr_node, self.tracer)
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``call_module`` node symbolically. Overridden so that the
    leaf-module policy from ``self.tracer`` is respected instead of the
    submodule being run eagerly.
    """
    assert isinstance(target, str)
    submodule = self.fetch_attr(target)
    # Delegate to the tracer, which decides whether to re-emit a
    # call_module node or trace into the submodule's forward.
    return self.tracer.call_module(submodule, submodule.forward, args, kwargs)
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``call_function`` node symbolically. Overridden so that
    functions that were wrapped during tracing stay wrapped: a new proxy
    is emitted instead of invoking the target directly.
    """
    return self.tracer.create_proxy('call_function', target, args, kwargs)
@compatibility(is_backward_compatible=True)
def transform(self) -> GraphModule:
    """
    Transform ``self.module`` and return the transformed
    ``GraphModule``.
    """
    # Interpret the original graph symbolically; each overridden handler
    # appends nodes to `self.new_graph` and yields Proxies.
    with fx_traceback.preserve_node_meta():
        result = super().run(enable_io_processing=False)
    if result is not None:
        def unwrap(a: Union[Argument, Proxy]) -> Any:
            # Replace each Proxy with its underlying Node so the output
            # node references graph values, not tracing wrappers.
            if isinstance(a, Proxy):
                return a.node
            return a
        self.new_graph.output(map_aggregate(result, unwrap))
    return GraphModule(self.module, self.new_graph)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.