Shortcuts

# Source code for torch.fx.subgraph_rewriter

from .graph_module import GraphModule
from .graph import Graph
from .node import Node
from .symbolic_trace import symbolic_trace

import copy
from typing import Callable, Dict, List, NamedTuple, Optional, Set
import torch

class Match(NamedTuple):
    """One place where a pattern was found inside a larger graph."""

    # The larger-graph node at which the successful match was rooted
    anchor: Node
    # Pattern-subgraph node -> the larger-graph node it was matched with
    nodes_map: Dict[Node, Node]

class SubgraphMatcher:
    """Finds occurrences of a single-output pattern Graph inside a larger graph.

    Matching starts from the pattern's output node (the "anchor") and walks
    backwards along input edges, comparing each pattern node with a graph
    node by ``op`` and ``target``. Pattern placeholders act as wildcards.
    """

    def __init__(self, pattern: Graph) -> None:
        """Prepare a matcher for ``pattern``.

        Args:
            pattern: The pattern graph to search for.

        Raises:
            ValueError: If ``pattern`` contains no nodes.
            AssertionError: If the last node of ``pattern`` (its output node)
                does not have exactly one input node.
        """
        self.pattern = pattern
        if len(pattern.nodes) == 0:
            raise ValueError("SubgraphMatcher cannot be initialized with an "
                             "empty pattern")
        # self.pattern_anchor is the output Node in pattern
        self.pattern_anchor = next(iter(reversed(pattern.nodes)))
        # Ensure that there is only a single output value in the pattern
        # since we don't support multiple outputs
        assert len(self.pattern_anchor.all_input_nodes) == 1, \
            "Pattern matching on multiple outputs is not supported"
        # Maps nodes in the pattern subgraph to nodes in the larger graph
        self.nodes_map: Dict[Node, Node] = {}

    def matches_subgraph_from_anchor(self, anchor: Node) -> bool:
        """
        Checks if the whole pattern can be matched starting from
        anchor in the larger graph.

        Pattern matching is done by recursively comparing the pattern
        node's use-def relationships against the graph node's. On success,
        self.nodes_map holds the pattern-node -> graph-node mapping; on
        failure it is left empty.
        """
        self.nodes_map = {}
        return self._match_nodes(self.pattern_anchor, anchor)

    @staticmethod
    def _attributes_are_equal(pn: Node, gn: Node) -> bool:
        """Compare a pattern node with a graph node, ignoring their inputs."""
        # Use placeholder and output nodes as wildcards. The only exception
        # is that an output node can't match a placeholder
        if (pn.op == "placeholder"
                or (pn.op == "output" and gn.op != "placeholder")):
            return True
        return pn.op == gn.op and pn.target == gn.target

    def _match_nodes(self, pn: Node, gn: Node) -> bool:
        """Compare pattern node pn against graph node gn, recursing into inputs.

        On success, pn -> gn (plus every mapping needed for pn's inputs) is
        recorded in self.nodes_map. On failure, self.nodes_map is restored to
        its state before this call, so partial mappings left by a failed
        attempt cannot cause a later candidate to be wrongly rejected.
        """
        # Check if we've already matched these nodes in the current
        # traversal; a pattern node must always map to the same graph node
        if pn in self.nodes_map:
            return self.nodes_map[pn] == gn

        # Terminate early if the node attributes are not equal
        if not self._attributes_are_equal(pn, gn):
            return False

        # Snapshot for rollback, then optimistically mark pn as a match for gn
        snapshot = dict(self.nodes_map)
        self.nodes_map[pn] = gn

        # Placeholders are wildcards: nothing upstream of them is compared
        if pn.op == "placeholder":
            return True

        if pn.op == "output":
            # The pattern output is only a container: its single input may
            # match any one of gn's inputs (__init__ guarantees exactly one)
            match_found = any(self._match_nodes(pn.all_input_nodes[0], gn_)
                              for gn_ in gn.all_input_nodes)
        else:
            # Regular nodes must match input-by-input, in order
            match_found = (
                len(pn.all_input_nodes) == len(gn.all_input_nodes)
                and all(self._match_nodes(pn_, gn_) for pn_, gn_
                        in zip(pn.all_input_nodes, gn.all_input_nodes)))

        if not match_found:
            # Undo pn -> gn and anything recorded by the recursive calls
            self.nodes_map.clear()
            self.nodes_map.update(snapshot)
            return False

        return True

def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
    """Make every submodule referenced by gm's graph exist on gm.

    For each call_module / get_attr node in ``gm.graph``, the target is looked
    up on ``gm`` first; if it is missing there but present on
    ``replacement``, a deep copy is added to ``gm`` under the same (possibly
    dotted) name.

    Args:
        gm: The GraphModule whose graph was rewritten. Modified in place.
        replacement: The module the replacement subgraph was traced from.

    Raises:
        RuntimeError: If a call_module / get_attr target cannot be resolved
            as a submodule of either ``gm`` or ``replacement``.
    """
    # Drop submodules that the rewritten graph no longer references
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Module]:
        # get_submodule raises AttributeError when the path does not resolve
        # to an nn.Module
        try:
            return mod.get_submodule(target)
        except AttributeError:
            return None

    # NOTE(review): get_submodule only resolves nn.Module attributes, so a
    # get_attr node that targets a parameter/buffer falls through to the
    # RuntimeError below — confirm whether that case needs support.
    for node in gm.graph.nodes:
        if node.op not in ("call_module", "get_attr"):
            continue

        # CASE 1: This target already exists as a submodule in our result
        # GraphModule. Whether or not it exists in replacement, the existing
        # submodule takes precedence.
        if try_get_submodule(gm, node.target) is not None:
            continue

        # CASE 2: The target exists as a submodule in replacement only, so
        # copy it over. Reuse the already-resolved module (getattr cannot
        # follow dotted paths such as "sub.lin") and actually install it.
        replacement_submod = try_get_submodule(replacement, node.target)
        if replacement_submod is not None:
            gm.add_submodule(node.target, copy.deepcopy(replacement_submod))
            continue

        # CASE 3: The target doesn't exist as a submodule in gm or replacement
        raise RuntimeError(f"Attempted to create a \"{node.op}\" node during "
                           f"subgraph rewriting with target {node.target}, "
                           "but the referenced submodule does not exist in "
                           "either the original GraphModule gm or the "
                           "replacement GraphModule replacement")

    gm.graph.lint()

def _pattern_is_contained(nodes_map: Dict[Node, Node],
                          returned_pattern_node: Node) -> bool:
    """Return True if a match does not leak into the rest of the graph.

    Every graph node matched by an interior pattern node (anything other than
    a placeholder, the pattern's output node, or the node the pattern
    returns) may only be used by other matched interior/returned nodes;
    otherwise erasing it after the rewrite would break its outside users.

    Args:
        nodes_map: Pattern node -> graph node mapping for one match.
        returned_pattern_node: The single input of the pattern's output node.
    """
    interior = {gn for pn, gn in nodes_map.items()
                if pn.op not in ("placeholder", "output")}
    for pn, gn in nodes_map.items():
        # Placeholders are wildcards, the output node is only a container,
        # and the returned node is what gets rewired to the replacement, so
        # all of these may have users outside the match. The check is made
        # on the *pattern* node's op: a pattern placeholder can match any
        # kind of graph node.
        if pn.op in ("placeholder", "output") or pn is returned_pattern_node:
            continue
        if any(user not in interior for user in gn.users):
            return False
    return True


def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (pattern) in the Graph of a GraphModule
    (gm), then replaces each of these matched subgraphs with another
    subgraph (replacement).

    Args:
        gm: The GraphModule that wraps the Graph to operate on
        pattern: The subgraph to match in gm for replacement
        replacement: The subgraph to replace pattern with

    Returns:
        List[Match]: A list of Match objects representing the places
        in the original graph that pattern was matched to, each listed
        once. Matches that were skipped because they overlapped an
        earlier replacement are still included. The list is empty if
        there are no matches. Match is defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2])

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match pattern in the forward
    method of traced_module. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    p = torch.cat([a, b]) in pattern, you could match
    m = torch.cat([a, b]) in the original forward function,
    despite the variable names being different (p vs m).

    The return statement in pattern is matched based on its
    value only; it may or may not match to the return statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by replacement. If there are multiple
    matches for pattern in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after self, while the
    last Node is whatever the function returns.)

    A match is only accepted if none of its intermediate values are used
    outside the matched subgraph; only the value the pattern returns may
    have other users.

    One important thing to note is that the parameters of the
    pattern Callable must be used in the Callable itself,
    and the parameters of the replacement Callable must match
    the pattern. The first rule is why, in the above code block, the
    forward function has parameters x, w1, w2, but the
    pattern function only has parameters w1, w2. pattern
    doesn't use x, so it shouldn't specify x as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, replacement needs the same number of parameters
    as pattern (both x and y), even though the parameter
    y isn't used in replacement.

    After calling subgraph_rewriter.replace_pattern, the generated
    Python code looks like this (exact node names may differ):

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            sum_1 = stack_1.sum()
            stack_2 = torch.stack([w1, w2])
            sum_2 = stack_2.sum()
            max_1 = torch.max(sum_1)
            add_1 = x + max_1
            max_2 = torch.max(sum_2)
            add_2 = add_1 + max_2
            return add_2
    """
    # Get the graphs for gm, pattern, replacement
    original_graph = gm.graph
    pattern_graph = symbolic_trace(pattern).graph
    replacement_graph = symbolic_trace(replacement).graph

    matcher = SubgraphMatcher(pattern_graph)
    # The pattern node whose value the pattern returns (the single input of
    # the pattern's output node)
    returned_pattern_node = matcher.pattern_anchor.all_input_nodes[0]

    # Find all possible pattern matches in original_graph. Note that
    # pattern matches may overlap with each other.
    matches: List[Match] = []
    # Consider each node as an "anchor" (deepest matching graph node)
    for anchor in original_graph.nodes:
        # It's not a match if the pattern leaks out into the rest of the graph
        if (matcher.matches_subgraph_from_anchor(anchor)
                and _pattern_is_contained(matcher.nodes_map,
                                          returned_pattern_node)):
            # Record the match once; shallow-copy nodes_map because the
            # matcher reuses it for the next anchor
            matches.append(Match(anchor=anchor,
                                 nodes_map=copy.copy(matcher.nodes_map)))

    # Placeholders are paired positionally: the i-th replacement parameter
    # receives whatever graph node the i-th pattern parameter matched
    pattern_placeholders = [n for n in pattern_graph.nodes
                            if n.op == "placeholder"]
    replacement_placeholders = [n for n in replacement_graph.nodes
                                if n.op == "placeholder"]
    if matches:
        assert len(pattern_placeholders), \
            "pattern must take at least one parameter"
        assert len(pattern_placeholders) == len(replacement_placeholders), \
            "pattern and replacement must take the same number of parameters"
    placeholder_map = dict(zip(replacement_placeholders, pattern_placeholders))

    # The set of all nodes in original_graph that we've seen thus far as
    # part of a replaced match; later matches touching them are stale
    replaced_nodes: Set[Node] = set()
    # Matched non-placeholder, non-output nodes of replaced matches; these
    # become dead once the replacement is spliced in
    nodes_to_erase: Set[Node] = set()

    for match in matches:
        # Skip matches that overlap an earlier replacement (sharing a
        # graph placeholder as an input is allowed)
        if any(n in replaced_nodes and n.op != "placeholder"
               for n in match.nodes_map.values()):
            continue
        replaced_nodes.update(match.nodes_map.values())

        # Initialize val_map with mappings from placeholder nodes in
        # replacement to their corresponding node in original_graph
        val_map: Dict[Node, Node] = {
            replacement_node: match.nodes_map[pattern_node]
            for replacement_node, pattern_node in placeholder_map.items()}

        # Copy the replacement graph over, just before the anchor (a user of
        # the returned node, so every input the replacement needs is
        # already defined at that point)
        with original_graph.inserting_before(match.anchor):
            copied_output = original_graph.graph_copy(replacement_graph,
                                                      val_map)
        assert isinstance(copied_output, Node), \
            "replacement must return a single value"

        # Rewire every user of the exact node the pattern returned to the
        # replacement's result. This also covers the case where the match
        # reaches the graph's output node, and preserves the structure of
        # the output's arguments (e.g. tuples).
        returned_node = match.nodes_map[returned_pattern_node]
        returned_node.replace_all_uses_with(copied_output)

        for pattern_node, graph_node in match.nodes_map.items():
            if pattern_node.op not in ("placeholder", "output"):
                nodes_to_erase.add(graph_node)

    # Erase the replaced pattern nodes only. Walking the graph backwards
    # removes users before the nodes they consume; unrelated dead code and
    # unused graph inputs are left alone.
    for node in reversed(original_graph.nodes):
        if node in nodes_to_erase and len(node.users) == 0:
            original_graph.erase_node(node)

    # Update the passed-in GraphModule to reflect the new state of
    # original_graph
    gm.recompile()

    # If replacement was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_submodules(gm, replacement)

    return matches

Access comprehensive developer documentation for PyTorch

View Docs

## Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials