# mypy: ignore-errors

import collections
import copy
import dis
import enum
import inspect
import logging
import operator
import sys
from dataclasses import fields, is_dataclass
from typing import Any, Callable, Dict, Iterator, Optional, OrderedDict, Tuple

import torch
import torch.fx.traceback as fx_traceback
from torch.utils._traceback import CapturedTraceback

from ._compatibility import compatibility
from .graph import Graph, magic_methods, reflectable_magic_methods
from .node import Argument, base_types, map_aggregate, Node, Target
from .operator_schemas import check_for_mutable_operation


__all__ = [
    "TracerBase",
    "GraphAppendingTracer",
    "TraceError",
    "Proxy",
    "MetaProxy",
    "Attribute",
    "ParameterProxy",
    "Scope",
    "ScopeContextManager",
]


log = logging.getLogger(__name__)


@compatibility(is_backward_compatible=False)
class Scope:
    """Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in a Graph of GraphModule. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # This will be a call_method Node in GraphModule,
                # scope for this would be (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)


        class M(torch.nn.Module):
            def __init__(self) -> None:
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x
    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type


@compatibility(is_backward_compatible=False)
class ScopeContextManager:
    """A context manager to track the Scope of Node during symbolic tracing.
    When entering a forward function of a Module, we'll update the scope information of
    the current module, and when we exit, we'll restore the previous scope information.
    """

    def __init__(
        self,
        scope: Scope,
        current_scope: Scope,
    ):
        super().__init__()
        # Keep a copy of prev scope to restore on exit
        self._prev_scope = copy.copy(scope)
        # Update scope to current scope
        scope.module_path = current_scope.module_path
        scope.module_type = current_scope.module_type
        # Save a reference so we can restore it
        self._scope = scope

    def __enter__(self):
        return self._scope

    def __exit__(self, *args):
        self._scope.module_path = self._prev_scope.module_path
        self._scope.module_type = self._prev_scope.module_type
        return


_COPY_META_FIELDS = [
    "nn_module_stack",
    "torch_fn",
    "source_fn_stack",
    "original_aten",
    "recompute",
    "ac_graph_id",
    "has_backward_hook",
    "from_node",
    "quantization_tag",  # TODO deprecated
    "_numeric_debug_handle",  # TODO deprecated
    "custom",
    "partitioner_tag",
]


@compatibility(is_backward_compatible=True)
class TracerBase:
    graph: Graph
    record_stack_traces: bool = False
    # Feature flag for mutable schema checking
    # Enabled by default in 1.12
    check_mutable_operations: bool = False
    # Feature flag for assert tracing
    trace_asserts: bool = False
    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes: bool = False

    # Name of the function to be traced. It will only be used when
    # ``root`` is an instance of ``nn.Module``
    traced_func_name: str = "forward"

    # Maps the containing module's name to the operator name
    scope: Scope

    # Records the module call stack
    module_stack: OrderedDict[str, Tuple[str, Any]]

    # Mapping of node name to module scope
    node_name_to_scope: Dict[str, Tuple[str, type]]

    @compatibility(is_backward_compatible=True)
    def create_node(
        self,
        kind: str,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
    ) -> Node:
        """
        Inserts a graph node given target, args, kwargs, and name.

        This method can be overridden to do extra checking, validation, or
        modification of values used in node creation. For example, one might
        want to disallow in-place operations from being recorded; a sketch of
        such a tracer appears after ``TraceError`` below.
        """
        if kind == "call_function" and self.check_mutable_operations:
            check_for_mutable_operation(target, args, kwargs)

        node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
        # TODO node_name_to_scope will be deprecated in favor of
        # node.meta['nn_module_stack']
        self.node_name_to_scope[node.name] = (
            self.scope.module_path,
            self.scope.module_type,
        )
        # Optionally set stack trace on the created Node for debugging purposes
        if fx_traceback.has_preserved_node_meta():
            current_meta: Dict[str, Any] = fx_traceback.get_current_meta()

            stack_trace = current_meta.get("stack_trace")
            if stack_trace:
                node.stack_trace = stack_trace
            # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta
            # If other meta fields are needed, they can be added here
            for field in _COPY_META_FIELDS:
                if field in current_meta:
                    node.meta[field] = copy.copy(current_meta[field])

            # Here we decrement to account for the sequence_nr having
            # just been incremented while tracing this lowered aten op.
            new_seq_nr = torch.autograd._get_sequence_nr() - 1
            # The sequence_nr increments every time a new autograd Node
            # is created. During the FWD pass we store the sequence_nr
            # corresponding to the last autograd Node created on this fx
            # node's meta. A single aten op can create multiple autograd
            # nodes as is the case with in-place foreach ops. During the
            # BWD pass we retrieve the sequence_nr stored on the current
            # executing autograd Node. See NOTE [ Sequence Number ].
            if current_meta.get("in_grad_fn", 0) > 0:
                new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
            node.meta["seq_nr"] = new_seq_nr

        elif self.module_stack:
            node.meta["nn_module_stack"] = copy.copy(self.module_stack)

        log.debug("create_node %s", node)
        return node

    @compatibility(is_backward_compatible=True)
    def proxy(self, node: Node) -> "Proxy":
        return Proxy(node, self)

    @compatibility(is_backward_compatible=True)
    def create_proxy(
        self,
        kind: str,
        target: Target,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
        # fix noqa when updating bc tests
        proxy_factory_fn: Callable[[Node], "Proxy"] = None,  # noqa: RUF013
    ):
        """
        Create a Node from the given arguments, then return the Node
        wrapped in a Proxy object.

        If kind = 'placeholder', then we're creating a Node that
        represents the parameter of a function. If we need to encode
        a default parameter, we use the ``args`` tuple. ``args`` is
        otherwise empty for ``placeholder`` Nodes.
        """

        args_ = self.create_arg(args)
        kwargs_ = self.create_arg(kwargs)
        assert isinstance(args_, tuple)
        assert isinstance(kwargs_, dict)

        node = self.create_node(kind, target, args_, kwargs_, name, type_expr)

        if not proxy_factory_fn:
            proxy = self.proxy(node)
        else:
            proxy = proxy_factory_fn(node)

        if self.record_stack_traces and not proxy.node.stack_trace:
            proxy.node.stack_trace = "".join(CapturedTraceback.extract().format())

        return proxy

    def _find_user_frame(self):
        """
        Find the Python stack frame executing the user code during
        symbolic tracing.
        """
        # We have to do a little dance here. Basically, walk up the callstack and
        # record the first frame not in the pytorch source. This is the frame
        # executing the user code during tracing.
        frame = inspect.currentframe()

        pt_files = [
            "torch/fx/proxy.py",
            "torch/fx/_symbolic_trace.py",
            "torch/fx/experimental/proxy_tensor.py",
            "torch/_ops.py",
            "torch/_tensor.py",
            "torch/utils/_python_dispatch.py",
            "torch/_prims_common/wrappers.py",
            "torch/_refs/__init__.py",
            "torch/_refs/nn/functional/__init__.py",
            "torch/utils/_stats.py",
        ]
        while frame:
            frame = frame.f_back
            if frame and all(
                not frame.f_code.co_filename.endswith(file) for file in pt_files
            ):
                break

        if not frame:
            return None

        return frame

    @compatibility(is_backward_compatible=True)
    def create_arg(self, a: Any) -> Argument:
        """
        A method that lowers the objects seen as arguments during symbolic evaluation
        into Argument types that can be stored in IR.

        Can be overridden to support more trace-specific types.
        """
        if isinstance(a, Proxy):
            return a.node  # most common arg type goes first
        elif hasattr(a, "__fx_create_arg__"):
            return a.__fx_create_arg__(self)
        # aggregates
        elif isinstance(a, tuple):
            if hasattr(a, "_fields"):
                # NamedTuple constructors don't seem to like getting a generator
                # expression as an argument to their constructor, so build this
                # intermediate tuple and unpack it into the NamedTuple constructor
                args = [self.create_arg(elem) for elem in a]
                return type(a)(*args)  # type: ignore[arg-type]
            return type(a)([self.create_arg(elem) for elem in a])
        elif isinstance(a, list):
            return [self.create_arg(elem) for elem in a]
        elif isinstance(a, dict):

            def no_node(arg):
                if isinstance(arg, Node):
                    raise RuntimeError(
                        "Keys for dictionaries used as an argument cannot contain a "
                        f"Node. Got key: {k}"
                    )

            r = {}
            for k, v in a.items():
                # Check for invalid dict keys. We do not want a Proxy to appear
                # anywhere within the key. Since keys can be collection types,
                # we iterate through the key with map_aggregate
                k = self.create_arg(k)
                map_aggregate(k, no_node)
                r[k] = self.create_arg(v)
            return r
        elif isinstance(a, slice):
            return slice(
                self.create_arg(a.start),
                self.create_arg(a.stop),
                self.create_arg(a.step),
            )
        elif isinstance(a, range):
            return range(
                self.create_arg(a.start),
                self.create_arg(a.stop),
                self.create_arg(a.step),
            )
        elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
            return a
        elif is_dataclass(a):
            kwargs = {
                field.name: self.create_arg(getattr(a, field.name))
                for field in fields(a)
            }
            return self.create_node("call_function", a.__class__, (), kwargs)
        elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
            return a

        raise NotImplementedError(f"argument of type: {type(a)}")

    @compatibility(is_backward_compatible=True)
    def to_bool(self, obj: "Proxy") -> bool:
        """Called when a proxy object is being converted to a boolean, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return a value.
        """
        raise TraceError(
            "symbolically traced variables cannot be used as inputs to control flow"
        )

    @compatibility(is_backward_compatible=True)
    def iter(self, obj: "Proxy") -> Iterator:
        """Called when a proxy object is being iterated over, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return an iterator.
        """
        raise TraceError(
            "Proxy object cannot be iterated. This can be "
            "attempted when the Proxy is used in a loop or"
            " as a *args or **kwargs function argument. "
            "See the torch.fx docs on pytorch.org for a "
            "more detailed explanation of what types of "
            "control flow can be traced, and check out the"
            " Proxy docstring for help troubleshooting "
            "Proxy iteration errors"
        )

    @compatibility(is_backward_compatible=True)
    def keys(self, obj: "Proxy") -> Any:
        """Called when a proxy object has the ``keys()`` method called.
        This is what happens when ``**`` is used on a proxy. This should
        return an iterator if ``**`` is supposed to work in your custom tracer.
        """
        return Attribute(obj, "keys")()


# Used in Proxy object when just appending to the graph while not tracing.
@compatibility(is_backward_compatible=True)
class GraphAppendingTracer(TracerBase):
    def __init__(self, graph: Graph):
        super().__init__()
        self.graph = graph
        self.scope = Scope("", None)
        self.module_stack = collections.OrderedDict()
        self.node_name_to_scope = {}


@compatibility(is_backward_compatible=False)
def assert_fn(x):
    assert x


@compatibility(is_backward_compatible=True)
class TraceError(ValueError):
    pass


@compatibility(is_backward_compatible=True)
class Proxy:
    """
    ``Proxy`` objects are ``Node`` wrappers that flow through the
    program during symbolic tracing and record all the operations
    (``torch`` function calls, method calls, operators) that they touch
    into the growing FX Graph.

    If you're doing graph transforms, you can wrap your own ``Proxy``
    object around a raw ``Node`` so that you can use the overloaded
    operators to add additional things to a ``Graph`` (a minimal sketch
    of this pattern follows the class definition).

    ``Proxy`` objects cannot be iterated. In other words, the symbolic
    tracer will throw an error if a ``Proxy`` is used in a loop or as
    an ``*args``/``**kwargs`` function argument.

    There are two main ways around this:
    1. Factor out the untraceable logic into a top-level function and
       use ``fx.wrap`` on it.
    2. If the control flow is static (i.e. the loop trip count is
       based on some hyperparameter), the code can be kept in its original
       position and refactored into something like::

        for i in range(self.some_hyperparameter):
            indexed_item = proxied_value[i]

    For a more detailed description of the Proxy internals, check out
    the "Proxy" section in `torch/fx/README.md`
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, node: Node, tracer: "Optional[TracerBase]" = None):
        if tracer is None:
            # This allows you to create a Proxy object around a raw Node
            tracer = GraphAppendingTracer(node.graph)
        self.tracer = tracer
        self.node = node

    def __repr__(self) -> str:
        return f"Proxy({self.node.name})"

    def __getattr__(self, k) -> "Attribute":
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return Attribute(self, k)

    def __getstate__(self) -> Dict:
        return self.__dict__

    def __deepcopy__(self, memo) -> Dict:
        # We have to explicitly override this method, because otherwise deepcopy
        # will go to __getattr__(self, "__deepcopy__") and return a
        # Attribute(__deepcopy__), and may go into an infinite loop in some cases.
        import copy

        new_dict = {}
        for k, v in self.__dict__.items():
            try:
                new_obj = copy.deepcopy(v, memo)
            except Exception:
                log.warning(
                    "Shallow copy %s of Proxy because it cannot be deepcopied. "
                    "Proxy is created for node %s",
                    k,
                    self.node.name,
                )
                new_obj = copy.copy(v)
            new_dict[k] = new_obj
        assert "node" in new_dict
        assert "tracer" in new_dict
        new_proxy = Proxy(new_dict["node"], new_dict["tracer"])
        for k, v in new_dict.items():
            new_proxy.__dict__[k] = v
        return new_proxy

    def __setstate__(self, d):
        # This is called when being unpickled/loaded.
        self.__dict__ = d

    def __call__(self, *args, **kwargs) -> "Proxy":
        return self.tracer.create_proxy(
            "call_method", "__call__", (self,) + args, kwargs
        )

    def __iter__(self) -> Iterator["Proxy"]:
        frame = inspect.currentframe()
        assert frame is not None
        calling_frame = frame.f_back
        assert calling_frame is not None
        inst_list = list(dis.get_instructions(calling_frame.f_code))
        if sys.version_info >= (3, 11):
            from bisect import bisect_left

            inst_idx = bisect_left(
                inst_list, calling_frame.f_lasti, key=lambda x: x.offset
            )
        else:
            inst_idx = calling_frame.f_lasti // 2
        inst = inst_list[inst_idx]
        if inst.opname == "UNPACK_SEQUENCE":
            return (self[i] for i in range(inst.argval))  # type: ignore[index]

        return self.tracer.iter(self)

    def __abs__(self):
        return self.tracer.create_proxy("call_function", operator.abs, (self,), {})

    def __bool__(self) -> bool:
        if self.tracer.trace_asserts:
            # check if this boolean is used in an assertion, bytecode pattern for assertions
            # is pretty stable for Python 3.7--3.9
            frame = inspect.currentframe()
            assert frame is not None
            calling_frame = frame.f_back
            assert calling_frame is not None
            insts = list(dis.get_instructions(calling_frame.f_code))
            if sys.version_info >= (3, 11):
                from bisect import bisect_left

                cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset)
            else:
                cur = calling_frame.f_lasti // 2
            inst = insts[cur]

            if inst.opname == "POP_JUMP_IF_TRUE":
                first = insts[cur + 1]
                assert inst.arg is not None
                last = insts[inst.arg // 2 - 1]
                starts_with_assert = (
                    first.opname == "LOAD_GLOBAL"
                    and first.argval == "AssertionError"
                    or first.opname == "LOAD_ASSERTION_ERROR"
                )
                if starts_with_assert and last.opname == "RAISE_VARARGS":
                    self.tracer.create_proxy("call_function", assert_fn, (self,), {})
                    return True

        return self.tracer.to_bool(self)

    @compatibility(is_backward_compatible=True)
    def keys(self):
        return self.tracer.keys(self)

    def __len__(self):
        raise RuntimeError(
            "'len' is not supported in symbolic tracing by default. If you want "
            "this call to be recorded, please call torch.fx.wrap('len') at "
            "module scope"
        )

    @classmethod
    def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        tracers: Dict[Any, None] = {}

        def find_tracer(a):
            if isinstance(a, cls):
                tracers[a.tracer] = None

        torch.fx.node.map_aggregate(args, find_tracer)
        torch.fx.node.map_aggregate(kwargs, find_tracer)

        if len(tracers) > 1:
            raise RuntimeError(
                f"Found multiple different tracers {list(tracers.keys())} while "
                f"trying to trace operations {orig_method}"
            )
        tracer = next(iter(tracers.keys()))

        if isinstance(orig_method, torch._C.ScriptMethod):
            args = (orig_method.owner,) + args
            return tracer.create_proxy("call_method", orig_method.name, args, kwargs)
        if torch.overrides.is_tensor_method_or_property(orig_method):
            return tracer.create_proxy(
                "call_method", orig_method.__name__, args, kwargs
            )
        else:
            if isinstance(orig_method, torch._ops.HigherOrderOperator):
                # TODO: Define how to symbolically trace HigherOrderOperators
                raise RuntimeError("Unable to symbolically trace HigherOrderOperators")
            return tracer.create_proxy(
                "call_function",
                orig_method,
                args,
                kwargs,
                name=tracer.graph._target_to_str(orig_method.__name__),
            )


@compatibility(is_backward_compatible=False)
class MetaProxy(Proxy):
    """
    A Proxy subclass that propagates metadata (meta['val']) during graph tracing.
    """

    def __init__(
        self, node: Node, tracer: "Optional[TracerBase]" = None, fake_mode=None
    ):
        super().__init__(node, tracer)
        self.fake_mode = fake_mode

    def __repr__(self) -> str:
        return f"MetaProxy({self.node.name})"

    @classmethod
    def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        meta_proxy = None
        for arg in args:
            if isinstance(arg, MetaProxy):
                meta_proxy = arg
                break

        assert (
            meta_proxy is not None
        ), "No MetaProxy found in arguments, but one is expected."

        proxy = super().__torch_function__(orig_method, types, args, kwargs)
        with meta_proxy.fake_mode:
            proxy.node.meta["val"] = orig_method(
                *[a.node.meta["val"] if isinstance(a, Proxy) else a for a in args],
                **kwargs,
            )
        return MetaProxy(proxy.node, proxy.tracer, meta_proxy.fake_mode)


@compatibility(is_backward_compatible=True)
class Attribute(Proxy):
    @compatibility(is_backward_compatible=True)
    def __init__(self, root: Proxy, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node: Optional[Node] = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy(
                "call_function", getattr, (self.root, self.attr), {}
            ).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy(
            "call_method", self.attr, (self.root,) + args, kwargs
        )


@compatibility(is_backward_compatible=False)
class ParameterProxy(Proxy):
    """
    A special proxy which lets "shape", "size", "dim", and a few other
    attribute accesses pass through to the underlying module parameter object,
    so that conditional tests on these attributes will not throw an exception
    during tracing.
    """

    def __init__(self, tracer: TracerBase, node: Node, name, param):
        super().__init__(node, tracer)
        assert isinstance(param, torch.nn.Parameter)
        self.param = param
        self.name = name

    def __repr__(self) -> str:
        return f"ParameterProxy({self.name})"

    @property
    def shape(self):
        return self.param.shape

    def size(self):
        return self.param.size()

    def dim(self):
        return self.param.dim()

    @property
    def ndim(self):
        return self.param.ndim

    def numel(self):
        return self.param.numel()

    def nelement(self):
        return self.param.nelement()


for method in magic_methods:

    def _scope(method):
        def impl(*args, **kwargs):
            tracer = args[0].tracer
            target = getattr(operator, method)
            return tracer.create_proxy("call_function", target, args, kwargs)

        impl.__name__ = method
        as_magic = f'__{method.strip("_")}__'
        setattr(Proxy, as_magic, impl)

    _scope(method)


def _define_reflectable(orig_method_name):
    method_name = f'__r{orig_method_name.strip("_")}__'

    def impl(self, rhs):
        target = getattr(operator, orig_method_name)
        return self.tracer.create_proxy("call_function", target, (rhs, self), {})

    impl.__name__ = method_name
    impl.__qualname__ = method_name
    setattr(Proxy, method_name, impl)


for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)
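

# A minimal sketch of what the two loops above install: ``magic_methods`` give
# ``Proxy`` the forward operators and ``reflectable_magic_methods`` the
# ``__r*__`` variants, so a Proxy on either side of an expression records a
# node. The ``_example_reflected_op`` name is illustrative only.
def _example_reflected_op() -> Graph:
    graph = Graph()
    tracer = GraphAppendingTracer(graph)
    x = Proxy(graph.placeholder("x"), tracer)
    # ``int.__sub__(2, x)`` returns NotImplemented, so Python falls back to
    # ``Proxy.__rsub__``, which records ``operator.sub(2, x)`` in the graph.
    y = 2 - x
    graph.output(y.node)
    return graph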