Source code for executorch.exir.lowered_backend_module
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import copy
import operator
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils._pytree as pytree
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
from executorch.exir.emit import emit_program
from executorch.exir.graph_module import _get_submodule
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
from executorch.exir.schema import Program
from executorch.exir.tracer import Value
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses import FakeTensor
from torch.export.exported_program import (
ConstantArgument,
ExportedProgram,
ExportGraphSignature,
InputKind,
InputSpec,
ModuleCallEntry,
ModuleCallSignature,
OutputKind,
OutputSpec,
TensorArgument,
)
from torch.fx.passes.utils.fuser_utils import (
erase_nodes,
fuse_as_graphmodule,
insert_subgm,
legalize_graph,
NodeList,
topo_sort,
)
class LoweredBackendModule(torch.nn.Module):
"""
A subclass of nn.Module that is generated for modules containing
delegated functions. This can be created by calling `to_backend`.
"""
_backend_id: str # The backend's name
_processed_bytes: bytes # The delegate blobs created from backend.preprocess
_compile_specs: List[
CompileSpec
] # A list of backend-specific objects with static metadata to configure the "compilation" process.
_original_exported_program: ExportedProgram # The original EXIR module
def __init__(
self,
edge_program: ExportedProgram,
backend_id: str,
processed_bytes: bytes,
compile_specs: List[CompileSpec],
) -> None:
super().__init__()
self._original_exported_program = edge_program
self._backend_id = backend_id
self._processed_bytes = processed_bytes
self._compile_specs = compile_specs
# pyre-ignore
def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
# Copy exported program
copied_program = ExportedProgram(
root=copy.deepcopy(self._original_exported_program.graph_module),
graph=copy.deepcopy(self._original_exported_program.graph),
graph_signature=copy.deepcopy(
self._original_exported_program.graph_signature
),
state_dict=self._original_exported_program.state_dict,
range_constraints=copy.deepcopy(
self._original_exported_program.range_constraints
),
module_call_graph=copy.deepcopy(
self._original_exported_program.module_call_graph
),
constants=self._original_exported_program.constants,
verifiers=[copy.deepcopy(self._original_exported_program.verifier)],
)
res = LoweredBackendModule(
edge_program=copied_program,
backend_id=self._backend_id,
processed_bytes=self._processed_bytes,
compile_specs=copy.deepcopy(self._compile_specs, memo),
)
# pyre-fixme[16]: `LoweredBackendModule` has no attribute `meta`.
res.meta = copy.copy(getattr(self, "meta", {}))
return res
@property
def backend_id(self) -> str:
"""
Returns the backend's name.
"""
return self._backend_id
@property
def processed_bytes(self) -> bytes:
"""
Returns the delegate blob created from backend.preprocess
"""
return self._processed_bytes
@property
def compile_specs(self) -> List[CompileSpec]:
"""
Returns a list of backend-specific objects with static metadata to configure the "compilation" process.
"""
return self._compile_specs
@property
def original_module(self) -> ExportedProgram:
"""
Returns the original EXIR module
"""
return self._original_exported_program
# TODO(chenlai): consolidate the serialization config with serialize_to_flatbuffer api
def buffer(
self,
extract_delegate_segments: bool = False,
segment_alignment: int = 128,
constant_tensor_alignment: Optional[int] = None,
delegate_alignment: Optional[int] = None,
memory_planning: MemoryPlanningPass = None, # pyre-fixme[9]
) -> bytes:
"""
Returns a buffer containing the serialized ExecuTorch binary.
"""
# TODO(T181463742): avoid calling bytes(..) which incurs large copies.
out = bytes(
_serialize_pte_binary(
program=self.program(memory_planning=memory_planning),
extract_delegate_segments=extract_delegate_segments,
segment_alignment=segment_alignment,
constant_tensor_alignment=constant_tensor_alignment,
delegate_alignment=delegate_alignment,
)
)
return out
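# Hedged usage sketch for `buffer` (the output file name is illustrative):
#
#     pte_bytes = lowered_module.buffer(extract_delegate_segments=True)
#     with open("lowered_module.pte", "wb") as f:
#         f.write(pte_bytes)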
# TODO(chenlai): re-consider recapture instead of manually constructing the program because
# the meta data construction is done manually.
def program(
self,
emit_stacktrace: bool = False,
memory_planning: MemoryPlanningPass = None, # pyre-fixme[9]
) -> Program:
"""
Returns the object that represents the ExecuTorch binary before serialization.
"""
# Imported locally to avoid the cyclic dependency that autodeps introduces:
# program -> verifier -> lowered_backend_module -> program
# @manual
from executorch.exir.program._program import (
_get_updated_graph_signature,
_transform,
)
# Creates a new module based on the original module. The original module will
# look something like the following:
#
# opcode name target args kwargs
# ------------- ------------------- ---------------- ------------------------------------------ --------
# placeholder arg0_1 arg0_1 () {}
# placeholder arg1_1 arg1_1 () {}
# call_function aten_repeat_default * (arg1_1, [4, 1]) {}
# call_function aten_mul_tensor * (aten_repeat_default, aten_repeat_default) {}
# call_function aten_add_tensor * (arg1_1, arg1_1) {}
# output output output ([aten_mul_tensor, aten_add_tensor],) {}
#
# If the whole module is lowered, the resulting lowered module will look like:
#
# opcode name target args kwargs
# ------------- ------------------------ --------------------------- ---------------------------------- --------
# placeholder arg0_1 arg0_1 () {}
# placeholder arg1_1 arg1_1 () {}
# get_attr lowered_module_0 lowered_module_0 () {}
# call_function executorch_call_delegate executorch_call_delegate (lowered_module_0, arg0_1, arg1_1) {}
# call_function getitem <built-in function getitem> (executorch_call_delegate, 0) {}
# call_function getitem_1 <built-in function getitem> (executorch_call_delegate, 1) {}
# output output_1 output ([getitem, getitem_1],) {}
#
# We'll remove all call_function nodes, insert a call_delegate node, insert getitem nodes
# to get the results from the call_delegate node, and return the list of getitems as the output
lowered_exported_program = copy.deepcopy(self._original_exported_program)
# The real input nodes are the placeholders that are neither buffers nor parameters
all_input_nodes = [
node
for node in lowered_exported_program.graph.nodes
if (
node.op == "placeholder"
and node.name
not in lowered_exported_program.graph_signature.inputs_to_buffers
and node.name
not in lowered_exported_program.graph_signature.inputs_to_parameters
)
]
output_node = [
node for node in lowered_exported_program.graph.nodes if node.op == "output"
]
assert len(output_node) == 1, "There should be only one output node"
# Step 1. Cleaning up the graph before inserting the call_delegate node
# Remove the original output node
lowered_exported_program.graph.erase_node(output_node[0])
# Remove everything else except the input placeholders
for node in reversed(lowered_exported_program.graph.nodes):
if node.op != "placeholder":
lowered_exported_program.graph.erase_node(node)
# Find placeholders that are parameters or buffers, remove them from the main graph
for node in lowered_exported_program.graph.nodes:
if node.op == "placeholder" and (
node.name in lowered_exported_program.graph_signature.inputs_to_buffers
or node.name
in lowered_exported_program.graph_signature.inputs_to_parameters
):
lowered_exported_program.graph.erase_node(node)
# Step 2. Start constructing the graph
lowered_name = get_lowered_module_name(
lowered_exported_program.graph_module, self
)
# Insert the lowered module into the graph module as an attribute
lowered_node = lowered_exported_program.graph.get_attr(lowered_name)
# Insert a call_delegate node to the graph module, with arguments from the arg list
delegate_node = lowered_exported_program.graph.call_function(
executorch_call_delegate, (lowered_node, *all_input_nodes)
)
# Get the output list. Since the output node's args are a tuple containing a list, like ([aten_mul_tensor, aten_add_tensor],),
# we add some handling logic to extract the list `[aten_mul_tensor, aten_add_tensor]` properly
original_output_nodes = [
node
for node in self._original_exported_program.graph.nodes
if node.op == "output"
][0].args[0]
delegate_node.meta["spec"] = tuple(
[make_spec(node.meta["val"]) for node in original_output_nodes]
)
delegate_node.meta["val"] = tuple(
[node.meta["val"] for node in original_output_nodes]
)
# The getitem nodes that are going to be inserted to the lowered graph module
getitem_nodes = []
for i in range(len(original_output_nodes)):
getitem_node = lowered_exported_program.graph.call_function(
operator.getitem,
args=(delegate_node, i),
)
getitem_node.meta["val"] = delegate_node.meta["val"][i]
getitem_nodes.append(getitem_node)
lowered_exported_program.graph.output(getitem_nodes)
lowered_exported_program.graph_module.recompile()
lowered_exported_program.graph.lint()
# The user outputs will be the getitem nodes instead
output_specs = [
OutputSpec(
kind=OutputKind.USER_OUTPUT,
arg=TensorArgument(name=getitem_node.name),
target=None,
)
for getitem_node in getitem_nodes
]
# All data is consumed by the delegate, so it should be removed from the state dict.
inputs_to_parameters = (
lowered_exported_program.graph_signature.inputs_to_parameters
)
inputs_to_buffers = lowered_exported_program.graph_signature.inputs_to_buffers
input_specs = [
InputSpec(
kind=InputKind.USER_INPUT,
arg=TensorArgument(name=user_input),
target=None,
)
for user_input in lowered_exported_program.graph_signature.user_inputs
if user_input not in inputs_to_parameters
and user_input not in inputs_to_buffers
]
# Double-check that the ExportedProgram data (especially everything except the graph) is good
exported_program = ExportedProgram(
root=lowered_exported_program.graph_module,
graph=lowered_exported_program.graph,
graph_signature=_get_updated_graph_signature(
ExportGraphSignature(
input_specs=input_specs, output_specs=output_specs
),
lowered_exported_program.graph_module,
),
# TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
# somewhere as we should pass it a list of tensors to the lowered module and output a
# list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the
# inputs/outputs to the toplevel program will be in the format of the eager module.
state_dict={}, # Empty because all data is consumed by the delegate
range_constraints=lowered_exported_program.range_constraints,
module_call_graph=lowered_exported_program.module_call_graph,
example_inputs=None,
verifiers=[lowered_exported_program.verifier],
)
if memory_planning is None:
memory_planning = MemoryPlanningPass()
exported_program = _transform(exported_program, SpecPropPass(), memory_planning)
emitted_program = emit_program(
exported_program, emit_stacktrace=emit_stacktrace
).program
return emitted_program
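# Hedged usage sketch for `program` (passing an explicit MemoryPlanningPass is optional;
# a default one is constructed inside `program` when `memory_planning` is None):
#
#     executorch_program = lowered_module.program(
#         emit_stacktrace=True, memory_planning=MemoryPlanningPass()
#     )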
# Used to patch each delegated function with a call_delegate call
# @staticmethod
def forward(
self,
*args: Value,
**kwargs: Tuple[Value, ...],
) -> Value:
return executorch_call_delegate(self, *args)
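# Note: in eager execution and during tracing, calling the lowered module forwards
# only the positional args to executorch_call_delegate; keyword args are accepted by
# the signature but are not passed through.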
# TODO(zhxchen17) Try ExportPass
def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
for node in reversed(gm.graph.nodes):
if node.op == "output":
with gm.graph.inserting_before(node):
assert len(node.args) == 1
outputs = node.args[0]
if isinstance(outputs, torch.fx.Node):
val = outputs.meta.get("val")
if isinstance(val, list):
# If a list is returned, in some cases it is represented as a
# single node, like `split_copy_tensor`, but EXIR will return an
# opened-up list like `[getitem1, getitem2]`
outputs = [
torch.fx.Proxy(outputs)[i].node for i in range(len(val))
]
returns, out_spec = pytree.tree_flatten(outputs)
node.args = (returns,)
return
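# Illustrative effect of _fixup_output_node (node names are hypothetical): if the
# submodule's output node has a single fx.Node argument `split_copy_tensor` whose
# "val" metadata is a list of two tensors, the output args are rewritten to
# `([getitem_0, getitem_1],)` so the submodule returns an opened-up list of nodes.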
def arrange_graph_placeholders(
gm: torch.fx.GraphModule, owning_program: ExportedProgram
) -> torch.fx.GraphModule:
"""
Replaces the graph of the given graph module with one that contains the same nodes as the
original, but with placeholders ordered as (Params + Buffers) followed by (User Inputs).
This is used by the delegate API, which disturbs the placeholder ordering when creating
a submodule from partitioned nodes.
Args:
gm: The graph module that we want arranged
owning_program: ExportedProgram that the submodule (gm) belongs to
Returns:
The graph module, rearranged in place
"""
new_graph = torch.fx.Graph()
node_map = {} # mapping of nodes from old graph to new graph
graph_sign = owning_program.graph_signature
# Add all placeholders into the graph first:
param_nodes = []
buffer_nodes = []
input_nodes = []
for node in gm.graph.nodes:
if node.op != "placeholder":
continue
if node.name in graph_sign.inputs_to_parameters:
param_nodes.append(node)
elif node.name in graph_sign.inputs_to_buffers:
buffer_nodes.append(node)
else:
input_nodes.append(node)
for param_node in param_nodes:
new_node = new_graph.node_copy(param_node, lambda x: node_map[x])
node_map[param_node] = new_node
for buffer_node in buffer_nodes:
new_node = new_graph.node_copy(buffer_node, lambda x: node_map[x])
node_map[buffer_node] = new_node
for input_node in input_nodes:
new_node = new_graph.node_copy(input_node, lambda x: node_map[x])
node_map[input_node] = new_node
# Now add all the other nodes in order
for node in gm.graph.nodes:
if node.op == "placeholder":
continue
new_node = new_graph.node_copy(node, lambda x: node_map[x])
node_map[node] = new_node
# lint to ensure correctness
new_graph.lint()
new_graph._codegen = gm.graph._codegen
gm.graph = new_graph
return gm
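# Illustrative ordering (placeholder names are hypothetical): a submodule whose
# placeholders arrive as [x, conv_weight, conv_bias, bn_running_mean] would be
# rearranged to [conv_weight, conv_bias, bn_running_mean, x] -- parameters first,
# then buffers, then user inputs.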
# TODO Don't regenerate new signature manually.
def _get_new_signature( # noqa: C901
original_program: ExportedProgram,
gm: torch.fx.GraphModule,
call_module_node: torch.fx.Node,
tag: str,
is_submodule: bool = False,
) -> Tuple[
ExportGraphSignature,
Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
Dict[str, InputSpec],
Dict[str, OutputSpec],
]:
"""
Args:
original_program: The original program that we are partitioning
gm: The partitioned graph module.
call_module_node: The node in the original program that is calling the
partitioned graph module.
tag: The tag being used for this partitioned submodule. This is used to
tell if a particular parameter/buffer/constant node is being tagged,
aka consumed by the delegate.
is_submodule: True if we are currently partitioning inside of a
submodule (like cond's submodule). If we are inside of a submodule,
we do not care about consuming params/buffers.
Returns:
new_signature (ExportGraphSignature): The new signature for the
partitioned graph module.
new_state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]): The
new state dict containing the consumed params/buffers.
new_constants (Dict[str, Union[torch.Tensor, FakeScriptObject,
torch.ScriptObject]]): The new constants table containing the
consumed constants.
input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
been consumed by the delegate (param/buffer input nodes) and should
be removed from the toplevel ExportedProgram.
output_specs_to_delete (Dict[str, OutputSpec]): The output specs that have
been consumed by the delegate (buffer mutation nodes) and should be
removed from the toplevel ExportedProgram.
"""
old_signature = original_program.graph_signature
input_specs = []
output_specs = []
input_specs_to_delete = {}
output_specs_to_delete = {}
new_state_dict = {}
new_constants = {}
# If we are within a submodule, we do not need to care about consuming
# parameter/buffers
input_node_to_sig: Dict[str, InputSpec] = (
{input_spec.arg.name: input_spec for input_spec in old_signature.input_specs}
if not is_submodule
else {}
)
toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list)
if not is_submodule:
for output_spec in old_signature.output_specs:
toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec)
for node in gm.graph.nodes:
if node.op == "placeholder":
if node.name not in input_node_to_sig:
input_specs.append(
InputSpec(
kind=InputKind.USER_INPUT,
arg=TensorArgument(name=node.name),
target=None,
)
)
continue
orig_input_spec = input_node_to_sig[node.name]
if not isinstance(orig_input_spec.arg, TensorArgument):
input_specs.append(orig_input_spec)
elif node.meta.get("delegation_tag", None) == tag:
input_specs.append(orig_input_spec)
if orig_input_spec.kind == InputKind.USER_INPUT:
continue
# The following input specs are all attributes that should be
# consumed by the delegate, so we want to remove them from the
# toplevel module's inputs/outputs
input_specs_to_delete[node.name] = orig_input_spec
input_target = orig_input_spec.target
if input_target in original_program.state_dict:
assert orig_input_spec.kind in (
InputKind.PARAMETER,
InputKind.BUFFER,
)
new_state_dict[input_target] = original_program.state_dict[
input_target
]
elif input_target in original_program.constants:
assert orig_input_spec.kind in (
InputKind.CONSTANT_TENSOR,
InputKind.CUSTOM_OBJ,
InputKind.BUFFER,
)
new_constants[input_target] = original_program.constants[
input_target
]
else:
raise RuntimeError(f"Invalid input spec {orig_input_spec} received")
else:
input_specs.append(
InputSpec(
kind=InputKind.USER_INPUT,
arg=TensorArgument(name=node.name),
target=None,
)
)
if node.op == "output":
buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
for user in call_module_node.users.keys():
if user.name in toplevel_output_node_to_sig:
assert (
user.op == "call_function" and user.target == operator.getitem
), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
getitem_idx = user.args[1]
assert isinstance(
getitem_idx, int
), f"Invalid getitem type: {type(getitem_idx)}"
buffer_mutation_idxs[getitem_idx].extend(
toplevel_output_node_to_sig[user.name]
)
for i, output_node in enumerate(node.args[0]):
if i in buffer_mutation_idxs:
assert isinstance(output_node, torch.fx.Node)
orig_output_specs = buffer_mutation_idxs[i]
if any(
orig_output_spec.kind == OutputKind.BUFFER_MUTATION
and orig_output_spec.target in new_state_dict
for orig_output_spec in orig_output_specs
):
# If the delegate wants to consume the buffer, then the
# delegate should also consume the buffer mutation
# (output spec would be a BUFFER_MUTATION). Otherwise
# the delegate will just return the result of the
# mutation as a USER_OUTPUT.
orig_output_spec = [
orig_output_spec
for orig_output_spec in orig_output_specs
if orig_output_spec.kind == OutputKind.BUFFER_MUTATION
and orig_output_spec.target in new_state_dict
][0]
assert len(orig_output_specs) == 1, (
f"Constant {orig_output_spec.target} was tagged to be "
"consumed by the buffer, and was found to also contain "
"a buffer mutation. However this buffer mutation node "
"was found to also be used as other types of outputs "
"which is currently not supported. Please file an "
"issue on Github. \n\n"
f"The toplevel program: {original_program}\n"
)
output_specs.append(
OutputSpec(
kind=OutputKind.BUFFER_MUTATION,
arg=TensorArgument(name=output_node.name),
target=orig_output_spec.target,
)
)
output_specs_to_delete[orig_output_spec.arg.name] = (
orig_output_spec
)
else:
output_specs.append(
OutputSpec(
kind=OutputKind.USER_OUTPUT,
arg=TensorArgument(name=output_node.name),
target=None,
)
)
elif not isinstance(output_node, torch.fx.Node):
output_specs.append(
OutputSpec(
kind=OutputKind.USER_OUTPUT,
arg=ConstantArgument(name="", value=output_node),
target=None,
)
)
else:
output_specs.append(
OutputSpec(
kind=OutputKind.USER_OUTPUT,
arg=TensorArgument(name=output_node.name),
target=None,
)
)
new_signature = ExportGraphSignature(
input_specs=input_specs, output_specs=output_specs
)
return (
new_signature,
new_state_dict,
new_constants,
input_specs_to_delete,
output_specs_to_delete,
)
def create_exported_program_from_submodule(
submodule: torch.fx.GraphModule,
owning_program: ExportedProgram,
tag: str,
call_module_node: torch.fx.Node,
is_submodule: bool,
) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]:
"""
Creates an ExportedProgram from the given submodule using the parameters and buffers
from the top-level owning program
Args:
submodule: submodule to create an exported program from
owning_program: exported program containing the parameters and buffers used within
the submodule
Returns:
The ExportedProgram created from submodule
input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
been consumed by the delegate (param/buffer input nodes) and should
be removed from the toplevel ExportedProgram.
output_specs_to_delete (Dict[str, OutputSpec]): The output specs that have
been consumed by the delegate (buffer mutation nodes) and should be
removed from the toplevel ExportedProgram.
"""
# Arrange the submodule's placeholders in order
submodule = arrange_graph_placeholders(submodule, owning_program)
# TODO: we probably need to arrange the outputs wrt buffer mutations.
# Get updated graph signature
(
subgraph_signature,
subgraph_state_dict,
subgraph_constants,
toplevel_input_specs_to_delete,
toplevel_output_specs_to_delete,
) = _get_new_signature(
owning_program, submodule, call_module_node, tag, is_submodule
)
in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]
return (
ExportedProgram(
root=submodule,
graph=submodule.graph,
graph_signature=subgraph_signature,
state_dict=subgraph_state_dict,
range_constraints=copy.deepcopy(owning_program.range_constraints),
module_call_graph=[
ModuleCallEntry(
"",
ModuleCallSignature(
inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
),
)
],
constants=subgraph_constants,
verifiers=[owning_program.verifier],
),
toplevel_input_specs_to_delete,
toplevel_output_specs_to_delete,
)
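# Note: the returned ExportedProgram carries a single, mostly empty ModuleCallEntry;
# only the pytree in/out specs derived from the subgraph signature's user inputs and
# user outputs are preserved.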
def create_submodule_from_nodes(
gm: torch.fx.GraphModule,
node_list: NodeList,
tag: str,
skip_legalize_graph: bool = False,
) -> Tuple[torch.fx.GraphModule, torch.fx.Node]:
"""
Modifies the given graph module in-place to separate out the given nodes
into a submodule. The given node_list should form a fully connected
subgraph.
Args:
gm: The graph module that we want to partition
node_list: A list of nodes that belong in the partition
Returns:
The submodule that has been partitioned, the call_module node in the
toplevel graph module calling the submodule
"""
sorted_nodes = topo_sort(node_list)
submodule_name = "fused_" + tag
sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
gm, sorted_nodes, submodule_name
)
_fixup_output_node(sub_gm)
gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
submodule_node = None
for node in gm.graph.nodes:
if node.op == "call_module":
if node.target == submodule_name:
submodule_node = node
else:
raise RuntimeError(
f"The submodule created with nodes {node_list} did not form \
one fully contained subgraph. Check that these nodes form a \
fully contained graph. Partitioned graph: {gm.graph}."
)
if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
# If the original output is a single tensor, it has been
# pytree.tree_flatten-ed to be a singleton list, so we want to replace
# all uses with a getitem call to the 0th index of the result
with gm.graph.inserting_after(submodule_node):
proxy_out = torch.fx.Proxy(submodule_node)[0].node # type: ignore[index]
submodule_node.replace_all_uses_with(proxy_out)
proxy_out.meta["val"] = submodule_node.meta["val"]
# Reset the args since it was overwritten in the previous line
proxy_out.args = (submodule_node, 0)
else:
# fuse_as_graphmodule will automatically propagate the metadata of the
# partition's last node to the getitem nodes that appear after the
# call_module node. However, in the case of delegation we do not want
# these getitem nodes to contain irrelevant previous metadata
# (ex. source_fn, nn_module_stack)
for user_node in submodule_node.users:
user_node.meta.pop("nn_module_stack", None)
user_node.meta.pop("source_fn_stack", None)
erase_nodes(gm, sorted_nodes)
# Topologically sort the original gm with the newly created sub_gm
# TODO : T153794167 Get rid of support for skipping legalize graph in create_submodule_from_nodes
# once we transition to using fuse_by_partitions.
if not skip_legalize_graph:
legalize_graph(gm)
# Get the call_module node
submodule_node = None
for node in gm.graph.nodes:
if node.op == "call_module" and node.target == submodule_name:
submodule_node = node
elif node.op == "call_module":
raise RuntimeError(
f"The submodule created with nodes {node_list} did not form \
one fully contained subgraph. Check that these nodes form a \
fully contained graph. Partitioned graph: {gm.graph}."
)
assert (
submodule_node is not None
), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}"
return sub_gm, submodule_node
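# Hedged usage sketch (the tag and node selection are illustrative; the
# "delegation_tag" metadata key matches how nodes are tagged elsewhere in this file):
#
#     tagged_nodes = [n for n in gm.graph.nodes if n.meta.get("delegation_tag") == "tag0"]
#     sub_gm, call_module_node = create_submodule_from_nodes(gm, tagged_nodes, "tag0")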
def get_lowered_submodules(
graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]:
"""
Returns a list of lowered modules that are in the given graph (does not look
into submodules). Specifically, the returned value is a list of tuples of (the name
under which the lowered module is stored in the graph module, the lowered module
itself, and the fx node that called this lowered module).
"""
lowered_submodules = []
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target == executorch_call_delegate:
name, module, node = _get_submodule(graph_module, node, 0)
assert isinstance(module, LoweredBackendModule)
lowered_submodules.append((name, module, node))
return lowered_submodules
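# Hedged usage sketch:
#
#     for name, lowered_module, call_node in get_lowered_submodules(gm):
#         print(name, lowered_module.backend_id)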
def get_lowered_backend_modules(
graph_module: torch.fx.GraphModule,
) -> List[LoweredBackendModule]:
"""
Returns a list of LoweredBackendModules which were lowered by backend delegates
"""
lowered_programs = []
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target == executorch_call_delegate:
lowered_backend_module = getattr(graph_module, node.args[0].name)
lowered_programs.append(lowered_backend_module)
return lowered_programs
def _unsafe_adjust_original_program( # noqa: C901
original_program: ExportedProgram,
call_delegate_node: torch.fx.Node,
input_specs_to_delete: Dict[str, InputSpec],
output_specs_to_delete: Dict[str, OutputSpec],
) -> None:
"""
Directly modify the original exported program's signature and state dict
based on the consumed params/buffers in the delegate.
"""
original_program._graph_signature.input_specs = [
input_spec
for input_spec in original_program.graph_signature.input_specs
if input_spec.arg.name not in input_specs_to_delete
]
currently_used_targets: Set[str] = {
input_spec.target
for input_spec in original_program._graph_signature.input_specs
if input_spec.target is not None
}
original_program._graph_signature.output_specs = [
output_spec
for output_spec in original_program.graph_signature.output_specs
if output_spec.arg.name not in output_specs_to_delete
]
# Delete all parameters/buffers consumed by the created exported program
# from the graph signature, state dict, constants table
for node in original_program.graph.nodes:
if node.op == "placeholder":
if node.name in input_specs_to_delete:
assert len(node.users) == 0
original_program.graph.erase_node(node)
else:
break
for input_spec in input_specs_to_delete.values():
input_target = input_spec.target
assert input_target is not None
if input_target in currently_used_targets:
continue
if input_spec.kind == InputKind.PARAMETER:
del original_program._state_dict[input_target]
elif input_spec.kind == InputKind.BUFFER:
if input_spec.persistent:
del original_program._state_dict[input_target]
else:
del original_program._constants[input_spec.target]
elif input_spec.kind == InputKind.CONSTANT_TENSOR:
del original_program._constants[input_spec.target]
else:
raise RuntimeError(f"Invalid input spec {input_spec} received")
# Delete buffer mutations from the output which were consumed by the delegate
toplevel_output_node = None
for node in reversed(original_program.graph.nodes):
if node.op == "output":
toplevel_output_node = node
break
assert toplevel_output_node is not None
assert (
len(toplevel_output_node.args) == 1
), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"
new_output_args = [
arg
for arg in toplevel_output_node.args[0]
if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
]
toplevel_output_node.args = (tuple(new_output_args),)
# Delete the buffer mutation getitem nodes
getitem_idxs: List[int] = []
user_nodes = list(call_delegate_node.users.keys())
for user in user_nodes:
if user.name in output_specs_to_delete:
assert (
user.op == "call_function" and user.target == operator.getitem
), f"Invalid user {user}, node.op is {node.op} and node.target is {node.target}"
user_idx = user.args[1]
assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
getitem_idxs.append(user_idx)
original_program.graph.erase_node(user)
getitem_idxs.sort(reverse=True)
# Adjust all the getitem indices after the deleted getitems
user_nodes = list(call_delegate_node.users.keys())
for user in user_nodes:
assert user.op == "call_function" and user.target == operator.getitem
user_idx = user.args[1]
assert isinstance(user_idx, int)
for i, idx in enumerate(getitem_idxs):
if user_idx > idx:
user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
break
original_program._validate()