# Source code for torch.fx.experimental.proxy_tensor
# mypy: allow-untyped-decorators# Copyright (c) Facebook, Inc. and its affiliates.# All rights reserved.## This source code is licensed under the BSD-style license found in the# LICENSE file in the root directory of this source tree.from__future__importannotationsimportfunctoolsimportinspectimportloggingimportoperatorimporttracebackimporttypingimporttyping_extensionsimportwarningsimportweakreffromcollectionsimportdefaultdictfromcontextlibimport_GeneratorContextManager,contextmanager,ExitStack,nullcontextfromdataclassesimportdataclassfromtypingimport(Any,Callable,Dict,Generator,List,Mapping,Optional,overload,Protocol,Sequence,Tuple,Type,TYPE_CHECKING,TypeVar,Union,)fromtyping_extensionsimportConcatenate,ParamSpec,SelffromweakrefimportWeakKeyDictionaryimporttorchimporttorch._opsimporttorch.fxasfximporttorch.fx.tracebackasfx_tracebackimporttorch.utils._pytreeaspytreefromtorchimportSymBool,SymInt,Tensorfromtorch._dispatch.pythonimportenable_python_dispatcherfromtorch._library.fake_class_registryimportFakeScriptObjectfromtorch._loggingimporttrace_structuredfromtorch._subclasses.fake_implsimportfast_detachfromtorch._subclasses.fake_tensorimport(FakeTensor,FakeTensorMode,is_fake,unset_fake_temporarily,)fromtorch._subclasses.meta_utilsimportis_sparse_anyfromtorch.fximportGraphModule,Proxy,Tracerfromtorch.fx.graph_moduleimport_assign_attrfromtorch.fx.nodeimport_side_effectful_need_to_be_preserved_pre_dispatchfromtorch.fx.passes.shape_propimport_extract_tensor_metadatafromtorch.nnimportModulefromtorch.overridesimportTorchFunctionModefromtorch.utils._python_dispatchimport(_disable_infra_mode,_push_mode,_unset_infra_mode,TorchDispatchMode,)fromtorch.utils._statsimportcountfromtorch.utils._thunkimportThunkfromtorch.utils._tracebackimportCapturedTracebackfromtorch.utils.weakimport_WeakHashRef,WeakIdKeyDictionary,WeakTensorKeyDictionaryfrom._backward_stateimportBackwardStatefrom.sym_nodeimportSymNodeifTYPE_CHECKING:importtypesfromcollections.abcimportMutableMappingimportsympyfromt
orch._opsimportOpOverloadfromtorch.fx._symbolic_traceimportPHBasefromtorch.typesimportIntLikeType__all__=["PythonKeyTracer","dispatch_trace","make_fx","DecompositionInterpreter","py_sym_types","get_innermost_proxy_mode","get_proxy_mode","handle_sym_dispatch","maybe_enable_thunkify","maybe_disable_thunkify",]_ProxyTracer=Union["PythonKeyTracer","_GraphAppendingTracerEx"]_AnyScriptObject=(torch.ScriptObject,FakeScriptObject)_AnyScriptObjectType=Union[torch.ScriptObject,FakeScriptObject]aten=torch.ops.atenprim=torch.ops.primlog=logging.getLogger(__name__)not_implemented_log=torch._logging.getArtifactLogger(__name__,"not_implemented")CURRENT_DECOMPOSITION_TABLE:Mapping[OpOverload,Callable]={}CONSTANT_NUMEL_LIMIT=1T=TypeVar("T")U=TypeVar("U")_P=ParamSpec("_P")R=TypeVar("R")null_ctx_type=type(nullcontext)# We currently convert all SymInt to proxies before we use them.# This could plausibly be handled at the Dynamo level.pytree.register_pytree_node(torch.Size,lambdaxs:(list(xs),None),lambdaxs,_:tuple(xs),flatten_with_keys_fn=lambdaxs:([(pytree.SequenceKey(i),x)fori,xinenumerate(xs)],None,),serialized_type_name="torch.Size",)deffake_signature(fn:Callable[_P,R],nargs:int)->Callable[_P,R]:"""FX gets confused by varargs, de-confuse it"""argnames=",".join(f"arg{i}"foriinrange(nargs))returneval(f"lambda {argnames}: fn({argnames})",{"fn":fn})@contextmanagerdefdecompose(decomposition_table:Optional[Mapping[OpOverload,Callable]])->Generator[Mapping[OpOverload,Callable],None,None]:globalCURRENT_DECOMPOSITION_TABLEold_decomposition_table=CURRENT_DECOMPOSITION_TABLECURRENT_DECOMPOSITION_TABLE=decomposition_tableor{}try:yieldCURRENT_DECOMPOSITION_TABLEfinally:CURRENT_DECOMPOSITION_TABLE=old_decomposition_table# ensure we cannot collide with other propertiesproxy_slot=object()class_NoDefault:passno_default=_NoDefault()fromtorch.typesimportpy_sym_types,PySymTypeclass_HasMeta(Protocol):meta:Dict[str,PySymType]defis_sym_node(node:_HasMeta)->bool:asserthasattr(node,"meta"),"All nodes 
traced with proxy_tensor should have meta"return"val"innode.metaandisinstance(node.meta["val"],py_sym_types)@overloaddefset_proxy_slot(obj:Tensor,tracer:_ProxyTracer,proxy:_ProxyTensor)->None:...@overloaddefset_proxy_slot(obj:_AnyScriptObjectType,tracer:_ProxyTracer,proxy:Proxy)->None:...@overloaddefset_proxy_slot(obj:PySymType,tracer:_ProxyTracer,proxy:_PySymProxyType)->None:...defset_proxy_slot(obj:Union[PySymType,_AnyScriptObjectType,Tensor],tracer:_ProxyTracer,proxy:object,)->None:log.debug("set_proxy_slot %s (%s) %s",obj,id(obj),proxy)ifisinstance(obj,Tensor):# We DO want to clobber proxies whenever we run an inplace operation# on a tensor, and it affects the metadata on the proxy.assertisinstance(proxy,_ProxyTensor)tracer.tensor_tracker[obj]=proxyelifisinstance(obj,(_AnyScriptObject)):# We DO want to clobber proxies, with a similar rationale as for tensors.assertisinstance(proxy,Proxy)tracer.script_object_tracker[obj]=proxyelse:# NB: Never clobber pre-existing proxy. Although the proxies# are in principle equivalent, when we do graph partitioning# we need there not to be spurious dependencies on tangent inputs.# This works because primals get their SymInts set first, and# THEN later we allocate tangent inputs. Make sure if a SymInt# is derivable from a primal that we use that.assertisinstance(obj,py_sym_types),type(obj)ifobjnotintracer.symnode_tracker:tracer.symnode_tracker[obj]=typing.cast(_PySymProxyType,proxy)# WAR: python test/dynamo/test_subclasses.py# TestNestedTensor.test_basic_autograd## AOTAutograd doesn't pass the "outer sizes" as an actual argument# to make_fx, but it is made use of internally in AOTAutograd's# call to tensor unflatten. Because the outer sizes isn't passed# as an argument, it is therefore untracked. 
However, it turns# out you luck out, because *Dynamo* will manually add the outer# sizes as an argument so you can fix up the proxy'ness.## This is probably fixed in# https://github.com/pytorch/pytorch/pull/125941/importsympyifisinstance(obj.node.expr,sympy.Symbol):tracer.sympy_expr_tracker[obj.node.expr]=proxydefhas_proxy_slot(obj:Tensor,tracer:_ProxyTracer)->bool:assertisinstance(obj,(Tensor,SymNode)),type(obj)returnbool(get_proxy_slot(obj,tracer,False,lambda_:True))_PySymProxyType=Thunk[Proxy]@overloaddefget_proxy_slot(obj:Tensor,tracer:_ProxyTracer,)->_ProxyTensor:...@overloaddefget_proxy_slot(obj:Tensor,tracer:_ProxyTracer,default:U,)->Union[_ProxyTensor,U]:...@overloaddefget_proxy_slot(obj:Tensor,tracer:_ProxyTracer,default:U,transform:Callable[[_ProxyTensor],R],)->Union[R,U]:...@overloaddefget_proxy_slot(obj:_AnyScriptObjectType,tracer:_ProxyTracer,)->Proxy:...@overloaddefget_proxy_slot(obj:_AnyScriptObjectType,tracer:_ProxyTracer,default:U,)->Union[Proxy,U]:...@overloaddefget_proxy_slot(obj:_AnyScriptObjectType,tracer:_ProxyTracer,default:U,transform:Callable[[Proxy],R],)->Union[R,U]:...@overloaddefget_proxy_slot(obj:PySymType,tracer:_ProxyTracer,)->_PySymProxyType:...@overloaddefget_proxy_slot(obj:PySymType,tracer:_ProxyTracer,default:T,)->Union[T,_PySymProxyType]:...@overloaddefget_proxy_slot(obj:PySymType,tracer:_ProxyTracer,default:U,transform:Callable[[_PySymProxyType],R],)->Union[R,U]:...# the default argument is what to return if the slot is not set.# the transform argument is handy if you need to extract a subfield from# the successfully looked up result (but NOT the 
default.)defget_proxy_slot(obj:Union[Tensor,_AnyScriptObjectType,PySymType],tracer:_ProxyTracer,default:object=no_default,transform:Callable=lambdax:x,)->object:tracker:Anyifisinstance(obj,Tensor):tracker=tracer.tensor_trackerelifisinstance(obj,_AnyScriptObject):tracker=tracer.script_object_trackerelse:assertisinstance(obj,py_sym_types),type(obj)tracker=tracer.symnode_trackerifobjnotintracker:# Last ditchifisinstance(obj,py_sym_types)andobj.node.exprintracer.sympy_expr_tracker:value=tracer.sympy_expr_tracker[obj.node.expr]else:ifisinstance(default,_NoDefault):raiseRuntimeError(f"{obj} ({id(obj)})is not tracked with proxy for {tracer}")returndefaultelse:value=tracker[obj]res=transform(value)returnresdefsnapshot_fake(val:Tensor)->Optional[Tensor]:# val.detach() will also eventually call fast_detach(),# but this saves us a full trip into __torch_dispatch__# (snapshot_fake is called a lot)ifisinstance(val,FakeTensor):returnfast_detach(val.fake_mode,val)else:returnval.detach()_ExtractValType=Optional[Union[PySymType,_AnyScriptObjectType,BackwardState,List["_ExtractValType"],Tuple["_ExtractValType",...],Dict[str,"_ExtractValType"],Tensor,int,float,bool,]]defextract_val(val:_ExtractValType)->_ExtractValType:ifis_fake(val):returnsnapshot_fake(val)elifisinstance(val,py_sym_types):returnvalelifisinstance(val,_AnyScriptObject):returnvalelifisinstance(val,BackwardState):returnvalelifisinstance(val,(list,tuple)):returnval.__class__([extract_val(x)forxinval])elifisinstance(val,dict):return{k:extract_val(v)fork,vinval.items()}elifisinstance(val,Tensor):ifnotval.is_sparse:# NB: Kinda hacky, but we should try to get val as the metadata# everywhere# TODO: This doesn't properly track storages. 
A more robust# approach would be to maintain a per-trace FakeTensorMode and# from_real_tensor to create fake values (don't forget to# snapshot_fake)fake_tensor_mode=FakeTensorMode(allow_fallback_kernels=True)withfake_tensor_mode:returntorch.empty_strided(val.shape,val.stride(),device=val.device,dtype=val.dtype)else:returnNoneelifisinstance(val,(int,float,bool)):returnvalelifvalisNone:returnNonetyping_extensions.assert_never(val)@contextmanagerdef_enable_thunkify(tracer:_ProxyTracer,*,enable:bool=True)->Generator[None,None,None]:""" Enable thunkification inside the context manager. Thunkification prevents SymNode computation from directly being traced into an FX graph; instead, the compute is only added to the graph if it is actually used. This helps us track SymNode compute when it is computed (since we need /something/ to put in the tracker) even if it is unlikely to be used. """old=tracer.enable_thunkifytracer.enable_thunkify=enabletry:yieldfinally:tracer.enable_thunkify=old
@contextmanager
def maybe_disable_thunkify() -> Generator[None, None, None]:
    """Within a context, disable thunkification.  See :func:`maybe_enable_thunkify`
    for more details.  This is helpful if you have a wrapper function which
    you want to enable thunkification on, but in some segment on the inside (say,
    the original user function), you want to disable thunkification as you know
    it is not needed there.
    """
    active_mode = get_proxy_mode()
    # With no proxy mode active there is nothing to toggle, so a no-op
    # context stands in for _enable_thunkify.
    guard = (
        nullcontext()
        if active_mode is None
        else _enable_thunkify(active_mode.tracer, enable=False)
    )
    with guard:
        yield
