import collections
import dataclasses
import itertools
import logging
import warnings
import pprint
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, NewType
from unittest.mock import patch

from functorch import make_fx

import torch
import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
from torch._guards import detect_fake_mode, tracing
from torch._prims_common import CUDARngStateHelper
from torch._logging import getArtifactLogger
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import immutable_collections, Interpreter
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_concrete_int, fx_placeholder_vals
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn.utils import stateless
from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions, PhiloxTotalOffsets
from . import config
from .partitioners import default_partition
from torch._guards import TracingContext, DuplicateInputs, Source

log = logging.getLogger(__name__)
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")

MutationType = Enum("MutationType", ("none", "metadata_only", "data", "data_and_metadata"))
OutputType = Enum(
    "OutputType",
    (
        # output is not an alias
        "non_alias",
        # output aliases an input
        "alias_of_input",
        # output **is** an input tensor
        "is_input",
        # output has a ._base tensor, which is a graph intermediate.
        # We need to return its ._base as a graph output,
        # so its requires_grad info is populated correctly.
        # Instructs the runtime code to regenerate the current output
        # from a base tensor, graph_intermediates[base_idx]
        "alias_of_intermediate_save_as_output",
        # Same as above; but we don't need to explicitly add its ._base
        # as a graph output, because it already **is** a graph output.
        "alias_of_intermediate",
        # Same as above; but the output's ._base is **already** a user output.
        # Instructs the runtime code to regenerate the current output from
        # a base tensor, user_outputs[base_idx]
        "alias_of_intermediate_base_is_user_output",
        # See Note [Intermediate Bases Optimization]
        "unsafe_view_alias",
    ),
)

pytree._register_pytree_node(
    immutable_collections.immutable_list,
    lambda x: (list(x), None),
    lambda x, c: immutable_collections.immutable_list(x),
)
pytree._register_pytree_node(
    immutable_collections.immutable_dict,
    lambda x: (list(x.values()), list(x.keys())),
    lambda x, c: immutable_collections.immutable_dict(dict(zip(c, x))),
)

def partial_asdict(obj: Any) -> Any:
    if dataclasses.is_dataclass(obj):
        return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}
    elif isinstance(obj, (list, tuple)):
        return obj.__class__([partial_asdict(item) for item in obj])
    elif isinstance(obj, dict):
        return {k: partial_asdict(v) for k, v in obj.items()}
    else:
        return obj

aten = torch.ops.aten
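
# Illustrative sketch (added commentary, not used by AOTAutograd itself): how
# partial_asdict behaves on a hypothetical dataclass. It converts a dataclass into
# a dict of its fields without recursing into those field values, while dataclasses
# found inside lists/tuples/dicts are converted as the container is walked.
def _example_partial_asdict():
    @dataclass
    class _Point:  # hypothetical, for illustration only
        x: int
        y: int

    # -> [{"x": 1, "y": 2}, {"p": {"x": 3, "y": 4}}]
    return partial_asdict([_Point(1, 2), {"p": _Point(3, 4)}])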
# This global counter increments every time we compile a graph with
# AOTAutograd. You can use this to correlate runtime error messages
# with compile time (e.g., if you get an error at runtime saying that
# compiled graph 3 failed, you can set a breakpoint at compile time
# for this graph number to investigate further at compile time.)
#
# NB: this is different from get_aot_compilation_context, which tracks
# each underlying graph that is compiled. In contrast, AOT_COUNTER
# corresponds to top-level invocations of aot_module/aot_function;
# one counter is allocated per entire compiled block (but this block
# may involve compiling multiple subgraphs; e.g., for forwards/backwards)
AOT_COUNTER = itertools.count()

KNOWN_TYPES = tuple(
    [torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types)
)

@contextmanager
def preserve_rng_state():
    with torch.utils._python_dispatch._disable_current_modes():
        rng_state = torch.clone(torch.random.get_rng_state())
        if torch.cuda.is_available():
            cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
    try:
        yield
    finally:
        with torch.utils._python_dispatch._disable_current_modes():
            torch.random.set_rng_state(rng_state)
            if torch.cuda.is_available():
                torch.cuda.set_rng_state(cuda_rng_state)

# Set up hooks so that during backward the fx's stack_trace is properly set
callback_set = False

def setup_stacktrace_preservation_hooks(roots: List):
    def iter_graph(roots):
        if not roots:
            return
        seen = set()
        q = collections.deque()
        for node in roots:
            if node is not None:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _idx in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)

            yield node

    def get_callback(saved_stack_):
        def callback():
            global callback_set
            fx_traceback.set_stack_trace(saved_stack_)
            callback_set = False

        return callback

    def get_prehook(stack_):
        def prehook(grad_output):
            global callback_set

            if not callback_set:
                torch.autograd.variable.Variable._execution_engine.queue_callback(
                    get_callback(fx_traceback.format_stack())
                )
                callback_set = True

            fx_traceback.set_stack_trace(stack_)

        return prehook

    def get_posthook(special_stack_):
        def posthook(grad_input, grad_output):
            fx_traceback.set_stack_trace(special_stack_)

        return posthook

    for node in iter_graph(roots):
        forward_node_stack = node.metadata.get("traceback_", [])
        node.register_prehook(get_prehook(forward_node_stack))

        special_stack = forward_node_stack.copy()
        special_stack.append(
            "Gradient addition node due to multiple use of tensor around:"
        )
        node.register_hook(get_posthook(special_stack))
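
# Illustrative sketch (added commentary, not used by AOTAutograd itself):
# preserve_rng_state snapshots the CPU (and, if available, CUDA) RNG state on entry
# and restores it on exit, so random ops inside the block don't perturb the caller.
def _example_preserve_rng_state():
    before = torch.random.get_rng_state()
    with preserve_rng_state():
        torch.randn(4)  # advances the RNG inside the context only
    assert torch.equal(before, torch.random.get_rng_state())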
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation
# that are external to the graph (they show up as side effects in some way when you run the graph).
#
# Take a look at the `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some example functions
# and what their compiled graphs look like.
# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them.
#
# Note [AOT Autograd: input data mutations]
#
# If we compile a function that mutates inputs, then those input mutations are real side effects
# that a user expects to see after running the compiled graph.
# However, the graph that we want to send to a backend needs to be *entirely* functional.
# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile,
# but we update the graph to return (updated_inputs, user_outputs).
# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals.
#
# Example: original user code:
#     def f(x):
#         x.mul_(2)
#         out = x.mul(3)
#         return out
#
# After AOT Autograd compiles, we end up with a:
# (a) compiled graph
# (b) autograd.Function.forward() method, that executes the compiled graph
# (c) wrapper function, that calls the autograd.Function.forward() and performs the epilogue
#
# The output of (a, b, c) are all written below.
#
#     def compiled_forward_graph(x):
#         x_updated = x.mul(2)
#         out = x_updated.mul(3)
#         return x_updated, out
#
#     # x_updated gets a gradient in the compiled backward
#     def compiled_backward_graph(grad_x_updated, grad_out):
#         grad_x = ...
#         return grad_x
#
#     def autograd.Function.forward(x):
#         x_updated, out = compiled_forward_graph(x)
#         return x_updated, out
#
#     def compiled_wrapper(x):
#         x_updated, out = autograd.Function.apply(x)
#         x.copy_(x_updated)
#         return out
#
# Another important thing to note is that updated inputs (due to data mutations) *do* participate
# in the compiled backward graph! Since the compiled forward graph gets N extra outputs
# (due to updated inputs showing up as graph outputs),
# the compiled backward gets an additional N inputs.
# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input
# back to the original input.
#
# Note [AOT Autograd: input metadata mutations]
#
# For the same reason as input mutations, we also don't put input metadata mutations in the graph.
# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph.
#
# Example: original user code:
#     def f(x):
#         x.t_()
#         out = x.mul(3)
#         return out
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
#     def compiled_forward_graph(x):
#         x_updated = x.t()
#         out = x_updated.mul(3)
#         return x_updated, out
#
#     # x_updated does *not* get a gradient in the compiled backward
#     def compiled_backward_graph(grad_out):
#         grad_x = ...
#         return grad_x
#
#     def autograd.Function.forward(x):
#         x_updated, out = compiled_forward_graph(x)
#         return x_updated, out
#
#     def compiled_wrapper(x):
#         x_updated, out = autograd.Function.apply(x)
#         x.as_strided_(x_updated)
#         return out
#
# Note [AOT Autograd: outputs aliasing inputs or intermediates!]
#
# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates!
# Why?
# (1) autograd.Function.forward() has a limitation, where views that are returned in the forward cannot later be mutated.
# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph,
#     in an epilogue.
# For outputs that alias inputs, we do the following:
# (a) *still* return the aliased output as a graph output
# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output.
#
# For outputs that alias *intermediates*, we do the following:
# (a) Return the output in the compiled forward, **and** return its ._base (a graph intermediate) as an output in the forward
# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output).
# You might wonder why we return the aliased output directly in the graph (making the graph compute it),
# only to not return it and instead generate a fresh alias off of the intermediate,
# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias.
# There are two reasons:
# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call
# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance.
#     This can result in problems if a user later tries to .view() that output expecting it to have one set of strides,
#     when it has a different set of strides.
#     By including the view op directly in the graph, inductor takes that into account when deciding what memory format
#     the graph intermediate should be.
#
# Another important thing to note is how our traced backward() graph handles aliases.
# (this applies to outputs aliasing inputs, outputs aliasing intermediates,
#  *and* updated inputs returned in the compiled forward due to metadata-only mutations).
# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph.
# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly
# at the end of the forward.
#
# Example: original user code:
#     def f(x):
#         out1 = x.t()
#         intermediate = x.mul(2)
#         out2 = intermediate.view(-1)
#         return out1, out2
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
#     def compiled_forward_graph(x):
#         out1 = x.t()
#         intermediate = x.mul(2)
#         out2 = intermediate.view(-1)
#         # the compiled graph also returns the intermediate
#         return out1, out2, intermediate
#
#     # intermediate gets a gradient in the compiled backward.
#     # both output aliases (out1 and out2) do not.
#     def compiled_backward_graph(grad_intermediate):
#         grad_x = ...
#         return grad_x
#
#     def autograd.Function.forward(x):
#         out1, out2, intermediate = compiled_forward_graph(x)
#         return out1, out2, intermediate
#
#     def compiled_wrapper(x):
#         out1, out2, intermediate = autograd.Function.apply(x)
#         # regenerate out1 from the input
#         out1_regenerated = out1._view_func(x)
#         # regenerate out2 from the intermediate
#         out2_regenerated = out2._view_func(intermediate)
#         return out1_regenerated, out2_regenerated
#
# Note [AOT Autograd: mutations to inputs that alias other inputs]
#
# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input.
# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other.
# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias
# given the mutation that occurred.
#
# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input
# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base
# inside of the compiled function.
#
# This logic is fully encapsulated in aot_wrapper_synthetic_base()
#
# Example: original user code:
#     def f(x, x_view):
#         x.mul_(2)
#         out = x * x_view
#         return out
#     f(x, x.view(-1))
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
#     def compiled_forward_graph(base):
#         x = generate_x(base)
#         x_view = generate_x_view(base)
#         x_updated = x.mul(2)
#         x_view_updated = x_updated.view(-1)
#         out = x_updated * x_view_updated
#         return x_updated, out
#
#     # The calling convention change from (aliases) -> (base) happens
#     # *outside* of the autograd.Function.forward().
#     # That means the forward() only has 1 input (base),
#     # and the backward() only has 1 output (grad_base)
#     def compiled_backward_graph(grad_out):
#         grad_base = ...
#         return grad_base
#
#     def autograd.Function.forward(base):
#         x_updated, out = compiled_forward_graph(base)
#         return x_updated, out
#
#     # The compiled wrapper is where we create synthetic bases.
#     # The info on which inputs are mutated is also tracked *before* synthetic base creation.
#     def compiled_wrapper(x, x_view):
#         base = merge_view_inputs(x, x_view)
#         x_updated, out = autograd.Function.apply(base)
#         # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view.
#         x.copy_(x_updated)
#         return out
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# This class stores info about every user output.
@dataclass(frozen=True)
class OutputAliasInfo:
    # Tells us if this output is:
    # (1) a regular (non-aliased) output
    # (2) an alias of a forward input
    # (3) **is** a forward input (special case of "alias_of_input")
    # (4) an alias of an intermediate (aka an alias of an output of the inner traced forward)
    # (5) an alias of an intermediate, that explicitly requires returning the intermediate
    #     as a graph output
    # (6) an alias of an intermediate, where that intermediate is also a user output
    output_type: OutputType
    # The raw type of the output (torch.Tensor, SymInt, etc)
    raw_type: type
    # If (1) above, then
    # - base_idx is None
    # If (2) or (3) above, then
    # - Tells us that the base of this alias is user_fwd_input[base_idx]
    #   (This is an index into the inputs *before* we make synthetic bases)
    # If (4) or (5) above, then
    # - Tells us that the base of this alias is output_graph_intermediates[base_idx]
    #   (here, this refers to the index into the directly traced graph intermediates)
    # If (6) above, then:
    # - Tells us that the base of this alias is output_user_fwds[base_idx]
    #   (here, this refers to the index into the directly traced user outputs)
    base_idx: Optional[int]
    # If it is a Tensor, what the dynamic dims are (otherwise None)
    dynamic_dims: Optional[Set[int]]

# This class tells us info about user inputs.
@dataclass(frozen=True)
class InputAliasInfo:
    is_leaf: bool
    mutates_data: bool
    mutates_metadata: bool

# This class encapsulates all aliasing + mutation info we need about the forward graph
# See a more detailed overview of the edge case handling at
# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit
@dataclass(eq=False)
class ViewAndMutationMeta:
    # length = # user inputs
    # This gives us info about every input, and what sort of mutation happened to it (if any)
    input_info: List[InputAliasInfo]

    # length = # user outputs
    # This gives us info about every output (mostly around whether it aliases other tensors)
    output_info: List[OutputAliasInfo]

    # length = # mutated inps + # user outputs
    # For every output *and* mutated input returned from the forward,
    # tells us whether or not that output should require gradients
    requires_grad_info: List[bool]

    # length = the number of intermediate bases appended as outputs to the end of the forward graph.
    # Note: this is not necessarily the same thing as:
    #     len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate])
    # Because outputs might share a ._base, or an output's ._base might itself be
    # another user output (in both cases, we won't redundantly append bases to the end of the graph)
    num_intermediate_bases: int

    # For inference only: instructs us to keep data-only input mutations directly in the graph
    keep_input_mutations: int

    # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors)
    #        + (# intermediate bases)
    # These are the FakeTensor (or
    # potential SymInt) outputs that we traced from our metadata pass of the
    # user's forward function. Their only use today is to pass them as a
    # best-guess for tangents when tracing the joint.
    # Stashing them as part of our "metadata" makes it simpler if we want to run our analysis
    # pass once, and re-use the output throughout AOTAutograd
    traced_tangents: List[Any]

    def __post_init__(self):
        mutated_inp_indices = [
            i for i, m in enumerate(self.input_info) if m.mutates_metadata or m.mutates_data
        ]
        # pre-compute the indices of the inputs that are mutated.
        # When keep_input_mutations is set, we don't need to worry about our epilogue
        # handling data-only mutations, because we keep them directly in the graph.
        mutated_inp_runtime_indices = [
            i
            for i, m in enumerate(self.input_info)
            if m.mutates_metadata or (not self.keep_input_mutations and m.mutates_data)
        ]
        aliased_out_indices = [
            i
            for i, m in enumerate(self.output_info)
            if m.output_type not in [OutputType.non_alias, OutputType.unsafe_view_alias]
        ]

        self.mutated_inp_indices = mutated_inp_indices
        # This is pre-computed in post_init for perf.
        # It contains the index of every element
        # of input_info that corresponds to a mutation (data or metadata or both)
        self.mutated_inp_runtime_indices = mutated_inp_runtime_indices
        # This is pre-computed for perf.
        # It contains the index of every element
        # of output_info that corresponds to an alias (either of an input or intermediate)
        self.aliased_out_indices = aliased_out_indices
        self.num_outputs = len(self.output_info)
        self.num_outputs_non_aliased = len(
            [
                x
                for x in self.output_info
                if x.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias]
            ]
        )
        self.num_outputs_aliased_to_inputs = len(
            [
                x
                for x in self.output_info
                if x.output_type in [
                    OutputType.alias_of_input,
                    OutputType.is_input,
                ]
            ]
        )
        self.num_outputs_aliased_to_intermediates = len(
            [
                x
                for x in self.output_info
                if x.output_type in [
                    OutputType.alias_of_intermediate,
                    OutputType.alias_of_intermediate_save_as_output,
                    OutputType.alias_of_intermediate_base_is_user_output,
                ]
            ]
        )
        self.num_outputs_aliased = (
            self.num_outputs_aliased_to_inputs + self.num_outputs_aliased_to_intermediates
        )
        self.num_mutated_data_inputs = len(
            [x for x in self.input_info if x.mutates_data]
        )
        self.num_mutated_metadata_inputs = len(
            [x for x in self.input_info if x.mutates_metadata]
        )
        self.num_mutated_metadata_only_inputs = len(
            [x for x in self.input_info if not x.mutates_data and x.mutates_metadata]
        )
        self.num_mutated_inputs = self.num_mutated_data_inputs + self.num_mutated_metadata_only_inputs
        self.dynamic_outputs = any(o.dynamic_dims for o in self.output_info)

    def __eq__(self, other):
        if not isinstance(other, ViewAndMutationMeta):
            return NotImplemented
        return (
            self.input_info == other.input_info
            and self.output_info == other.output_info
            and self.requires_grad_info == other.requires_grad_info
            and self.num_intermediate_bases == other.num_intermediate_bases
            and self.keep_input_mutations == other.keep_input_mutations
            and len(self.traced_tangents) == len(other.traced_tangents)
            and all(
                x.shape == y.shape and x.dtype == y.dtype
                for x, y in zip(self.traced_tangents, other.traced_tangents)
            )
        )

# This side data structure stores RNG-functionalization metadata that is needed at
# runtime. In the future, we can repurpose this class into a more general RuntimeMeta
# if more metadata use cases pop up.
@dataclass
class RNGMeta:
    # Stores whether config.functionalize_rng_ops was True at compile time
    is_compiled_with_functional_rng_ops: bool

    # Stores PhiloxTotalOffsets to be used at runtime
    philox_total_offsets: PhiloxTotalOffsets

# This class exists because:
# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs
# - we only care about the metadata on those aliases, so we can regenerate them.
#   We do not want them to participate in the autograd.Function.
# We do that by wrapping them in an opaque class, so the autograd.Function
# does not know to treat them as tensors.
@dataclass(frozen=True)
class TensorAlias:
    alias: torch.Tensor

def has_same_metadata(t1, t2):
    return (
        t1.size() == t2.size()
        and t1.stride() == t2.stride()
        and t1.storage_offset() == t2.storage_offset()
    )

def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires_grad):
    # Try to do view-replay if possible.
    # fall back to .as_strided() if we can't.
    if target_meta_tensor._base is not None:
        # The base that we want to replay our view off of might have a different shape than the view's original base.
        b = target_meta_tensor._base
        abt = aliased_base_tensor
        # Don't unnecessarily call as_strided if nothing changed; as_strided's
        # backward is poorly implemented and slow
        if abt is not b and (
            abt.size() != b.size()
            or abt.stride() != b.stride()
            or abt.storage_offset() != b.storage_offset()
        ):
            reshaped_base_tensor = aliased_base_tensor.as_strided(
                b.size(), b.stride(), b.storage_offset()
            )
        else:
            reshaped_base_tensor = aliased_base_tensor
        out = target_meta_tensor._view_func(reshaped_base_tensor)
        # This shape mismatch can happen due to a bug in inplace/view handling in autograd.
        # Try putting a breakpoint here and running
        # `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types`
        # Also, https://github.com/pytorch/pytorch/issues/49825
        #
        # As a stopgap, we'll fall back to as_strided.
        if out is not None and out.shape == target_meta_tensor.shape:
            if aliased_base_tensor.requires_grad and not target_requires_grad:
                out = out.detach()
            elif not aliased_base_tensor.requires_grad and target_requires_grad:
                out.requires_grad_(True)
            return out
    size = target_meta_tensor.size()
    stride = target_meta_tensor.stride()
    storage_offset = target_meta_tensor.storage_offset()
    if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex():
        aliased_out = torch.view_as_real(aliased_base_tensor).as_strided(
            size, stride, storage_offset
        )
    elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex():
        aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided(
            size, stride, storage_offset
        )
    else:
        aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset)
    # For outputs aliasing inputs, we need to check if the requires-gradness has changed.
    if aliased_base_tensor.requires_grad and not target_requires_grad:
        aliased_out = aliased_out.detach()
    elif not aliased_base_tensor.requires_grad and target_requires_grad:
        aliased_out.requires_grad_(True)
    return aliased_out

def to_fun(t):
    if isinstance(t, Tensor):
        return torch._to_functional_tensor(t, mirror_autograd_meta=True)
    else:
        return t

def from_fun(t):
    if not isinstance(t, Tensor) or not torch._is_functional_tensor(t):
        return t
    torch._sync(t)
    return torch._from_functional_tensor(t)
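
# Illustrative sketch (added commentary, not used by AOTAutograd itself): the runtime
# epilogue uses gen_alias_from_base to rebuild an aliased output from its base, via
# view-replay when possible and an as_strided() fallback otherwise. Either way the
# result is a real alias of the base tensor's storage.
def _example_gen_alias_from_base():
    base = torch.arange(6.0)
    target = base.view(2, 3)  # stands in for a forward output that aliases `base`
    regenerated = gen_alias_from_base(base, target, target_requires_grad=False)
    assert regenerated.shape == target.shape
    assert regenerated.data_ptr() == base.data_ptr()  # still aliases the base storage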
# This is a version of functionalization that is specifically designed
# for the AOTAutograd use case.
#
# Unlike functorch's variant, this doesn't use the functorch level system;
# instead it directly uses PyTorch's conventional dispatcher to hit the
# functionalization key. In particular, this means that FunctionalTensorWrapper
# can have autograd data stored directly on it.
#
# In typical AOTAutograd usage, the dispatch key order will look like:
#
#   Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor
#       outer tensor                        inner tensor
#
# Returns:
# - ViewAndMutationMeta, telling us metadata about the inputs and outputs, and
#   the list of outputs from the forward, but **only** the outputs that we need
#   to pass in as tangents into the backward.
#   Specifically, aliased outputs from the forward get regenerated, and don't participate
#   in the compiled backward function.
def run_functionalized_fw_and_collect_metadata(
    f,
    *,
    keep_input_mutations: bool,
) -> ViewAndMutationMeta:
    memo = {}

    def to_fun(t):
        if isinstance(t, Tensor):
            if t in memo:
                return memo[t]
            r = torch._to_functional_tensor(t, mirror_autograd_meta=True)
            memo[t] = r
            return r
        else:
            return t

    def from_fun(t):
        if not isinstance(t, Tensor) or not torch._is_functional_tensor(t):
            return t
        torch._sync(t)
        return torch._from_functional_tensor(t)

    @wraps(f)
    def inner(*flat_args):
        # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
        assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)

        input_info: List[InputAliasInfo] = []
        output_info: List[OutputAliasInfo] = []
        input_requires_grad_info: List[bool] = []
        output_requires_grad_info: List[bool] = []

        flat_f_args = pytree.tree_map(to_fun, flat_args)

        torch._enable_functionalization(reapply_views=True)
        try:
            # precondition: The passed in function already handles unflattening inputs + flattening outputs
            flat_f_outs = f(*flat_f_args)
        finally:
            torch._disable_functionalization()

        # Inspect the state of the input tensor functional wrapper to detect input mutation info
        # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
        for (i, (arg, f_arg)) in enumerate(zip(flat_args, flat_f_args)):
            if not isinstance(arg, Tensor):
                new_arg = arg
            else:
                torch._sync(f_arg)
                new_arg = torch._from_functional_tensor(f_arg)
            if arg is not new_arg:
                if StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(new_arg.untyped_storage()):
                    mutates_data = False
                    mutates_metadata = True
                else:
                    mutates_data = True
                    mutates_metadata = not has_same_metadata(arg, new_arg)
                # Only track requires_grad info on *mutated* inputs,
                # because they show up in the autograd.Function.forward as outputs
                input_requires_grad_info.append(
                    isinstance(f_arg, torch.Tensor) and f_arg.requires_grad
                )
            else:
                mutates_data = False
                mutates_metadata = False

            input_info.append(InputAliasInfo(
                is_leaf=isinstance(arg, torch.Tensor) and arg.is_leaf,
                mutates_data=mutates_data,
                mutates_metadata=mutates_metadata,
            ))

        # If a function involves creating a tensor and returning a view of it, such that its ._base is the intermediate,
        # we need to make sure our graph returns the ._base as a graph output, and we manually recreate the view
        # to return to the user. Why?
        # The backend compiler is free to (incorrectly) not set requires_grad
        # on the base tensor, but we are obligated to properly set requires-gradness on the real output.

        num_mutated_inps = len(
            [x for x in input_info if x.mutates_data or x.mutates_metadata]
        )
        inp_storage_refs = {
            StorageWeakRef(inpt.untyped_storage()): idx
            for idx, inpt in enumerate(flat_f_args)
            if isinstance(inpt, torch.Tensor)
        }

        # We need inp tensor id's to be able to tell if an output **is** an input.
        inp_tensor_ids = {
            id(inpt) for inpt in flat_f_args if isinstance(inpt, torch.Tensor)
        }
        # We need output tensor id's to tell if any `output._base` attributes **are** other outputs.
        # (This is also a dict because we need to know that output's index, so we can regenerate
        # the alias from it).
        out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)}

        # Keep track of which outputs alias other outputs
        out_tensor_alias_counts = collections.defaultdict(int)
        for o in flat_f_outs:
            if isinstance(o, torch.Tensor):
                out_tensor_alias_counts[StorageWeakRef(o.untyped_storage())] += 1

        # maps the id of an intermediate base to its index in the output of the compiled forward
        intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {}
        intermediate_bases: List[torch.Tensor] = []
        for o in flat_f_outs:
            if (
                isinstance(o, torch.Tensor)
                and StorageWeakRef(o.untyped_storage()) in inp_storage_refs
            ):
                base_idx = inp_storage_refs[StorageWeakRef(o.untyped_storage())]
                is_input_tensor = id(o) in inp_tensor_ids
                if is_input_tensor:
                    output_type = OutputType.is_input
                else:
                    output_type = OutputType.alias_of_input

            # We only need to handle the intermediate base case when both
            # the intermediate base and the output require gradients.
            # See Note [AOT Autograd: outputs aliasing inputs or intermediates!]
            elif (
                isinstance(o, torch.Tensor)
                and o._base is not None
                and o.requires_grad
                and o._base.requires_grad
            ):
                if out_tensor_alias_counts[StorageWeakRef(o.untyped_storage())] == 1:
                    # Note [Intermediate Bases Optimization]
                    # Normally if we have an output that aliases an intermediate,
                    # we need to add the extra "intermediate base" logic further down
                    # to prevent autograd from yelling at us if the user later tries to
                    # mutate that output.
                    # However, the common case here is if we have an output that aliases an intermediate,
                    # but doesn't alias any other outputs.
                    # In that case, autograd shouldn't have to worry about the aliasing at all
                    # (if that output is mutated, there are no other live aliases for autograd to worry about).
                    # The "intermediate bases" can hurt inductor perf by forcing more variables to become outputs.
                    # So as an optimization, we won't do intermediate base handling in this case.
                    # Instead, we'll hide the aliasing from autograd using aten._unsafe_view().
                    output_type = OutputType.unsafe_view_alias
                    base_idx = None
                else:
                    # First, check if o's ._base is an existing output
                    maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None)
                    if maybe_existing_out_idx is not None:
                        # Special case where the output is an alias of a graph intermediate, but that intermediate
                        # is itself also a user output.
                        output_type = OutputType.alias_of_intermediate_base_is_user_output
                        base_idx = maybe_existing_out_idx
                    else:
                        # Next, check if o's ._base is an intermediate base that we already returned
                        maybe_existing_base_output_idx = intermediate_base_tensor_id_to_output_idx.get(
                            id(o._base), None
                        )
                        if maybe_existing_base_output_idx is not None:
                            output_type = OutputType.alias_of_intermediate
                            base_idx = maybe_existing_base_output_idx
                        else:
                            # Otherwise, take o._base and explicitly return it as an output in the compiled graph
                            new_out_idx = len(intermediate_bases)
                            base_idx = new_out_idx
                            # Indicate to the logic later on (when we trace the joint)
                            # that this particular output should get its ._base
                            # appended to the forward graph outputs
                            output_type = OutputType.alias_of_intermediate_save_as_output
                            intermediate_base_tensor_id_to_output_idx[id(o._base)] = new_out_idx
                            intermediate_bases.append(o._base)
            else:
                output_type = OutputType.non_alias
                base_idx = None

            if isinstance(o, torch.Tensor):
                dynamic_dims = {i for i, s in enumerate(o.shape) if not is_concrete_int(s)}
            else:
                dynamic_dims = None
            out_info = OutputAliasInfo(
                output_type=output_type,
                raw_type=type(o),
                base_idx=base_idx,
                dynamic_dims=dynamic_dims,
            )
            output_info.append(out_info)
            output_requires_grad_info.append(
                isinstance(o, torch.Tensor) and o.requires_grad
            )

        # Our autograd.Function.forward returns both mutated inputs and outputs,
        # so we need grad info on all of them.
        requires_grad_info = input_requires_grad_info + output_requires_grad_info
        assert len(requires_grad_info) == len(output_info) + len(
            [x for x in input_info if x.mutates_data or x.mutates_metadata]
        )

        # This analysis function returns *only* the outputs that are meant to be tangents to the backwards.
        # Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
        # is *regenerated* later, and not used directly in the autograd graph
        f_input_tangents = [
            inp
            for inp, info in zip(flat_f_args, input_info)
            if info.mutates_data
        ]
        f_output_tangents = [
            o
            for o, info in zip(flat_f_outs, output_info)
            if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias]
            and issubclass(info.raw_type, torch.Tensor)
        ]
        # intermediate bases are also included in the backward graph
        f_tangents = f_input_tangents + f_output_tangents + intermediate_bases
        traced_tangents = pytree.tree_map(from_fun, f_tangents)

        metadata = ViewAndMutationMeta(
            input_info=input_info,
            requires_grad_info=requires_grad_info,
            output_info=output_info,
            num_intermediate_bases=len(intermediate_bases),
            keep_input_mutations=keep_input_mutations,
            traced_tangents=traced_tangents,
        )
        return metadata

    return inner

@dataclasses.dataclass
class AOTConfig:
    """
    Configuration for AOTDispatcher
    """

    fw_compiler: Callable
    bw_compiler: Callable
    partition_fn: Callable
    decompositions: Dict[Callable, Callable]
    num_params_buffers: int
    aot_id: int
    keep_inference_input_mutations: bool
    dynamic_shapes: bool = False
    aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
    inference_compiler: Optional[Callable] = None
    enable_log: bool = True
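
# Illustrative sketch (hypothetical values, not a real compilation entry point): the
# minimal set of required AOTConfig fields. A trivial "compiler" can hand back the
# (callable) fx.GraphModule it was given; real backends return an optimized callable.
def _example_aot_config() -> AOTConfig:
    def trivial_compiler(fx_g, example_inputs):
        return fx_g

    return AOTConfig(
        fw_compiler=trivial_compiler,
        bw_compiler=trivial_compiler,
        partition_fn=default_partition,
        decompositions={},
        num_params_buffers=0,
        aot_id=0,  # normally allocated via next(AOT_COUNTER)
        keep_inference_input_mutations=False,
    )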
# This function takes in a tensor t, and returns one of t, t.view(), or t.clone().
# When tracing the joint forward + backward, for any inputs in the graph that are mutated,
# we need to clone them first (and similarly for metadata-only mutations, we need to view them first).
# The idea is that when we trace the backward, we need to pass in the *original* primals
# to autograd.grad(), before they were mutated.
# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them.
# This means that "idx" here represents the index of the (potentially) synthetic base.
# What we need to do is:
# (1) map the current (post-synthetic-base calling convention) input argument index
#     to the index pre-synthetic-base-calling-convention.
# (2) There could be multiple, if this index corresponds to a synthetic base
#     that has multiple input aliases.
# (3) If any of those corresponding inputs get metadata mutations, then we clone the base.
def maybe_to_fresh_input(idx, t, meta):
    if not isinstance(t, Tensor):
        return t
    if idx in meta.mutated_inp_indices:
        # We only need to bother cloning mutated inputs that participate in autograd.
        mutated_inp_idx = meta.mutated_inp_indices.index(idx)
        if meta.requires_grad_info[mutated_inp_idx] and meta.input_info[idx].mutates_data:
            # Make sure the primal we pass to autograd.grad()
            # sees the tensor before the mutation
            return t.clone()
        if meta.requires_grad_info[mutated_inp_idx] and meta.input_info[idx].mutates_metadata:
            # Make sure the primal we pass to autograd.grad()
            # sees the tensor before the metadata mutation
            return t.view(t.shape)
    return t

# This function returns a new function that returns mutated inputs as outputs.
# if keep_data_input_mutations is set, then we assume that data-only mutations
# will be left in the graph, and we only return metadata-mutated inputs as outputs.
def fn_input_mutations_to_outputs(
    fn: Callable,
    meta: ViewAndMutationMeta,
    keep_data_input_mutations: bool,
) -> Any:
    def inner_fn(*args):
        outs = fn(*args)
        assert len(meta.output_info) == len(outs)
        # The compiled fw will return mutated input tensors, *including* metadata-only mutation.
        # However, if keep_data_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs.
        # (because data-only input mutations are handled directly in the compiled graph)
        mutated_inputs_to_return = [
            x
            for (i, x) in enumerate(args)
            if meta.input_info[i].mutates_metadata
            or (meta.input_info[i].mutates_data and not keep_data_input_mutations)
        ]
        return *mutated_inputs_to_return, *outs

    return inner_fn

# This function takes in a fn with external aliasing and mutation,
# and returns a new fn with no external aliasing and mutation,
# as needed for autograd.
# The main transformations are:
# - Return mutated inputs as extra outputs
# - Clone mutated inputs that require gradients,
#   because autograd will require us to pass the pre-mutated inputs into autograd.grad
# - Return intermediate bases of outputs as additional outputs,
#   needed to appease autograd.Function
# The new function returns:
# (1) The updated outputs
# (2) A boolean mask of len(new_fn_outputs),
#     that can be used to tell autograd.grad which outputs should get tangents
#     if we trace the backward.
def fn_prepped_for_autograd(
    fn: Callable,
    meta: ViewAndMutationMeta,
) -> Any:
    def inner_fn(*args):
        args_maybe_cloned = [
            maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args)
        ]

        outs = fn(*args_maybe_cloned)
        assert isinstance(outs, (tuple, list))
        outs = list(outs)
        assert len(meta.output_info) == len(outs)

        mutated_inputs_to_return = [
            x
            for (i, x) in enumerate(args_maybe_cloned)
            if meta.input_info[i].mutates_metadata or meta.input_info[i].mutates_data
        ]

        intermediate_bases = []
        for i, (o, info) in enumerate(zip(outs, meta.output_info)):
            if info.output_type == OutputType.alias_of_intermediate_save_as_output:
                intermediate_bases.append(o._base)
            elif info.output_type == OutputType.unsafe_view_alias:
                # See Note [Intermediate Bases Optimization]
                outs[i] = torch.ops.aten._unsafe_view.default(o, o.shape)

        assert meta.num_intermediate_bases == len(intermediate_bases)

        # the compiled forward should return (mutated_inputs, user_outs, intermediate_bases)
        fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases

        # Also return a boolean mask specifying which outputs of this function will be used as tangents
        mutated_inputs_grad_mask = [
            meta.input_info[meta.mutated_inp_indices[i]].mutates_data
            for (i, x) in enumerate(mutated_inputs_to_return)
        ]

        # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw
        # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead,
        # which we *should* send to grad()
        output_grad_mask = [
            meta.output_info[i].output_type in [OutputType.non_alias, OutputType.unsafe_view_alias]
            # Also, only tensor outputs should participate in the backward
            # (in particular, SymInt outputs in the forward graph shouldn't get tangents)
            and issubclass(meta.output_info[i].raw_type, torch.Tensor)
            for (i, x) in enumerate(outs)
        ]

        intermediate_base_grad_mask = [True for _ in range(len(intermediate_bases))]

        out_grad_mask = mutated_inputs_grad_mask + output_grad_mask + intermediate_base_grad_mask
        assert len(out_grad_mask) == len(fw_outs_to_return)

        # Take care to grab and sync the updated inputs from args_maybe_cloned (the inputs we actually mutate!)
        # and not args (the preserved inputs, pre-mutation, that we pass to grad())
        # This is annoying: our joint function needs to be aware of functionalization
        # (syncing mutated inputs before calling autograd.grad())
        # In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner.
        for i, arg in enumerate(args_maybe_cloned):
            if not isinstance(arg, Tensor):
                continue
            torch._sync(arg)

        return fw_outs_to_return, out_grad_mask

    return inner_fn

# Given a fn, computes the joint.
# NOTE: fn is expected to behave as follows:
# (1) fn() needs to return a tuple of (outs, mask),
#     where `mask` tells us which outputs are meant to have tangents.
#     We don't know this info automatically, because we don't actually want to blindly
#     compute tangents for every output that requires grad.
#     Specifically, outputs that alias inputs won't participate in the backward and won't get tangents.
# (2) fn() cannot mutate any inputs that require gradient.
#     Otherwise, when we compute autograd.grad(), we will not take those input mutations into account
#     (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
def create_joint(
    fn: Callable,
) -> Any:
    def inner_fn(primals: List[Any], tangents: List[Any]):
        outs, tangent_mask = fn(*primals)
        assert len(tangent_mask) == len(outs)
        outs_to_grad = [
            o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent
        ]
        assert len(outs_to_grad) == len(tangents)

        # Get the inputs that need gradients
        grad_primals = []
        inputs_needs_grads = []
        # Note: we pass the original (pre-mutation) primals here,
        # being careful not to pass any mutated inputs into autograd.grad()
        for p in primals:
            is_grad_tensor = isinstance(p, Tensor) and p.requires_grad
            inputs_needs_grads.append(is_grad_tensor)
            if is_grad_tensor:
                grad_primals.append(p)

        # Get the outputs that need gradients
        needed_outs = []
        needed_tangents = []
        for out, tangent in zip(outs_to_grad, tangents):
            if isinstance(out, Tensor) and out.requires_grad:
                # A bit sketchy, but fixes e.g.
                # test_aot_autograd_exhaustive_matmul_cpu_float32.
                # The issue is that we are sensitive to decomps that don't accurately maintain
                # their output's _base.shape compared to eager mode, and this helps mitigate a bit.
                needed_outs.append(
                    out if out.shape == tangent.shape else out.view(tangent.shape)
                )
                needed_tangents.append(tangent)

        setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])

        if config.functionalize_rng_ops:
            PhiloxStateTracker.mark_beginning_of_backward()
        backward_out = []
        # Call the backwards pass
        if grad_primals:
            with fx_traceback.preserve_node_meta():
                backward_out = torch.autograd.grad(
                    needed_outs,
                    grad_primals,
                    grad_outputs=needed_tangents,
                    allow_unused=True,
                )
        backward_out_iter = iter(backward_out)
        return outs, [
            next(backward_out_iter) if i else None for i in inputs_needs_grads
        ]

    return inner_fn
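
# Illustrative sketch (toy function, not AOTAutograd's real tracing flow): the fn
# handed to create_joint must return (outs, mask), where mask marks which outputs
# should receive tangents. The resulting joint maps (primals, tangents) to
# (outs, grads), with grads[i] lined up against primals[i] (None for primals that
# don't require grad). AOTAutograd then traces such a joint with make_fx.
def _example_create_joint():
    def fw(x):
        out = x.sin()
        return [out], [True]  # every output gets a tangent

    joint = create_joint(fw)
    x = torch.randn(3, requires_grad=True)
    outs, grads = joint([x], [torch.ones(3)])
    assert len(grads) == 1 and grads[0].shape == x.shape
    return outs, grads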
# This creates the final function that we want to trace using make_fx(),
# in both aot_dispatch_autograd and aot_dispatch_base.
# Preconditions:
# - fn corresponds to the user's fw function
# - fn arguments have been flattened, duplicate arguments have been handled
# - In the returned function, the "primals" arguments *include* synthetic bases.
# This function does the work of functionalizing the input function,
# and performing copy_() calls at the end of the function if `keep_input_mutations` is set.
# The function returned has a signature that is either:
# (1) "traced_fn(primals: List[Any])" if trace_joint is False
# (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True
def create_functionalized_graph(
    fn,
    args,
    *,
    meta: ViewAndMutationMeta,
    aot_config: AOTConfig,
    trace_joint: bool,
):
    def functionalized_f_helper(*args):
        # Wrap inputs into functional wrappers
        f_args = pytree.tree_map(to_fun, args)
        torch._enable_functionalization(reapply_views=True)
        try:
            # Run the joint
            f_outs = fn(*f_args)
        finally:
            torch._disable_functionalization()

        if aot_config.keep_inference_input_mutations and not trace_joint:
            # Note: This is a bit annoying. There's a layering issue here, where:
            # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
            # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs.
            #     However, we **only** want to support this for inputs that have data-only (and no metadata) mutations,
            #     because inductor (and backends generally) would prefer not to see these (e.g. as_strided_(), resize_()).
            #     This makes it pretty difficult for this logic to operate on synthetic bases.
            # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual
            #     (unpacked) input aliases, instead of the synthetic base.
            # Example case where (3) could be important:
            #
            #     def f(x, y):
            #         x.mul_(2)
            #         y.mul_(3)
            #         return x, y
            #     a = torch.ones(1_000_000)
            #     x, y = f(a[0:9], a[1:10])
            #
            # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing
            # a giant "updated synthetic base" and copying into a's entire storage.
            #
            # For now, we are pessimistically not performing the optimization from (3);
            # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
            # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
            # about synthetic bases.
            for i, (inpt_old, inpt_f) in enumerate(zip(args, f_args)):
                if not isinstance(inpt_f, torch.Tensor):
                    continue
                torch._sync(inpt_f)
                inpt_new = torch._from_functional_tensor(inpt_f)
                if meta.input_info[i].mutates_data and not meta.input_info[i].mutates_metadata:
                    # We found an input that had a (data-only) mutation.
                    # Since keep_input_mutations is set, we need to faithfully apply a copy_()
                    # so the compiler will see the input mutation in the graph.
                    assert inpt_new is not inpt_old
                    assert has_same_metadata(inpt_new, inpt_old)
                    inpt_old.copy_(inpt_new)

        return pytree.tree_map(from_fun, f_outs)

    # Kinda annoying, but needed to make sure that the fx graph we trace out has "primals"
    # and "tangents" as its input names (which are special-cased by the partitioner)
    def joint_helper(primals, tangents):
        return functionalized_f_helper(primals, tangents)

    def fwd_helper(*args):
        return functionalized_f_helper(*args)

    helper = joint_helper if trace_joint else fwd_helper
    if config.functionalize_rng_ops:
        # Setup the wrapper for functionalization of rng ops
        helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint)

    with enable_python_dispatcher():
        fx_g = make_fx(helper, decomposition_table=aot_config.decompositions)(*args)

    return fx_g

def normalize_as_list(x):
    if isinstance(x, tuple):
        return list(x)
    elif isinstance(x, list):
        return x
    return [x]

aot_autograd_decompositions = {}

# This is a list since, looking forward, we can have this arbitrarily nested.
graph_being_compiled: List[str] = []
# TODO: It would be nice to reset the numbering every time aot_id goes
# up, but this is annoying to do right now (because we don't know if
# an aot_id will come back from the dead), so right now this also happens
# to be a globally unique number too (at the cost of wobbling if you change
# how the graphs compile)
nth_graph: int = 0
model_name: str = "model"

def set_model_name(name):
    global model_name
    model_name = name

def get_aot_compilation_context() -> Tuple[List[str], str, int]:
    return list(graph_being_compiled), model_name, nth_graph

def get_aot_graph_name() -> str:
    """
    Returns the name of the graph being compiled.
    """
"""globalmodel_name,graph_being_compiled,nth_graphreturnf"{model_name}__{'_'.join(graph_being_compiled)}_{nth_graph}"get_graph_being_compiled=get_aot_graph_name@contextmanagerdeftrack_graph_compiling(aot_config,graph_name):globalgraph_being_compiled# TODO: Don't shove the aot_id in here; set it in the contextgraph_being_compiled=[f"{aot_config.aot_id}_{graph_name}"]try:yieldfinally:globalnth_graphnth_graph+=1graph_being_compiled=[]defmake_boxed_func(f):defg(args):returnf(*args)g._boxed_call=Truereturngdefmake_boxed_compiler(compiler):@wraps(compiler)deff(fx_g,inps):out_f=compiler(fx_g,inps)fx_g=make_boxed_func(out_f)returnfx_greturnfdefcall_func_with_args(f,args,steal_args=False,disable_amp=False):ifnotsteal_args:args=list(args)assertisinstance(args,list)ifdisable_amp:guard=torch._C._DisableAutocast()try:ifhasattr(f,"_boxed_call"):out=normalize_as_list(f(args))else:# TODO: Please remove soon# https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670warnings.warn("Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. ""Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. ""See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.")out=normalize_as_list(f(*args))finally:ifdisable_amp:delguardreturnoutdefaot_dispatch_base(flat_fn,flat_args:List[Tensor],aot_config:AOTConfig,*,fw_metadata:ViewAndMutationMeta):# aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case.# The cases that aot_dispatch_base doesn't need to handle include:# - outputs that are aliases of graph intermediates# - outputs that are aliases of graph inputs# While cases that it does need to handle include:# - input mutations (including when inputs are aliases of each other)# - input metadata mutationskeep_mutations=aot_config.keep_inference_input_mutationsfn_to_trace=fn_input_mutations_to_outputs(flat_fn,fw_metadata,keep_data_input_mutations=aot_config.keep_inference_input_mutations,)fw_module=create_functionalized_graph(fn_to_trace,flat_args,meta=fw_metadata,aot_config=aot_config,trace_joint=False,)# As long as we opted to remove input mutations, then# there should be *NO* mutating ops in the graph at this point.copy_count=assert_functional_graph(fw_module.graph,allow_input_mutations=aot_config.keep_inference_input_mutations)fw_module.graph.eliminate_dead_code()fw_module.recompile()copy_count2=assert_functional_graph(fw_module.graph,allow_input_mutations=aot_config.keep_inference_input_mutations)assertcopy_count==copy_count2ifaot_config.enable_log:aot_graphs_log.info("%s",lazy_format_graph_code("Forward graph",fw_module,aot_config.aot_id))disable_amp=torch._C._is_any_autocast_enabled()context=disable_autocast_managerifdisable_ampelsenullcontextwithcontext(),track_graph_compiling(aot_config,"inference"):compiler=aot_config.inference_compilerifaot_config.inference_compilerisnotNoneelseaot_config.fw_compilerifconfig.functionalize_rng_ops:# Add the seed and offset as example inputs to pass to the compilerfake_mode=detect_fake_mode()seed,offset=CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)flat_args=(seed,offset,*flat_args)compiled_fw=compiler(fw_module,flat_args)# Get the RNG functionalization related metadata to be used at 
    # runtime.
    rng_metadata = RNGMeta(False, PhiloxTotalOffsets(0, 0))
    if config.functionalize_rng_ops:
        rng_metadata = RNGMeta(config.functionalize_rng_ops, PhiloxStateTracker.get_accumulated_offsets())

    compiled_fn = create_runtime_wrapper(
        compiled_fw,
        runtime_metadata=fw_metadata,
        indices_of_inps_to_detach=[],
        trace_joint=False,
        keep_input_mutations=aot_config.keep_inference_input_mutations,
        disable_amp=disable_amp,
    )

    @wraps(compiled_fn)
    def wrapper(*args):
        if rng_metadata.is_compiled_with_functional_rng_ops:
            # Add the seed and offset to args
            seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
            out = compiled_fn(seed, offset, *args)
            # Advance the rng state offset
            CUDARngStateHelper.advance_torch_state(rng_metadata.philox_total_offsets.total_fwd_offset)
            return out
        else:
            return compiled_fn(*args)

    return wrapper

# Returns the number of detected copy_
def assert_functional_graph(fx_g: torch.fx.Graph, *, allow_input_mutations: bool = False) -> int:
    placeholders = set()
    copy_count = 0
    # NB: It would also be nice to verify that the mutations all happen at the
    # end, but we also do some administrative views after mutations so this
    # isn't actually true.  (TODO: Could this cause problems for Inductor?)
    for n in fx_g.nodes:
        if n.op == "placeholder":
            placeholders.add(n)
        if isinstance(n.target, torch._ops.OpOverload):
            if n.target is aten.copy_.default and allow_input_mutations:
                suffix = True
                # Can only copy_ into an input, and can only do so once
                assert n.args[0] in placeholders
                placeholders.remove(n.args[0])
                copy_count += 1
            else:
                assert not n.target._schema.is_mutable, (
                    f"aot_autograd expected to have an entirely functional graph, "
                    f"but found {n.format_node()}"
                )
    return copy_count

@contextmanager
def disable_autocast_manager():
    guard = torch._C._DisableAutocast()
    try:
        yield
    finally:
        del guard

def are_differentiable_views(view1, view2):
    if view1 is view2:
        return True
    if view1._base is None and view2._base is None:
        return False
    if view1._base is view2._base or view1._base is view2 or view1 is view2._base:
        return True
    return False

def same_dtype_views(view1, view2):
    if view1.dtype != view2.dtype:
        return False
    if view1._base is not None and view1.dtype != view1._base.dtype:
        return False
    if view2._base is not None and view2.dtype != view2._base.dtype:
        return False
    return True
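
# Illustrative sketch (added commentary, not used by AOTAutograd itself): two views of
# the same tensor share a ._base, so autograd understands their aliasing relationship;
# unrelated tensors (with no ._base) do not count as differentiable views of each other.
def _example_are_differentiable_views():
    x = torch.ones(4, requires_grad=True)
    a = x.view(2, 2)
    b = x.view(-1)
    assert are_differentiable_views(a, b)  # a._base is b._base (both are x)
    assert not are_differentiable_views(x, torch.ones(4))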
# Note [Handling mutations on an input that aliases other inputs]
# The easiest example to show-case this edge case is here:
#
#     def f(a, b):
#         a.mul_(2)
#         out = a + b
#         return out
#     b = torch.ones(...)
#     a = b.view(-1)
#     f(a, b)
#
# In this situation, if a and b happened to be aliased, we need to trace something different!
# Suppose we had b = a.view(-1)
# (In this case, that means that `a._base is b`)
#
# We need to ensure that the aliasing relationship between a and b is preserved.
# We do that by detecting the specific situation above (mutate an input that aliases another input),
# and when we do that, we create a "synthetic base" argument. Then inside of the traced forward,
# we regenerate a and b off of that base.
# The complete example of the transformed function looks like this:
#
#     // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views
#     // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph
#     def traced_forward(base):
#         a = base.as_strided(...)
#         b = base.as_strided(...)
#         a_updated = a.mul(2)
#         base_updated = torch.as_strided_scatter(base, a_updated, ...)
#         b_updated = base_updated.as_strided(...)
#         out = a_updated + b_updated
#         return a_updated, out
#
#     def compiled_fn(a, b):
#         // we detect that a is the "differentiable base" here
#         base = a
#         // In other situations, we might do either:
#         // (1) a and b are both views off of some larger differentiable base
#         //     assert a._base is b._base and a._base is not None
#         //     base = a._base
#         // (2) a and b both don't require gradients. Create a base from the storage
#         //     assert a._base is None and b._base is None
#         //     base = torch.Tensor(a.storage())
#         a_updated, out = traced_forward(base)
#         a.copy_(a_updated)
#         return out
#
# This function:
# (1) Merges input views into a synthetic base argument, when any of those input views are mutated
# (2) Returns metadata telling the autograd.Function how to modify their arguments properly,
#     to respect the new calling convention.
#
# The calling convention is as follows.
# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base.
# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN],
# where the ordering of the bases is determined from the ordering of the original view args.
# baseA will come before baseB if the earliest original argument coming from baseA
# showed up earlier in the argument list than the earliest original argument coming from baseB.
#
# Example, given some tensors a, b, c, d
# call site:
#     f(a, c.view(-1), b.view(-1), b, c, d)
# Modified argument list:
#     c_base comes first because the first c view came earlier in the arg list than the first b view
#     a and d still show up in the modified arg list, but b and c don't - they're regenerated from their bases
#     b_base = torch.Tensor(b.storage())
#     c_base = torch.Tensor(c.storage())
#     f(c_base, b_base, a, d)
def merge_view_inputs(
    fwd_inputs: List[Any],
    mutated_input_info: List[InputAliasInfo],
    *,
    # The autograd case currently has more restrictions than the inference case.
    is_inference: bool,
) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, torch.Tensor]]]]]:
    assert len(fwd_inputs) == len(mutated_input_info)
    storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list)
    base_args = []
    other_args = []
    for i, inpt in enumerate(fwd_inputs):
        if isinstance(inpt, Tensor):
            storage_ref = StorageWeakRef(inpt.untyped_storage())
            storage_ref_to_idx[storage_ref].append(i)
        else:
            other_args.append(inpt)
    # Note [Synthetic Base Info Metadata]
    # This list contains metadata that tells you what the i'th argument in the inner calling convention should be.
    # It's either:
    # - another int (corresponding to the index in the argument list of the element from the outer calling convention)
    # - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx])
    #   idx corresponds to which synthetic base from the outer calling context to view
    inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, torch.Tensor]]] = {}
    for aliased_input_indices in storage_ref_to_idx.values():
        if len(aliased_input_indices) <= 1 or not any(
            # We only care about mutations that affect all aliases,
            # so metadata mutations on an input don't require us to do synthetic base handling.
            mutated_input_info[inpt_idx].mutates_data
            for inpt_idx in aliased_input_indices
        ):
            for curr_idx in aliased_input_indices:
                other_args.append(fwd_inputs[curr_idx])
            continue
        # We detected an input that was mutated, AND aliases with another input.
        # we need to replace this set of aliased inputs with a single synthetic base.
        # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases
        # and error out. We can fix them later.
        # These checks are transitive, so we don't need to check every pair.
        for idx1, idx2 in zip(aliased_input_indices, aliased_input_indices[1:]):
            view1 = fwd_inputs[idx1]
            view2 = fwd_inputs[idx2]
            # The "inputs that are aliased but have different differentiable bases" case
            # is more complicated and hopefully pretty rare.
            # Not currently handled.
            if not is_inference:
                assert are_differentiable_views(
                    view1, view2
                ), "aot_autograd() does not yet handle non-differentiable view input mutations."
            # Regenerating views when reinterpreting complex / real tensors seems non-trivial,
            # not handling for now
            assert same_dtype_views(
                view1, view2
            ), "aot_autograd() does not yet handle input mutations on views with different dtypes."
        non_none_bases = [
            fwd_inputs[i]._base
            for i in aliased_input_indices
            if fwd_inputs[i]._base is not None
        ]
        aliases_with_none_bases = [
            fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None
        ]
        if len(non_none_bases) == 0:
            # Case where none of the aliases have a ._base
            # we generate a synthetic base without gradients, and generate views off of it
            # We hit this case when we have input tensors to the graph that share a storage,
            # but do not have a ._base field.
            # Wondering when we hit this case?
            # The _base field simply says that autograd knows about the aliasing relationship,
            # but sometimes we create tensors which are aliased out of the same storage but guaranteed
            # to be disjoint. In these cases, we will skip setting up the _base relationship
            # for performance reasons (because the fact that the tensors share the same storage
            # is unobservable unless you (1) do naughty things with resize_/as_strided
            # or (2) look at the storage--as we are doing here.)
            # One particular example of this is optimizer steps on the LSTM module:
            # LSTM parameters are packed into a contiguous storage for efficiency reasons when
            # calling cuDNN kernels, so when these parameters get passed to the optimizer we will
            # find they share the same storage, but do not have _base set since they are all disjoint.
            #
            # NOTE: There is one case where this is unsafe:
            # torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily
            # the same shape as the "actual" base that the tensor came from.
            # For the most part this is fine, because we always use as_strided()
            # to generate the original aliased inputs again.
            # If we were to use view-replay though, this could cause the aliased views
            # to have incorrect sizes.
            example_idx = aliased_input_indices[0]
            example_alias = fwd_inputs[example_idx]
            # Note that this function is re-used at both trace time and runtime.
            # At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor.
            synthetic_base = torch.empty((0,), dtype=example_alias.dtype, device=example_alias.device)
            # We don't actually have a convenient way of going from storage -> tensor,
            # so using set_() here (we suffer some minor overhead, but this case is rare).
            synthetic_base.set_(example_alias.untyped_storage())
        else:
            # Case where all of the aliases require gradients, and have the same _base.
            synthetic_base = non_none_bases[0]
            for other_base in non_none_bases[1:]:
                assert (
                    other_base is synthetic_base
                ), "aot_autograd() does not yet handle non-differentiable view input mutations."
            for alias in aliases_with_none_bases:
                assert (
                    alias is synthetic_base
                ), "aot_autograd() does not yet handle non-differentiable view input mutations."
        base_args.append(synthetic_base)
        for curr_view_idx in aliased_input_indices:
            curr_view = fwd_inputs[curr_view_idx]
            base_idx = len(base_args) - 1
            # We store just enough info here so that we can regenerate the view later.
            # Regeneration: curr_view._view_func(args[base_idx])
            inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view)
    if len(base_args) == 0:
        assert len(other_args) == len(fwd_inputs)
        # If no synthetic bases are necessary, just return the original inputs.
        return fwd_inputs, None
    else:
        # Otherwise, return:
        # (1) The new args according to the updated calling convention: (synthetic_bases,
other_args)# (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention.# We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention.args_to_functionalization=base_args+other_argsarg_to_old_idx_map={arg:ifor(i,arg)inenumerate(fwd_inputs)}fori,other_arginenumerate(other_args):new_idx=len(base_args)+iold_idx=arg_to_old_idx_map[other_arg]inner_calling_convention_meta[old_idx]=new_idx# post process into a listpost_processed_calling_convention_meta:List[Union[int,Callable]]=[-1for_inrange(len(inner_calling_convention_meta))]fork,vininner_calling_convention_meta.items():post_processed_calling_convention_meta[k]=v# Quick assert: every argument in the inner calling convention should be accounted for.forxinpost_processed_calling_convention_meta:assertx!=-1returnargs_to_functionalization,post_processed_calling_convention_metadefformat_guard_bug_msg(aot_config,expected):return(f"At compilation time, graph {aot_config.aot_id} was compiled under the "f"assumption that {expected}, but at runtime this was not the case. ""This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.")defremove_dupe_metadata(m:ViewAndMutationMeta,keep_arg_mask:List[bool],)->ViewAndMutationMeta:assertlen(m.input_info)==len(keep_arg_mask)# Easy invariant: the first argument should never be a dupe (it will be kept)assertlen(keep_arg_mask)>0andkeep_arg_mask[0]dupe_to_dedup_idx=[0]fori,binenumerate(keep_arg_mask[1:]):ifb:dupe_to_dedup_idx.append(dupe_to_dedup_idx[-1]+1)else:dupe_to_dedup_idx.append(dupe_to_dedup_idx[-1])# Filter dupe'd mutated inputs out of traced_tangentsnum_data_mutations=len([xforxinm.input_infoifx.mutates_data])other_traced_tangents=m.traced_tangents[num_data_mutations:]inp_traced_tangents=m.traced_tangents[:num_data_mutations]filtered_inp_traced_tangents=[xfori,xinenumerate(inp_traced_tangents)ifkeep_arg_mask[m.mutated_inp_indices[i]]]traced_tangents=filtered_inp_traced_tangents+other_traced_tangentsreturnViewAndMutationMeta(input_info=[xfori,xinenumerate(m.input_info)ifkeep_arg_mask[i]],# requires_grad_info consists of (mutated_inputs, forward_outputs).# Need to remove only the duplicate entries that correspond to the mutated inputs.requires_grad_info=[xfori,xinenumerate(m.requires_grad_info)ifi>=len(m.mutated_inp_indices)orkeep_arg_mask[m.mutated_inp_indices[i]]],# For outputs that are views of inputs, we store the index of the input that the output# was generated from. 
Need to update that index to account for removed dupes.output_info=[OutputAliasInfo(output_type=o.output_type,raw_type=o.raw_type,dynamic_dims=o.dynamic_dims,base_idx=Noneifo.base_idxisNoneelsedupe_to_dedup_idx[o.base_idx])foroinm.output_info],num_intermediate_bases=m.num_intermediate_bases,keep_input_mutations=m.keep_input_mutations,traced_tangents=traced_tangents,)# Given our ViewAndMutation metadata, this fn constructs a new set of metadata,# after adding synthetic base arguments to the function.# Most of the work in this fn is slogging through all of the metadata corresponding to inputs,# and updating it with our synthetic base calling convention.## When config.debug_assert is set, we automatically regenerate the metadata# and compare it to this output for sanity.## In addition to the updated metadata, also return the list of input indices# that will need to be updated in the synthetic base epiloguedefcreate_synthetic_base_metadata(m:ViewAndMutationMeta,# Maps each outer argument idx to its inner idx (or, if this outer arg is generated from a# synthetic base, you get a tuple of (i, TensorMeta), telling you the base tensor idx, and view metadata)synthetic_base_info:List[Union[int,Tuple[int,torch.Tensor]]],outer_args:List[Any],inner_args:List[Any],)->Tuple[ViewAndMutationMeta,List[int]]:S_Outer=NewType('S_Outer',int)S_Inner=NewType('S_Inner',int)synthetic_base_to_indices:Dict[S_Inner,List[S_Outer]]={}forinner_idxinrange(len(inner_args)):outer_aliased_indices_of_current_base_arg=[outer_idxforouter_idx,inner_idx_or_tupleinenumerate(synthetic_base_info)if(isinstance(inner_idx_or_tuple,int)andinner_idx_or_tuple==inner_idx)or(isinstance(inner_idx_or_tuple,tuple)andinner_idx_or_tuple[0]==inner_idx)]synthetic_base_to_indices[inner_idx]=outer_aliased_indices_of_current_base_arg# given the requires_grad info on mutated inputs,# generate the requires_grad info on those same mutated inputs, but after constructing synthetic bases.input_infos=[]mutated_inp_require_grad_info=[]for_,outer_indicesinsynthetic_base_to_indices.items():# leaf-ness should be all-or-nothing for aliased tensor.# (aka if "a" and "b" are views, then a.is_leaf == b.is_leaf)any_leaf=any(m.input_info[x].is_leafforxinouter_indices)all_leaf=all(m.input_info[x].is_leafforxinouter_indices)assertany_leaf==all_leafinpt_info=InputAliasInfo(# If len(outer_indices) > 1, then this input is a synthetic base.# The invariant is that to the rest of aot autograd, synthetic bases only show up if# one of their aliases gets a data mutation. 
And if any of their aliases get metadata# mutations, they will be hidden from the rest of aot autograd.mutates_data=Trueiflen(outer_indices)>1elsem.input_info[outer_indices[0]].mutates_data,mutates_metadata=Falseiflen(outer_indices)>1elsem.input_info[outer_indices[0]].mutates_metadata,is_leaf=any_leaf,)input_infos.append(inpt_info)# requires_grad_info consists of (mutated_inputs, forward_outputs).# For any mutated inputs that correspond to aliased inputs,# Need to replace them with their mutated synthetic baseifinpt_info.mutates_dataorinpt_info.mutates_metadata:mutated_inp_require_grad_info.append(any(m.requires_grad_info[x]forxinouter_indices))# Find any inputs that fulfill the following criteria:# (1) They are part of a synthetic base (because they alias another input,# and at least one input experiences a data mutation)# (2) They experience a metadata mutationouter_aliased_arg_idx_with_metadata_mutations=[outer_idxforouter_idx,inpt_infoinenumerate(m.input_info)ifinpt_info.mutates_metadataandnotisinstance(synthetic_base_info[outer_idx],int)]# grab the original requires grad info on the outputs, except the ones from the mutated inputsnum_original_input_data_mutations=len([xforxinm.input_infoifx.mutates_dataorx.mutates_metadata])output_grad_info=m.requires_grad_info[num_original_input_data_mutations:]input_metadata_mutation_grad_info=[outer_args[outer_idx].requires_gradforouter_idxinouter_aliased_arg_idx_with_metadata_mutations]input_metadata_output_info=[OutputAliasInfo(output_type=OutputType.alias_of_input,raw_type=torch.Tensor,dynamic_dims={ifori,sinenumerate(outer_args[outer_idx].shape)ifnotis_concrete_int(s)},base_idx=synthetic_base_info[outer_idx][0],)forouter_idxinouter_aliased_arg_idx_with_metadata_mutations]existing_output_infos=[OutputAliasInfo(output_type=o.output_type,raw_type=o.raw_type,dynamic_dims=o.dynamic_dims,# Map the input idx pre-synthetic-bases to the new idx post-synthetic-basesbase_idx=Noneifo.base_idxisNoneelsesynthetic_base_info[o.base_idx]ifisinstance(synthetic_base_info[o.base_idx],int)elsesynthetic_base_info[o.base_idx][0])foroinm.output_info]num_outer_mutated_data_inps=len([xforxinm.input_infoifx.mutates_data])inner_mutated_data_inps=[xforinner_idx,xinenumerate(inner_args)ifinput_infos[inner_idx].mutates_data]requires_grad_info=mutated_inp_require_grad_info+output_grad_info+input_metadata_mutation_grad_infooutput_info=existing_output_infos+input_metadata_output_info# Regenerate traced tangents to include mutated inputs including synthetic basestraced_tangents=inner_mutated_data_inps+m.traced_tangents[num_outer_mutated_data_inps:]returnViewAndMutationMeta(input_info=input_infos,requires_grad_info=requires_grad_info,output_info=output_info,num_intermediate_bases=m.num_intermediate_bases,keep_input_mutations=m.keep_input_mutations,traced_tangents=traced_tangents,),outer_aliased_arg_idx_with_metadata_mutations# MOTIVATION:## When tracing functions for future execution, one must be careful not to pass# in the same input tensor multiple times (e.g., f(x, x), as this can result# in graphs that are ONLY valid if you later pass a new tensor in exactly the# same way (e.g., f(y, y)). (NB: we really mean duplicate; two distinct# tensors that alias each other is a different situation that is covered by# aot_dispatch_deduplicated_autograd). 
Here are two examples:## (1) Suppose you have a function:## def f(x, y):# return x + y## If you make_fx(f)(x, x), you will trace out:## def f(x, y):# return y + y## Oops!## (2) For most tensors x and y, you can compute f's gradient with respect to# these to inputs by saying torch.autograd.grad(f(x, y), (x, y)). However,# if x is y, you will trace out a program that gets incorrect gradients:## >>> x = torch.randn(1, requires_grad=True)# >>> torch.autograd.grad(x + x, (x, x))# (tensor([2.]), tensor([2.]))## In other words, the gradient is double-counted. Deduplicating the arguments# gives you an appropriate gradient:## >>> y = torch.randn(1, requires_grad=True)# >>> torch.autograd.grad(x + y, (x, y))# (tensor([1.]), tensor([1.]))## HOW TO DEDUPLICATE:## There are a few strategies, in order of preference:## 1. For every duplicate argument to the function, detach it into# a separate leaf tensor, so that it is no longer duplicated.## PRO: The resulting compiled graph works for any configuration# of duplicated arguments.## CON: It does not (naively) work if you mutate the metadata of inputs:## def f(x, y):# x.transpose_(0, 1)# y.transpose_(0, 2)## x = torch.randn(2, 3, 4)# f(x, x)## The ordering of the transposes inside f dictates whether or not# you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute# what metadata mutations should get applied to each input; you need to# assume they aren't duplicates (what we do today) or preserve# the original metadata mutations exactly in order, so that they work# for any duplicate configuration.## CON: It does not (naively) work if you mutate the data of inputs.# In particular, leaf tensors that require grad cannot be mutated,# this makes it impossible to differentiate with respect to the original# base.## 2. For every duplicate argument to the function, remove it, so it is# no longer part of the "true" signature:## PRO: Implemented naively, it still works for metadata/data mutation.## CON: The resulting compiled graph is duplicate-specialized: it only# works if future calls duplicate arguments in exactly the same way.# Horribly, Dynamo doesn't guard on this at the moment. But even if# it did, you could still end up recompiling a bunch of each duplicate.## Our strategy is to do (1) if we can, and do (2) otherwise, erroring if# Dynamo's guards are not enough. In practice, this seems to cover# everything.#defaot_wrapper_dedupe(flat_fn,flat_args:List[Tensor],aot_config:AOTConfig,*,compiler_fn,fw_metadata,):# Use information about whether or not flat_fn mutates its arguments# or not to handle dupe args# Strategy 1: For any input that is not mutated, we can leafify it if we# need to remove a duplicate.leaf_flat_args=[]args_set=set()ok=Truefori,ainenumerate(flat_args):ifnotisinstance(a,torch.Tensor):leaf_flat_args.append(a)elifanotinargs_set:args_set.add(a)leaf_flat_args.append(a)elifnotfw_metadata.input_info[i].mutates_dataandnotfw_metadata.input_info[i].mutates_metadata:leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad))else:ok=Falsebreakifok:returncompiler_fn(flat_fn,leaf_flat_args,aot_config,fw_metadata=fw_metadata)# Strategy 2: Duplicate specialize.## In Haskell types, suppose you have:## add_dupe_args :: DedupedArgs -> Args# remove_dupe_args :: Args -> DedupedArgs## compiler_fn# :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R)# deped_compiler_fn# :: (Args -> R) -> Args -> AOTConfig -> (Args -> R)## Then the code below can be written in point-free style as:## deduped_compiler_fn f a c =# compiler_fn (f . 
add_dupe_args) (remove_dupe_args a) c . remove_dupe_args## Suppose you have:## [a, b, a, c]## We want:## remove_dupe_args([a, b, a, c]) == [a, b, c]# add_dupe_args([a, b, c]) == [a, b, a, c]## This is done via (respectively):## seen_args = {a: 0, b: 1, c: 2}# enumerate(add_dupe_map) = [ # how to get args from the deduped list# (0, 0),# (1, 1),# (2, 0),# (3, 2),# ]# keep_arg_mask = [True, True, False, True]seen_args={}keep_arg_mask=[]# Implicitly map duped arg position (list index) to de-duped arg positionadd_dupe_map:List[int]=[]duped_arg_len=len(flat_args)j=0# index into deduped_flat_argsfori,tinenumerate(flat_args):iftinseen_args:keep_arg_mask.append(False)add_dupe_map.append(seen_args[t])continuekeep_arg_mask.append(True)seen_args[t]=jadd_dupe_map.append(j)j+=1assertlen(add_dupe_map)==duped_arg_len,(f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}")# NB: Hot path, avoid set lookups here# TODO: Can avoid the zip here too, probablydefremove_dupe_args(args):return[tfort,keepinzip(args,keep_arg_mask)ifkeep]defadd_dupe_args(args):return[args[add_dupe_map[i]]foriinrange(duped_arg_len)]deduped_flat_args=remove_dupe_args(flat_args)# Update our input metadata to remove duped input metadata.updated_fw_metadata=remove_dupe_metadata(fw_metadata,keep_arg_mask)tracing_context=TracingContext.get()iftracing_contextandaot_config.aot_autograd_arg_pos_to_source:# TODO(voz): This structure is 1:1, we could consider an alternate structure like# kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there,# which feels like needless complexity for a tiny bit of efficiency at this point.fordupe_arg_pos,(kept_pos,keep_arg)inenumerate(zip(add_dupe_map,keep_arg_mask)):ifnotkeep_arg:dupe_arg_source=aot_config.aot_autograd_arg_pos_to_source[dupe_arg_pos]kept_arg_source=aot_config.aot_autograd_arg_pos_to_source[kept_pos]tracing_context.guards_context.aotautograd_guards.append(DuplicateInputs(kept_arg_source,dupe_arg_source))@wraps(flat_fn)defwrapped_flat_fn(*args):returnflat_fn(*add_dupe_args(args))ifconfig.debug_assert:ref_fw_metadata=run_functionalized_fw_and_collect_metadata(wrapped_flat_fn,keep_input_mutations=fw_metadata.keep_input_mutations,)(*deduped_flat_args)assertref_fw_metadata==updated_fw_metadata, \
f'ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}'compiled_fn=compiler_fn(wrapped_flat_fn,deduped_flat_args,aot_config,fw_metadata=updated_fw_metadata)ifnothasattr(compiled_fn,"_boxed_call"):compiled_fn=make_boxed_func(compiled_fn)@wraps(compiled_fn)defwrapped_compiled_fn(args):deduped_args=remove_dupe_args(args)args.clear()returncompiled_fn(deduped_args)wrapped_compiled_fn._boxed_call=True# This can be uncommented when we properly guard for duplicates,# but right now we must not do it.# if not config.debug_assert:# return wrapped_compiled_fn@wraps(wrapped_compiled_fn)defdebugged_compiled_fn(args):# Test that the computed remove/add arg functions are an inversenew_args=add_dupe_args(remove_dupe_args(args))seen={}fori,(x,y)inenumerate(zip(new_args,args)):seen[y]=Noneassertxisy,format_guard_bug_msg(aot_config,f"{describe_input(i,aot_config)} would be a duplicate of "f"{describe_input(add_dupe_map[i],aot_config)}",)# This is only an error if there is metadata mutation on both of# the duped arguments; in this case, we need to know what order# the metadata mutation applies in. You'll get the correct result# otherwise, because a graph that assumes distinct inputs works if# you dupe the inputs (the gradient contributions from each input# will get summed up appropriately.)## TODO: work out how to setup this assert correctly""" assert len(seen) == unique_args, format_guard_bug_msg(aot_config, f"there would be {unique_args} distinct arguments" ) """returnwrapped_compiled_fn(args)debugged_compiled_fn._boxed_call=Truereturndebugged_compiled_fn# This layer handles the situation where you have two inputs that alias each other,# and one of the inputs is mutated.# We need to take special care to ensure that the mutation is applied to the other aliases in the graph.## pre-condition: aot_wrapper_dedup has already run.# (This function will in theory work if there are duplicate args.# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs# would cause us to hit that path more frequently).defaot_wrapper_synthetic_base(flat_fn,flat_args:List[Tensor],aot_config:AOTConfig,*,fw_metadata:ViewAndMutationMeta,# Currently, the only reason we need to plumb this bool is because# the synthetic base code prohibits more cases in the autograd case than the inference case.needs_autograd:bool,compiler_fn,):is_inference=notneeds_autogradflat_args_with_synthetic_bases,synthetic_base_info=merge_view_inputs(flat_args,fw_metadata.input_info,is_inference=is_inference,)# Happy path: we don't need synthetic basesifsynthetic_base_infoisNone:returncompiler_fn(flat_fn,flat_args,aot_config,fw_metadata=fw_metadata)assertlen(fw_metadata.input_info)==len(synthetic_base_info)# Update our forward metadata to take synthetic bases into accountfw_metadata_updated,aliased_arg_idx_with_metadata_mutations= \
create_synthetic_base_metadata(fw_metadata,synthetic_base_info,flat_args,flat_args_with_synthetic_bases)num_aliased_args_with_metadata_mutations=len(aliased_arg_idx_with_metadata_mutations)defunpack_synthetic_bases(primals:List[Any])->List[Any]:f_args_inner=[]forinner_idx_or_tupleinsynthetic_base_info:ifisinstance(inner_idx_or_tuple,int):f_args_inner.append(primals[inner_idx_or_tuple])else:inner_base_idx,view_tensor=inner_idx_or_tuplebase=primals[inner_base_idx]view_arg=gen_alias_from_base(base,view_tensor,view_tensor.requires_grad)f_args_inner.append(view_arg)returnf_args_inner@wraps(flat_fn)defwrapped_flat_fn(*args):unpacked_args=unpack_synthetic_bases(args)# This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases)# is to relieve the downstream logic from having to reason about mutations on inputs that alias# each other, by replacing aliased inputs with a synthetic base.# One area where this breaks down a bit however is if one of those aliased inputs# experienced a metadata mutation.# We are now obligated to reapply the metadata mutation directly to the user's input;# it isn't enough to apply mutations back to the synthetic base in the downstream logic.## The way we handle this is by pretending that those aliased inputs that experience metadata mutations# are additional outputs in the user's forward function.# The downstream logic will just treat these as "user outputs that alias inputs".# However, we will manually grab them at runtime here, use them to reapply the metadata mutation# to the user inputs, and not return them to the user.aliased_args_with_metadata_mutations=[xfori,xinenumerate(unpacked_args)ifiinaliased_arg_idx_with_metadata_mutations]iflen(aliased_args_with_metadata_mutations)>0:return*(flat_fn(*unpacked_args)),*aliased_args_with_metadata_mutationselse:returnflat_fn(*unpacked_args)ifconfig.debug_assert:ref_fw_metadata=run_functionalized_fw_and_collect_metadata(wrapped_flat_fn,keep_input_mutations=fw_metadata.keep_input_mutations,)(*flat_args_with_synthetic_bases)assertref_fw_metadata==fw_metadata_updated,(f'ref_metadata={pprint.pformat(partial_asdict(ref_fw_metadata))}, 'f'actual_metadata={pprint.pformat(partial_asdict(fw_metadata_updated))}')compiled_fn=compiler_fn(wrapped_flat_fn,flat_args_with_synthetic_bases,aot_config,fw_metadata=fw_metadata_updated)ifnothasattr(compiled_fn,"_boxed_call"):compiled_fn=make_boxed_func(compiled_fn)@wraps(compiled_fn)defwrapped_compiled_fn(args):args_with_synthetic_bases,synthetic_base_info=merge_view_inputs(args,fw_metadata.input_info,is_inference=is_inference)assertsynthetic_base_infoisnotNonealiased_args_w_metadata_mutations=[args[i]foriinaliased_arg_idx_with_metadata_mutations]args.clear()outs=compiled_fn(args_with_synthetic_bases)ifnum_aliased_args_with_metadata_mutations>0:# This code does not handle **all** input metadata mutations.# Instead, it only handles metadata mutations on inputs that were converted into synthetic bases# (which only happens if at least one aliased input experienced a data mutation).# e.g:# def f(a, b):# a.mul_(2)# b.t_(1, 0)# f(x.view(2, 2), x.view(2, 
2))mutated_metadata_inps=outs[-num_aliased_args_with_metadata_mutations:]user_outs=outs[:-num_aliased_args_with_metadata_mutations]forinp,mutated_inpinzip(aliased_args_w_metadata_mutations,mutated_metadata_inps):inp.as_strided_(mutated_inp.size(),mutated_inp.stride(),mutated_inp.storage_offset())returnuser_outsreturnoutsreturnwrapped_compiled_fndefdescribe_input(i,aot_config):ifi<aot_config.num_params_buffers:returnf"parameter/buffer {i}"else:returnf"input {i-aot_config.num_params_buffers}"# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic# that needs to run after the compiled function.## This function accepts a trace_joint flag, indicating whether or not we're generating the runtime# epilogue for a forward-only inference graph, or for an autograd.Function.apply function.# This is because there are some minor differences in how we treat these cases at runtime:# - resize_() is currently handled in the inference case, but not fully handled in the autograd case.# - the autograd cases inserts TensorAlias wrapper objects for outputs that alias inputsdefcreate_runtime_wrapper(compiled_fn,*,runtime_metadata:ViewAndMutationMeta,indices_of_inps_to_detach:List[int],trace_joint:bool,keep_input_mutations:bool,disable_amp:bool):ifnothasattr(compiled_fn,"_boxed_call"):compiled_fn=make_boxed_func(compiled_fn)defruntime_wrapper(*args):iftrace_joint:args_=list(args)# See Note [Detaching inputs that never need gradients]foridxinindices_of_inps_to_detach:ifisinstance(args_[idx],torch.Tensor):args_[idx]=args_[idx].detach()withtorch.autograd._force_original_view_tracking(True):all_outs=call_func_with_args(compiled_fn,args_,disable_amp=disable_amp,)else:all_outs=call_func_with_args(compiled_fn,args,disable_amp=disable_amp,)num_mutated_inps=runtime_metadata.num_mutated_inputsnum_metadata_mutated_inps=runtime_metadata.num_mutated_metadata_inputsnum_intermediate_bases=runtime_metadata.num_intermediate_basesifkeep_input_mutations:assert(len(all_outs)==num_metadata_mutated_inps+runtime_metadata.num_outputs+num_intermediate_bases)assert(len(runtime_metadata.mutated_inp_runtime_indices)==num_metadata_mutated_inps)else:assert(len(all_outs)==num_mutated_inps+runtime_metadata.num_outputs+num_intermediate_bases)assert(len(runtime_metadata.mutated_inp_runtime_indices)==num_mutated_inps)# Step 3: After running the compiled fw, apply updates to mutated inputsnum_mutations_to_apply=len(runtime_metadata.mutated_inp_runtime_indices)ifnum_mutations_to_apply>0:updated_inputs=all_outs[:num_mutations_to_apply]fw_outs=all_outs[num_mutations_to_apply:]fori,inpt_idxinenumerate(runtime_metadata.mutated_inp_runtime_indices):meta=runtime_metadata.input_info[inpt_idx]ifnotmeta.mutates_dataandnotmeta.mutates_metadata:continueoriginal_inpt=args[inpt_idx]updated_inpt=updated_inputs[i]# TODO: add better resize_() support for autograd case.# Check for the case when an input has been resized.# Note: One important thing to check for is user code that calls inpt.storage().resize_().# We can't trace operations on storage into the graph, so we should get dynamo to graph break.# TODO: handle resize_() on inputs to a larger size.# This is actually non-trivial to detect, so we should probably just handle it# (or make dynamo detect).# We can't just check of original_inpt.storage_size != updated_inpt.storage_size,# Because the original_inpt might be a view of some larger tensor,# and updated_inpt is always densely 
packed.ifnottrace_jointandoriginal_inpt.storage().size()!=updated_inpt.storage().size():original_inpt.resize_(updated_inpt.size())ifmeta.mutates_metadataandnotmeta.mutates_data:iftrace_joint:assertisinstance(updated_inpt,TensorAlias)updated_inpt=updated_inpt.alias# We need to grab the size/stride/storage_offset from the compiled forward,# and use that to mutate the metadata of the inputoriginal_inpt.as_strided_(updated_inpt.size(),updated_inpt.stride(),updated_inpt.storage_offset(),)else:ifmeta.mutates_dataandmeta.mutates_metadata:original_inpt.as_strided_(updated_inpt.size(),updated_inpt.stride(),updated_inpt.storage_offset(),)else:assertmeta.mutates_dataifmeta.is_leafandoriginal_inpt.requires_grad:# We can hit this situation in this case:# def f(x):# x.detach().mul_(2)# return x + 1# AOTAutograd will see a mutation in the above case, and try to# apply a copy_() here, in the epilogue.# But if x required gradients, and is a leaf, then autograd# will yell at us for trying to mutate it.# However, it's only possible to end up in this scenario (like the above)# if all of the mutations to the leaf input were non-autograd-tracking mutations# (aka mutations under no_grad(), or on detached views).# In that case, we fully want to hide the mutation from autograd, so detaching is ok.original_inpt.detach().copy_(updated_inpt)else:original_inpt.copy_(updated_inpt)else:fw_outs=all_outs# Step 4: Manually regenerate any outputs that are aliased to inputs, instead of# compiling them.ifruntime_metadata.num_outputs_aliased>0:# The compiled forward also returned intermediate bases. We don't want to return them to the user.ifruntime_metadata.num_intermediate_bases>0:fw_outs_no_intermediate_bases=fw_outs[:-runtime_metadata.num_intermediate_bases]intermediate_bases=fw_outs[-runtime_metadata.num_intermediate_bases:]else:fw_outs_no_intermediate_bases=fw_outsintermediate_bases=[]assertlen(fw_outs_no_intermediate_bases)==len(runtime_metadata.output_info)fw_outs_including_aliases=[]fori,(o,info)inenumerate(zip(fw_outs_no_intermediate_bases,runtime_metadata.output_info)):ifinfo.output_type==OutputType.non_aliasorinfo.output_type==OutputType.unsafe_view_alias:fw_outs_including_aliases.append(o)continueiftrace_joint:assertisinstance(o,TensorAlias)o_=o.aliaselse:o_=oo_grad=runtime_metadata.requires_grad_info[runtime_metadata.num_mutated_inputs+i]ifinfo.output_type==OutputType.alias_of_input:aliased_base_tensor=args[info.base_idx]regenerated_out=gen_alias_from_base(aliased_base_tensor,o_,o_grad)fw_outs_including_aliases.append(regenerated_out)continueelifinfo.output_type==OutputType.is_input:aliased_base_tensor=args[info.base_idx]regenerated_out=aliased_base_tensorfw_outs_including_aliases.append(regenerated_out)continueelifinfo.output_type==OutputType.alias_of_intermediate:base_tensor_list=intermediate_baseselifinfo.output_type==OutputType.alias_of_intermediate_save_as_output:base_tensor_list=intermediate_baseselse:assertinfo.output_type==OutputType.alias_of_intermediate_base_is_user_outputbase_tensor_list=fw_outs_no_intermediate_basesaliased_base_tensor=base_tensor_list[info.base_idx]# TODO: handle the custom autograd function case here.# We need a way to check whether a tensor came from a custom autograd fn from python,# AND a way to replay that custom view 
fn.regenerated_out=gen_alias_from_base(aliased_base_tensor,o_,o_grad)fw_outs_including_aliases.append(regenerated_out)ret_outs=fw_outs_including_aliaseselse:ret_outs=fw_outsifruntime_metadata.dynamic_outputs:fort,oinzip(ret_outs,runtime_metadata.output_info):ifo.dynamic_dimsisNone:continueifhasattr(t,'_dynamo_weak_dynamic_indices'):t._dynamo_weak_dynamic_indices|=o.dynamic_dimselse:t._dynamo_weak_dynamic_indices=o.dynamic_dims.copy()returnret_outsreturnruntime_wrapperdefcreate_functionalized_rng_ops_wrapper(func,args,trace_joint=True):# Functionalization of rng ops changes the calling convention of the joint graph.# It goes from (primals, tangents) to (seed, offset, primals, tangents)# At runtime, we pass on the current seed and offset. This is hidden from# the user.fake_mode=detect_fake_mode()defoverride_get_rng_state(device:Union[int,str,torch.device]='cuda'):out=PhiloxStateTracker.get_state_as_tensor()returnoutdefoverride_set_rng_state(x,device:Union[int,str,torch.device]='cuda'):PhiloxStateTracker.set_state_from_tensor(x)deftraced_joint(fwd_seed,fwd_base_offset,bwd_seed,bwd_base_offset,primals,tangents):withpatch("torch.cuda.get_rng_state",override_get_rng_state),patch("torch.cuda.set_rng_state",override_set_rng_state):returnfunc(primals,tangents)deftraced_forward(fwd_seed,fwd_base_offset,*primals):withpatch("torch.cuda.get_rng_state",override_get_rng_state),patch("torch.cuda.set_rng_state",override_set_rng_state):returnfunc(*primals)iftrace_joint:# Get the current seed and offset to setup tracing.fwd_seed,fwd_base_offset=CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)bwd_seed,bwd_base_offset=CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)PhiloxStateTracker.record_state(fwd_seed,fwd_base_offset,"forward")PhiloxStateTracker.record_state(bwd_seed,bwd_base_offset,"backward")returntraced_joint,(fwd_seed,fwd_base_offset,bwd_seed,bwd_base_offset,*args)else:# Get the current seed and offset to setup tracing.fwd_seed,fwd_base_offset=CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)PhiloxStateTracker.record_state(fwd_seed,fwd_base_offset,"forward")returntraced_forward,(fwd_seed,fwd_base_offset,*args)# Has the precondition that there# are no duplicate arguments in flat_args (e.g., the same Tensor# object never shows up twice. 
However, two tensor inputs MAY alias# the same storage, so long as they have separate TensorImpls.)defaot_dispatch_autograd(flat_fn,flat_args:List[Any],aot_config:AOTConfig,*,fw_metadata:ViewAndMutationMeta):# traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward.# It includes outputs of the original forward, *and* any updated inputs due to input mutations.# However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations.traced_tangents=pytree.tree_map(lambdax:x.detach().contiguous()ifisinstance(x,Tensor)elsex,fw_metadata.traced_tangents,)assertlen(fw_metadata.requires_grad_info)==fw_metadata.num_mutated_inputs+fw_metadata.num_outputsjoint_inputs=(flat_args,traced_tangents)disable_amp=torch._C._is_any_autocast_enabled()fn_prepared_for_autograd=fn_prepped_for_autograd(flat_fn,fw_metadata,)joint_fn_to_trace=create_joint(fn_prepared_for_autograd)fx_g=create_functionalized_graph(joint_fn_to_trace,joint_inputs,meta=fw_metadata,aot_config=aot_config,trace_joint=True,)# There should be *NO* mutating ops in the graph at this point.assert_functional_graph(fx_g.graph)# Redudant with the check above, but worth having in case tracing introduced# a fake tensor. Unlikely.# See Note: [Fake Modules and AOTAutograd]torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)fx_g.graph.eliminate_dead_code()fx_g.recompile()ifaot_config.enable_log:aot_joint_log.info("%s",lazy_format_graph_code("Joint graph",fx_g,aot_config.aot_id))withtorch.no_grad():withtrack_graph_compiling(aot_config,"joint"):num_inner_fwd_outputs=fw_metadata.num_mutated_inputs+fw_metadata.num_outputs+fw_metadata.num_intermediate_basesfw_module,bw_module=aot_config.partition_fn(fx_g,joint_inputs,num_fwd_outputs=num_inner_fwd_outputs)fw_outs=[nforninfw_module.graph.nodesifn.op=="output"][0].args[0]# we only need to bookkeep the symints that are saved for bw, not any symints# the user forward might have returned in its own outputfw_outs_saved_for_bw=fw_outs[num_inner_fwd_outputs:]symint_outs_saved_for_bw=[nforninfw_outs_saved_for_bwifis_sym_node(n)]_num_symints_saved_for_bw=len(symint_outs_saved_for_bw)# Note [Detaching inputs that never need gradients]# See https://github.com/pytorch/pytorch/issues/97745# Suppose we have a function like this that we want to compile:## def f(x, y):# return torch.mul(x, y.detach())## What gradients should we compute for x and y?# By default, AOTAutograd will compute a gradient for **every** input that requires gradients,# and so we'll compute:# x_grad_input = y# y_grad_input = None# Does this preserve the semantics of eager mode?# Unfortunately, no.# Doing the above will cause autograd to **continue** to backprop the autograd tape# that was generated from constructing y.## This is **different** from what would have happened in eager mode.# In eager mode, if we backprop through the output of this function, autograd will only traverse# the bit of the autograd tape corresponding to "x".# In particular, if a user had previously backpropped through y's autograd tape,# And then they try to backprop through the output of the above function,# then we'll hit the dreaded "Trying to backward through the graph a second time" error.## You might think: If autograd sees that a gradient is None, shouldn't it stop early,# instead of continuing the backprop through the ancestors of that node in the graph?## Autograd has two passes:# (1) a first pass that traverses the autograd graph and figures out 
which nodes need to be executed# (2) a second pass that actually goes ahead and executes each node when it becomes ready,# propagating gradients# By the time we're executing a node and we see that it produces a None, the set of nodes to execute# is already locked-in.## The fix: instead, we can recognize statically that the graph we're compiling will never contribute# gradients to y, and prevent autograd from trying to traverse y's autograd tape at all.# We can do this by manually detach'ing y before sending it through the `CompiledFunction`.## Note that this solution is not bulletproof.# It's possible to construct a case where eager may or may not have have tried to autograd through y,# depending on the actual grad_outputs that were passed in during the backward.# There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`,# allowing autograd to re-use the graph.## An example of this case is:# def f(x):# return x.detach() * 2, x * 3# If we were to only backprop through outs[0], in eager, we would stop# If we backward only on the first output, we shouldn't send a grad through x.# But the custom autograd function doesn't know that: it will materialize zero grads for x * 3# and we will end up with a zero grad at x.# If we later backprop through the second output, this will also require backprop'ing through x.# Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time._indices_of_inps_to_detach=[]bw_outs=[nforninbw_module.graph.nodesifn.op=="output"][0].args[0]assertlen(bw_outs)==len(fw_metadata.input_info)fori,(bw_out)inenumerate(bw_outs):ifbw_outisNone:_indices_of_inps_to_detach.append(i)ifaot_config.enable_log:aot_graphs_log.info("%s",lazy_format_graph_code("Forward graph",fw_module,aot_config.aot_id))aot_graphs_log.info("%s",lazy_format_graph_code("Backward graph",bw_module,aot_config.aot_id))withtrack_graph_compiling(aot_config,"forward"):ifconfig.functionalize_rng_ops:# Update example inputs for the fw_compilerfake_mode=detect_fake_mode()seed,offset=CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)flat_args=(seed,offset,*flat_args)compiled_fw_func=aot_config.fw_compiler(fw_module,flat_args)# Get the rng_metadata so that it can be used at runtimerng_metadata=RNGMeta(False,PhiloxTotalOffsets(0,0))ifconfig.functionalize_rng_ops:rng_metadata=RNGMeta(config.functionalize_rng_ops,PhiloxStateTracker.get_accumulated_offsets())# Total offsets differ for each AOT traced graph, so they can't be saved# in the PhiloxStateTracker singleton object. 
Therefore, we save them# into rng_meta side datastructure.saved_context=TracingContext.get()classCompiledFunction(torch.autograd.Function):compiled_fw=compiled_fw_funccompiled_bw=Nonemetadata=fw_metadatanum_symints_saved_for_bw=_num_symints_saved_for_bw@staticmethoddefforward(ctx,*deduped_flat_tensor_args):args=deduped_flat_tensor_argsifrng_metadata.is_compiled_with_functional_rng_ops:# Add the seed and offset to argsseed,offset=CUDARngStateHelper.get_torch_state_as_tuple()args=(seed,offset,*args)# There is a pretty complicated calling convention around what the compiled fw returns.# The full list of outputs and their relative order is:# (*mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)# - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version# of the original view, and not the synthetic basefw_outs=call_func_with_args(CompiledFunction.compiled_fw,args,disable_amp=disable_amp,)num_outputs=CompiledFunction.metadata.num_outputsnum_outputs_aliased_to_inputs=(CompiledFunction.metadata.num_outputs_aliased_to_inputs)num_outputs_aliased_to_intermediates=(CompiledFunction.metadata.num_outputs_aliased_to_intermediates)num_outputs_aliased=CompiledFunction.metadata.num_outputs_aliasednum_intermediate_bases=CompiledFunction.metadata.num_intermediate_basesnum_symints_saved_for_bw=CompiledFunction.num_symints_saved_for_bwnum_mutated_inputs=CompiledFunction.metadata.num_mutated_inputsnum_mutated_metadata_only_inputs=(CompiledFunction.metadata.num_mutated_metadata_only_inputs)# Our forward() returns both (mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints)num_forward_returns=num_mutated_inputs+num_outputs+num_intermediate_basesassertnum_forward_returns==len(CompiledFunction.metadata.requires_grad_info)+num_intermediate_bases# Partitioners must put symint arguments at the end separate from tensor argumentsifnum_symints_saved_for_bw>0:tensors_saved_for_backwards=fw_outs[num_forward_returns:-num_symints_saved_for_bw]assertall([isinstance(x,torch.Tensor)forxintensors_saved_for_backwards])# See Note [Detaching saved tensors in AOTAutograd]ctx.save_for_backward(*(x.detach()ifx._is_view()elsexforxintensors_saved_for_backwards))symint_outs=fw_outs[-num_symints_saved_for_bw:]assertall([isinstance(x,(int,float,torch.SymInt,torch.SymFloat))forxinsymint_outs])ctx.symints=symint_outselse:tensors_saved_for_backwards=fw_outs[num_forward_returns:]# See Note [Detaching saved tensors in AOTAutograd]ctx.save_for_backward(*(x.detach()ifx._is_view()elsexforxintensors_saved_for_backwards))ctx.symints=[]raw_returns=fw_outs[0:num_forward_returns]# Wrap all autograd.Function.forward() outputs that are aliases# so that autograd.Function doesn't treat them as tensorsifnum_mutated_metadata_only_inputs>0:fori,idxinenumerate(CompiledFunction.metadata.mutated_inp_indices):# We could make this faster by only looping over inputs with metadata-only mutations# (instead of looping over inputs with either data or metadata mutations), but there shouldn't be 
many.info=CompiledFunction.metadata.input_info[idx]ifinfo.mutates_metadataandnotinfo.mutates_data:raw_returns[i]=TensorAlias(raw_returns[i])ifconfig.debug_assert:user_mutated_inputs_raw=raw_returns[0:num_mutated_inputs]mut_inp_infos=[xforxinCompiledFunction.metadata.input_infoifx.mutates_dataorx.mutates_metadata]assertlen(user_mutated_inputs_raw)==len(mut_inp_infos)ifnum_outputs_aliased>0:foridxinCompiledFunction.metadata.aliased_out_indices:raw_return_idx=num_mutated_inputs+idxraw_returns[raw_return_idx]=TensorAlias(raw_returns[raw_return_idx])ifconfig.debug_assert:intermediates_raw=raw_returns[num_mutated_inputs+num_outputs:]assertnotany(isinstance(x,TensorAlias)forxinintermediates_raw)# invariant: intermediate bases always require gradients, so we don't have to# consider marking them as non-differentiable.raw_returns_not_including_intermediate_bases=raw_returns[:num_mutated_inputs+num_outputs]fw_outs_not_requiring_grad=[xfor(i,x)inenumerate(raw_returns_not_including_intermediate_bases)ifisinstance(x,torch.Tensor)andnotCompiledFunction.metadata.requires_grad_info[i]]ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)ifrng_metadata.is_compiled_with_functional_rng_ops:# Advance total fwd offsetCUDARngStateHelper.advance_torch_state(rng_metadata.philox_total_offsets.total_fwd_offset)returntuple(raw_returns)@staticmethoddefbackward(ctx,*flat_args):# Calling convention: we expect a grad_out passed to the backward:# - for every output of the fw that does *not* alias an input or graph intermediate# - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations)# - for every graph intermediate that we need to use to generate an output later.# The other outputs in the autograd.Function.forward that do *not* show up in the backward include:# - outputs that alias inputs or graph intermediates# - updated inputs due to metadata-only mutations.# We need to return them in the forward, but ensure that they all do not get gradients in the backward,# and we filter them out here before passing the remaining grad_outputs into the compiled backward.num_mutated_inps=CompiledFunction.metadata.num_mutated_inputsnum_intermediate_bases=CompiledFunction.metadata.num_intermediate_basesexpected_grad_outs=(CompiledFunction.metadata.num_outputs+num_mutated_inps+num_intermediate_bases)assertlen(flat_args)==expected_grad_outsout_info=CompiledFunction.metadata.output_infoif(CompiledFunction.metadata.num_mutated_metadata_only_inputs>0orCompiledFunction.metadata.num_outputs_aliased>0):inp_tangents,out_tangents,intermediate_base_tangents=(flat_args[0:num_mutated_inps],flat_args[num_mutated_inps:num_mutated_inps+CompiledFunction.metadata.num_outputs],flat_args[num_mutated_inps+CompiledFunction.metadata.num_outputs:],)# input_info contains info on *every* input,# But in the backward(), we are only given grad outputs for every mutated input.# We then need to filter out the grad outputs that correspond to metadata-only mutations.mutated_inp_indices=CompiledFunction.metadata.mutated_inp_indicesinput_info=CompiledFunction.metadata.input_infoassertlen(inp_tangents)==len(mutated_inp_indices)inp_tangents_filtered=[xforx,info_idxinzip(inp_tangents,mutated_inp_indices)ifinput_info[info_idx].mutates_data]# We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediatesout_tangents_filtered=[xforx,infoinzip(out_tangents,out_info)if(info.output_type==OutputType.non_aliasorinfo.output_type==OutputType.unsafe_view_alias)andissubclass(info.raw_type,torch.Tensor)]# 
intermediate bases always require gradients, and always participate in the backward graph.flat_bw_args=itertools.chain(inp_tangents_filtered,out_tangents_filtered,intermediate_base_tangents)# sanity asserts# metadata_only_inps = [# x for x, info_idx in zip(inp_tangents, mutated_inp_indices)# if not input_info[info_idx].mutates_data# ]# aliased_outputs = [# x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]# assert all(x is None for x in metadata_only_inps)# assert all(x is None for x in aliased_outputs)else:# filter out non-tensor grad_outputs (aka due to ints being returned as outputs in the forward)num_mutated_inps=CompiledFunction.metadata.num_mutated_inputsmutated_inp_args=flat_args[:num_mutated_inps]ifnum_mutated_inps>0else[]user_tangents=flat_args[num_mutated_inps:]assertlen(user_tangents)==len(out_info)filtered_user_tangents=[xforx,infoinzip(user_tangents,out_info)ifissubclass(info.raw_type,torch.Tensor)]flat_bw_args=tuple(mutated_inp_args)+tuple(filtered_user_tangents)contiguous_args=[t.contiguous()iftorch.is_tensor(t)elsetfortinflat_bw_args]rng_args=[]ifrng_metadata.is_compiled_with_functional_rng_ops:# Add the seed and offset to argsrng_args=CUDARngStateHelper.get_torch_state_as_tuple()all_args=(list(rng_args)+list(ctx.symints)+list(ctx.saved_tensors)+list(contiguous_args))delcontiguous_argsdefcall_compiled_backward():ifCompiledFunction.compiled_bwisNone:assertall(aisnotNoneforainall_args)context=disable_autocast_managerifdisable_ampelsenullcontextwithtracing(saved_context),context(),track_graph_compiling(aot_config,"backward"):CompiledFunction.compiled_bw=aot_config.bw_compiler(bw_module,fx_placeholder_vals(bw_module))ctx.maybe_clear_saved_tensors()out=call_func_with_args(CompiledFunction.compiled_bw,all_args,steal_args=True,disable_amp=disable_amp,)ifrng_metadata.is_compiled_with_functional_rng_ops:# Advance total bwd rng offsetCUDARngStateHelper.advance_torch_state(rng_metadata.philox_total_offsets.total_bwd_offset)returntuple(out)iftorch.is_grad_enabled()andany(t.requires_gradfortinall_argsifisinstance(t,torch.Tensor)):# Ensure that the graph is connected, and error if double backward is performed.# See comment for why once_differentiable is not sufficient:# https://github.com/pytorch/pytorch/pull/92348/files#r1072962107classCompiledFunctionBackward(torch.autograd.Function):@staticmethoddefforward(ctx,*unused_args):returncall_compiled_backward()@staticmethoddefbackward(ctx,*args):raiseRuntimeError("torch.compile with aot_autograd does not currently support double backward")# Pass args even though they're unused, so that the graph is builtout=CompiledFunctionBackward.apply(*all_args)else:out=call_compiled_backward()returnoutcompiled_function=create_runtime_wrapper(CompiledFunction.apply,runtime_metadata=fw_metadata,indices_of_inps_to_detach=_indices_of_inps_to_detach,trace_joint=True,keep_input_mutations=False,disable_amp=disable_amp)ifnotconfig.debug_assert:returncompiled_functionflat_requires_grad=[a.requires_gradifisinstance(a,Tensor)elseNoneforainflat_args]@wraps(compiled_function)defdebug_compiled_function(*args):# TODO: Check aliasing relationships# TODO: Check strides for metadata mutation# (NB: ideally, this logic is factored out of this function and# you move these debug checks there)# Check requires grad. 
Bad case is when we compiled with# requires_grad = False, but input requires_grad = True# (vice versa is OK; we compute a gradient and then throw# it away when it hits the input.)fori,ainenumerate(args):can_require_grad=flat_requires_grad[i]ifcan_require_gradisNone:assertnotisinstance(a,Tensor)elifnotcan_require_grad:assertnota.requires_grad,format_guard_bug_msg(aot_config,f"{describe_input(i,aot_config)} would not require grad",)returncompiled_function(*args)returndebug_compiled_function@dynamo_timeddefcreate_aot_dispatcher_function(flat_fn,flat_args:List[Any],aot_config:AOTConfig):""" Traces the forward and backward graphs of the attr:`flat_fn` to generate a joint graph. The joint graph is an Fx graph with Aten ops. Please refer to the tracing mechanism to understand the graph capturing details. The joint graph is then passed through attr:`partition_fn` to isolate the forward and backward portions, which are then respectively compiled via the provided attr:`fw_compiler` and attr:`bw_compiler`. The resulting compiled forward and backward graphs are then wrapped up in a ``torch.autograd.Function`` object. The calling convention here is that the first aot_config.num_params_buffers inputs in flat_args are parameters and buffers, and the rest are inputs. We use this to assume that parameters/buffer's shapes don't change. """# This is the main entry point.# TODO: Chillee argues that dynamo itself should pass in fake tensors to# the list of arguments when compiling; at the moment we do not do thisifaot_config.decompositionsisNone:aot_config.decompositions={}aot_config.decompositions={**aot_autograd_decompositions,**aot_config.decompositions,}ifconfig.functionalize_rng_ops:# Update the decompositions with functionalized random decompositionsaot_config.decompositions={**rng_decompositions,**aot_config.decompositions,}# Check flat_args to see if they're already fake. If so, use that fake# mode instead.fake_mode=detect_fake_mode(flat_args)iffake_modeisNone:shape_env=ShapeEnv()ifaot_config.dynamic_shapeselseNonefake_mode=FakeTensorMode(shape_env=shape_env)else:shape_env=fake_mode.shape_envpython_dispatcher_mode=(enable_python_dispatcher()ifshape_envisnotNoneelsenullcontext())withtorch.autograd.set_multithreading_enabled(False),preserve_rng_state(),fake_mode,python_dispatcher_mode,PhiloxStateTracker():defprocess_inputs(flat_args):defconvert(idx,x):ifshape_envisnotNone:fromtorch._dynamo.sourceimportConstantSourceifisinstance(x,int):returnshape_env.create_symintnode(shape_env.create_symbol(x,ConstantSource(f"sym_{idx}")),hint=x)ifnotisinstance(x,torch.Tensor):returnxifisinstance(x,FakeTensor):assertx.fake_modeisfake_modereturnx# TODO: Ensure that this codepath is never exercised from# Dynamoif(idx<aot_config.num_params_buffersandconfig.static_weight_shapes):returnfake_mode.from_tensor(x,static_shapes=True)returnfake_mode.from_tensor(x,static_shapes=False)return[convert(idx,x)foridx,xinenumerate(flat_args)]fake_flat_args=process_inputs(flat_args)needs_autograd=(any([x.requires_gradforxinfake_flat_argsifisinstance(x,Tensor)])andtorch.is_grad_enabled())withenable_python_dispatcher():# Patch set_rng_state as set_rng_state with fake tensors is# nonsensical. 
This does not affect the collection of metadata.withpatch("torch.cuda.set_rng_state",lambda*args:None):fw_metadata=run_functionalized_fw_and_collect_metadata(flat_fn,keep_input_mutations=aot_config.keep_inference_input_mutationsandnotneeds_autograd,)(*fake_flat_args)# crappy version of dispatcher# TODO: Do this properlyifneeds_autograd:compiler_fn=aot_dispatch_autogradelse:compiler_fn=aot_dispatch_basecompiler_fn=partial(aot_wrapper_synthetic_base,compiler_fn=compiler_fn,needs_autograd=needs_autograd)compiler_fn=partial(aot_wrapper_dedupe,compiler_fn=compiler_fn)# You can put more passes herecompiled_fn=compiler_fn(flat_fn,fake_flat_args,aot_config,fw_metadata=fw_metadata)ifnothasattr(compiled_fn,"_boxed_call"):compiled_fn=make_boxed_func(compiled_fn)returncompiled_fn# Inspired by autodidax (thanks!)classPytreeThunk:spec=None# These are some kinda dumb microoptimizations that save about 3-4 us of overhead.is_simple=(None# if the output spec is a tuple/list, we won't bother unflattening it.)is_really_simple=None# if the output spec is a LeafSpecdefset(self,spec):assertself.specisNoneorself.spec==specself.spec=speciftype(self.spec)in[tuple,list]andall(isinstance(i,pytree.LeafSpec)foriinspec.children_specs):self.is_simple=Trueifisinstance(self.spec,pytree.LeafSpec):self.is_really_simple=Truedefunflatten(self,x):ifself.is_really_simple:returnx[0]ifself.is_simple:returnxreturnpytree.tree_unflatten(x,self.spec)
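# A minimal sketch of what the PytreeThunk above caches, using only the
# torch.utils._pytree helpers already used in this module: on the first call the
# output pytree is flattened once and its spec is stored; subsequent calls only
# unflatten the compiled function's flat outputs with that cached spec (or skip
# unflattening entirely when the spec is a simple tuple/list of leaves).
#
# >>> flat, spec = pytree.tree_flatten({"out": (torch.ones(2), 3)})
# >>> flat
# [tensor([1., 1.]), 3]
# >>> pytree.tree_unflatten(flat, spec)
# {'out': (tensor([1., 1.]), 3)}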
[docs]defaot_function(fn:Callable,fw_compiler:Callable,bw_compiler:Optional[Callable]=None,partition_fn:Callable=default_partition,decompositions:Optional[Dict]=None,num_params_buffers:int=0,keep_inference_input_mutations:bool=False,inference_compiler:Optional[Callable]=None,*,# Whether or not to trace with dynamic shapesdynamic=False,enable_log=True,)->Callable:""" Traces the forward and backward graph of :attr:`fn` using torch dispatch mechanism, and then compiles the generated forward and backward graphs through :attr:`fw_compiler` and :attr:`bw_compiler`. :func:`aot_function` traces the forward and backward graph ahead of time, and generates a joint forward and backward graph. :attr:`partition_fn` is then used to separate out forward and backward graphs. The partitioner function can be used to perform optimizations such as recomputation. One can set `decompositions` dictionary to decompose the operators into a sequence of core or simpler operators supported by the backend compilers. .. warning:: This API is experimental and likely to change. Args: fn (Callable): A Python function that takes one ore more arguments. Must return one or more Tensors. fw_compiler (Callable): A Python function that accepts an Fx graph with Aten ops and input args, and returns a Callable that semantically is equivalent to the input Fx graph. bw_compiler (Optional[Callable]): A Python function that accepts an Fx graph with Aten ops and input args, and returns a Callable that semantically is equivalent to the input Fx graph. Default: None (when None, it defaults to the :attr:`fw_compiler`) partition_fn (Callable): A Python function that takes a joint forward and backward graph, and partitions it into separate forward and backward graphs. decompositions (Dict): A dictionary to define the decomposition of larger Aten ops into simpler or core Aten ops. inference_compiler (Optional[Callable]): A Python function that accepts an Fx graph with Aten ops and input args, and returns a Callable that semantically is equivalent to the input Fx graph. inference_compiler is invoked if no autograd is needed. Default: None (when None, it defaults to the :attr:`fw_compiler`) Returns: Returns a ``Callable`` that retains the eager behavior of the original :attr:`fn`, but with forward and backward graph compiled via :attr:`fw_compile` and :attr:`bw_compile`. A simple example usage of :func:`aot_function` is as follows. 
This example will print the forward and backward graphs of the function ``fn`` >>> fn = lambda x : x.sin().cos() >>> def print_compile_fn(fx_module, args): >>> print(fx_module) >>> return fx_module >>> aot_fn = aot_function(fn, print_compile_fn) >>> x = torch.randn(4, 5, requires_grad=True) >>> aot_fn(x) """ifbw_compilerisNone:bw_compiler=fw_compilerifinference_compilerisNone:inference_compiler=fw_compileraot_config=AOTConfig(fw_compiler=fw_compiler,bw_compiler=bw_compiler,inference_compiler=fw_compiler,partition_fn=partition_fn,decompositions=decompositions,num_params_buffers=num_params_buffers,aot_id=next(AOT_COUNTER),keep_inference_input_mutations=keep_inference_input_mutations,dynamic_shapes=dynamic,aot_autograd_arg_pos_to_source=None,enable_log=enable_log,)cached_res=None@wraps(fn)defreturned_function(*args,**kwargs):nonlocalcached_res# Now flatten the tensor argsflat_args,_=pytree.tree_flatten((args,kwargs))# Compile the function and save it in the cacheifcached_resisNone:# Save the args_spec for flat_tensor_args to unflatten while tracing_,tensor_args_spec=pytree.tree_flatten((args,kwargs))out_spec=PytreeThunk()defflat_fn(*flat_args):# The input are flattened tensor args. Prepare the args in the# order that original function expects. Add static args as well.# They will appear as tensor constants in the traced graph.nonlocalout_specargs,kwargs=pytree.tree_unflatten(flat_args,tensor_args_spec)tree_out=fn(*args,**kwargs)flat_out,spec=pytree.tree_flatten(tree_out)foriinflat_out:is_known_type=FalseforjinKNOWN_TYPES:ifisinstance(i,j):is_known_type=Truebreakifnotis_known_type:raiseRuntimeError(f"Found {type(i)} in output, which is not a known type. ""If this type holds tensors, you need to register a pytree for it. ""See https://github.com/pytorch/functorch/issues/475 for a brief ""explanation why. If you don't need to register a pytree, please ""leave a comment explaining your use case and we'll make this more ""ergonomic to deal with")out_spec.set(spec)returnflat_outcompiled_fn=create_aot_dispatcher_function(flat_fn,flat_args,aot_config,)cached_res=(compiled_fn,out_spec)cached_fn,out_spec=cached_resout=cached_fn(flat_args)returnout_spec.unflatten(out)returnreturned_function
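# A minimal usage sketch for aot_function (the compiler name `nop_compiler` below is
# purely illustrative): the traced forward/backward fx graphs are handed to the
# provided compilers, and the returned callable keeps eager autograd semantics.
# Note that the result is compiled and cached on the first call, for the flattened
# arguments seen at that call.
#
# >>> def nop_compiler(fx_module, example_inputs):
# ...     return fx_module  # any callable equivalent to the traced graph is acceptable
# >>> f = aot_function(lambda x: (x.sin() + x.cos()).sum(), fw_compiler=nop_compiler)
# >>> x = torch.randn(4, requires_grad=True)
# >>> f(x).backward()
# >>> torch.allclose(x.grad, x.cos() - x.sin())
# True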
def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
    """
    Traces the forward and backward graph of :attr:`mod` using torch dispatch
    tracing mechanism. It is a wrapper function that underneath uses
    :func:`aot_function` to perform tracing and compilation.

    :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs
    to a new callable which is then compiled through :func:`aot_function`.

    .. warning::
        This API is experimental and likely to change.

    Args:
        mod (Callable): A ``nn.Module`` module.
        args : args to be passed to :func:`aot_function`
        kwargs : kwargs to be passed to :func:`aot_function`

    Returns:
        Returns a ``nn.Module`` that retains the eager behavior of the original
        :attr:`mod`, but with the forward and backward graph compiled.
    """
    # See Note: [Fake Modules and AOTAutograd]
    torch._dynamo.utils.assert_no_fake_params_or_buffers(mod)

    def functional_call(named_params, named_buffers, *args, **kwargs):
        params_and_buffers = {**named_params, **named_buffers}
        return torch.func.functional_call(mod, params_and_buffers, args, kwargs)

    named_params = dict(mod.named_parameters(remove_duplicate=False))
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))
    num_params_buffers = len(named_params) + len(named_buffers)
    compiled_f = aot_function(
        functional_call, num_params_buffers=num_params_buffers, *args, **kwargs
    )

    class AOTModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.orig_module = mod

        def forward(self, *args, **kwargs):
            return compiled_f(
                named_params,
                named_buffers,
                *args,
                **kwargs,
            )

    return AOTModule()
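# A minimal usage sketch for aot_module (the compiler `print_and_return` is purely
# illustrative): parameters and buffers are lifted to explicit graph inputs, so the
# printed forward graph shows them as extra placeholders ahead of the user inputs.
#
# >>> def print_and_return(fx_module, example_inputs):
# ...     print(fx_module.graph)
# ...     return fx_module
# >>> mod = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
# >>> compiled_mod = aot_module(mod, fw_compiler=print_and_return)
# >>> out = compiled_mod(torch.randn(2, 4))
# >>> out.sum().backward()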
defaot_module_simplified(mod:nn.Module,args,fw_compiler:Callable,bw_compiler:Optional[Callable]=None,partition_fn:Callable=default_partition,decompositions:Optional[Dict]=None,keep_inference_input_mutations=False,inference_compiler:Optional[Callable]=None,)->nn.Module:""" This is the simplified or low overhead version of aot_module. For frontends like TorchDynamo, the input functions/modules to AOT are static and have unpacked inputs/outputs. This gives us an opportunity to remove the (1) pytree overhead to parse inputs/outputs, (2) AOT Autograd cache, (3) Reading of params/buffers in every forward call :func:`aot_module_simplified` removes these overheads. """########################################################## Redudant with dynamo, but worth having in case this gets invoked elsewhere.# Note [Fake Modules and AOTAutograd]## A simple heuristic for when to use fake versus real tensors is that fake tensors are for compile time# (when we don't want to actually run the compute, but we do want to know about metadata),# and real tensors are for runtime (when we actually want to do the compute.) However, in AOTAutograd,# modules are the exception: we always pass AOTAutograd modules with real tensors.# This is because AOTAutograd will produce a compiled function which needs to directly access any# parameters the compiled function may need, but these parameters will NOT be passed in by the caller (aka Dynamo).# So at compile time, the compiled function we produce must close over any parameters, and those parameters must be# real parameters, and we cannot do this unless at compile time we get a module with real tensors.# Even if Dynamo did pass all parameters explicitly at runtime, which would eliminate the need to close over# the parameters, it would still be profitable to pass real tensor parameters to the compiler at compile time,# because some compilation strategies like CUDA graphs want to burn in the pointer addresses where the parameter data live,# and of course we can't do that unless we give the backend a real tensor.torch._dynamo.utils.assert_no_fake_params_or_buffers(mod)params={**dict(mod.named_parameters(remove_duplicate=False)),**dict(mod.named_buffers(remove_duplicate=False)),}params_flat,params_spec=pytree.tree_flatten(params)params_flat=tuple(params_flat)params_len=len(params_flat)deffunctional_call(*args,**kwargs):withstateless._reparametrize_module(mod,pytree.tree_unflatten(args[:params_len],params_spec)):ifisinstance(mod,torch.fx.GraphModule):withfx_traceback.preserve_node_meta(),warnings.catch_warnings():warnings.filterwarnings("ignore","Anomaly Detection has been enabled.")withtorch.autograd.detect_anomaly(check_nan=False):out=Interpreter(mod).run(*args[params_len:],**kwargs)else:out=mod(*args[params_len:],**kwargs)ifnotisinstance(out,(tuple,list)):raiseRuntimeError("Graph output must be a tuple(). This is so that we can avoid ""pytree processing of the ouputs. Please change the module to ""have tuple outputs or use aot_module instead.")returnoutifbw_compilerisNone:bw_compiler=fw_compilerifinference_compilerisNone:inference_compiler=fw_compilerseen_sources=set()full_args=[]# First, the paramsfull_args.extend(params_flat)aot_autograd_arg_pos_to_source=None# Then, the params 1:1 mapped sources, if relevant.ifhasattr(mod,"_param_name_to_source"):aot_autograd_arg_pos_to_source=[]# We now know this came from dynamo, and (1) we care about guards,# so setting up aot_autograd_arg_pos_to_source for downstream dedup guards# can now be done safely. 
(2) Dynamo logic protects the 1:1 sizing below.fornameinparams.keys():assertnameinmod._param_name_to_source,f"{name} not found."source=mod._param_name_to_source[name]assertsourcenotinseen_sources,sourceseen_sources.add(source)aot_autograd_arg_pos_to_source.append(source)# Next, the input argsfull_args.extend(args)ifhasattr(mod,"graph"):# Non dynamo entrypoints can get to here...fori,nodeinenumerate(mod.graph.nodes):ifnode.op=="placeholder":ifhasattr(node,"_dynamo_source"):# ... but not here!ifaot_autograd_arg_pos_to_sourceisNone:aot_autograd_arg_pos_to_source=[]source=node._dynamo_sourceassertsourcenotinseen_sources,sourceseen_sources.add(source)aot_autograd_arg_pos_to_source.append(source)ifaot_autograd_arg_pos_to_sourceisnotNone:assertlen(full_args)==len(aot_autograd_arg_pos_to_source)dynamic_shapes=Falseforxinfull_args:ifisinstance(x,FakeTensor):dynamic_shapes=x.fake_mode.shape_envisnotNonebreakaot_config=AOTConfig(fw_compiler=fw_compiler,bw_compiler=bw_compiler,inference_compiler=inference_compiler,partition_fn=partition_fn,decompositions=decompositions,num_params_buffers=params_len,aot_id=next(AOT_COUNTER),keep_inference_input_mutations=keep_inference_input_mutations,dynamic_shapes=dynamic_shapes,aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source)compiled_fn=create_aot_dispatcher_function(functional_call,full_args,aot_config,)# TODO: There is something deeply wrong here; compiled_fn running with# the boxed calling convention, but aot_module_simplified somehow# historically returned a function that was not the boxed calling# convention. This should get fixed...defforward(*runtime_args):full_args=[]full_args.extend(params_flat)full_args.extend(runtime_args)returncompiled_fn(full_args)# Just for convenienceforward.zero_grad=mod.zero_gradforward.named_parameters=mod.named_parametersforward.named_buffers=mod.named_buffersreturnforwardcompiled_function=aot_functioncompiled_module=aot_module
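# A minimal usage sketch for aot_module_simplified (the module `M` and the identity
# compiler below are purely illustrative): example inputs are provided up front, the
# returned `forward` closes over the real parameters/buffers rather than taking them
# as arguments, and the wrapped module must return a tuple or list of outputs.
#
# >>> class M(nn.Module):
# ...     def __init__(self):
# ...         super().__init__()
# ...         self.lin = nn.Linear(4, 4)
# ...     def forward(self, x):
# ...         return (self.lin(x).relu(),)
# >>> example_inputs = [torch.randn(2, 4)]
# >>> fwd = aot_module_simplified(M(), example_inputs, fw_compiler=lambda gm, inps: gm)
# >>> (out,) = fwd(*example_inputs)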