Shortcuts

Source code for executorch.exir.lowered_backend_module

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import operator
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
from executorch.exir.emit import emit_program

from executorch.exir.graph_module import _get_submodule

from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
from executorch.exir.schema import Program

from executorch.exir.tracer import Value

from torch._subclasses import FakeTensor
from torch.export.exported_program import (
    ConstantArgument,
    ExportedProgram,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    ModuleCallEntry,
    ModuleCallSignature,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.fx.passes.utils.fuser_utils import (
    erase_nodes,
    fuse_as_graphmodule,
    insert_subgm,
    legalize_graph,
    NodeList,
    topo_sort,
)


class LoweredBackendModule(torch.nn.Module):
    """
    A subclass of nn.Module that is generated for modules containing
    delegated functions. This can be created by calling `to_backend`.
    """

    _backend_id: str  # The backend's name
    _processed_bytes: bytes  # The delegate blobs created from backend.preprocess
    # A list of backend-specific objects with static metadata to configure the
    # "compilation" process.
    _compile_specs: List[CompileSpec]
    _original_exported_program: ExportedProgram  # The original EXIR module

    def __init__(
        self,
        edge_program: ExportedProgram,
        backend_id: str,
        processed_bytes: bytes,
        compile_specs: List[CompileSpec],
    ) -> None:
        super().__init__()
        self._original_exported_program = edge_program
        self._backend_id = backend_id
        self._processed_bytes = processed_bytes
        self._compile_specs = compile_specs

    # pyre-ignore
    def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
        """
        Deep-copy this module.

        The graph module, graph, signature, range constraints, module call graph,
        verifier and compile specs are deep-copied. The state dict, constants and
        processed bytes are shared with the original (not copied). ``meta`` is
        shallow-copied if present.
        """
        # Copy exported program
        copied_program = ExportedProgram(
            root=copy.deepcopy(self._original_exported_program.graph_module),
            graph=copy.deepcopy(self._original_exported_program.graph),
            graph_signature=copy.deepcopy(
                self._original_exported_program.graph_signature
            ),
            state_dict=self._original_exported_program.state_dict,
            range_constraints=copy.deepcopy(
                self._original_exported_program.range_constraints
            ),
            module_call_graph=copy.deepcopy(
                self._original_exported_program.module_call_graph
            ),
            verifier=copy.deepcopy(self._original_exported_program.verifier),
            constants=self._original_exported_program.constants,
        )

        res = LoweredBackendModule(
            edge_program=copied_program,
            backend_id=self._backend_id,
            processed_bytes=self._processed_bytes,
            compile_specs=copy.deepcopy(self._compile_specs, memo),
        )
        res.meta = copy.copy(getattr(self, "meta", {}))
        return res

    @property
    def backend_id(self) -> str:
        """
        Returns the backend's name.
        """
        return self._backend_id

    @property
    def processed_bytes(self) -> bytes:
        """
        Returns the delegate blob created from backend.preprocess
        """
        return self._processed_bytes

    @property
    def compile_specs(self) -> List[CompileSpec]:
        """
        Returns a list of backend-specific objects with static metadata to configure the "compilation" process.
        """
        return self._compile_specs

    @property
    def original_module(self) -> ExportedProgram:
        """
        Returns the original EXIR module
        """
        return self._original_exported_program

    # TODO(chenlai): consolidate the serialization config with serialize_to_flatbuffer api
    def buffer(
        self,
        extract_delegate_segments: bool = False,
        segment_alignment: int = 4096,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
    ) -> bytes:
        """
        Returns a buffer containing the serialized ExecuTorch binary.

        All keyword arguments are forwarded unchanged to ``_serialize_pte_binary``
        together with the program built by :meth:`program`.
        """
        # TODO(T181463742): avoid calling bytes(..) which incurs large copies.
        out = bytes(
            _serialize_pte_binary(
                program=self.program(),
                extract_delegate_segments=extract_delegate_segments,
                segment_alignment=segment_alignment,
                constant_tensor_alignment=constant_tensor_alignment,
                delegate_alignment=delegate_alignment,
            )
        )
        return out

    # TODO(chenlai): re-consider recapture instead of manually constructing the program because
    # the meta data construction is done manually.
    def program(self, emit_stacktrace: bool = False) -> Program:
        """
        Returns the object that represents the ExecuTorch binary before serialization.

        A deep copy of the original exported program has its body replaced by a
        single ``executorch_call_delegate`` node (plus one ``getitem`` per
        output). Parameter/buffer placeholders are removed and the state dict is
        emptied because the delegate consumes that data. The resulting program is
        run through ``SpecPropPass`` and a greedy ``MemoryPlanningPass`` and then
        emitted.

        Args:
            emit_stacktrace: Forwarded to ``emit_program``.
        """
        # Imported locally: importing at module level introduces a cyclic dependency
        # program -> verifier -> lowered_backend_module -> program
        # @manual
        from executorch.exir.program._program import (
            _get_updated_graph_signature,
            _transform,
        )

        # Creates a new module based on the original module. The original module will
        # look something like following:
        #
        # opcode         name                 target            args                                        kwargs
        # -------------  -------------------  ----------------  ------------------------------------------  --------
        # placeholder    arg0_1               arg0_1            ()                                          {}
        # placeholder    arg1_1               arg1_1            ()                                          {}
        # call_function  aten_repeat_default  *                 (arg1_1, [4, 1])                            {}
        # call_function  aten_mul_tensor      *                 (aten_repeat_default, aten_repeat_default)  {}
        # call_function  aten_add_tensor      *                 (arg1_1, arg1_1)                            {}
        # output         output               output            ([aten_mul_tensor, aten_add_tensor],)       {}
        #
        # if the whole module is lowered, the resulting lowered module looks like
        #
        # opcode         name                      target                       args                                kwargs
        # -------------  ------------------------  ---------------------------  ----------------------------------  --------
        # placeholder    arg0_1                    arg0_1                       ()                                  {}
        # placeholder    arg1_1                    arg1_1                       ()                                  {}
        # get_attr       lowered_module_0          lowered_module_0             ()                                  {}
        # call_function  executorch_call_delegate  executorch_call_delegate     (lowered_module_0, arg0_1, arg1_1)  {}
        # call_function  getitem                   <built-in function getitem>  (executorch_call_delegate, 0)       {}
        # call_function  getitem_1                 <built-in function getitem>  (executorch_call_delegate, 1)       {}
        # output         output_1                  output                       ([getitem, getitem_1],)             {}
        #
        # We'll remove all call_function nodes, insert a call_delegate node, insert getitem nodes
        # to get the results of the call_delegate node, and return the list of getitems as the output

        lowered_exported_program = copy.deepcopy(self._original_exported_program)

        # The real input nodes are the ones not buffer or parameter
        all_input_nodes = [
            node
            for node in lowered_exported_program.graph.nodes
            if (
                node.op == "placeholder"
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_buffers
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_parameters
            )
        ]

        output_node = [
            node for node in lowered_exported_program.graph.nodes if node.op == "output"
        ]
        assert len(output_node) == 1, "There should be only one output node"

        # Step 1. Cleaning up the graph before inserting the call_delegate node
        # Remove the original output node
        lowered_exported_program.graph.erase_node(output_node[0])

        # Remove everything except the inputs; reverse order erases users before
        # their producers.
        for node in reversed(lowered_exported_program.graph.nodes):
            if node.op != "placeholder":
                lowered_exported_program.graph.erase_node(node)

        # Find placeholders that are parameters or buffers, remove them from the main graph
        for node in lowered_exported_program.graph.nodes:
            if node.op == "placeholder" and (
                node.name in lowered_exported_program.graph_signature.inputs_to_buffers
                or node.name
                in lowered_exported_program.graph_signature.inputs_to_parameters
            ):
                lowered_exported_program.graph.erase_node(node)

        # Step 2. Start constructing the graph
        lowered_name = get_lowered_module_name(
            lowered_exported_program.graph_module, self
        )
        # Insert the lowered module to the graph module as an attribute
        lowered_node = lowered_exported_program.graph.get_attr(lowered_name)

        # Insert a call_delegate node to the graph module, with arguments from the arg list
        delegate_node = lowered_exported_program.graph.call_function(
            executorch_call_delegate, (lowered_node, *all_input_nodes)
        )
        # Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],)
        # We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly
        original_output_nodes = [
            node
            for node in self._original_exported_program.graph.nodes
            if node.op == "output"
        ][0].args[0]

        delegate_node.meta["spec"] = tuple(
            [make_spec(node.meta["val"]) for node in original_output_nodes]
        )
        delegate_node.meta["val"] = tuple(
            [node.meta["val"] for node in original_output_nodes]
        )

        # The getitem nodes that are going to be inserted to the lowered graph module
        getitem_nodes = []
        for i in range(len(original_output_nodes)):
            getitem_node = lowered_exported_program.graph.call_function(
                operator.getitem,
                args=(delegate_node, i),
            )
            getitem_node.meta["val"] = delegate_node.meta["val"][i]
            getitem_nodes.append(getitem_node)
        lowered_exported_program.graph.output(getitem_nodes)

        lowered_exported_program.graph_module.recompile()
        lowered_exported_program.graph.lint()

        # User outputs will be the getitem nodes instead
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=getitem_node.name),
                target=None,
            )
            for getitem_node in getitem_nodes
        ]
        # All data are consumed by the delegates so they should be removed from the state dict.
        inputs_to_parameters = (
            lowered_exported_program.graph_signature.inputs_to_parameters
        )
        inputs_to_buffers = lowered_exported_program.graph_signature.inputs_to_buffers
        input_specs = [
            InputSpec(
                kind=InputKind.USER_INPUT,
                # Name each spec after its own user input (the comprehension
                # variable), not a loop variable left over from the code above.
                arg=TensorArgument(name=user_input),
                target=None,
            )
            for user_input in lowered_exported_program.graph_signature.user_inputs
            if user_input not in inputs_to_parameters
            and user_input not in inputs_to_buffers
        ]

        # Double check the ExportedProgram data (especially everything except graph) is good
        exported_program = ExportedProgram(
            root=lowered_exported_program.graph_module,
            graph=lowered_exported_program.graph,
            graph_signature=_get_updated_graph_signature(
                ExportGraphSignature(
                    input_specs=input_specs, output_specs=output_specs
                ),
                lowered_exported_program.graph_module,
            ),
            # TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
            # somewhere as we should pass it a list of tensors to the lowered module and output a
            # list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the
            # inputs/outputs to the toplevel program will be in the format of the eager module.
            state_dict={},  # None because all data are consumed by delegate
            range_constraints=lowered_exported_program.range_constraints,
            module_call_graph=lowered_exported_program.module_call_graph,
            example_inputs=None,
            verifier=lowered_exported_program.verifier,
        )
        exported_program = _transform(
            exported_program, SpecPropPass(), MemoryPlanningPass("greedy")
        )
        emitted_program = emit_program(
            exported_program, emit_stacktrace=emit_stacktrace
        ).program
        return emitted_program

    # Used to patch each delegated function with a call_delegate call
    # @staticmethod
    def forward(
        self,
        *args: Value,
        **kwargs: Tuple[Value, ...],
    ) -> Value:
        """
        Invoke ``executorch_call_delegate`` on this module with the positional
        ``args``. Keyword arguments are accepted but not forwarded.
        """
        return executorch_call_delegate(self, *args)
# TODO(zhxchen17) Try ExportPass
def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
    """
    Flatten the arguments of the last output node of ``gm`` in place.

    If the output is a single node whose ``meta["val"]`` is a list, it is first
    expanded into one ``getitem`` node per list element, so the output becomes
    an opened-up flat list of nodes.
    """
    for node in reversed(gm.graph.nodes):
        if node.op == "output":
            with gm.graph.inserting_before(node):
                assert len(node.args) == 1
                outputs = node.args[0]
                if isinstance(outputs, torch.fx.Node):
                    val = outputs.meta.get("val")
                    if isinstance(val, list):
                        # If a list is returned, in some cases it is represented as a
                        # singular node, like `split_copy_tensor` but EXIR will return an
                        # opened-up list like `[getitem1, getitem2]`
                        outputs = [
                            torch.fx.Proxy(outputs)[i].node for i in range(len(val))
                        ]
            returns, _ = pytree.tree_flatten(outputs)
            node.args = (returns,)
            return


def arrange_graph_placeholders(
    gm: torch.fx.GraphModule, owning_program: ExportedProgram
) -> torch.fx.GraphModule:
    """
    Modifies the graph of the given graphmodule with one that contains the same nodes as the original,
    but with placeholders in order of (Params + Buffers) (User Inputs)

    This is used by the delegate api which disturbs the placeholder ordering when creating a submodule
    from partitioned nodes

    Args:
        gm: The graph module that we want arranged
        owning_program: ExportedProgram that the submodule (gm) belongs to

    Returns:
        The graph module in-placed arranged
    """
    new_graph = torch.fx.Graph()

    node_map = {}  # mapping of nodes from old graph to new graph

    graph_sign = owning_program.graph_signature

    # Bucket the placeholders by what they feed in the owning program.
    param_nodes = []
    buffer_nodes = []
    input_nodes = []
    other_nodes = []
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            other_nodes.append(node)
        elif node.name in graph_sign.inputs_to_parameters:
            param_nodes.append(node)
        elif node.name in graph_sign.inputs_to_buffers:
            buffer_nodes.append(node)
        else:
            input_nodes.append(node)

    # Copy placeholders first (params, buffers, user inputs), then every other
    # node in its original order so all arguments are already mapped.
    for node in param_nodes + buffer_nodes + input_nodes + other_nodes:
        node_map[node] = new_graph.node_copy(node, lambda x: node_map[x])

    # lint to ensure correctness
    new_graph.lint()

    new_graph._codegen = gm.graph._codegen
    gm.graph = new_graph

    return gm


# TODO Don't regenerate new signature manually.
def _get_new_signature(  # noqa: C901
    original_program: ExportedProgram,
    gm: torch.fx.GraphModule,
    tag: Optional[str] = None,
) -> Tuple[
    ExportGraphSignature,
    Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
    Dict[str, Union[torch.Tensor, torch.ScriptObject]],
]:
    """
    Args:
        tag: If tag is None, this means that we are constructing the graph
        signature for the toplevel graph, after delegation. We need to do this
        because sometimes delegates will swallow some parameters/buffers, so we
        need to update the graph signature/state dict to reflect these changes.

        Otherwise, if tag is not None, this means we are constructing the graph
        signature for the delegated modules. In this case, we need to look
        through the input nodes and see which ones were originally
        parameters/buffers, and lower them down to the delegate.

    Returns:
        (new graph signature, state dict entries for the kept parameters and
        persistent buffers, constants for the kept non-persistent buffers,
        constant tensors and custom objects).
    """

    old_signature = original_program.graph_signature

    input_specs = []
    output_specs = []
    # The spec lists are filled in below; the signature holds references to them.
    new_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )
    new_state_dict = {}
    new_constants = {}

    # Only TensorArgument inputs are indexed, so every spec looked up below has
    # a TensorArgument arg.
    input_tensor_node_to_sig = {
        input_spec.arg.name: input_spec
        for input_spec in old_signature.input_specs
        if isinstance(input_spec.arg, TensorArgument)
    }

    for node in gm.graph.nodes:
        is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag

        if node.op == "placeholder":

            if node.name not in input_tensor_node_to_sig:
                assert tag is not None
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )
                continue

            orig_input_spec = input_tensor_node_to_sig[node.name]

            if is_tagged:
                input_specs.append(orig_input_spec)

                if orig_input_spec.kind == InputKind.PARAMETER:
                    new_state_dict[orig_input_spec.target] = (
                        original_program.state_dict[orig_input_spec.target]
                    )
                elif (
                    orig_input_spec.kind == InputKind.BUFFER
                    and orig_input_spec.persistent
                ):
                    new_state_dict[orig_input_spec.target] = (
                        original_program.state_dict[orig_input_spec.target]
                    )
                elif orig_input_spec.kind == InputKind.BUFFER:
                    assert not orig_input_spec.persistent
                    new_constants[orig_input_spec.target] = original_program.constants[
                        orig_input_spec.target
                    ]
                elif orig_input_spec.kind in (
                    InputKind.CONSTANT_TENSOR,
                    InputKind.CUSTOM_OBJ,
                ):
                    new_constants[orig_input_spec.target] = original_program.constants[
                        orig_input_spec.target
                    ]
            else:
                # Not owned by this delegate: it is passed in as a plain user input.
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )

        if node.op == "output":
            output_nodes = pytree.tree_leaves((node.args, node.kwargs))

            if tag is not None:
                # We are constructing output_specs for the delegate outputs.
                # These don't have any buffer mutations.

                for output_node in output_nodes:
                    if not isinstance(output_node, torch.fx.Node):
                        arg = ConstantArgument(output_node)
                    else:
                        arg = TensorArgument(name=output_node.name)
                    output_specs.append(
                        OutputSpec(
                            kind=OutputKind.USER_OUTPUT,
                            arg=arg,
                            target=None,
                        )
                    )

            else:
                # We are reconstructing the toplevel module which contains
                # delegates. Delegation should not change the number of outputs
                # in the toplevel module, and it does not touch the mutated buffers

                assert len(old_signature.output_specs) == len(output_nodes)
                for prev_output_spec, output_node in zip(
                    old_signature.output_specs, output_nodes
                ):
                    if not isinstance(output_node, torch.fx.Node):
                        assert isinstance(prev_output_spec.arg, ConstantArgument)
                        output_specs.append(
                            OutputSpec(
                                kind=OutputKind.USER_OUTPUT,
                                arg=ConstantArgument(output_node),
                                target=None,
                            )
                        )

                    else:
                        new_output_spec = copy.deepcopy(prev_output_spec)
                        new_output_spec.arg.name = output_node.name
                        output_specs.append(new_output_spec)

    return new_signature, new_state_dict, new_constants