@contextmanager
def maybe_enable_thunkify() -> Generator[None, None, None]:
    """Within this context manager, if you are doing make_fx tracing, we will thunkify
    all SymNode compute and avoid tracing it into the graph unless it is actually needed.
    You should prefer to avoid using this as much as possible, as lazy evaluation of
    SymNode tracing can lead to long chains of thunks which will stack overflow
    if you evaluate them.  However, this is currently sometimes necessary as there
    are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error
    due to insufficient tracing of SymNode computation.
    """
    active_mode = get_proxy_mode()
    # Outside make_fx tracing (no proxy mode) this is a plain pass-through.
    guard = (
        nullcontext()
        if active_mode is None
        else _enable_thunkify(active_mode.tracer)
    )
    with guard:
        yield
# Note [invariants for node meta 'val']# What invariants do we have for the 'val' set on the FX node? It has accurate# metadata... but only for metadata that exists "below" all other subsystems# (most notably autograd, but also vmap, functorch transforms, etc). This means# you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad,# grad_fn, _base (_base actually may be set due to recursive call to# ADInplaceOrView, but you shouldn't rely on it.)defset_meta(proxy:Proxy,val:_ExtractValType)->Proxy:proxy.node.meta["val"]=extract_val(val)with_enable_thunkify(proxy.tracer):# type: ignore[arg-type]# Best effort tensor_meta setting; prefer using val!ifis_fake(val):proxy.node.meta["tensor_meta"]=_extract_tensor_metadata(val)elifisinstance(val,Tensor)andnotval.is_sparse:proxy.node.meta["tensor_meta"]=_extract_tensor_metadata(val)returnproxydefthunkify(tracer:_ProxyTracer,f:Callable[_P,R],*args:_P.args,**kwargs:_P.kwargs)->Thunk[R]:""" Delays computation of f until it's called again Also caches the result """iftracer.enable_thunkify:returnThunk(functools.partial(f,*args,**kwargs))else:r=f(*args,**kwargs)returnThunk(lambda:r)deftrack_tensor(tensor:Tensor,proxy:Proxy,*,constant:Optional[Tensor],tracer:_ProxyTracer)->None:deftry_set_proxy_slot(outer_s:IntLikeType,proxy_callable:Callable[Concatenate[PySymType,_P],Proxy],*args:_P.args,**kwargs:_P.kwargs,)->None:assertcallable(proxy_callable)ifisinstance(outer_s,SymInt):with_enable_thunkify(tracer):set_proxy_slot(outer_s,tracer,thunkify(tracer,proxy_callable,outer_s,*args,**kwargs),)# The basic idea is that we need to associate each tensor/SymInt# with a Proxy. How do we setup this association? 
We just store# the proxy on the proxy slot of the object, keyed on the tracer# (so that if we have multiple tracers at the same time, they# don't clobber each other.)fori,sinenumerate(tensor.shape):try_set_proxy_slot(s,lambdax,i:set_meta(tracer.create_proxy("call_function",torch.ops.aten.sym_size.int,(proxy,i),{}),x,),i,)ifnotis_sparse_any(tensor):fori,sinenumerate(tensor.stride()):try_set_proxy_slot(s,lambdax,i:set_meta(tracer.create_proxy("call_function",torch.ops.aten.sym_stride.int,(proxy,i),{}),x,),i,)try_set_proxy_slot(tensor.numel(),lambdax:set_meta(tracer.create_proxy("call_function",torch.ops.aten.sym_numel.default,(proxy,),{}),x,),)ifnotis_sparse_any(tensor):try_set_proxy_slot(tensor.storage_offset(),lambdax:set_meta(tracer.create_proxy("call_function",torch.ops.aten.sym_storage_offset.default,(proxy,),{},),x,),)set_proxy_slot(tensor,tracer,_ProxyTensor(proxy,constant))_NestedProxys=Union[Proxy,Sequence["_NestedProxys"],Mapping[object,"_NestedProxys"]]_NestedTensors=Union[Tensor,Sequence["_NestedTensors"],Mapping[object,"_NestedTensors"]]deftrack_tensor_tree(inner_res:T,proxy_res:_NestedProxys,*,constant:Optional[_NestedTensors],tracer:_ProxyTracer,)->T:# NB: We call set_unbacked_bindings only on the *topmost* call to# track_tensor_tree, not recursive calls. This is because there must# be only ONE unbacked_binding proxy call, and it should be the one# where all of the unbacked SymInts actually first come into existence.# If you call this again on the inner proxies for the tuple projections,# you will have multiple unbacked_bindings for the same symbol, but# they're not going to show up anywhere.## I was briefly deceived into setting unbacked bindings recursively when# working on https://github.com/pytorch/pytorch/pull/133585 because I# observed that some extra unbacked bindings were needed to handle some# higher order operator code. 
But actually it looks like this was# just an unrelated bug that needed to be fixed separately._set_unbacked_bindings(inner_res,proxy_res)defwrap_with_proxy(e:object,proxy:_NestedProxys,constant:Optional[_NestedTensors])->None:ifisinstance(e,Tensor):assertisinstance(proxy,Proxy)assertconstantisNoneorisinstance(constant,Tensor)track_tensor(e,proxy,tracer=tracer,constant=constant)set_meta(proxy,e)elifisinstance(e,py_sym_types):assertisinstance(proxy,Proxy)# NB: eagerly set meta here, so that the numbering is in orderset_meta(proxy,e)set_proxy_slot(e,tracer,thunkify(tracer,lambda:proxy))elifisinstance(e,_AnyScriptObject):assertisinstance(proxy,Proxy)set_proxy_slot(e,tracer,proxy)set_meta(proxy,e)elifisinstance(e,(tuple,list)):# example use case: allreduce_ returns ([tensor], work)ifisinstance(proxy,fx.Proxy):set_meta(proxy,e)defget_constant(c:Optional[_NestedTensors],idx:int)->Optional[_NestedTensors]:ifcisNone:returnNoneelse:assertisinstance(c,(list,tuple))returnc[idx]foridx,eeinenumerate(e):# Use an indexer here - if proxy is a List then it will unwrap# it. 
If it's a Proxy then it will proxy the getelem.wrap_with_proxy(ee,proxy[idx],get_constant(constant,idx))# type: ignore[index]elifisinstance(e,dict):# example use case: triton_kernel_wrapper takes arguments as kwargs# In theory we could support const-prop when proxy-tensor-tracing# operators that returns dicts of tensors, but we have no use case# for it today (since the only op we currently trace that can# return a dict is triton_kernel_wrapper_functional/mutation,# which does not participate in const-prop)assertconstantisNoneifisinstance(proxy,fx.Proxy):set_meta(proxy,e)forkey,valine.items():wrap_with_proxy(val,proxy[key],None)# type: ignore[index]elifisinstance(e,BackwardState):assertisinstance(proxy,Proxy)set_meta(proxy,e)e.proxy=proxyelse:# intentionally pass on primitivespasswrap_with_proxy(inner_res,proxy_res,constant)returninner_res@dataclassclass_ProxyTensor:proxy:Proxyconstant:Optional[Tensor]deffetch_sym_proxy(tracer:_ProxyTracer,)->Callable[[PySymType],Union[bool,int,float,Proxy]]:definner(e:PySymType)->Union[int,bool,float,Proxy]:n=e.nodeifn.constantisnotNone:returnn.constantife.node.expr.is_number:ifisinstance(e,SymBool):returnbool(e.node.expr)elifisinstance(e,SymInt):returnint(e.node.expr)returnfloat(e.node.expr)else:assertisinstance(e,py_sym_types)# NB: we REQUIRE all symints to be trackedreturnget_proxy_slot(e,tracer).force()returninner@overloaddeffetch_object_proxy(tracer:_ProxyTracer,t:Tensor)->Union[_ProxyTensor,Tensor]:...@overloaddeffetch_object_proxy(tracer:_ProxyTracer,t:_AnyScriptObjectType)->Union[Proxy,_AnyScriptObjectType]:...@overloaddeffetch_object_proxy(tracer:_ProxyTracer,t:PySymType)->Union[_PySymProxyType,PySymType]:...deffetch_object_proxy(tracer:_ProxyTracer,t:Union[Tensor,_AnyScriptObjectType,PySymType])->object:returnget_proxy_slot(t,tracer,t)HANDLED_TYPES=(Tensor,torch.nn.Parameter,FakeTensor)def_maybe_record_pointwise_barrier(func:object,proxy_mode:ProxyTorchDispatchMode)->None:""" Records pointwise operators in user program 
(non decomposed) that were output in fp16/bf16 """ifproxy_mode.decomp_layersornotproxy_mode.emulate_precision_casts:returnif(notisinstance(func,torch._ops.OpOverload)ortorch.Tag.pointwisenotinfunc.tags):returnlast_node=next(iter(reversed(proxy_mode.tracer.graph.nodes)))t=last_node.meta.get("val")ifnotisinstance(t,torch.Tensor)ort.dtypenotin(torch.bfloat16,torch.float16,):returnlast_node.meta["low_precision_pointwise_barrier"]=Truedefproxy_call(proxy_mode:ProxyTorchDispatchMode,func:OpOverload,pre_dispatch:bool,args:Tuple[object,...],kwargs:Dict[str,object],)->object:unrecognized_types:List[Type]=[]flat_args_kwargs,spec=pytree.tree_flatten((args,kwargs))defcan_handle_tensor(x:Tensor)->bool:r=type(x)inHANDLED_TYPESorhas_proxy_slot(x,proxy_mode.tracer)ifproxy_mode._allow_fake_constant:r=rortype(x)in(torch._subclasses.FakeTensor,)ifnotr:unrecognized_types.append(type(x))returnr# If there are any tensor subclasses, we need to handle those tensor subclasses first# TODO: we could use types to test thisifnotall(can_handle_tensor(x)forxinflat_args_kwargsifisinstance(x,Tensor)):not_implemented_log.debug("ProxyTensorMode tensors without proxy had unrecognized subclasses: %s",unrecognized_types,)returnNotImplementedr=maybe_handle_decomp(proxy_mode,func,args,kwargs)ifrisnotNotImplemented:_maybe_record_pointwise_barrier(func,proxy_mode)returnr# For pre-autograd tracing, we do not want to run CompositeImplicit decomps.ifnotpre_dispatchandfuncnotin[torch.ops.aten.size.default,torch.ops.aten.stride.default,torch.ops.aten.storage_offset.default,]:withproxy_mode:r=func.decompose(*args,**kwargs)ifrisnotNotImplemented:returnriffuncistorch.ops.aten.is_nonzero.default:withproxy_mode:return(args[0]!=0).item()# type: ignore[attr-defined]tracer=proxy_mode.tracerf_flat_args_kwargs=[(fetch_object_proxy(tracer,x)ifisinstance(x,(Tensor,_AnyScriptObject))elsex)forxinflat_args_kwargs]# If there are SymInts, we also should not consider this constant.# However, fake tensor handling of SymInts is 
sufficiently broken that# I couldn't write a test for this caseall_constant=(notany(t.constantisNonefortinf_flat_args_kwargsifisinstance(t,_ProxyTensor))# TODO: maybe constant SymInts should also be allowed? Not sure if# this can happenandnotany(isinstance(x,py_sym_types)forxinflat_args_kwargs))iftorch.Tag.data_dependent_outputinfunc.tags:# Check if all of the Tensor inputs are constantsifall_constant:const_flat_args_kwargs=[t.constantifisinstance(t,_ProxyTensor)elsetfortinf_flat_args_kwargs]const_args,const_kwargs=pytree.tree_unflatten(const_flat_args_kwargs,spec)withunset_fake_temporarily():returnfunc(*const_args,**const_kwargs)# If any of the Tensor inputs are "real" (not FakeTensor), we may# incorrectly burn in constants by allowing this access. Raise# an error in this caseifproxy_mode._error_on_data_dependent_opsandpytree.tree_all_only(Tensor,lambdat:notis_fake(t),(args,kwargs)):raiseRuntimeError(f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! ""It's likely that this is caused by data-dependent control flow or similar. ""It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' ""in your make_fx call.")proxy_flat_args_kwargs=[e.proxyifisinstance(e,_ProxyTensor)elseeforeinf_flat_args_kwargs]proxy_flat_args_kwargs=[(fetch_sym_proxy(proxy_mode.tracer)(e)ifisinstance(e,py_sym_types)elsee)foreinproxy_flat_args_kwargs]proxy_args,proxy_kwargs=pytree.tree_unflatten(proxy_flat_args_kwargs,spec)# When we trace through a torch.tensor invocation, you never actually# see a torch.ops.aten.tensor call. Instead, the way this function is# implemented internally is that we allocate a plain tensor (this is# *guaranteed* to be a plain tensor, we disable all modes when doing# so), and then call at::lift_fresh on it (to give modes a chance to do# their stuff). 
Furthermore, the tensor argument to lift_fresh is guaranteed# to be freshly allocated, so we want lift_fresh to be a no-op (directly# returning the input argument).## Here is the basic problem: when we trace this sequence of executions# into an FX graph, what happens to this call sequence? Traditionally,# tensor constants get interned as buffers on the FX GraphModule. But# this is dangerous. Consider:## x = torch.tensor(1)# x.add_(2)## Naively, this traces into:## t = self._tensor_constant0 # initialized to torch.tensor(1)# x = torch.ops.aten.lift_fresh(t)# x.add_(2)## If lift_fresh returns t directly, the subsequent add_ call will# modify the tensor constant. Really, the problem is we've violated# the invariant the argument to lift is fresh. So what we should# preserve the invariant by replacing lift_fresh with lift_fresh_copy:## t = self._tensor_constant0 # initialized to torch.tensor(1)# x = torch.ops.aten.lift_fresh_copy(t)# x.add_(2)## This is what the overload modification does.iffuncistorch.ops.aten.lift_fresh.default:func=torch.ops.aten.lift_fresh_copy.defaultproxy_out=proxy_mode.tracer.create_proxy("call_function",func,proxy_args,proxy_kwargs,name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__),)with_enable_thunkify(proxy_mode.tracer):out=func(*args,**kwargs)# In some circumstances, we will be tracing in a situation where a tensor# is *statically* known to be a constant (currently, this only happens if# you run torch.tensor; deterministic factory functions like torch.arange# don't get this treatment). When the tensor in question is small, it's# helpful to due constant propagation in case we call item() (in which# case we can return the constant value that is known, rather than give# an error.) The logic here tests if constant propagation is possible# (because all of the inputs are constant). 
If so, we disable fake tensor# mode (if it is on) and do true compute on the constant.## It's worth highlighting that we're making a policy decision here.# There is a potential that the tensor is actually quite large, and we# don't actually want to run the compute. The tensor being quite large# is one of the reasons why factory functions don't get this treatment# (since they can be quite large; if a parameter is initialized to a# constant value it will be!) Similarly, there is also a potential# to run an operator that blows up the size of a small tensor; we don't# protect against this case, but we could force, e.g., only single# element constant computation by testing the numel of the result before# propagating const-ness. Similarly, we don't require the constant to# live on CPU, but we could.any_constant=any(t.constantisnotNonefortinf_flat_args_kwargsifisinstance(t,_ProxyTensor))constant=Nonedeftensor_numel_in_limit(t:Tensor)->bool:returnt.numel()<=CONSTANT_NUMEL_LIMIT# If this is a lift, the input tensor is guaranteed to be a# constant, so we keep a copy of the original argument along so# we can query it if we're asked to item() it at some later pointif(funcistorch.ops.aten.lift_fresh_copy.defaultandout.numel()<=CONSTANT_NUMEL_LIMIT):withunset_fake_temporarily():assertisinstance(args[0],(Proxy,Tensor)),type(args[0])constant=args[0].clone()elif(torch.Tag.nondeterministic_seedednotinfunc.tagsandall_constantandany_constantandpytree.tree_all_only(Tensor,tensor_numel_in_limit,out)):# NB: do NOT include factories as constantswithunset_fake_temporarily():const_flat_args_kwargs=[t.constantifisinstance(t,_ProxyTensor)elsetfortinf_flat_args_kwargs]const_args,const_kwargs=pytree.tree_unflatten(const_flat_args_kwargs,spec)constant=func(*const_args,**const_kwargs)else:constant=Nonetrack_tensor_tree(out,proxy_out,constant=constant,tracer=tracer)_maybe_record_pointwise_barrier(func,proxy_mode)returnoutclass_SymNodeDict:""" Wrapper around a dictionary that will hash SymInts with 
their nodes """def__init__(self)->None:self.sym_node_dict:Dict[PySymType,_PySymProxyType]={}def__setitem__(self,key:PySymType,value:_PySymProxyType)->None:self.sym_node_dict[key.node]=valuedef__getitem__(self,key:PySymType)->_PySymProxyType:returnself.sym_node_dict[key.node]def__contains__(self,key:PySymType)->bool:returnkey.nodeinself.sym_node_dictdefget(self,key:PySymType,default:Optional[_PySymProxyType]=None)->_PySymProxyType:# dict.get()'s annotation doesn't accept `None` when the value type# isn't Optional.returnself.sym_node_dict.get(key.node,default)# type: ignore[arg-type]def__iter__(self)->Any:raiseNotImplementedErrordef__len__(self)->int:returnlen(self.sym_node_dict)classPythonKeyTracer(Tracer):script_object_tracker:MutableMapping[_AnyScriptObjectType,Proxy]symnode_tracker:_SymNodeDictsympy_expr_tracker:Dict[sympy.Symbol,object]tensor_tracker:MutableMapping[Tensor,_ProxyTensor]torch_fn_counts:Dict[OpOverload,int]enable_thunkify:bool=Falsedef__init__(self)->None:super().__init__(autowrap_modules=())# type: ignore[arg-type]self.tensor_tracker=WeakTensorKeyDictionary()self.symnode_tracker=_SymNodeDict()self.script_object_tracker=WeakIdKeyDictionary(dict=None,ref_type=_WeakHashRef)self.sympy_expr_tracker=dict()# Stores the torch function that was called during tracingself.torch_fn_metadata=None# Stores the counts for every torch function called. This is to help# distinguish between different calls to the same torch function.self.torch_fn_counts={}self.enable_thunkify=False# In general, we don't want to make modules leaves. In principle, users of# this tracer might want to override this in order to turn a couple specific# modules into leaves in the traced graph.defcall_module(self,m:Module,forward:Callable[...,Any],args:Tuple[Any,...],kwargs:Dict[str,Any],)->Any:returnforward(*args,**kwargs)# We don't want to turn getattr calls into proxies. 
So we just return the actual value.defgetattr(self,attr:str,attr_val:object,parameter_proxy_cache:Dict[str,Proxy])->object:returnattr_valdefcreate_arg(self,a:object)->fx.node.Node:ifisinstance(a,torch.nn.Parameter):forn,pinself.root.named_parameters():ifaisp:returnself.create_node("get_attr",n,(),{})qualname=self.get_fresh_qualname("_param_constant")setattr(self.root,qualname,a)returnself.create_node("get_attr",qualname,(),{})elifisinstance(a,py_sym_types):asserta.node.constantisnotNonereturna.node.constantreturnsuper().create_arg(a)# type: ignore[return-value]@overloaddefunwrap_proxy(self,e:Tensor)->Union[Proxy,Tensor]:...@overloaddefunwrap_proxy(self,e:PySymType)->Union[Proxy,PySymType]:...@overloaddefunwrap_proxy(self,e:_AnyScriptObjectType)->Union[Proxy,_AnyScriptObjectType]:...defunwrap_proxy(self,e:T)->object:ifisinstance(e,Tensor):returnget_proxy_slot(e,self,e,lambdax:x.proxy)elifisinstance(e,py_sym_types):returnget_proxy_slot(e,self,e,lambdae:e.force())elifisinstance(e,_AnyScriptObject):returnget_proxy_slot(e,self,e)else:returnedef_make_temp_remove_mode_context_manager(mode_ty:Type[TorchFunctionMode],)->Callable[[],_GeneratorContextManager[Optional[TorchFunctionMode]]]:@contextmanagerdefcontext_manager_fn()->Generator[Optional[TorchFunctionMode],None,None]:fromtorch.overridesimport_len_torch_function_stack,_pop_mode,_push_modetemp_elements=[]removed_mode=Nonewhile_len_torch_function_stack()>0:mode=_pop_mode()ifisinstance(mode,mode_ty):removed_mode=modebreakelse:temp_elements.append(mode)formodeinreversed(temp_elements):_push_mode(mode)try:yieldremoved_modefinally:ifremoved_modeisnotNone:count=len(temp_elements)whilecount>0:mode=_pop_mode()count-=1temp_elements.append(removed_mode)formodeinreversed(temp_elements):_push_mode(mode)returncontext_manager_fn@torch._disable_dynamodefdispatch_trace(root:Union[Module,Callable],tracer:Tracer,concrete_args:Optional[Tuple[Any,...]]=None,)->GraphModule:graph=tracer.trace(root,concrete_args)# type: ignore[arg-type]# NB: 
be careful not to DCE .item() callsdefimpure_pred(n:fx.Node)->bool:from.symbolic_shapesimportis_accessor_node# Always defer to the built-in notion of impureifn.is_impure():returnTrue# Accessors always OK to DCEifis_accessor_node(n):returnFalse# If the operator in question takes SymInt args to SymInt output,# we assume it's pure and OK to DCEif(isinstance(n.meta.get("val"),py_sym_types)and# NB: constant args okall(isinstance(a.meta.get("val"),py_sym_types)forainn.argsifisinstance(a,fx.Node))):returnFalse# No idea, just assume it's not OKreturnTruegraph.eliminate_dead_code(impure_pred)fromtorch._inductor.fx_passes.dedupe_symint_usesimportdedupe_symintsdedupe_symints(graph)name=root.__class__.__name__ifisinstance(root,Module)elseroot.__name__returnfx._lazy_graph_module._make_graph_module(tracer.root,graph,name)defwrap_key(f:Callable[_P,R],tensors:_P.args,tracer:_ProxyTracer,pre_dispatch:bool)->Callable[_P,R]:flat_tensors,_tensors_spec=pytree.tree_flatten(tensors)@functools.wraps(f)defwrapped(*proxies:_P.args,**_unused:_P.kwargs)->R:flat_proxies,_proxies_spec=pytree.tree_flatten(proxies)assertlen(flat_proxies)==len(flat_tensors)withdisable_proxy_modes_tracing()asm:assertisinstance(m,ProxyTorchDispatchMode)track_tensor_tree(flat_tensors,flat_proxies,constant=None,tracer=tracer)defget_tensor_proxy_slot(t:Tensor)->Union[Tensor,Proxy]:returnget_proxy_slot(t,tracer,t,lambdax:x.proxy)out=f(*tensors)# type:ignore[call-arg]out=pytree.tree_map_only(Tensor,get_tensor_proxy_slot,out)out=pytree.tree_map_only(_AnyScriptObject,lambdat:get_proxy_slot(t,tracer,t,lambdax:x),out)defget_sym_proxy_slot(t:PySymType)->Proxy:returnget_proxy_slot(t,tracer).force()out=pytree.tree_map_only(py_sym_types,get_sym_proxy_slot,out)returnoutreturnwrapped# TODO: Make downstream users of this work with 
# OperatorBase
ORIGINAL_ATEN: Optional[object] = None


@contextmanager
def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]:
    """Record ``func`` as the "original aten" op for the duration of the block.

    This only takes effect when no op is recorded yet (``ORIGINAL_ATEN is
    None``) and ``fx_traceback.has_preserved_node_meta()`` returns true. In
    that case both the module-level ``ORIGINAL_ATEN`` and
    ``fx_traceback.current_meta["original_aten"]`` are set to ``func`` and
    reset to ``None`` when the block exits (also on exceptions). Otherwise the
    block runs with nothing changed, so in nested uses the outermost op stays
    recorded.
    """
    global ORIGINAL_ATEN
    if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta():
        ORIGINAL_ATEN = func
        fx_traceback.current_meta["original_aten"] = func
        try:
            yield
        finally:
            ORIGINAL_ATEN = None
            fx_traceback.current_meta["original_aten"] = None
    else:
        yield


class TorchFunctionMetadataMode(TorchFunctionMode):
    """TorchFunctionMode that records torch-function call metadata on a tracer.

    Every intercepted call stores ``func`` in ``tracer.torch_fn_metadata`` and
    increments ``tracer.torch_fn_counts[func]``, then runs ``func`` unchanged.
    """

    def __init__(self, tracer: _ProxyTracer) -> None:
        self.tracer = tracer

    def __torch_function__(
        self,
        func: OpOverload,
        types: Tuple[torch._C._TensorMeta, ...],
        args: Tuple[object, ...] = (),
        kwargs: Optional[Dict[str, object]] = None,
    ) -> object:
        kwargs = kwargs or {}
        # Remember the most recent torch function and how many times it has
        # been called so far (used to tell repeated calls apart).
        self.tracer.torch_fn_metadata = func
        self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1
        return func(*args, **kwargs)


# Context-manager factory built by _make_temp_remove_mode_context_manager
# (defined elsewhere in this module); presumably it temporarily removes an
# active TorchFunctionMetadataMode -- confirm against that helper.
_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager(
    TorchFunctionMetadataMode
)


# This mode is **only** used for pre_dispatch tracing.
# In particular, we need to make sure that autograd/autocast API's
# that do not desugar into dispatcher operators stay in the graph.
class PreDispatchTorchFunctionMode(TorchFunctionMode):
    """TorchFunctionMode that records selected side-effectful calls as graph nodes.

    Functions listed in ``_side_effectful_need_to_be_preserved_pre_dispatch``
    are turned into ``call_function`` nodes on ``tracer`` instead of being
    executed; every other function runs normally.
    """

    def __init__(self, tracer: _ProxyTracer) -> None:
        self.tracer = tracer
        # The input to torch.amp.autocast_mode._exit_autocast graph node should be the
        # enter_autocast node. So we have to save the enter autocast node here, and assign it
        # to the exit_autocast call_function node.
        # Used as a LIFO stack: _enter_autocast pushes, _exit_autocast pops.
        self.enter_autocast_nodes: List[torch.fx.Node] = []

    def __torch_function__(
        self,
        func: Union[OpOverload, Callable],
        types: Tuple[torch._C._TensorMeta, ...],
        args: Tuple[object, ...] = (),
        kwargs: Optional[Dict[str, object]] = None,
    ) -> object:
        """Trace preserved side-effectful calls; run everything else.

        Returns the created ``fx.Node`` for preserved functions (the function
        itself is NOT executed), otherwise the function's own result.
        """
        kwargs = kwargs or {}
        if func in _side_effectful_need_to_be_preserved_pre_dispatch:
            # It's for passing the export verifier which needs to verify the meta['val']
            # TODO(tmanlaibaatar): we should systematically couple it with export verifier,
            # instead of hardcoding it here.
            # T203648563
            if func == torch.amp.autocast_mode._exit_autocast:
                # Pair this exit with the most recent enter node.
                enter_node = self.enter_autocast_nodes.pop()
                args = (enter_node,)
            node = self.tracer.create_node("call_function", func, args, {})  # type: ignore[arg-type]
            if func == torch.amp.autocast_mode._enter_autocast:
                self.enter_autocast_nodes.append(node)
            if func in [
                torch._C._set_grad_enabled,
                torch.amp.autocast_mode._enter_autocast,
                torch.amp.autocast_mode._exit_autocast,
            ]:
                node.meta["val"] = None
            return node
            # Don't actually run the function! We just want to trace the calls
            # into a graph. We don't actually want to change global autograd state.
        return func(*args, **kwargs)


# Counterpart of _temp_remove_metadata_torch_function_mode for
# PreDispatchTorchFunctionMode (built by the same helper).
_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager(
    PreDispatchTorchFunctionMode
)


class ProxyTorchDispatchMode(TorchDispatchMode):
    """TorchDispatchMode that records dispatched ops into ``tracer``'s graph.

    It registers itself as the PROXY infra mode (``_mode_key`` and
    ``is_infra_mode``). When ``pre_dispatch`` is true it is constructed with
    the ``PreDispatch`` dispatch key.
    """

    # Ensure this is read-only; this exists only for legacy reasons
    @property
    def enable_tracing(self) -> bool:
        return True

    def __init__(
        self,
        tracer: _ProxyTracer,
        tracing_mode: str,
        pre_dispatch: bool = False,
        _allow_fake_constant: bool = False,
        _error_on_data_dependent_ops: bool = True,
    ) -> None:
        dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
        super().__init__(dk)
        self.tracer = tracer
        self.tracing_mode = tracing_mode
        self.pre_dispatch = pre_dispatch
        self._allow_fake_constant = _allow_fake_constant
        self._error_on_data_dependent_ops = _error_on_data_dependent_ops
        # Indicates to our torch_dispatch dispatching infra that
        # this is an "infra" mode with lower dispatching precedence.
        self._mode_key = torch._C._TorchDispatchModeKey.PROXY
        # Every time we enter a mode, we maintain a stack telling us what the previous
        # ProxyTorchDispatchMode state was (if there was any).
        # This lets us properly reset the state on exit.
        self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = []
        # Incremented/decremented around decomposition calls by maybe_handle_decomp.
        self.decomp_layers = 0
        from torch._inductor import config

        self.emulate_precision_casts = config.emulate_precision_casts

    @count
    def __torch_dispatch__(
        self,
        func: OpOverload,
        types: Tuple[torch._C._TensorMeta, ...],
        args: Tuple[object, ...] = (),
        kwargs: Optional[Dict[str, object]] = None,
    ) -> object:
        """Trace ``func`` via ``proxy_call``; ``prim.device.default`` runs untraced."""
        with set_original_aten_op(func):
            kwargs = kwargs or {}

            if func in (prim.device.default,):
                return func(*args, **kwargs)

            return proxy_call(self, func, self.pre_dispatch, args, kwargs)

    def __enter__(self) -> Self:
        # Stash and store the previous proxy mode (there may or may not be
        # one)
        maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
        self.enter_stack.append(maybe_prev_proxy_mode)
        return super().__enter__()

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[types.TracebackType],
    ) -> Optional[bool]:
        # NOTE(review): if super().__exit__ raises, enter_stack is not popped and
        # the previous proxy mode is not re-pushed.
        b = super().__exit__(exc_type, exc_value, traceback)

        # Re-enable the previous proxy mode, if there was one.
        mb_previous_proxy_mode = self.enter_stack.pop()
        if mb_previous_proxy_mode is not None:
            _push_mode(mb_previous_proxy_mode)

        return b

    @classmethod
    def is_infra_mode(cls) -> bool:
        return True

    def _compute_proxy(
        self, func: OpOverload, args: Tuple[object, ...], out: PySymType
    ) -> Proxy:
        """Create the ``call_function`` node for a symbolic op and tag it with ``out``.

        Symbolic arguments (``py_sym_types``) are replaced by their forced
        proxy nodes. A single list/tuple argument is handled element-wise.
        """
        # Handle torch.sym_sum
        n_args: Tuple[object, ...]
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            n_args = (
                tuple(
                    get_proxy_slot(a, self.tracer).force().node
                    if isinstance(a, py_sym_types)
                    else a
                    for a in args[0]
                ),
            )
        else:
            n_args = tuple(
                get_proxy_slot(a, self.tracer).force().node
                if isinstance(a, py_sym_types)
                else a
                for a in args
            )

        # func doesn't have a __torch_function__ that Proxy can interpose, so
        # we gotta do it manually
        n_out = self.tracer.create_node("call_function", func, n_args, {})  # type: ignore[arg-type]
        p_out = fx.Proxy(n_out, self.tracer)
        set_meta(p_out, out)
        return p_out

    def __sym_dispatch__(
        self,
        func: OpOverload,
        types: Tuple[torch._C._TensorMeta, ...],
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> object:
        """Run a symbolic op and attach a lazily-built proxy to its result.

        Multiplication by the plain int ``1`` short-circuits to the other
        operand. If the result is symbolic, a thunk wrapping
        ``_compute_proxy`` is stored via ``set_proxy_slot``. Constant results
        are returned without tracing.
        """
        # Peephole optimize multiply by one
        # NB: be careful not to trigger guards here!
        if func == operator.mul:
            if isinstance(args[1], int) and args[1] == 1:
                return args[0]
            elif isinstance(args[0], int) and args[0] == 1:
                return args[1]

        # For speed, we assume there are no nested data structures
        # (otherwise we could use tree_map)
        # We also assume there are no keyword arguments.
        assert not kwargs
        out = func(*args, **kwargs)

        # If func returned a constant, we don't need to trace; we have
        # determined that the result is constant (no matter if the inputs
        # were symbolic) and it is no longer necessary to trace the
        # computation. This could occur if func triggered some guards.
        if isinstance(out, py_sym_types):
            p_out_thunk = thunkify(
                self.tracer, self._compute_proxy, func=func, args=args, out=out
            )
            set_proxy_slot(out, self.tracer, p_out_thunk)

        return out


class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
    """GraphAppendingTracer that also carries the tracking tables and
    torch_fn metadata fields stored on tracers in this module.

    DecompositionInterpreter uses it to append to an existing graph.
    """

    script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
    symnode_tracker: MutableMapping[PySymType, _PySymProxyType]
    tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
    sympy_expr_tracker: Dict[sympy.Symbol, object]
    torch_fn_metadata: Optional[OpOverload]
    torch_fn_counts: Dict[OpOverload, int]
    enable_thunkify: bool = False

    def __init__(self, graph: fx.graph.Graph) -> None:
        super().__init__(graph)
        # Weak-keyed so tracked objects are not kept alive by the tracer.
        self.symnode_tracker = weakref.WeakKeyDictionary()
        self.tensor_tracker = WeakTensorKeyDictionary()
        self.sympy_expr_tracker = {}
        self.script_object_tracker = WeakIdKeyDictionary(
            dict=None, ref_type=_WeakHashRef
        )
        # Stores the torch function that was called during tracing
        self.torch_fn_metadata = None
        # Stores the counts for every torch function called. This is to help
        # distinguish between different calls to the same torch function.
        self.torch_fn_counts = {}


# TODO: I'm not sure what the point of this class is; you can just
# make_fx through a regular Interpreter
class DecompositionInterpreter(fx.Interpreter):
    """Interpret ``module`` while re-recording its computation into ``new_graph``.

    Placeholder, get_attr and output nodes are recreated in ``new_graph``
    explicitly. Other ops are recorded by the ``ProxyTorchDispatchMode``
    entered in ``run`` (with ``decomposition_table`` active).
    """

    def __init__(
        self,
        module: fx.GraphModule,
        new_graph: fx.Graph,
        decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
        **kwargs: object,
    ) -> None:
        super().__init__(module, **kwargs)  # type: ignore[arg-type]
        self.new_graph = new_graph
        self.tracer = _GraphAppendingTracerEx(self.new_graph)  # Blegh
        self.decomposition_table = decomposition_table or {}
        self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")

    def placeholder(
        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
    ) -> object:
        """Create a matching placeholder in ``new_graph`` and track the real value against it."""
        out = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
        proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
        # TODO handle case where the first character of target is '*'
        return out

    def get_attr(
        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
    ) -> object:
        """Create a matching get_attr node in ``new_graph`` and track the value against it."""
        out = super().get_attr(target, args, kwargs)  # type: ignore[arg-type]
        proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
        return out

    # call_function, call_method, call_module get traced automatically by the outer mode.

    def output(
        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
    ) -> object:
        """Emit ``new_graph``'s output node from the proxy nodes of the returned tensors."""
        out = super().output(target, args, kwargs)  # type: ignore[arg-type]

        def get_proxy_node(x: _ProxyTensor) -> fx.node.Node:
            return x.proxy.node

        def unwrap(e: Tensor) -> Union[Tensor, fx.Node]:
            # Falls back to the tensor itself when it has no proxy slot.
            return get_proxy_slot(e, self.tracer, e, get_proxy_node)

        self.new_graph.output(pytree.tree_map(unwrap, out))
        return out

    def run(self, *args: object, **kwargs: object) -> object:
        # Should enter the mode at least once for being able to restore it later
        # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025
        with decompose(self.decomposition_table), self.mode:
            return super().run(*args, **kwargs)  # type: ignore[arg-type]


