from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.symbolic_shapes import hint_int
import torch
import torch.fx as fx
import operator
import math
import torch.utils._pytree as pytree
import copy
import os
from collections import defaultdict
from torch.fx.passes import graph_drawer
from typing import Tuple
from .compile_utils import fx_graph_cse, get_aten_target
from . import config
import functools

AOT_PARTITIONER_DEBUG = config.debug_partitioner


class InvalidNodeBase:
    """Sentinel type for values that cannot be computed in an extracted subgraph."""

    def __repr__(self):
        return "Invalid Node"


# Shared sentinel stored in ``env`` for nodes that depend on a placeholder
# which is not one of the requested inputs.
InvalidNode = InvalidNodeBase()


def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs):
    """
    Given a graph, extracts out a subgraph that takes the specified nodes as
    inputs and returns the specified outputs.

    This includes specifying non-placeholder nodes as inputs.

    The general strategy is to initialize all inputs with proxies as we
    encounter them, and trace through the graph, only keeping values which take
    in valid proxies. Then, all dead code is eliminated.

    Args:
        joint_graph: The ``fx.Graph`` to extract from.
        inputs: Nodes of ``joint_graph`` that become placeholders of the new
            graph, created in this order.
        outputs: Values returned by the new graph. ``fx.Node`` entries are
            remapped into the new graph; anything else is passed through.

    Returns:
        A new, dead-code-eliminated and linted ``fx.Graph``.

    Raises:
        RuntimeError: If an output node was never visited while walking
            ``joint_graph``.
    """
    new_graph = fx.Graph()
    env = {}
    # Set for O(1) membership checks below; ``inputs`` itself still fixes the
    # placeholder order.
    input_set = set(inputs)

    # Add new placeholder nodes in the order specified by the inputs
    for node in inputs:
        new_node = new_graph.placeholder(node.name)
        # Can't use node_copy here as we may be turning previous call_function into placeholders
        new_node.meta = node.meta
        env[node] = new_node

    for node in joint_graph.nodes:
        if node in input_set:
            continue
        elif node.op == 'placeholder':
            # A placeholder that was not requested as an input cannot be computed.
            env[node] = InvalidNode
        elif node.op == 'call_function':
            all_args = pytree.tree_flatten((node.args, node.kwargs))[0]
            all_args = [isinstance(env[x], InvalidNodeBase) for x in all_args if isinstance(x, fx.Node)]
            # Anything that consumes an invalid value is itself invalid.
            if any(all_args):
                env[node] = InvalidNode
                continue
            env[node] = new_graph.node_copy(node, lambda x: env[x])
        elif node.op == 'get_attr':
            env[node] = new_graph.node_copy(node, lambda x: env[x])
        elif node.op == 'output':
            pass

    output_values = []
    for x in outputs:
        if isinstance(x, fx.Node):
            if x not in env:
                raise RuntimeError(f"Node {x} couldn't be found in env")
            output_values.append(env[x])
        else:
            output_values.append(x)
    new_graph.output(output_values)

    new_graph.eliminate_dead_code()
    new_graph.lint()
    return new_graph


def _is_primal(node):
    """Return True for placeholder nodes that are not tangent inputs."""
    return node.op == "placeholder" and "tangents" not in node.target


def _is_tangent(node):
    """Return True for placeholder nodes whose target contains ``"tangents"``."""
    return node.op == "placeholder" and "tangents" in node.target


def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs):
    """
    Split the flattened outputs of ``joint_module`` into forward and backward parts.

    The first ``num_fwd_outputs`` flattened outputs are the forward outputs;
    the rest are the backward outputs.
    """
    outputs = pytree.tree_flatten([node.args for node in joint_module.graph.nodes if node.op == 'output'])[0]
    fwd_outputs = outputs[:num_fwd_outputs]
    bwd_outputs = outputs[num_fwd_outputs:]
    return fwd_outputs, bwd_outputs


def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes=(), *, num_fwd_outputs):
    """
    Build the forward and backward ``fx.GraphModule`` from a joint graph.

    The forward graph takes the primal inputs and returns the forward outputs,
    followed by ``saved_values`` and then ``saved_sym_nodes``. The backward
    graph takes ``saved_sym_nodes``, ``saved_values`` and the tangent inputs,
    and returns the backward outputs. Saved values and sym nodes that the
    backward graph does not use are removed, and both graphs are regenerated.

    Args:
        joint_module: The joint forward/backward module.
        saved_values: Tensor nodes to pass from forward to backward. When a
            list is given, unused entries are removed from it in place.
        saved_sym_nodes: Symbolic-scalar nodes to pass from forward to
            backward. Lists are pruned in place, like ``saved_values``.
        num_fwd_outputs: Number of leading joint outputs that belong to the
            forward graph.

    Returns:
        ``(fwd_module, bwd_module)``.
    """
    # The graph outputs below are built by list concatenation, and pruning
    # uses ``list.remove``. With the default ``()`` (or any tuple), both would
    # fail. Normalize non-lists, but keep list arguments as-is so in-place
    # pruning stays visible to callers that read them afterwards.
    if not isinstance(saved_values, list):
        saved_values = list(saved_values)
    if not isinstance(saved_sym_nodes, list):
        saved_sym_nodes = list(saved_sym_nodes)

    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes))
    # Construct the forward module
    # Keep symints separate from tensors, passed between fwd/bwd graphs, and in the right order.
    fwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, primal_inputs, fwd_outputs + saved_values + saved_sym_nodes
    )
    bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, saved_sym_nodes + saved_values + tangent_inputs, bwd_outputs
    )

    # This is to filter out saved values that don't actually end up being used by the backwards pass
    for node in bwd_graph.nodes:
        if node.op == 'placeholder' and not node.users:
            for saved_value in saved_values:
                if saved_value.name == node.name:
                    saved_values.remove(saved_value)
                    break

            for saved_sym in saved_sym_nodes:
                if saved_sym.name == node.name:
                    saved_sym_nodes.remove(saved_sym)
                    break

    # Now, we re-generate the fwd/bwd graphs.
    # NB: This might increase compilation time, but I doubt it matters
    fwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, primal_inputs, fwd_outputs + saved_values + saved_sym_nodes
    )
    bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, saved_sym_nodes + saved_values + tangent_inputs, bwd_outputs
    )

    fwd_module = fx.GraphModule(joint_module, fwd_graph)
    bwd_module = fx.GraphModule(joint_module, bwd_graph)
    return fwd_module, bwd_module
def default_partition(
    joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the :attr:`joint_module` in a manner that closely resembles the
    behavior observed in the original ``.forward()`` and ``.backward()`` of the
    callable, i.e., the resulting forward graph contains those operators that
    are executed in the original ``.forward()`` callable passed to
    :func:`aot_function`.

    The default partitioner collects the operators that are between the forward
    inputs and the forward outputs. This helps in finding the tensors which have
    to be stashed for the backward pass. These stashed tensors become the output
    of the generated forward graph. The remaining operators are then placed in
    the backward graph.

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing.
        _joint_inputs: The inputs to the joint graph. Not used here.
        num_fwd_outputs: The number of leading joint-graph outputs that belong
            to the forward graph.

    Returns:
        Returns the generated forward and backward Fx graph modules.
    """
    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
    forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != 'output'}
    saved_values = []
    saved_sym_nodes = []

    for node in joint_module.graph.nodes:
        if node.name not in forward_node_names:
            continue
        if is_sym_node(node):
            # Symints must be kept separate from tensors so that PythonFunction only calls
            # save_for_backward on tensors and stashes symints in autograd .ctx
            saved_sym_nodes.append(node)
        elif 'tensor_meta' not in node.meta and node.op == 'call_function':
            # Since we can't save tuple of tensor values, we need to flatten out what we're saving
            users = node.users
            assert all(user.target == operator.getitem for user in users)
            for user in users:
                saved_values.append(user)
        else:
            backward_usages = [n for n in node.users if n.name not in forward_node_names]
            if 'tensor_meta' in node.meta and all(is_sym_node(n) for n in backward_usages):
                # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
                # and not the actual tensor data,
                # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
                #
                # Note that saving the tensor could also cause compilation problems:
                # If the user mutated an input in the forward and uses its sizes/strides in the backward,
                # then we would be obligated to clone the input before saving it to appease autograd.
                # (This is how we originally found this bug).
                for user in backward_usages:
                    saved_sym_nodes.append(user)
            else:
                saved_values.append(node)

    # De-duplicate while keeping first-seen (graph) order. ``list(set(...))``
    # ordered these by node hash, which made the forward outputs / backward
    # inputs order nondeterministic across runs.
    saved_values = list(dict.fromkeys(saved_values))
    saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes))

    return _extract_fwd_bwd_modules(
        joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs
    )
def _prod(x):
    """Return the product of the items of ``x`` (``1`` for an empty iterable)."""
    s = 1
    for i in x:
        s *= i
    return s


def _tensor_nbytes(numel, dtype):
    """
    Return the number of bytes used by ``numel`` elements of ``dtype``.

    Common dtypes are looked up in a fixed table. Any other ``torch.dtype``
    falls back to the element size PyTorch reports for a 0-d tensor of that
    dtype.

    Raises:
        NotImplementedError: If ``dtype`` is not a ``torch.dtype``, or PyTorch
            cannot allocate a tensor of that dtype to query its size.
    """
    sizes = {
        torch.float: 4,
        torch.float16: 2,
        torch.bfloat16: 2,
        torch.float32: 4,
        torch.float64: 8,
        torch.int: 4,
        torch.int8: 1,
        torch.int16: 2,
        torch.int32: 4,
        torch.int64: 8,
        torch.uint8: 1,
        torch.bool: 1,
    }
    itemsize = sizes.get(dtype)
    if itemsize is None:
        if not isinstance(dtype, torch.dtype):
            raise NotImplementedError(f"Don't know the size of dtype {dtype}")
        try:
            itemsize = torch.empty((), dtype=dtype).element_size()
        except (RuntimeError, TypeError) as err:
            raise NotImplementedError(f"Don't know the size of dtype {dtype}") from err

    return numel * itemsize


def _size_of(node: fx.Node) -> int:
    """
    Estimate the size in bytes of the value produced by ``node``.

    ``node.meta['val']`` is preferred. Symbolic scalars count as 1, lists and
    tuples sum their tensor entries, and tensors use their numel and dtype.
    Otherwise ``node.meta['tensor_meta']`` is used. Nodes with neither key
    count as 0.

    Raises:
        RuntimeError: If ``node.meta['val']`` has an unrecognized type.
    """
    if 'val' in node.meta:
        val = node.meta['val']
        if isinstance(val, py_sym_types):
            return 1
        elif isinstance(val, (list, tuple)):
            # Multi-output op: only the tensor entries contribute.
            return sum(
                _tensor_nbytes(hint_int(n.numel()), n.dtype)
                for n in val
                if isinstance(n, torch.Tensor)
            )
        elif isinstance(val, torch.Tensor):
            return _tensor_nbytes(hint_int(val.numel()), val.dtype)

        raise RuntimeError(f"Unknown metadata type {type(val)}")

    # Only needed since we don't always trace with fake tensors.
    if 'tensor_meta' not in node.meta:
        return 0

    metadata = node.meta['tensor_meta']
    # Use the same ``hint_int`` helper as the ``val`` path above. The previous
    # ``to_size_hint`` was never defined in this module and raised NameError.
    numel = _prod(map(hint_int, metadata.shape))
    return _tensor_nbytes(numel, metadata.dtype)


# Used for some investigative purposes
def _count_ops(graph):
    """Print ``call_function`` target names of ``graph`` by descending count."""
    from collections import Counter

    cnt = Counter(node.target.__name__ for node in graph.nodes if node.op == 'call_function')
    print(sorted(cnt.items(), key=lambda x: x[1], reverse=True))


@functools.lru_cache(None)
def pointwise_ops():
    """
    Return the ``torch.ops.aten`` overload packets that have at least one
    overload tagged ``torch.Tag.pointwise``. The result is cached.
    """
    ops = []
    for attr_name in dir(torch.ops.aten):
        opoverloadpacket = getattr(torch.ops.aten, attr_name)
        if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket):
            continue

        for overload in opoverloadpacket.overloads():
            op_overload = getattr(opoverloadpacket, overload)
            if torch.Tag.pointwise in op_overload.tags:
                # currently aot autograd uses packet not overload
                ops.append(opoverloadpacket)
                break

    return ops
def min_cut_rematerialization_partition(
    joint_module: fx.GraphModule,
    _joint_inputs,
    compiler="nvfuser",
    recomputable_ops=None,
    *,
    num_fwd_outputs
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the joint graph such that the backward recomputes the forward.
    Recomputing helps in trading off memory bandwidth with computation.

    To create the fwd and bwd graph, we copy the joint graph, manually set the
    outputs to just original forward or backward outputs. And then we run the
    resulting graphs through dead code elimination.

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing.
        _joint_inputs: The inputs to the joint graph. Not used here other than
            being forwarded to ``default_partition`` when there is nothing to
            partition.
        compiler: This option determines the default set of recomputable ops.
            Currently, there are two options: ``nvfuser`` and ``inductor``.
        recomputable_ops: This is an optional set of recomputable ops. If this
            is not None, then this set of ops will be used instead of the
            default set of ops.
        num_fwd_outputs: The number of outputs from the forward graph.

    Returns:
        Returns the generated forward and backward Fx graph modules.

    Raises:
        RuntimeError: If ``networkx`` is not installed.
    """
    try:
        import networkx as nx
    except ImportError as e:
        raise RuntimeError("Need networkx installed to perform smart recomputation "
                           "heuristics") from e

    joint_module.graph.eliminate_dead_code()
    joint_module.recompile()

    fx_g = joint_module.graph

    # add the CSE pass
    if config.cse:
        cse_graph = fx_graph_cse(fx_g)
        joint_module.graph = cse_graph
    full_bw_graph = joint_module.graph

    name_to_node = {node.name: node for node in joint_module.graph.nodes}

    def classify_nodes(joint_module):
        # Backward nodes: tangent placeholders and everything downstream of them.
        required_bw_nodes = set()
        for node in joint_module.graph.nodes:
            if node.op == 'placeholder' and "tangents" in node.target:
                required_bw_nodes.add(node)
            if node in required_bw_nodes:
                for user in node.users:
                    required_bw_nodes.add(user)

        # Forward nodes: those needed to compute the forward outputs from the primals.
        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
        forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
        required_fw_nodes = {name_to_node[node.name] for node in forward_only_graph.nodes
                             if node.op != 'output'}
        unclaimed_nodes = {node for node in joint_module.graph.nodes
                           if node not in required_fw_nodes and node not in required_bw_nodes}
        return fwd_outputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes

    _, required_fw_nodes, required_bw_nodes, _ = classify_nodes(joint_module)

    # networkx blows up on graphs with no required backward nodes
    # Since there's nothing to partition anyway, and the default partitioner can "handle"
    # this case, send our graph over to the default partitioner.
    if len(required_bw_nodes) == 0:
        return default_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)

    # Annotate each node with its distance (in graph edges) to the nearest
    # node outside the required forward set; walked in reverse so users are done first.
    for node in reversed(joint_module.graph.nodes):
        if node not in required_fw_nodes:
            node.dist_from_bw = 0
        else:
            node.dist_from_bw = int(1e9)
            for user in node.users:
                node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1)

    aten = torch.ops.aten
    prims = torch.ops.prims

    # compiler == "nvfuser" is the default set of recomputable ops
    default_recomputable_ops = [
        aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow,
        aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__,
        aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs,
        aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round,
        aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma,
        aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin,
        aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt,
        aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold,
        aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu,
        aten.gelu_backward, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size,
        aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze,
        aten.rsub, aten._to_copy,
    ]
    view_ops = [aten.squeeze, aten.unsqueeze, aten.alias]
    if compiler == "inductor":
        default_recomputable_ops += [
            prims.div, prims.convert_element_type, aten.clone, aten._to_copy, aten.full_like,
            prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select,
            aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape,
            aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros,
            aten.lift_fresh_copy, aten.arange, aten.triu, aten.var_mean, aten.isinf, aten.any,
            aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum,
        ]
        view_ops += [aten.view, aten.slice, aten.permute, aten.t, prims.broadcast_in_dim, aten.expand, aten.as_strided]
        # Natalia said that we should allow recomputing indexing :)
        default_recomputable_ops += [aten.index]
    default_recomputable_ops += view_ops

    default_recomputable_ops += pointwise_ops()

    recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)

    random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
    compute_intensive_ops = [
        aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm,
        aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data,
        aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm,
        aten.native_batch_norm_backward, aten._native_batch_norm_legit,
    ]

    unrecomputable_ops = random_ops + compute_intensive_ops

    fusible_ops = recomputable_ops | set(random_ops)
    if AOT_PARTITIONER_DEBUG:
        joint_module_ops = {
            str(node.target._overloadpacket)
            for node in joint_module.graph.nodes
            if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
        }
        ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops}
        print("Ops banned from rematerialization: ", ops_ignored)
        print()

    AGGRESSIVE_RECOMPUTATION = False

    def is_materialized_backwards(node):
        # Walk through backward-side view ops; if any backward user cannot be
        # fused with its producer, the value must exist in the backward pass.
        cur_nodes = {node}
        while len(cur_nodes) > 0:
            cur = cur_nodes.pop()
            for user in cur.users:
                if user not in required_fw_nodes and not is_fusible(cur, user):
                    return True
                if user not in required_fw_nodes and get_aten_target(user) in view_ops:
                    cur_nodes.add(user)

        return False

    def ban_recomputation(node):
        if AGGRESSIVE_RECOMPUTATION:
            return (node.op == 'call_function' and get_aten_target(node) in unrecomputable_ops)
        else:
            if node.op != 'call_function':
                return False
            if get_aten_target(node) not in recomputable_ops:
                return True
            if node.target == operator.getitem:
                return False
            if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]:
                return False

            # If a node *must* be materialized in the backwards pass, then we
            # should never recompute it. This is a pretty subtle point.  In
            # general, the assumption we make is that recomputing a node in the
            # backwards pass is "free". However, if a node must be materialized
            # in the backwards pass, then recomputing it is never free.
            if is_materialized_backwards(node):
                return True

            # Arbitrary hack that sometimes seems to help things. The above
            # modification appears to have made this heuristic a lot less critical
            # for performance.
            # TODO: Investigate why this hack helps.
            if compiler == "inductor" and node.dist_from_bw > config.max_dist_from_bw:
                return True

            # If the output of an op is 4x smaller (arbitrary choice),
            # then we don't allow recomputation.
            input_tensors_size = sum(_size_of(i) for i in node.args if isinstance(i, fx.Node))
            output_size = _size_of(node)
            return (output_size * 4 < input_tensors_size)

    def is_fusible(a, b):
        return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops

    def is_materialized(node):
        if node.op == 'placeholder':
            return True

        return not all(is_fusible(node, user) for user in node.users)

    def get_node_weight(node) -> int:
        mem_sz = _size_of(node)

        # Heuristic to bias towards nodes closer to the backwards pass
        # Complete guess about current value
        mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1)))

        if is_materialized(node):
            return mem_sz
        else:
            return mem_sz * 2

    # Each node N becomes an edge "N_in" -> "N_out" whose capacity is the cost
    # of saving N. Data edges and source/sink edges are infinite, so the min
    # cut selects the cheapest set of nodes to save for the backward pass.
    nx_graph = nx.DiGraph()
    for node in full_bw_graph.nodes:
        if node.op == 'output':
            continue

        if node in required_bw_nodes:
            nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
            continue

        if node.op == 'placeholder' and "primals" in node.target:
            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)

        # If a node can't be recomputed (too expensive or involves randomness),
        # we prevent it from being recomputed by adding an inf edge to the source
        # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
        if ban_recomputation(node) and node in required_fw_nodes:
            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)

        # Checks if a node is actually a tuple. Can be simplified to just an isisinstance check if we always use faketensors.
        is_non_tensor_node = (('val' not in node.meta and 'tensor_meta' not in node.meta) or
                              ('val' in node.meta and not isinstance(node.meta['val'], torch.Tensor)))

        if is_sym_node(node):
            weight = 1
        elif is_non_tensor_node:
            weight = math.inf
        else:
            weight = get_node_weight(node)

        # Creates the weights on the "node" edge
        nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight)
        for user in node.users:
            nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf)

    _, partition = nx.minimum_cut(nx_graph, "source", "sink")
    reachable, non_reachable = partition
    cutset = set()
    for u, nbrs in ((n, nx_graph[n]) for n in reachable):
        cutset.update((u, v) for v in nbrs if v in non_reachable)

    # Every finite-capacity cut edge is a "<name>_in" -> "<name>_out" edge.
    cut_nodes = set()
    for node_in, node_out in cutset:
        assert node_in[:-3] == node_out[:-4]
        node_name = node_in[:-3]
        cut_nodes.add(node_name)

    # To make this stuff deterministic
    node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
    saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x])
    # Symints must be kept separate from tensors so that PythonFunction only calls
    # save_for_backward on tensors and stashes symints in autograd .ctx
    saved_sym_nodes = list(filter(lambda n: is_sym_node(n), saved_values))
    saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))
    fw_module, bw_module = _extract_fwd_bwd_modules(
        joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs
    )

    if AOT_PARTITIONER_DEBUG:
        print("Theoretical Activations Stored: ", sum(_size_of(i) for i in saved_values) / 1e9)
        fw_module_nodes = {node.name for node in fw_module.graph.nodes if node.op == 'call_function'}
        bw_module_nodes = {node.name for node in bw_module.graph.nodes if node.op == 'call_function'}
        remat_nodes = fw_module_nodes & bw_module_nodes

        counts = defaultdict(int)
        for node in fw_module.graph.nodes:
            if node.name in remat_nodes and hasattr(node.target, '_overloadpacket'):
                counts[str(node.target._overloadpacket)] += 1
        print(f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}")
        print("Count of Ops Rematerialized: ", sorted(counts.items(), key=lambda x: x[1], reverse=True))
    return fw_module, bw_module
def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph", clear_meta=True):
    """
    Render ``traced`` to an image file using ``graph_drawer.FxGraphDrawer``.

    Args:
        traced: Module whose graph is drawn.
        fname: Output path. If it has no extension, ``.svg`` is used. The
            extension selects the ``write_<ext>`` method called on the dot graph.
        figname: Name passed to ``FxGraphDrawer``.
        clear_meta: If true, draw a deep copy of the graph with every node's
            ``meta`` emptied; the nodes of ``traced`` keep their meta.
    """
    if clear_meta:
        new_graph = copy.deepcopy(traced.graph)
        traced = fx.GraphModule(traced, new_graph)
        for node in traced.graph.nodes:
            node.meta = {}
    base, ext = os.path.splitext(fname)
    if not ext:
        ext = ".svg"
    print(f"Writing FX graph to file: {base}{ext}")
    g = graph_drawer.FxGraphDrawer(traced, figname)
    x = g.get_main_dot_graph()
    getattr(x, "write_" + ext.lstrip("."))(f"{base}{ext}")


def draw_joint_graph(graph, joint_inputs, file_name="full_graph.png", *, num_fwd_outputs=None):
    """
    Draw ``graph`` to ``file_name``, then partition it with ``default_partition``.

    ``default_partition`` requires ``num_fwd_outputs`` as a keyword-only
    argument, so it is accepted here and forwarded.

    Raises:
        TypeError: If ``num_fwd_outputs`` is not given. The graph is still
            drawn before the error is raised.
    """
    draw_graph(graph, file_name)
    if num_fwd_outputs is None:
        raise TypeError(
            "draw_joint_graph() needs the keyword argument 'num_fwd_outputs' "
            "to partition the joint graph with default_partition()"
        )
    return default_partition(graph, joint_inputs, num_fwd_outputs=num_fwd_outputs)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.