def create_exported_program_from_submodule(
    submodule: torch.fx.GraphModule,
    owning_program: ExportedProgram,
    tag: str,
) -> ExportedProgram:
    """
    Creates an ExportedProgram from the given submodule using the parameters and buffers
    from the top-level owning program

    Args:
        submodule: submodule to create an exported program from
        owning_program: exported program containing the parameters and buffers used within
            the submodule
        tag: delegation tag; placeholders carrying this tag in
            ``meta["delegation_tag"]`` keep their parameter/buffer/constant role

    Returns:
        The ExportedProgram created from submodule
    """
    # Arrange the submodule's placeholders in order
    submodule = arrange_graph_placeholders(submodule, owning_program)

    # Get updated graph signature
    subgraph_signature, subgraph_state_dict, subgraph_constants = _get_new_signature(
        owning_program, submodule, tag
    )

    in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
    out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]

    return ExportedProgram(
        root=submodule,
        graph=submodule.graph,
        graph_signature=subgraph_signature,
        state_dict=subgraph_state_dict,
        range_constraints=copy.deepcopy(owning_program.range_constraints),
        module_call_graph=[
            ModuleCallEntry(
                "",
                ModuleCallSignature(
                    inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
                ),
            )
        ],
        verifier=owning_program.verifier,
        constants=subgraph_constants,
    )


