from .graph_module import GraphModule
from .graph import Graph
from .node import Node
from ._symbolic_trace import symbolic_trace
from ._compatibility import compatibility
import copy
from typing import Callable, Dict, List, NamedTuple, Optional, Set
import torch


@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]


class _SubgraphMatcher:
    """Structurally match a single-output pattern ``Graph`` against a larger graph.

    Two nodes match when their ``op`` and ``target`` are equal and their
    ``all_input_nodes`` match pairwise, recursively. Pattern placeholders
    match any node. The pattern's output node matches any non-placeholder
    node that has at least one input matching the value the pattern returns.
    """

    def __init__(self, pattern: Graph) -> None:
        self.pattern = pattern
        if len(pattern.nodes) == 0:
            raise ValueError("_SubgraphMatcher cannot be initialized with an "
                             "empty pattern")
        # `self.pattern_anchor` is the output Node in `pattern`
        self.pattern_anchor = next(iter(reversed(pattern.nodes)))
        # Ensure that there is only a single output value in the pattern
        # since we don't support multiple outputs
        assert len(self.pattern_anchor.all_input_nodes) == 1, \
            "Pattern matching on multiple outputs is not supported"
        # Maps nodes in the pattern subgraph to nodes in the larger graph
        self.nodes_map: Dict[Node, Node] = {}

    def matches_subgraph_from_anchor(self, anchor: Node) -> bool:
        """
        Checks if the whole pattern can be matched starting from
        ``anchor`` in the larger graph.

        Pattern matching is done by recursively comparing the pattern
        node's use-def relationships against the graph node's.

        ``self.nodes_map`` is reset at the start of every call; when this
        returns True it holds the pattern-node -> graph-node mapping of
        the match.
        """
        self.nodes_map = {}
        return self._match_nodes(self.pattern_anchor, anchor)

    @staticmethod
    def _attributes_are_equal(pn: Node, gn: Node) -> bool:
        """Return True if ``pn`` and ``gn`` agree on ``op`` and ``target``.

        Placeholder and output pattern nodes act as wildcards. The only
        exception is that an output node can't match a placeholder.
        """
        if pn.op == "placeholder":
            return True
        if pn.op == "output" and gn.op != "placeholder":
            return True
        return pn.op == gn.op and pn.target == gn.target

    def _match_nodes(self, pn: Node, gn: Node) -> bool:
        """Compare the pattern node ``pn`` against the graph node ``gn``.

        On success, ``self.nodes_map`` maps ``pn`` to ``gn`` and holds a
        mapping for every pattern node reachable from ``pn``. On failure,
        every mapping added during this call (including those recorded by
        recursive calls that succeeded before a sibling input failed) is
        removed again. A failed attempt therefore cannot leave stale
        entries behind that would make a later attempt fail spuriously.
        """
        # Check if we've already matched these nodes in the current
        # traversal
        if pn in self.nodes_map:
            return self.nodes_map[pn] == gn

        # Terminate early if the node attributes are not equal
        if not self._attributes_are_equal(pn, gn):
            return False

        # Placeholders are wildcards: record the binding, nothing to traverse
        if pn.op == "placeholder":
            self.nodes_map[pn] = gn
            return True

        # Remember which pattern nodes were already bound so that a failed
        # attempt can be undone exactly (existing keys are never rebound).
        mapped_before = set(self.nodes_map)

        # Optimistically mark `pn` as a match for `gn`
        self.nodes_map[pn] = gn

        # Traverse the use-def relationships to ensure that `pn` is a true
        # match for `gn`
        if pn.op == "output":
            # The value returned by the pattern may match any input of `gn`
            match_found = any(self._match_nodes(pn.all_input_nodes[0], gn_)
                              for gn_ in gn.all_input_nodes)
        else:
            match_found = (
                len(pn.all_input_nodes) == len(gn.all_input_nodes)
                and all(self._match_nodes(pn_, gn_)
                        for pn_, gn_ in zip(pn.all_input_nodes,
                                            gn.all_input_nodes)))

        if not match_found:
            for stale in set(self.nodes_map) - mapped_before:
                del self.nodes_map[stale]
            return False
        return True


def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
    """Ensure every module/attribute referenced by ``gm``'s graph exists in ``gm``.

    Unused submodules of ``gm`` are deleted first. Then, for each
    ``call_module`` / ``get_attr`` node:

    * if the target already exists on ``gm`` (submodule, parameter, buffer
      or other attribute), it takes precedence and is left alone;
    * otherwise, if ``replacement`` has a submodule at that path, a deep
      copy of it is added to ``gm``.

    Raises:
        RuntimeError: if the target exists on neither module.
    """
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_submodule(mod: torch.nn.Module,
                          target: str) -> Optional[torch.nn.Module]:
        try:
            return mod.get_submodule(target)
        except AttributeError:
            return None

    def try_get_attr(mod: torch.nn.Module, target: str) -> Optional[object]:
        # Resolve "a.b.c" as attribute "c" of submodule "a.b" so parameters
        # and buffers are found, not only submodules ("" resolves to `mod`).
        module_path, _, attr_name = target.rpartition(".")
        owner = try_get_submodule(mod, module_path)
        return None if owner is None else getattr(owner, attr_name, None)

    for node in gm.graph.nodes:
        if node.op not in ("call_module", "get_attr"):
            continue

        # CASE 1: This target already exists in our result GraphModule.
        # Whether or not it exists in `replacement`, the existing
        # attribute takes precedence.
        if try_get_attr(gm, node.target) is not None:
            continue

        # CASE 2: The target exists as a submodule in `replacement`
        # only, so we need to copy it over. The resolved submodule is
        # copied directly so that dotted targets ("sub.linear") work.
        replacement_submod = try_get_submodule(replacement, node.target)
        if replacement_submod is not None:
            gm.add_submodule(node.target, copy.deepcopy(replacement_submod))
            continue

        # CASE 3: The target doesn't exist in `gm` or `replacement`
        raise RuntimeError(
            f'Attempted to create a "{node.op}" node during subgraph '
            f"rewriting with target {node.target}, but the referenced "
            "submodule does not exist in either the original GraphModule "
            "`gm` or the replacement GraphModule `replacement`")

    gm.graph.lint()
@compatibility(is_backward_compatible=True)
def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``).

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Returns:
        List[Match]: A list of ``Match`` objects representing the places
        in the original graph that ``pattern`` was matched to. Each match
        appears once. Overlapping matches that were skipped during
        replacement are still included. The list is empty if there are no
        matches. ``Match`` is defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match to the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    Only the nodes of ``gm`` that belonged to a replaced match are
    erased. Other nodes, including placeholders that become unused, are
    left in place.

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks roughly like this (exact node names may differ):

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            max_1 = torch.max(stack_1)
            add_1 = x + max_1
            stack_2 = torch.stack([w1, w2])
            max_2 = torch.max(stack_2)
            add_2 = add_1 + max_2
            return add_2
    """
    # Get the graphs for `gm`, `pattern`, `replacement`
    original_graph = gm.graph
    pattern_graph = symbolic_trace(pattern).graph
    replacement_graph = symbolic_trace(replacement).graph

    # Find all possible pattern matches in original_graph. Note that
    # pattern matches may overlap with each other.
    matcher = _SubgraphMatcher(pattern_graph)
    matches: List[Match] = []

    def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
        """Return False if a matched interior node is used outside the match."""
        # `lookup` represents all the nodes in `original_graph`
        # that are part of `pattern`, mapped back to their pattern node
        lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()}
        for n, pn in lookup.items():
            # Nodes that can "leak"...
            # Pattern inputs (placeholders, by definition). The pattern
            # node is what matters: a pattern placeholder may be bound to
            # an intermediate graph value that is legitimately used
            # elsewhere, and such inputs are never replaced.
            if pn.op == "placeholder" or n.op == "placeholder":
                continue
            # Pattern output (acts as a container)
            if pn.op == "output":
                continue
            # Result contained by pattern output (what we'll hook in to the
            # new Graph, thus what we'll potentially use in other areas of
            # the Graph as an input Node)
            if len(pn.users) == 1 and next(iter(pn.users)).op == "output":
                continue
            # If this node has users that were not in `lookup`, then it
            # must leak out of the pattern subgraph
            if any(user not in lookup for user in n.users):
                return False
        return True

    # Consider each node as an "anchor" (deepest matching graph node).
    # It's not a match if the pattern leaks out into the rest of the graph.
    for anchor in original_graph.nodes:
        if (matcher.matches_subgraph_from_anchor(anchor)
                and pattern_is_contained(matcher.nodes_map)):
            # Record each match exactly once. Shallow copy nodes_map so later
            # matching attempts cannot alter it.
            matches.append(Match(anchor=anchor,
                                 nodes_map=copy.copy(matcher.nodes_map)))

    # The pattern node whose value `pattern` returns. In each match, the
    # graph node bound to it is the value rerouted to the replacement.
    pattern_result = matcher.pattern_anchor.all_input_nodes[0]

    pattern_placeholders = [n for n in pattern_graph.nodes
                            if n.op == "placeholder"]
    replacement_placeholders = [n for n in replacement_graph.nodes
                                if n.op == "placeholder"]
    # Pair replacement parameters with pattern parameters positionally
    placeholder_map = {r: p for r, p in zip(replacement_placeholders,
                                            pattern_placeholders)}

    # The set of all nodes in `original_graph` that we've seen thus far
    # as part of a pattern match
    replaced_nodes: Set[Node] = set()
    # Graph nodes bound to non-placeholder, non-output pattern nodes in the
    # matches that were actually replaced; only these may be erased below
    nodes_to_erase: Set[Node] = set()

    # Return True if one of the nodes in the current match has already
    # been used as part of another match
    def overlaps_with_prev_match(match: Match) -> bool:
        return any(n in replaced_nodes and n.op != "placeholder"
                   for n in match.nodes_map.values())

    def mark_node_as_replaced(n: Node, matched: Set[Node]) -> None:
        # Walk backwards from `n` through the nodes of this match
        if n not in matched:
            return
        for n_ in n.all_input_nodes:
            mark_node_as_replaced(n_, matched)
        replaced_nodes.add(n)

    for match in matches:
        # Skip overlapping matches
        if overlaps_with_prev_match(match):
            continue

        assert len(pattern_placeholders)
        assert len(pattern_placeholders) == len(replacement_placeholders)

        mark_node_as_replaced(match.anchor, set(match.nodes_map.values()))

        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding node in `original_graph`
        val_map: Dict[Node, Node] = {
            replacement_node: match.nodes_map[pattern_node]
            for replacement_node, pattern_node in placeholder_map.items()
        }

        # Copy the replacement graph over, just before the anchor so the
        # copied nodes precede every user of the matched result
        with original_graph.inserting_before(match.anchor):
            copied_output = original_graph.graph_copy(replacement_graph,
                                                      val_map)
        assert isinstance(copied_output, Node)

        # Hook the replacement in by rerouting every use of the matched
        # result node. This covers both a result used in the middle of the
        # graph and one returned by the graph's output node, and it keeps
        # any other values the output node returns.
        result_node = match.nodes_map[pattern_result]
        result_node.replace_all_uses_with(copied_output)

        nodes_to_erase.update(gn for pn, gn in match.nodes_map.items()
                              if pn.op not in ("placeholder", "output"))

    # Erase the replaced `pattern` nodes. Walking in reverse order visits
    # users before their inputs, so interior nodes become dead in turn.
    # Nodes still in use (e.g. by the anchor) are kept.
    for node in reversed(original_graph.nodes):
        if node in nodes_to_erase and len(node.users) == 0:
            original_graph.erase_node(node)

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()

    # If `replacement` was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_submodules(gm, replacement)

    return matches
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.