def wrapper_and_args_for_make_fx(
    func: Callable[..., R], args: Tuple[object, ...], kwargs: Dict[str, object]
) -> Tuple[Callable[[List[object]], R], List[object]]:
    """Adapt ``func(*args, **kwargs)`` into a callable taking one flat list.

    Returns ``(wrapped, flat_args)``, where ``flat_args`` is the pytree
    flattening of ``(args, kwargs)``. ``wrapped(flat_args)`` unflattens its
    input back into positional and keyword arguments and calls ``func``.
    """
    # make_fx doesn't support kwargs, so we need to do this flattening
    # and then unflatten the args before calling func
    flat_args, spec = pytree.tree_flatten((args, kwargs))

    def wrapped(flat_args: List[object]) -> R:
        fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec)
        return func(*fn_args, **fn_kwargs)

    return wrapped, flat_args


@contextmanager
def disable_autocast_cache() -> Generator[None, None, None]:
    """Disable the autocast cache inside the block, restoring the prior setting on exit."""
    old_value = torch.is_autocast_cache_enabled()
    torch.set_autocast_cache_enabled(False)
    try:
        yield
    finally:
        torch.set_autocast_cache_enabled(old_value)


class _ModuleNotInstalledAsSubmoduleError(NameError):
    """Raised by _ModuleStackTracer.path_of_module when the module cannot be
    located under the root (Tracer.path_of_module raised NameError)."""

    pass