def _not_fully_contained_error(
    node_list: NodeList, gm: torch.fx.GraphModule
) -> RuntimeError:
    """Build the error raised when a partition did not fuse into exactly one submodule."""
    # Implicit concatenation keeps source indentation out of the message.
    return RuntimeError(
        f"The submodule created with nodes {node_list} did not form "
        "one fully contained subgraph. Check that these nodes form a "
        f"fully contained graph. Partitioned graph: {gm.graph}."
    )


def create_submodule_from_nodes(
    gm: torch.fx.GraphModule,
    node_list: NodeList,
    tag: str,
    skip_legalize_graph: bool = False,
) -> Tuple[torch.fx.GraphModule, torch.fx.Node]:
    """
    Modifies the given graph module in-place to separate out the given nodes
    into a submodule. The given node_list should form a fully connected
    subgraph.

    Args:
        gm: The graph module that we want to partition
        node_list: A list of nodes that belong in the partition
        tag: Used to name the submodule ``"fused_" + tag``
        skip_legalize_graph: If True, skip the topological re-sort of ``gm``
            after the partition's nodes are erased

    Returns:
        The submodule that has been partitioned, the call_module node in the
        toplevel graph module calling the submodule

    Raises:
        RuntimeError: if ``gm`` contains a call_module node other than the newly
            created submodule.
    """
    sorted_nodes = topo_sort(node_list)

    submodule_name = "fused_" + tag
    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
        gm, sorted_nodes, submodule_name
    )

    _fixup_output_node(sub_gm)

    gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
    submodule_node = None
    for node in gm.graph.nodes:
        if node.op == "call_module":
            if node.target == submodule_name:
                submodule_node = node
            else:
                raise _not_fully_contained_error(node_list, gm)

    # submodule_node is dereferenced right below; fail with a clear message
    # instead of an opaque AttributeError on None.
    assert (
        submodule_node is not None
    ), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}"

    if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
        # If the original output is a single tensor, it has been
        # pytree.tree_flatten-ed to be a singleton list, so we want to replace
        # all uses with a getitem call to the 0th index of the result
        with gm.graph.inserting_after(submodule_node):
            proxy_out = torch.fx.Proxy(submodule_node)[0].node  # type: ignore[index]
            submodule_node.replace_all_uses_with(proxy_out)
            proxy_out.meta["val"] = submodule_node.meta["val"]
            # Reset the args since it was overwritten in the previous line
            proxy_out.args = (submodule_node, 0)
    else:
        # fuse_as_graphmodule will automatically propagate the metadata of the
        # partition's last node to the getitem nodes that appear after the
        # call_module node. However, in the case of delegation we do not want
        # these getitem nodes to contain irrelevant previous metadata
        # (ex. source_fn, nn_module_stack)
        for user_node in submodule_node.users:
            user_node.meta.pop("nn_module_stack", None)
            user_node.meta.pop("source_fn_stack", None)

    erase_nodes(gm, sorted_nodes)

    # Topological sort original gm with newly created sub_gm
    # TODO : T153794167 Get rid of support for skipping legalize graph in create_submodule_from_nodes
    # once we transition to using fuse_by_partitions.
    if not skip_legalize_graph:
        legalize_graph(gm)

    # Get the call_module node
    submodule_node = None
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target == submodule_name:
            submodule_node = node
        elif node.op == "call_module":
            raise _not_fully_contained_error(node_list, gm)

    assert (
        submodule_node is not None
    ), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}"

    return sub_gm, submodule_node


def get_lowered_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]:
    """
    Returns a list of lowered modules that are in the given graph (does not look
    into submodules). Specifically, the returned value is a list containing a
    tuple of (name of the lowered module that's stored in the graph module, the
    lowered module itself, and the fx node that called this lowered module).
    """
    lowered_submodules = []
    for call_node in graph_module.graph.nodes:
        if (
            call_node.op == "call_function"
            and call_node.target == executorch_call_delegate
        ):
            name, module, node = _get_submodule(graph_module, call_node, 0)
            assert isinstance(module, LoweredBackendModule)
            lowered_submodules.append((name, module, node))
    return lowered_submodules


def get_lowered_backend_modules(
    graph_module: torch.fx.GraphModule,
) -> List[LoweredBackendModule]:
    """
    Returns a list of the lowered modules (created by backend delegates) called
    via ``executorch_call_delegate`` in the given graph module. Each module is
    looked up on ``graph_module`` by the name of the call's first argument node.
    """
    lowered_programs = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            lowered_backend_module = getattr(graph_module, node.args[0].name)
            lowered_programs.append(lowered_backend_module)

    return lowered_programs

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources