import torch
import torch.fx as fx
import operator
import math
import torch.utils._pytree as pytree
import copy
import os
from collections import defaultdict
from torch.fx.passes import graph_drawer
from typing import Tuple
from .compile_utils import fx_graph_cse, get_aten_target
from . import config

# When set in the package config, the partitioners print extra diagnostics.
AOT_PARTITIONER_DEBUG = config.debug_partitioner


class InvalidNodeBase:
    """Sentinel type for values that cannot be computed from the chosen inputs."""

    def __repr__(self):
        return "Invalid Node"


# Singleton sentinel stored in the value environment of
# ``_extract_graph_with_inputs_outputs`` for uncomputable nodes.
InvalidNode = InvalidNodeBase()


def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs):
    """
    Given a graph, extracts out a subgraph that takes the specified nodes as
    inputs and returns the specified outputs.

    This includes specifying non-placeholder nodes as inputs.

    The general strategy is to initialize all inputs with proxies as we
    encounter them, and trace through the graph, only keeping values which take
    in valid proxies. Then, all dead code is eliminated.

    Only ``call_function`` and ``get_attr`` nodes are copied into the new
    graph; placeholders of ``joint_graph`` that are not in ``inputs`` are
    treated as unavailable, as is anything that depends on them.

    Args:
        joint_graph (fx.Graph): The graph to extract from.
        inputs (list[fx.Node]): Nodes of ``joint_graph`` that become the
            placeholders of the new graph, in this order. Each new placeholder
            shares the ``meta`` dict of the node it replaces.
        outputs (list): Values the new graph returns. ``fx.Node`` entries are
            mapped to their copies; any other value is passed through as-is.

    Returns:
        fx.Graph: The extracted, dead-code-eliminated and linted graph.

    Raises:
        RuntimeError: If an output node was never reached, or cannot be
            computed from ``inputs``.
    """
    new_graph = fx.Graph()
    env = {}
    # Set view of ``inputs`` for O(1) membership tests in the walk below.
    input_set = set(inputs)

    # Add new placeholder nodes in the order specified by the inputs
    for node in inputs:
        new_node = new_graph.placeholder(node.name)
        # Can't use node_copy here as we may be turning previous call_function into placeholders
        new_node.meta = node.meta
        env[node] = new_node

    for node in joint_graph.nodes:
        if node in input_set:
            continue
        elif node.op == 'placeholder':
            env[node] = InvalidNode
        elif node.op == 'call_function':
            all_args = pytree.tree_flatten((node.args, node.kwargs))[0]
            # A node depending on any unavailable value is itself unavailable.
            if any(isinstance(env[x], InvalidNodeBase) for x in all_args if isinstance(x, fx.Node)):
                env[node] = InvalidNode
                continue
            env[node] = new_graph.node_copy(node, lambda x: env[x])
        elif node.op == 'get_attr':
            env[node] = new_graph.node_copy(node, lambda x: env[x])
        elif node.op == 'output':
            pass

    output_values = []
    for x in outputs:
        if isinstance(x, fx.Node):
            if x not in env:
                raise RuntimeError(f"Node {x} couldn't be found in env")
            # Previously the sentinel leaked into the output node unnoticed.
            if isinstance(env[x], InvalidNodeBase):
                raise RuntimeError(f"Node {x} was invalid, but is output")
            output_values.append(env[x])
        else:
            output_values.append(x)
    new_graph.output(output_values)

    new_graph.eliminate_dead_code()
    new_graph.lint()
    return new_graph


def _is_primal(node):
    """Return True for a placeholder whose target does not contain "tangents"."""
    return node.op == "placeholder" and "tangents" not in node.target


def _is_tangent(node):
    """Return True for a placeholder whose target contains "tangents"."""
    return node.op == "placeholder" and "tangents" in node.target


def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule):
    """
    Split the flattened arguments of the joint graph's output node(s) into
    ``(fwd_outputs, bwd_outputs)``.

    The number of forward outputs is the leaf count of the first child of
    ``joint_module._out_spec``; everything after that is a backward output.
    """
    num_fwd_outputs = joint_module._out_spec.children_specs[0].num_leaves
    outputs = pytree.tree_flatten([node.args for node in joint_module.graph.nodes if node.op == 'output'])[0]
    fwd_outputs = outputs[:num_fwd_outputs]
    bwd_outputs = outputs[num_fwd_outputs:]
    return fwd_outputs, bwd_outputs


def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values):
    """
    Build the forward and backward ``fx.GraphModule``s from ``joint_module``.

    The forward graph takes the primal placeholders and returns the forward
    outputs followed by ``saved_values``. The backward graph takes
    ``saved_values`` followed by the tangent placeholders and returns the
    backward outputs.

    Saved values that the backward graph never reads are dropped first.
    NOTE: ``saved_values`` (a list of nodes of ``joint_module.graph``) is
    pruned *in place*, so callers see the final list after this returns.

    Returns:
        tuple[fx.GraphModule, fx.GraphModule]: ``(fwd_module, bwd_module)``.
    """
    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module)
    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes))

    # Trial backward extraction, only used to discover which saved values the
    # backward never reads. No trial forward graph is needed for this.
    trial_bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, saved_values + tangent_inputs, bwd_outputs)
    unused_names = {
        node.name for node in trial_bwd_graph.nodes
        if node.op == 'placeholder' and not node.users
    }
    pruned = []
    for saved_value in saved_values:
        if saved_value.name in unused_names:
            # Drop only the first saved value per unused placeholder name.
            unused_names.discard(saved_value.name)
            continue
        pruned.append(saved_value)
    # In-place update: callers may read the pruned ``saved_values`` afterwards.
    saved_values[:] = pruned

    # Now, generate the real fwd/bwd graphs from the pruned saved values.
    fwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, primal_inputs, fwd_outputs + saved_values)
    bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, saved_values + tangent_inputs, bwd_outputs)

    fwd_module = fx.GraphModule(joint_module, fwd_graph)
    bwd_module = fx.GraphModule(joint_module, bwd_graph)
    return fwd_module, bwd_module
def default_partition(
    joint_module: fx.GraphModule, _joint_inputs
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the :attr:`joint_module` in a manner that closely resembles the
    behavior observed in the original ``.forward()`` and ``.backward()`` of the
    callable, i.e., the resulting forward graph contains those operators that
    are executed in the original ``.forward()`` callable passed to
    :func:`aot_function`.

    The default partitioner collects the operators that are between the forward
    inputs and the forward outputs. This helps in finding the tensors which have
    to be stashed for the backward pass. These stashed tensors become the output
    of the generated forward graph. The remaining operators are then placed in
    the backward graph.

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing.
        _joint_inputs: Unused by this partitioner.

    Returns:
        Returns the generated forward and backward Fx graph modules.
    """
    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module)
    forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
    forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != 'output'}

    saved_values = []
    for node in joint_module.graph.nodes:
        if node.name not in forward_node_names:
            continue
        # Since we can't save tuple of tensor values, we need to flatten out what we're saving:
        # a call_function without 'tensor_meta' is expected to be a tuple-producing
        # op whose users are all getitem nodes, so save those users instead.
        if 'tensor_meta' not in node.meta and node.op == 'call_function':
            users = node.users
            assert all(user.target == operator.getitem for user in users), (
                f"Expected only getitem users of tuple-producing node {node}")
            saved_values.extend(users)
        else:
            saved_values.append(node)

    # De-duplicate while keeping graph order. ``set()`` ordered nodes by hash
    # (object identity), making the saved-value order nondeterministic.
    saved_values = list(dict.fromkeys(saved_values))
    return _extract_fwd_bwd_modules(joint_module, saved_values)
def _prod(x):
    """Return the product of the items of *x* (1 for an empty iterable)."""
    s = 1
    for i in x:
        s *= i
    return s


def _size_of(metadata):
    """
    Return the size in bytes of a tensor described by *metadata*.

    *metadata* only needs ``shape`` and ``dtype`` attributes (for example the
    ``tensor_meta`` entry of an FX node's ``meta``).

    Raises:
        NotImplementedError: If the element size of ``metadata.dtype`` cannot
            be determined.
    """
    # ``torch.float`` / ``torch.int`` are aliases of float32 / int32, so they
    # are covered by the entries below.
    sizes = {
        torch.float16: 2,
        torch.bfloat16: 2,
        torch.float32: 4,
        torch.float64: 8,
        torch.int8: 1,
        torch.int16: 2,
        torch.int32: 4,
        torch.int64: 8,
        torch.uint8: 1,
        torch.bool: 1,
    }

    numel = _prod(metadata.shape)
    dtype = metadata.dtype
    if dtype in sizes:
        return numel * sizes[dtype]
    # Ask torch for dtypes outside the table (e.g. complex) instead of failing.
    try:
        itemsize = torch.empty((), dtype=dtype).element_size()
    except (RuntimeError, TypeError, NotImplementedError) as err:
        raise NotImplementedError(f"Don't know the size of dtype {dtype}") from err
    return numel * itemsize


# Used for some investigative purposes
def _count_ops(graph):
    """Print ``(op name, count)`` for call_function nodes, most frequent first."""
    from collections import Counter

    cnt = Counter(node.target.__name__ for node in graph.nodes if node.op == 'call_function')
    print(sorted(cnt.items(), key=lambda x: x[1], reverse=True))
def min_cut_rematerialization_partition(
    joint_module: fx.GraphModule, _joint_inputs, compiler="nvfuser"
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the joint graph such that the backward recomputes the forward.
    Recomputing helps in trading off memory bandwidth with computation.

    To create the fwd and bwd graph, we copy the joint graph, manually set the
    outputs to just original forward or backward outputs. And then we run the
    resulting graphs through dead code elimination.

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing. Note: its graph is replaced
            by a dead-code-eliminated, CSE'd copy.
        _joint_inputs: Unused by this partitioner.
        compiler(str): Backend name. With ``"inductor"``, extra ops are treated
            as recomputable and nodes more than 4 user-steps away from the
            backward are banned from recomputation.

    Returns:
        Returns the generated forward and backward Fx graph modules.

    Raises:
        RuntimeError: If ``networkx`` is not installed.
    """
    try:
        import networkx as nx
    except ImportError as err:
        raise RuntimeError("Need networkx installed to perform smart recomputation heuristics") from err

    joint_module.graph.eliminate_dead_code()
    joint_module.recompile()
    fx_g = joint_module.graph

    # add the CSE pass
    cse_graph = fx_graph_cse(fx_g)
    joint_module.graph = cse_graph
    full_bw_graph = joint_module.graph

    name_to_node = {}
    for node in joint_module.graph.nodes:
        name_to_node[node.name] = node

    def classify_nodes(joint_module):
        # Backward-required: tangent placeholders and everything that
        # (transitively) uses them; graph order guarantees users come later.
        required_bw_nodes = set()
        for node in joint_module.graph.nodes:
            if node.op == 'placeholder' and "tangents" in node.target:
                required_bw_nodes.add(node)
            if node in required_bw_nodes:
                for user in node.users:
                    required_bw_nodes.add(user)

        # Forward-required: nodes needed to compute the forward outputs from
        # the primal inputs.
        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module)
        forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
        required_fw_nodes = {name_to_node[node.name] for node in forward_only_graph.nodes
                             if node.op != 'output'}
        unclaimed_nodes = {node for node in joint_module.graph.nodes
                           if node not in required_fw_nodes and node not in required_bw_nodes}
        return required_fw_nodes, required_bw_nodes, unclaimed_nodes

    required_fw_nodes, required_bw_nodes, _ = classify_nodes(joint_module)

    # dist_from_bw: 0 for nodes outside the forward set; otherwise 1 + the
    # minimum over users (int(1e9) if there are none). Walking in reverse
    # graph order visits users before the nodes they use.
    for node in reversed(joint_module.graph.nodes):
        if node not in required_fw_nodes:
            node.dist_from_bw = 0
        else:
            node.dist_from_bw = int(1e9)
            for user in node.users:
                node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1)

    aten = torch.ops.aten
    prims = torch.ops.prims

    recomputable_ops = [
        aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow,
        aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__,
        aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs,
        aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round,
        aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma,
        aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin,
        aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt,
        aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold,
        aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu,
        aten.gelu_backward, aten.alias, aten.softmax, aten._softmax,
        aten._softmax_backward_data, aten.sum, aten.mean, aten._grad_sum_to_size,
        aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze,
        aten.unsqueeze,
    ]
    if compiler == "inductor":
        recomputable_ops += [
            prims.div, prims.convert_element_type, aten.sign, aten.clone, aten._to_copy,
            aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim,
            aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice,
            aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones,
            aten.new_zeros, aten.lift_fresh_copy, aten.minimum, aten.arange,
            aten.bitwise_and, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.isnan,
            aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum,
            aten.bitwise_or, aten.logical_and, aten.logical_or,
        ]
    # Natalia said that we should allow recomputing indexing :)
    recomputable_ops += [aten.index]

    recomputable_ops = set(recomputable_ops)

    random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
    compute_intensive_ops = [
        aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm,
        aten.upsample_bilinear2d,
    ]

    unrecomputable_ops = random_ops + compute_intensive_ops

    fusible_ops = recomputable_ops | set(random_ops)
    if AOT_PARTITIONER_DEBUG:
        joint_module_ops = {
            str(node.target._overloadpacket)
            for node in joint_module.graph.nodes
            if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
        }
        ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops}
        print("Ops banned from rematerialization: ", ops_ignored)
        print()

    AGGRESSIVE_RECOMPUTATION = False

    def _maybe_size_of(node):
        # Byte size of ``node`` if it carries tensor metadata, else 0.
        if 'tensor_meta' in node.meta:
            return _size_of(node.meta['tensor_meta'])
        return 0

    def ban_recomputation(node):
        # True when ``node`` must not be recomputed in the backward.
        if AGGRESSIVE_RECOMPUTATION:
            return (node.op == 'call_function' and get_aten_target(node) in unrecomputable_ops)
        else:
            if node.op != 'call_function':
                return False
            if get_aten_target(node) not in recomputable_ops:
                return True
            if node.target == operator.getitem:
                return False
            if compiler == "inductor" and node.dist_from_bw > 4:
                return True
            # If the output of an op is 4x smaller (arbitrary choice),
            # then we don't allow recomputation.
            if 'tensor_meta' not in node.meta:
                return False
            input_tensors_size = sum(_maybe_size_of(i) for i in node.args if isinstance(i, fx.Node))
            output_size = _size_of(node.meta['tensor_meta'])
            return (output_size * 4 < input_tensors_size)

    def is_fusible(a, b):
        return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops

    def is_materialized(node):
        # Placeholders are always materialized; otherwise a node is
        # materialized unless it is fusible with every one of its users.
        if node.op == 'placeholder':
            return True

        return not all(is_fusible(node, user) for user in node.users)

    def get_node_weight(node):
        mem_sz = _size_of(node.meta['tensor_meta'])

        # Heuristic to bias towards nodes closer to the backwards pass
        # Complete guess about current value
        mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1)))
        if is_materialized(node):
            return mem_sz
        else:
            return mem_sz * 2

    # Flow network: each node is split into "<name>_in" -> "<name>_out" with
    # capacity equal to its weight. Nodes whose split edge is cut become the
    # values saved for the backward.
    nx_graph = nx.DiGraph()
    for node in full_bw_graph.nodes:
        if node.op == 'output':
            continue

        if node in required_bw_nodes:
            nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
            continue

        if node.op == 'placeholder' and "primals" in node.target:
            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)

        # If a node can't be recomputed (too expensive or involves randomness),
        # we prevent it from being recomputed by adding an inf edge to the source
        # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
        if ban_recomputation(node) and node in required_fw_nodes:
            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)

        if 'tensor_meta' not in node.meta:
            weight = math.inf
        else:
            weight = get_node_weight(node)

        # Creates the weights on the "node" edge
        nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight)
        for user in node.users:
            nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf)

    _, partition = nx.minimum_cut(nx_graph, "source", "sink")
    reachable, non_reachable = partition
    cutset = set()
    for u, nbrs in ((n, nx_graph[n]) for n in reachable):
        cutset.update((u, v) for v in nbrs if v in non_reachable)

    # Every finite-capacity edge is a "<name>_in" -> "<name>_out" edge, so
    # stripping the suffixes recovers the node name.
    cut_nodes = set()
    for node_in, node_out in cutset:
        assert node_in[:-3] == node_out[:-4]
        node_name = node_in[:-3]
        cut_nodes.add(node_name)

    # To make this stuff deterministic
    node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
    saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x])
    fw_module, bw_module = _extract_fwd_bwd_modules(joint_module, saved_values)
    if AOT_PARTITIONER_DEBUG:
        print("Theoretical Activations Stored: ", sum(_size_of(i.meta['tensor_meta']) for i in saved_values) / 1e9)
        fw_module_nodes = {node.name for node in fw_module.graph.nodes if node.op == 'call_function'}
        bw_module_nodes = {node.name for node in bw_module.graph.nodes if node.op == 'call_function'}
        remat_nodes = fw_module_nodes & bw_module_nodes

        counts = defaultdict(int)
        for node in fw_module.graph.nodes:
            if node.name in remat_nodes and hasattr(node.target, '_overloadpacket'):
                counts[str(node.target._overloadpacket)] += 1
        print("# nodes rematerialized: ", len(remat_nodes))
        print("Count of Ops Rematerialized: ", sorted(counts.items(), key=lambda x: x[1], reverse=True))
    return fw_module, bw_module
def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph", clear_meta=True):
    """
    Render ``traced`` with ``graph_drawer.FxGraphDrawer`` and write it to ``fname``.

    Args:
        traced: The graph module to draw.
        fname: Output path. Its extension selects the writer method
            (``write_<ext>``) on the dot graph; ``.svg`` is used when there is
            no extension.
        figname: Name passed to ``FxGraphDrawer``.
        clear_meta: When true, draw a deep copy of the graph with every node's
            ``meta`` emptied; ``traced`` itself is left untouched.
    """
    if clear_meta:
        new_graph = copy.deepcopy(traced.graph)
        traced = fx.GraphModule(traced, new_graph)
        for node in traced.graph.nodes:
            node.meta = {}
    base, ext = os.path.splitext(fname)
    if not ext:
        ext = ".svg"
    print(f"Writing FX graph to file: {base}{ext}")
    g = graph_drawer.FxGraphDrawer(traced, figname)
    x = g.get_main_dot_graph()
    # Writer methods are named after lowercase format names, so "graph.PNG"
    # must map to ``write_png`` rather than a nonexistent ``write_PNG``.
    fmt = ext.lstrip(".").lower()
    getattr(x, "write_" + fmt)(f"{base}{ext}")


def draw_joint_graph(graph, joint_inputs, file_name="full_graph.png"):
    """Draw ``graph`` to ``file_name``, then return ``default_partition(graph, joint_inputs)``."""
    draw_graph(graph, file_name)
    return default_partition(graph, joint_inputs)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.