# Base class for inline
# _ModuleStackTracer.__init__.AttrProxy
class _AttrProxy:
    """Marker base class for the ``AttrProxy`` class that
    ``_ModuleStackTracer.__init__`` builds. Lets ``isinstance`` checks in
    ``path_of_module``, ``getattr`` and ``trace`` recognize proxies."""

    def reset_proxy_mapping(self, base: Module, path: str) -> None:
        pass


class _ModuleStackTracer(PythonKeyTracer):
    r"""Customized version of PythonKeyTracer that retains module stack
    information in node.meta["nn_module_stack"].

    FX symbolic trace actually does this already, but it relies on `self.root`
    being the actual module being traced. Since make_fx traces a lambda of our
    creation, things don't work properly.

    So for this version we hold onto a reference to the original module
    (scope_root) and use that to match the path. Also when we see,
            A
           / \
          B   C
           \ /
            D
    we want to record the path as A.B.D by recording only one path.
    See Note [Preserving the nn module stack metadata during export non-strict mode]  # noqa: W605
    """

    def __init__(self, scope_root: GraphModule) -> None:
        super().__init__()
        self.scope_root = scope_root
        # Attribute proxying is only switched on when some module instance is
        # reachable under more than one name (i.e. a shared submodule).
        self.enable_attr_proxy = False
        self.submodule_paths = {}
        for name, m in self.scope_root.named_modules(remove_duplicate=False):
            if m in self.submodule_paths:
                self.enable_attr_proxy = True
            else:
                self.submodule_paths[m] = name

        # proxy -> access path, real module -> its (cached) proxy, proxy -> real module.
        self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary()
        self.attr_proxy_map: WeakKeyDictionary[Module, _AttrProxy] = WeakKeyDictionary()
        self.proxy_modules: WeakKeyDictionary[_AttrProxy, Module] = WeakKeyDictionary()
        self.counter = 0

        # id(module) -> every qualified name the module appears under.
        self.module_id_cache = defaultdict(list)
        for name, mod in self.scope_root.named_modules(remove_duplicate=False):
            self.module_id_cache[id(mod)].append(name)

        # Build a wrapper around _AttrProxy to provide the tracer. We can't
        # store it on _AttrProxy itself because we mimic the underlying class
        # (including its attributes).
        tracer = self

        class AttrProxy(_AttrProxy):
            def __init__(self, base: Module, path: str) -> None:
                # Class is modified to be a subclass of torch.nn.Module
                # Warning: We blow away our own attributes here to mimic the base class
                # - so don't expect `self.x` to do anything useful.
                self.__class__ = type(
                    base.__class__.__name__,
                    (self.__class__, base.__class__),
                    {},
                )
                # Share (not copy) the wrapped module's instance dict.
                self.__dict__ = base.__dict__
                self.__class__.__module__ = base.__class__.__module__
                self.__class__.__qualname__ = base.__class__.__qualname__
                self.reset_proxy_mapping(base, path)

            def reset_proxy_mapping(self, base: Module, path: str) -> None:
                # Bookkeeping lives on the tracer, not on the proxy itself.
                tracer.proxy_paths[self] = path
                tracer.proxy_modules[self] = base

            def __getattr__(self, name: str) -> AttrProxy:
                assert isinstance(self, Module)
                # Calling into torch.nn.Module.__getattr__ with super(),
                # That __getattr__ is patched to be module_getattr_wrapper in _symbolic_trace.py.
                # which then calls into _ModuleStackTracer.getattr
                attr_val = super().__getattr__(name)  # type: ignore[misc]
                if isinstance(attr_val, AttrProxy):
                    attr_val = tracer.proxy_modules[attr_val]
                elif not isinstance(attr_val, Module):
                    # Non-module attributes are returned as-is.
                    return attr_val
                if attr_val not in tracer.attr_proxy_map:
                    tracer.attr_proxy_map[attr_val] = AttrProxy(
                        attr_val, tracer.proxy_paths[self] + "." + name
                    )
                else:
                    # NOTE [caching AttrProxy]. Caching ensures a 1-1 mapping between AttrProxy and the actual attr_val.
                    # 1. We reset the proxy_mapping to solve the diamond shape reference problem: we want to record the
                    # path as A.B.D instead of A.C.D (the purpose of _ModuleStackTracer).
                    # 2. Instead of creating a new AttrProxy, we just reset the proxy_mapping of existing one. This is to avoid
                    # dynamo creating multiple guards for the same attr_val but different AttrProxy when exporting
                    # a model that calls torch.compile (e.g when a model uses torch.cond.)
                    tracer.attr_proxy_map[attr_val].reset_proxy_mapping(
                        attr_val, tracer.proxy_paths[self] + "." + name
                    )
                return tracer.attr_proxy_map[attr_val]

            def get_base(self) -> Module:
                return tracer.proxy_modules[self]

            @property
            def _modules(self) -> Dict[str, AttrProxy]:
                # Wrap each non-None child in a fresh AttrProxy whose path
                # extends this proxy's path.
                assert "_modules" in self.__dict__
                submodules = self.__dict__["_modules"]
                assert isinstance(submodules, dict)
                return {
                    key: (
                        AttrProxy(value, tracer.proxy_paths[self] + "." + str(key))  # type: ignore[misc]
                        if value is not None
                        else value
                    )
                    for key, value in submodules.items()
                }

        self.proxy_type = AttrProxy

    def path_of_module(self, mod: Module) -> str:
        """
        Use tracked access path during tracing instead of the default BFS behavior.
        Still use all the possible module paths to verify the result.

        Raises _ModuleNotInstalledAsSubmoduleError when Tracer.path_of_module
        raises NameError.
        """
        if mod is self.scope_root:
            return ""

        if isinstance(mod, _AttrProxy):
            return self.proxy_paths[mod]

        try:
            return Tracer.path_of_module(self, mod)
        except NameError as e:
            raise _ModuleNotInstalledAsSubmoduleError from e

    def getattr(
        self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy]
    ) -> object:
        """Return a cached AttrProxy for submodule attributes when proxying is enabled.

        Non-modules, GraphModules, and the case where ``enable_attr_proxy`` is
        false all defer to the parent implementation.
        """
        if (
            not isinstance(attr_val, Module)
            or isinstance(attr_val, fx.GraphModule)
            or not self.enable_attr_proxy
        ):
            return super().getattr(attr, attr_val, parameter_proxy_cache)
        if isinstance(attr_val, _AttrProxy):
            return attr_val

        # See NOTE [caching AttrProxy].
        if attr_val not in self.attr_proxy_map:
            self.attr_proxy_map[attr_val] = self.proxy_type(attr_val, attr)
        else:
            self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr)
        return self.attr_proxy_map[attr_val]

    def trace(  # type: ignore[override]
        self, root: Union[Module, Callable], concrete_args: Optional[Dict[str, object]]
    ) -> fx.Graph:
        """Trace as usual, then swap any proxy submodules left on ``self.root`` for the real modules."""
        res = super().trace(root, concrete_args)

        # Since we are making _AttrProxy mimic the original
        # submodule, when someone registers a module directly
        # to the tracer while tracing, the proxy object gets registered
        # first. So we need to replace the proxy modules with the real ones
        # This can happen during HOO tracing
        proxy_module_names_to_be_replaced: List[Tuple[str, _AttrProxy]] = []
        for name, module in self.root.named_modules():
            if module in self.proxy_modules:
                proxy_module_names_to_be_replaced.append((name, module))

        def _delete_proxy_attr(obj: Module, target: str) -> bool:
            # Copied from fx/graph_module.py
            # Customized it for proxy type
            # Returns True only if the leaf attribute existed, was an
            # _AttrProxy, and was deleted.
            atoms = target.split(".")
            path, target_submod = atoms[:-1], atoms[-1]
            assert isinstance(obj, Module)
            mod = obj

            # Get the parent module
            for item in path:
                if not hasattr(mod, item):
                    return False

                mod = getattr(mod, item)

                if not isinstance(mod, (_AttrProxy, Module)):
                    return False

            if not hasattr(mod, target_submod):
                return False

            # At least the leaf module should be proxy type.
            if not isinstance(getattr(mod, target_submod), _AttrProxy):
                return False

            delattr(mod, target_submod)
            return True

        for proxy_module_name, proxy_module in proxy_module_names_to_be_replaced:
            _delete_proxy_attr(self.root, proxy_module_name)
            actual_module = self.proxy_modules[proxy_module]
            _assign_attr(actual_module, self.root, proxy_module_name)

        return res

    def call_module(
        self,
        m: Module,
        forward: Callable,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> None:
        """PythonKeyTracer overrides call_module to avoid the scope handling,
        but we actually want it.
        """
        from torch._dynamo import OptimizedModule

        # FIXME (tmanlaibaatar)
        # When we call torch.compile inside HOO, we will end up
        # invoking a module that is not registered on the root. For
        # now, we just inline them. But once we start supporting
        # mark_strict in export, we do need to properly handle this.
        # Right now, it doesn't matter because current non-strict
        # use cases don't need to work with HOO.
        if isinstance(m, (OptimizedModule, GraphModule)):
            return forward(*args, **kwargs)

        try:
            return Tracer.call_module(self, m, forward, args, kwargs)
        except _ModuleNotInstalledAsSubmoduleError:
            warnings.warn(
                f"Unable to find the path of the module {m}. "
                "This might be because the module was not properly registered "
                "as a submodule, which is not good practice. We will trace "
                "through the module without recording stack information."
            )
            return forward(*args, **kwargs)

    def is_leaf_module(self, m: Module, module_qualified_name: str) -> bool:
        # Never treat a module as a leaf: always trace into submodules.
        return False

    def create_node(self, *args: object, **kwargs: object) -> fx.node.Node:
        """
        Create node and add on metadata.
        Add nn_module_stack here instead of TracerBase,
        since calls to make_fx() might not want to record module stack metadata.
        Add torch_fn by looking at torch_fn_metadata and torch_fn_counts.
        Add stack_trace by filtering out forward() stack frames.
        """
        node = super().create_node(*args, **kwargs)  # type: ignore[arg-type]

        # nn_module_stack
        if node.op not in ["placeholder", "output"]:
            if "nn_module_stack" not in node.meta:
                # NOTE(review): this stores self.module_stack by reference (no
                # copy), and the loop below rewrites its values in place;
                # confirm that sharing is intended.
                node.meta["nn_module_stack"] = self.module_stack
            # convert nn_module_stack from Dict[key, (FQN, class)] -> Dict[str, Tuple[str, str]]
            for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items():
                if isinstance(mod_cls, type):
                    node.meta["nn_module_stack"][key] = (
                        fqn,
                        mod_cls.__module__ + "." + mod_cls.__qualname__,
                    )

        # torch_fn
        if (
            node.op == "call_function"
            and self.torch_fn_metadata is not None
            and "torch_fn" not in node.meta
        ):
            # (unique-ish "<name>_<call count>", "<class name>.<name>")
            node.meta["torch_fn"] = (
                f"{self.torch_fn_metadata.__name__}_{self.torch_fn_counts[self.torch_fn_metadata]}",
                f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}",
            )

        # stack_trace
        if "stack_trace" not in node.meta and node.op not in ["placeholder", "output"]:
            user_frame_summary = CapturedTraceback.extract().summary()
            if user_frame_summary:
                # we retain frames from forward() calls, or ops
                # located in torch/__init__.py (e.g. sym_int, sym_constrain_range, vmap)
                stack_trace = [
                    frame
                    for frame in user_frame_summary
                    if (
                        frame.name == "forward"
                        or frame.filename.endswith("torch/__init__.py")
                    )
                ]
                # filter out forward() frames from fx/_symbolic_trace.py, export/_trace.py
                # this is hardcoded, but leads to a much cleaner stack trace
                stack_trace = [
                    frame
                    for frame in stack_trace
                    if not (
                        frame.filename.endswith("fx/_symbolic_trace.py")
                        or frame.filename.endswith("export/_trace.py")
                    )
                ]
                if (
                    stack_trace
                ):  # empty list for strict mode, dynamo should handle stack_trace
                    stack_trace = traceback.StackSummary.from_list(stack_trace)
                    node.meta["stack_trace"] = "".join(stack_trace.format()).strip()

        return node


class _MakefxTracer:
    """Driver behind ``make_fx``.

    Holds the make_fx configuration plus the per-trace context managers: fake
    tensor mode, proxy dispatch/function modes, fx tracer, python dispatcher
    and torch_fn metadata mode. ``trace`` initializes them from the inputs.
    ``trace_subgraph`` builds a child tracer that shares this tracer's
    ``fake_tensor_mode``.
    """

    def __init__(
        self,
        decomposition_table: Optional[Mapping[OpOverload, Callable]],
        tracing_mode: str,
        _allow_non_fake_inputs: bool,
        pre_dispatch: bool,
        record_module_stack: bool,
        _allow_fake_constant: bool,
        _error_on_data_dependent_ops: bool,
    ) -> None:
        # Configurations that are used to initialize the context managers and their states.
        # Should not modify them during tracing.
        # Copied into a fresh dict so the setdefault below never mutates the caller's table.
        self.decomposition_table: Dict[OpOverload, Callable] = dict(
            decomposition_table or {}
        )
        self.decomposition_table.setdefault(
            torch.ops.aten.sym_numel.default, torch._decomp.decompositions.sym_numel
        )
        self.tracing_mode: str = tracing_mode
        self._allow_non_fake_inputs: bool = _allow_non_fake_inputs
        self.pre_dispatch: bool = pre_dispatch
        self.record_module_stack: bool = record_module_stack
        self._allow_fake_constant: bool = _allow_fake_constant
        self._error_on_data_dependent_ops: bool = _error_on_data_dependent_ops

        # All context managers and their states should be initialized before tracing based on the inputs
        # and configurations. After tracing, their states should be cleaned except for shape_env.
        # Remember to specify how to initialize it from user inputs and from parent tracer whenever
        # adding new modes in _MakefxTracer.
        self.fake_tensor_mode: Optional[FakeTensorMode] = None
        self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext()
        self.proxy_function_mode: Union[
            nullcontext, PreDispatchTorchFunctionMode
        ] = nullcontext()
        self.fx_tracer: Optional[PythonKeyTracer] = None
        self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext()
        self.torch_fn_metadata_mode: Union[
            nullcontext, TorchFunctionMetadataMode
        ] = nullcontext()

    def _checkpoint_modes(self) -> List[Any]:
        # Order must match the positional parameters of _restore_modes.
        return [
            self.fake_tensor_mode,
            self.proxy_mode,
            self.proxy_function_mode,
            self.fx_tracer,
            self.python_dispatcher_mode,
            self.torch_fn_metadata_mode,
        ]

    def _restore_modes(
        self,
        prev_fake_tensor_mode: Optional[FakeTensorMode],
        prev_proxy_mode: Union[nullcontext, ProxyTorchDispatchMode],
        prev_proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode],
        prev_fx_tracer: Optional[PythonKeyTracer],
        prev_python_dispatcher_mode: Union[nullcontext, Any],
        prev_torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode],
    ) -> None:
        self.fake_tensor_mode = prev_fake_tensor_mode
        self.proxy_mode = prev_proxy_mode
        self.proxy_function_mode = prev_proxy_function_mode
        self.fx_tracer = prev_fx_tracer
        self.python_dispatcher_mode = prev_python_dispatcher_mode
        self.torch_fn_metadata_mode = prev_torch_fn_metadata_mode

    @contextmanager
    def _init_modes_from_inputs(
        self, f: Callable, args: Tuple[object, ...]
    ) -> Generator[None, None, None]:
        """Set up the fx tracer and modes for a top-level trace of ``f(*args)``.

        Uses a _ModuleStackTracer when ``f`` has ``_orig_mod`` and
        ``record_module_stack`` is set. For "fake"/"symbolic" tracing it
        reuses a fake mode detected from ``args`` or creates a new one. All
        mode fields are restored to their previous values on exit.
        """
        prev_modes = self._checkpoint_modes()
        try:
            # Avoid importing sympy at a module level
            from .symbolic_shapes import ShapeEnv

            if hasattr(f, "_orig_mod") and self.record_module_stack:
                scope_root = f._orig_mod
                self.fx_tracer = _ModuleStackTracer(scope_root)
            else:
                self.fx_tracer = PythonKeyTracer()

            if self.tracing_mode == "fake":
                import torch._dynamo

                fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
                if fake_tensor_mode is None:
                    import torch._functorch.config as _config

                    with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
                        fake_tensor_mode = FakeTensorMode(
                            allow_fallback_kernels=True,
                            allow_non_fake_inputs=self._allow_non_fake_inputs,
                            shape_env=ShapeEnv(),
                            static_shapes=True,
                        )
                self.fake_tensor_mode = fake_tensor_mode
            elif self.tracing_mode == "symbolic":
                import torch._dynamo

                fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
                if fake_tensor_mode is None:
                    shape_env = ShapeEnv()
                    import torch._functorch.config as _config

                    with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
                        fake_tensor_mode = FakeTensorMode(
                            allow_fallback_kernels=False,
                            allow_non_fake_inputs=self._allow_non_fake_inputs,
                            shape_env=shape_env,
                        )
                assert (
                    fake_tensor_mode.shape_env is not None
                ), "shape_env should be set if tracing with 'symbolic'"
                self.fake_tensor_mode = fake_tensor_mode
            else:
                if not self.tracing_mode == "real":
                    raise AssertionError(
                        f"Unexpected tracing type: {self.tracing_mode}"
                    )

            self._construct_modes_with_fx_tracer(self.fx_tracer)
            yield
        finally:
            self._restore_modes(*prev_modes)

    def _construct_modes_with_fx_tracer(self, fx_tracer: _ProxyTracer) -> None:
        """Create the proxy / function / dispatcher / metadata modes bound to ``fx_tracer``."""
        self.proxy_mode = ProxyTorchDispatchMode(
            fx_tracer,
            self.tracing_mode,
            pre_dispatch=self.pre_dispatch,
            _allow_fake_constant=self._allow_fake_constant,
            _error_on_data_dependent_ops=self._error_on_data_dependent_ops,
        )

        if self.pre_dispatch:
            self.proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer)

        # pre-autograd tracing uses per-dispatch-key modes,
        # which requires the python dispatcher
        if self.tracing_mode == "symbolic" or self.pre_dispatch:
            self.python_dispatcher_mode = enable_python_dispatcher()

        self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer)

    @contextmanager
    def _init_modes_from_parent(
        self, parent_tracer: _MakefxTracer
    ) -> Generator[None, None, None]:
        # By default, subtracer creates new modes based on parent tracer's config.
        # However, there are cases where we want to share the same modes with parent tracer
        # For example, fake_tensor_mode, we want the example value's fake_mode of parent graph and subgraphs to be the same.
        prev_modes = self._checkpoint_modes()
        try:
            self.fake_tensor_mode = parent_tracer.fake_tensor_mode

            def _create_sub_fx_tracer(parent_tracer: _ProxyTracer) -> PythonKeyTracer:
                # Exact type match: subclasses of these tracers hit the RuntimeError.
                if type(parent_tracer) == PythonKeyTracer:
                    return PythonKeyTracer()
                elif type(parent_tracer) == _ModuleStackTracer:
                    return _ModuleStackTracer(parent_tracer.scope_root)
                else:
                    raise RuntimeError(
                        f"Unexpected tracer type: {type(parent_tracer)}."
                    )

            assert parent_tracer.fx_tracer is not None
            self.fx_tracer = _create_sub_fx_tracer(parent_tracer.fx_tracer)
            self._construct_modes_with_fx_tracer(self.fx_tracer)
            yield
        finally:
            self._restore_modes(*prev_modes)

    def _trace_inner(self, f: Callable, *args: object) -> GraphModule:
        """Trace ``f(*args)`` with the already-initialized modes and return the GraphModule."""
        # TODO: We need to explicitly import torch._dynamo before calling dispatch_trace,
        # because dispatch_trace will introduce the lazy import of torch._dynamo,
        # and some contexts set before calling dispatch_trace will cause problems with the import of torch._dynamo,
        # such as some torch API(torch.ones and so on) in populate_builtin_to_tensor_fn_map() will be affected
        # by the context set before dispatch_trace.
        import torch._dynamo

        phs = pytree.tree_map(lambda _: torch.fx._symbolic_trace.PH, args)

        def _wrap_fake(args: T) -> T:
            # Convert real inputs into fake/symbolic ones according to tracing_mode.
            arg_count = 0

            def inner_wrap_fake(x: object) -> object:
                nonlocal arg_count
                # TODO: it would be nice to line these up with the names
                # FX will choose for the placeholders, but we don't
                # actually know what the names will be at this point yet
                # NB: the Source here is actually meaningless
                from torch._dynamo.source import ConstantSource

                assert self.fake_tensor_mode is not None
                source = ConstantSource(f"input{arg_count}")
                if isinstance(x, Tensor):
                    arg_count += 1
                    return self.fake_tensor_mode.from_tensor(x, source=source)
                # NB: don't match on bools
                elif type(x) is int and self.tracing_mode == "symbolic":
                    assert (
                        self.fake_tensor_mode.shape_env is not None
                    ), "shape_env should be set if tracing with 'symbolic'"
                    return self.fake_tensor_mode.shape_env.create_symintnode(
                        self.fake_tensor_mode.shape_env.create_symbol(
                            x, source, positive=None
                        ),
                        hint=x,
                        source=source,
                    )
                elif isinstance(x, torch.ScriptObject):
                    return torch._library.fake_class_registry.maybe_to_fake_obj(
                        self.fake_tensor_mode, x
                    )

                assert not isinstance(
                    x, FakeScriptObject
                ), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
                return x

            wrap_fn_map = {
                "real": lambda x: x,
                "fake": inner_wrap_fake,
                "symbolic": inner_wrap_fake,
            }
            return pytree.tree_map(wrap_fn_map[self.tracing_mode], args)

        def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]:
            if (
                not hasattr(inspect.unwrap(f), "__code__")
                or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS
            ):
                # FX doesn't support varargs, so we gotta fake up a wrapper
                # TODO: Would be nice to fix this at the source...
                return fake_signature(f, len(phs))
            return f

        args = _wrap_fake(args)
        func = _wrap_func(f, phs)
        # We disable the autocast cache as the autocast cache causes type conversions on parameters to
        # check a cache, which introduces untracked tensors into the graph
        #
        # We also disable tracing by any other tensor proxy-based tracers except the current. The
        # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
        # thus irrelevant to any external functional trace.
        proxy_mode: ProxyTorchDispatchMode = typing.cast(
            ProxyTorchDispatchMode, self.proxy_mode
        )
        with ExitStack() as stack:
            stack.enter_context(decompose(self.decomposition_table))
            if self.fake_tensor_mode:
                stack.enter_context(self.fake_tensor_mode)
            stack.enter_context(self.python_dispatcher_mode)
            stack.enter_context(self.proxy_function_mode)
            stack.enter_context(self.torch_fn_metadata_mode)
            stack.enter_context(proxy_mode)
            stack.enter_context(disable_autocast_cache())
            stack.enter_context(_set_make_fx_tracer(self))

            assert self.fx_tracer is not None
            try:
                t = dispatch_trace(
                    wrap_key(func, args, self.fx_tracer, self.pre_dispatch),
                    tracer=self.fx_tracer,
                    concrete_args=tuple(phs),
                )
            except Exception:
                # Log the partially built graph for debugging, then re-raise.
                trace_structured(
                    "artifact",
                    metadata_fn=lambda: {
                        "name": "make_fx_fail_partial",
                        "encoding": "string",
                    },
                    payload_fn=lambda: self.fx_tracer.graph.python_code(  # type: ignore[union-attr]
                        root_module="self",
                        verbose=True,
                        include_stride=True,
                        include_device=True,
                    ).src,
                )
                raise

        # TODO: kind of a bad way to do it, should maybe figure out a better way
        if self.tracing_mode == "symbolic":
            assert self.fake_tensor_mode is not None
            t.shape_env = self.fake_tensor_mode.shape_env
        return t

    def trace(self, f: Callable, *args: object) -> fx.GraphModule:
        """Top-level entry: initialize modes from the inputs and trace ``f(*args)``."""
        with self._init_modes_from_inputs(f, args):
            return self._trace_inner(f, *args)

    def trace_subgraph(self, f: Callable, *args: object) -> GraphModule:
        """Trace ``f(*args)`` with a child tracer that shares this tracer's fake mode.

        The child is created with tracing_mode "real" (so inputs are not
        re-fakified), but its fake_tensor_mode is inherited from this tracer
        and is entered by _trace_inner if set.
        """
        # Create a new tracer based on parent's config
        sub_tracer = _MakefxTracer(
            self.decomposition_table,
            "real",
            self._allow_non_fake_inputs,
            self.pre_dispatch,
            self.record_module_stack,
            self._allow_fake_constant,
            self._error_on_data_dependent_ops,
        )
        with sub_tracer._init_modes_from_parent(self):
            return sub_tracer._trace_inner(f, *args)


# The _MakefxTracer whose _trace_inner is currently running (set by
# _set_make_fx_tracer), or None outside of make_fx tracing.
_CURRENT_MAKE_FX_TRACER: Optional[_MakefxTracer] = None


@contextmanager
def _set_make_fx_tracer(tracer: _MakefxTracer) -> Generator[None, None, None]:
    """Make ``tracer`` the current make_fx tracer for the block; restore the previous one on exit."""
    global _CURRENT_MAKE_FX_TRACER
    prev_tracer = _CURRENT_MAKE_FX_TRACER
    try:
        _CURRENT_MAKE_FX_TRACER = tracer
        yield
    finally:
        _CURRENT_MAKE_FX_TRACER = prev_tracer
def make_fx(
    f: Callable,
    decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
    tracing_mode: str = "real",
    _allow_non_fake_inputs: bool = False,
    *,
    pre_dispatch: bool = False,
    record_module_stack: bool = False,
    _allow_fake_constant: bool = False,
    _error_on_data_dependent_ops: bool = True,
) -> Callable[..., GraphModule]:
    """
    Given a function f, return a new function which when executed with valid
    arguments to f, returns an FX GraphModule representing the set of operations that
    were executed during the course of execution.

    ``tracing_mode`` must be one of "real", "fake" or "symbolic". A single
    tracer instance is built up front and reused by every call of the
    returned function.
    """
    assert tracing_mode in ("real", "fake", "symbolic")

    tracer = _MakefxTracer(
        decomposition_table,
        tracing_mode,
        _allow_non_fake_inputs,
        pre_dispatch,
        record_module_stack,
        _allow_fake_constant,
        _error_on_data_dependent_ops,
    )

    def wrapped(*args: object) -> GraphModule:
        return tracer.trace(f, *args)

    # Copy f's name/docstring/etc. onto the returned callable.
    return functools.wraps(f)(wrapped)
def get_torch_dispatch_modes() -> List[TorchDispatchMode]:
    """Return the current TorchDispatchMode stack, as reported by
    ``torch.utils._python_dispatch._get_current_dispatch_mode_stack``."""
    python_dispatch = torch.utils._python_dispatch
    return python_dispatch._get_current_dispatch_mode_stack()


# TODO: this is a legacy name, there is only ever one proxy mode as it's an
# infra mode
def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
    """Legacy alias: return exactly what ``get_proxy_mode()`` returns."""
    active_mode = get_proxy_mode()
    return active_mode
[docs]defget_proxy_mode()->Optional[ProxyTorchDispatchMode]:""" Current the currently active proxy tracing mode, or None if we are not currently tracing. This includes pre-dispatch proxy tracing. """pre_dispatch_mode=torch._ops._get_dispatch_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)mode=torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)assert(pre_dispatch_modeisNoneormodeisNone),f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"returnpre_dispatch_modeormode
def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs) -> R:
    """
    Route a SymInt/SymFloat/SymBool operation through the active proxy mode.

    Looks up the current proxy tracing mode (one must be active) and, with
    proxy-mode tracing disabled, calls its ``__sym_dispatch__`` with ``func``,
    an empty ``types`` list, ``args`` and ``kwargs``.
    """
    active_mode = get_proxy_mode()
    assert active_mode
    # The normal torch dispatch machinery would disable proxy tracing before
    # invoking the mode; we call into the mode by hand, so do it ourselves.
    with disable_proxy_modes_tracing():
        # TODO: properly compute types
        arg_types: List[Type] = []
        return active_mode.__sym_dispatch__(func, arg_types, args, kwargs)  # type: ignore[arg-type, return-value]
@contextmanager
def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, None]:
    """Temporarily disable the active PROXY infra mode (if any).

    Delegates to ``_disable_infra_mode`` for the PROXY dispatch-mode key.

    NOTE(review): the body ``return``s the helper's result instead of
    yielding, so ``contextmanager`` uses that returned object as its
    generator. This only works if ``_disable_infra_mode`` returns a plain
    generator; confirm against its definition.
    """
    return _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY)


def maybe_handle_decomp(
    proxy_mode: ProxyTorchDispatchMode,
    op: OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    """Run the decomposition registered for ``op``, if any, under ``proxy_mode``.

    Returns the decomposition's result. Returns ``NotImplemented`` when
    ``op`` has no entry in ``CURRENT_DECOMPOSITION_TABLE``, or when the
    compiler bisector disables the "decomposition" subsystem for it.

    ``proxy_mode.decomp_layers`` is incremented for the duration of the
    decomposition call and decremented afterwards, even if the
    decomposition raises.
    """
    from torch._inductor.compiler_bisector import CompilerBisector

    if op not in CURRENT_DECOMPOSITION_TABLE:
        return NotImplemented
    if CompilerBisector.disable_subsystem(
        "aot_eager_decomp_partition", "decomposition", lambda: repr(op)
    ):
        return NotImplemented

    with proxy_mode:
        proxy_mode.decomp_layers += 1
        try:
            return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
        finally:
            # Keep the counter balanced when the decomposition raises and the
            # exception is caught further up; previously it was left
            # incremented forever.
            proxy_mode.decomp_layers -= 1


def get_isolated_graphmodule(
    func: Callable,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
    tracing_mode: str = "real",
    decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
) -> GraphModule:
    """A helper function used to get the GraphModule for the given func.

    It's expected to be used in the ProxyTensor tracing context.
    It detaches the args and kwargs from the current tracer so that the trace of
    the current graph module can be created without any side-effects.
    """
    # make_fx only takes positional args, so flatten (args, kwargs) first.
    wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs)

    with disable_proxy_modes_tracing():
        gm = make_fx(
            wrapped, decomposition_table=decomposition_table, tracing_mode=tracing_mode
        )(all_args)
    return gm


def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None:
    """A helper function for setting up unbacked_bindings on the destination FX graph.

    When a FAKE dispatch mode with a shape_env is active and
    ``compute_unbacked_bindings`` returns a non-empty result for ``out``, that
    result is stored in ``out_proxy.node.meta["unbacked_bindings"]``.
    """
    from .symbolic_shapes import compute_unbacked_bindings

    # Can't use detect_fake_mode here,
    #
    #   python test/distributed/_tensor/test_dtensor_compile.py -k
    #   test_tp_compile_fullgraph_is_seq_parallel_False
    #
    # will fail.  Very strange, it probably isn't right for them to be using
    # two fake modes there...
    fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
    if fake_mode and fake_mode.shape_env:
        if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out):
            assert isinstance(out_proxy, Proxy), out_proxy
            out_proxy.node.meta["unbacked_bindings"] = symbol_to_path
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.