
Source code for executorch.exir.lowered_backend_module

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import operator
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils._pytree as pytree
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
from executorch.exir.emit import emit_program

from executorch.exir.graph_module import _get_submodule

from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
from executorch.exir.schema import Program

from executorch.exir.tracer import Value
from torch._library.fake_class_registry import FakeScriptObject

from torch._subclasses import FakeTensor
from torch.export.exported_program import (
    ConstantArgument,
    ExportedProgram,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    ModuleCallEntry,
    ModuleCallSignature,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.fx.passes.utils.fuser_utils import (
    erase_nodes,
    fuse_as_graphmodule,
    insert_subgm,
    legalize_graph,
    NodeList,
    topo_sort,
)


class LoweredBackendModule(torch.nn.Module):
    """
    A subclass of nn.Module that is generated for modules containing
    delegated functions. This can be created by calling `to_backend`.
    """

    _backend_id: str  # The backend's name
    _processed_bytes: bytes  # The delegate blobs created from backend.preprocess
    _compile_specs: List[
        CompileSpec
    ]  # A list of backend-specific objects with static metadata to configure the "compilation" process.
    _original_exported_program: ExportedProgram  # The original EXIR module

    def __init__(
        self,
        edge_program: ExportedProgram,
        backend_id: str,
        processed_bytes: bytes,
        compile_specs: List[CompileSpec],
    ) -> None:
        super().__init__()
        self._original_exported_program = edge_program
        self._backend_id = backend_id
        self._processed_bytes = processed_bytes
        self._compile_specs = compile_specs

    # pyre-ignore
    def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
        # Copy exported program
        copied_program = ExportedProgram(
            root=copy.deepcopy(self._original_exported_program.graph_module),
            graph=copy.deepcopy(self._original_exported_program.graph),
            graph_signature=copy.deepcopy(
                self._original_exported_program.graph_signature
            ),
            state_dict=self._original_exported_program.state_dict,
            range_constraints=copy.deepcopy(
                self._original_exported_program.range_constraints
            ),
            module_call_graph=copy.deepcopy(
                self._original_exported_program.module_call_graph
            ),
            constants=self._original_exported_program.constants,
            verifiers=[copy.deepcopy(self._original_exported_program.verifier)],
        )

        res = LoweredBackendModule(
            edge_program=copied_program,
            backend_id=self._backend_id,
            processed_bytes=self._processed_bytes,
            compile_specs=copy.deepcopy(self._compile_specs, memo),
        )
        # pyre-fixme[16]: `LoweredBackendModule` has no attribute `meta`.
        res.meta = copy.copy(getattr(self, "meta", {}))
        return res

    @property
    def backend_id(self) -> str:
        """
        Returns the backend's name.
        """
        return self._backend_id

    @property
    def processed_bytes(self) -> bytes:
        """
        Returns the delegate blob created from backend.preprocess
        """
        return self._processed_bytes

    @property
    def compile_specs(self) -> List[CompileSpec]:
        """
        Returns a list of backend-specific objects with static metadata to configure the "compilation" process.
        """
        return self._compile_specs

    @property
    def original_module(self) -> ExportedProgram:
        """
        Returns the original EXIR module
        """
        return self._original_exported_program

    # TODO(chenlai): consolidate the serialization config with serialize_to_flatbuffer api
    def buffer(
        self,
        extract_delegate_segments: bool = False,
        segment_alignment: int = 128,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
        memory_planning: MemoryPlanningPass = None,  # pyre-fixme[9]
    ) -> bytes:
        """
        Returns a buffer containing the serialized ExecuTorch binary.
        """
        # TODO(T181463742): avoid calling bytes(..) which incurs large copies.
        out = bytes(
            _serialize_pte_binary(
                program=self.program(memory_planning=memory_planning),
                extract_delegate_segments=extract_delegate_segments,
                segment_alignment=segment_alignment,
                constant_tensor_alignment=constant_tensor_alignment,
                delegate_alignment=delegate_alignment,
            )
        )
        return out
    # TODO(chenlai): re-consider recapture instead of manually constructing the program because
    # the metadata construction is done manually.
    def program(
        self,
        emit_stacktrace: bool = False,
        memory_planning: MemoryPlanningPass = None,  # pyre-fixme[9]
    ) -> Program:
        # Fix autodeps-introduced cyclic dependencies:
        # program -> verifier -> lowered_backend_module -> program
        # @manual
        from executorch.exir.program._program import (
            _get_updated_graph_signature,
            _transform,
        )

        """
        Returns the object that represents the ExecuTorch binary before serialization.
        """
        # Creates a new module based on the original module. The original module will
        # look something like the following:
        #
        # opcode         name                 target   args                                        kwargs
        # -------------  -------------------  -------  ------------------------------------------  --------
        # placeholder    arg0_1               arg0_1   ()                                          {}
        # placeholder    arg1_1               arg1_1   ()                                          {}
        # call_function  aten_repeat_default  *        (arg1_1, [4, 1])                            {}
        # call_function  aten_mul_tensor      *        (aten_repeat_default, aten_repeat_default)  {}
        # call_function  aten_add_tensor      *        (arg1_1, arg1_1)                            {}
        # output         output               output   ([aten_mul_tensor, aten_add_tensor],)       {}
        #
        # If the whole module is lowered, the resulting lowered module looks like:
        #
        # opcode         name                      target                       args                                kwargs
        # -------------  ------------------------  ---------------------------  ----------------------------------  --------
        # placeholder    arg0_1                    arg0_1                       ()                                  {}
        # placeholder    arg1_1                    arg1_1                       ()                                  {}
        # get_attr       lowered_module_0          lowered_module_0             ()                                  {}
        # call_function  executorch_call_delegate  executorch_call_delegate     (lowered_module_0, arg0_1, arg1_1)  {}
        # call_function  getitem                   <built-in function getitem>  (executorch_call_delegate, 0)       {}
        # call_function  getitem_1                 <built-in function getitem>  (executorch_call_delegate, 1)       {}
        # output         output_1                  output                       ([getitem, getitem_1],)             {}
        #
        # We'll remove all call_function nodes, insert a call_delegate node, insert getitem nodes
        # to get the results from the call_delegate node, and return the list of getitems as the output.

        lowered_exported_program = copy.deepcopy(self._original_exported_program)

        # The real input nodes are the ones that are not buffers or parameters
        all_input_nodes = [
            node
            for node in lowered_exported_program.graph.nodes
            if (
                node.op == "placeholder"
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_buffers
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_parameters
            )
        ]

        output_node = [
            node for node in lowered_exported_program.graph.nodes if node.op == "output"
        ]
        assert len(output_node) == 1, "There should be only one output node"

        # Step 1. Clean up the graph before inserting the call_delegate node

        # Remove the original output node
        lowered_exported_program.graph.erase_node(output_node[0])

        # Remove everything else except the inputs
        for node in reversed(lowered_exported_program.graph.nodes):
            if node.op != "placeholder":
                lowered_exported_program.graph.erase_node(node)

        # Find placeholders that are parameters or buffers, remove them from the main graph
        for node in lowered_exported_program.graph.nodes:
            if node.op == "placeholder" and (
                node.name in lowered_exported_program.graph_signature.inputs_to_buffers
                or node.name
                in lowered_exported_program.graph_signature.inputs_to_parameters
            ):
                lowered_exported_program.graph.erase_node(node)

        # Step 2. Start constructing the graph
        lowered_name = get_lowered_module_name(
            lowered_exported_program.graph_module, self
        )
        # Insert the lowered module into the graph module as an attribute
        lowered_node = lowered_exported_program.graph.get_attr(lowered_name)

        # Insert a call_delegate node into the graph module, with arguments from the arg list
        delegate_node = lowered_exported_program.graph.call_function(
            executorch_call_delegate, (lowered_node, *all_input_nodes)
        )

        # Get the output list. Since the output node is a tuple containing a list, like
        # ([aten_mul_tensor, aten_add_tensor],), we add some handling logic to get the
        # list `[aten_mul_tensor, aten_add_tensor]` properly.
        original_output_nodes = [
            node
            for node in self._original_exported_program.graph.nodes
            if node.op == "output"
        ][0].args[0]

        delegate_node.meta["spec"] = tuple(
            [make_spec(node.meta["val"]) for node in original_output_nodes]
        )
        delegate_node.meta["val"] = tuple(
            [node.meta["val"] for node in original_output_nodes]
        )

        # The getitem nodes that are going to be inserted into the lowered graph module
        getitem_nodes = []
        for i in range(len(original_output_nodes)):
            getitem_node = lowered_exported_program.graph.call_function(
                operator.getitem,
                args=(delegate_node, i),
            )
            getitem_node.meta["val"] = delegate_node.meta["val"][i]
            getitem_nodes.append(getitem_node)

        lowered_exported_program.graph.output(getitem_nodes)

        lowered_exported_program.graph_module.recompile()
        lowered_exported_program.graph.lint()

        # The user outputs will be the getitem nodes instead
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=getitem_node.name),
                target=None,
            )
            for getitem_node in getitem_nodes
        ]

        # All data are consumed by the delegates so they should be removed from the state dict.
        inputs_to_parameters = (
            lowered_exported_program.graph_signature.inputs_to_parameters
        )
        inputs_to_buffers = lowered_exported_program.graph_signature.inputs_to_buffers
        input_specs = [
            InputSpec(
                kind=InputKind.USER_INPUT,
                arg=TensorArgument(name=node.name),
                target=None,
            )
            for user_input in lowered_exported_program.graph_signature.user_inputs
            if user_input not in inputs_to_parameters
            and user_input not in inputs_to_buffers
        ]

        # Double check the ExportedProgram data (especially everything except the graph) is good
        exported_program = ExportedProgram(
            root=lowered_exported_program.graph_module,
            graph=lowered_exported_program.graph,
            graph_signature=_get_updated_graph_signature(
                ExportGraphSignature(
                    input_specs=input_specs, output_specs=output_specs
                ),
                lowered_exported_program.graph_module,
            ),
            # TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
            # somewhere as we should pass it a list of tensors to the lowered module and output a
            # list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the
            # inputs/outputs to the toplevel program will be in the format of the eager module.
            state_dict={},  # None because all data are consumed by delegate
            range_constraints=lowered_exported_program.range_constraints,
            module_call_graph=lowered_exported_program.module_call_graph,
            example_inputs=None,
            verifiers=[lowered_exported_program.verifier],
        )
        if memory_planning is None:
            memory_planning = MemoryPlanningPass()
        exported_program = _transform(exported_program, SpecPropPass(), memory_planning)
        emitted_program = emit_program(
            exported_program, emit_stacktrace=emit_stacktrace
        ).program
        return emitted_program

    # Used to patch each delegated function with a call_delegate call
    # @staticmethod
    def forward(
        self,
        *args: Value,
        **kwargs: Tuple[Value, ...],
    ) -> Value:
        return executorch_call_delegate(self, *args)
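
A LoweredBackendModule is normally produced by `to_backend` rather than constructed directly. The following is a minimal usage sketch, assuming the `to_backend(backend_id, exported_program, compile_specs)` entry point from `executorch.exir.backend.backend_api`; the backend id "MyBackend" is a hypothetical placeholder for a registered backend.

    import torch
    from executorch.exir import to_edge
    from executorch.exir.backend.backend_api import to_backend  # assumed entry point

    class Mul(torch.nn.Module):
        def forward(self, x, y):
            return x * y

    edge = to_edge(torch.export.export(Mul(), (torch.ones(2), torch.ones(2))))
    # "MyBackend" is illustrative; a real, registered backend id is required here.
    lowered = to_backend("MyBackend", edge.exported_program(), [])
    print(lowered.backend_id, len(lowered.processed_bytes))
    # Serialize the standalone lowered module to an ExecuTorch binary blob.
    pte_bytes = lowered.buffer()

The `buffer()` call above runs `program()` internally, so the delegate blob and the surrounding call_delegate graph are emitted together.
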
# TODO(zhxchen17) Try ExportPass
def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
    for node in reversed(gm.graph.nodes):
        if node.op == "output":
            with gm.graph.inserting_before(node):
                assert len(node.args) == 1
                outputs = node.args[0]
                if isinstance(outputs, torch.fx.Node):
                    val = outputs.meta.get("val")
                    if isinstance(val, list):
                        # If a list is returned, in some cases it is represented as a
                        # singular node, like `split_copy_tensor`, but EXIR will return an
                        # opened-up list like `[getitem1, getitem2]`
                        outputs = [
                            torch.fx.Proxy(outputs)[i].node for i in range(len(val))
                        ]

                returns, out_spec = pytree.tree_flatten(outputs)
                node.args = (returns,)
            return


def arrange_graph_placeholders(
    gm: torch.fx.GraphModule, owning_program: ExportedProgram
) -> torch.fx.GraphModule:
    """
    Modifies the graph of the given graph module with one that contains the same nodes
    as the original, but with placeholders in order of (Params + Buffers) (User Inputs).

    This is used by the delegate API, which disturbs the placeholder ordering when
    creating a submodule from partitioned nodes.

    Args:
        gm: The graph module that we want arranged
        owning_program: ExportedProgram that the submodule (gm) belongs to

    Returns:
        The graph module, arranged in place
    """
    new_graph = torch.fx.Graph()

    node_map = {}  # mapping of nodes from old graph to new graph

    graph_sign = owning_program.graph_signature

    # Add all placeholders into the graph first:
    param_nodes = []
    buffer_nodes = []
    input_nodes = []
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue

        if node.name in graph_sign.inputs_to_parameters:
            param_nodes.append(node)
        elif node.name in graph_sign.inputs_to_buffers:
            buffer_nodes.append(node)
        else:
            input_nodes.append(node)

    for param_node in param_nodes:
        new_node = new_graph.node_copy(param_node, lambda x: node_map[x])
        node_map[param_node] = new_node
    for buffer_node in buffer_nodes:
        new_node = new_graph.node_copy(buffer_node, lambda x: node_map[x])
        node_map[buffer_node] = new_node
    for input_node in input_nodes:
        new_node = new_graph.node_copy(input_node, lambda x: node_map[x])
        node_map[input_node] = new_node

    # Now add all the other nodes in order
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue

        new_node = new_graph.node_copy(node, lambda x: node_map[x])
        node_map[node] = new_node

    # lint to ensure correctness
    new_graph.lint()

    new_graph._codegen = gm.graph._codegen
    gm.graph = new_graph

    return gm


# TODO Don't regenerate new signature manually.
def _get_new_signature(  # noqa: C901
    original_program: ExportedProgram,
    gm: torch.fx.GraphModule,
    call_module_node: torch.fx.Node,
    tag: str,
    is_submodule: bool = False,
) -> Tuple[
    ExportGraphSignature,
    Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
    Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
    Dict[str, InputSpec],
    Dict[str, OutputSpec],
]:
    """
    Args:
        original_program: The original program that we are partitioning
        gm: The partitioned graph module.
        call_module_node: The node in the original program that is calling the
            partitioned graph module.
        tag: The tag being used for this partitioned submodule. This is used to
            tell if a particular parameter/buffer/constant node is being tagged,
            aka consumed by the delegate.
        is_submodule: True if we are currently partitioning inside of a submodule
            (like cond's submodule). If we are inside of a submodule, we do not
            care about consuming params/buffers.

    Returns:
        new_signature (ExportGraphSignature): The new signature for the
            partitioned graph module.
        new_state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]): The
            new state dict containing the consumed params/buffers.
        new_constants (Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]]):
            The new constants table containing the consumed constants.
        input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
            been consumed by the delegate (param/buffer input nodes) and should
            be removed from the toplevel ExportedProgram.
        output_specs_to_delete (Dict[str, OutputSpec]): The output specs that
            have been consumed by the delegate (buffer mutation nodes) and
            should be removed from the toplevel ExportedProgram.
    """
    old_signature = original_program.graph_signature

    input_specs = []
    output_specs = []
    input_specs_to_delete = {}
    output_specs_to_delete = {}
    new_state_dict = {}
    new_constants = {}

    # If we are within a submodule, we do not need to care about consuming
    # parameters/buffers
    input_node_to_sig: Dict[str, InputSpec] = (
        {input_spec.arg.name: input_spec for input_spec in old_signature.input_specs}
        if not is_submodule
        else {}
    )

    toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list)
    if not is_submodule:
        for output_spec in old_signature.output_specs:
            toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec)

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if node.name not in input_node_to_sig:
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )
                continue

            orig_input_spec = input_node_to_sig[node.name]

            if not isinstance(orig_input_spec.arg, TensorArgument):
                input_specs.append(orig_input_spec)

            elif node.meta.get("delegation_tag", None) == tag:
                input_specs.append(orig_input_spec)

                if orig_input_spec.kind == InputKind.USER_INPUT:
                    continue

                # The following input specs are all attributes that should be
                # consumed by the delegate, so we want to remove them from the
                # toplevel module input/output
                input_specs_to_delete[node.name] = orig_input_spec

                input_target = orig_input_spec.target
                if input_target in original_program.state_dict:
                    assert orig_input_spec.kind in (
                        InputKind.PARAMETER,
                        InputKind.BUFFER,
                    )

                    new_state_dict[input_target] = original_program.state_dict[
                        input_target
                    ]
                elif input_target in original_program.constants:
                    assert orig_input_spec.kind in (
                        InputKind.CONSTANT_TENSOR,
                        InputKind.CUSTOM_OBJ,
                        InputKind.BUFFER,
                    )

                    new_constants[input_target] = original_program.constants[
                        input_target
                    ]
                else:
                    raise RuntimeError(f"Invalid input spec {orig_input_spec} received")

            else:
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )

        if node.op == "output":
            buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
            for user in call_module_node.users.keys():
                if user.name in toplevel_output_node_to_sig:
                    assert (
                        user.op == "call_function" and user.target == operator.getitem
                    ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
                    getitem_idx = user.args[1]
                    assert isinstance(
                        getitem_idx, int
                    ), f"Invalid getitem type: {type(getitem_idx)}"
                    buffer_mutation_idxs[getitem_idx].extend(
                        toplevel_output_node_to_sig[user.name]
                    )

            for i, output_node in enumerate(node.args[0]):
                if i in buffer_mutation_idxs:
                    assert isinstance(output_node, torch.fx.Node)
                    orig_output_specs = buffer_mutation_idxs[i]

                    if any(
                        orig_output_spec.kind == OutputKind.BUFFER_MUTATION
                        and orig_output_spec.target in new_state_dict
                        for orig_output_spec in orig_output_specs
                    ):
                        # If the delegate wants to consume the buffer, then the
                        # delegate should also consume the buffer mutation
                        # (output spec would be a BUFFER_MUTATION). Otherwise
                        # the delegate will just return the result of the
                        # mutation as a USER_OUTPUT.

                        orig_output_spec = [
                            orig_output_spec
                            for orig_output_spec in orig_output_specs
                            if orig_output_spec.kind == OutputKind.BUFFER_MUTATION
                            and orig_output_spec.target in new_state_dict
                        ][0]

                        assert len(orig_output_specs) == 1, (
                            f"Constant {orig_output_spec.target} was tagged to be "
                            "consumed by the buffer, and was found to also contain "
                            "a buffer mutation. However this buffer mutation node "
                            "was found to also be used as other types of outputs "
                            "which is currently not supported. Please file an "
                            "issue on Github. \n\n"
                            f"The toplevel program: {original_program}\n"
                        )

                        output_specs.append(
                            OutputSpec(
                                kind=OutputKind.BUFFER_MUTATION,
                                arg=TensorArgument(name=output_node.name),
                                target=orig_output_spec.target,
                            )
                        )
                        output_specs_to_delete[orig_output_spec.arg.name] = (
                            orig_output_spec
                        )
                    else:
                        output_specs.append(
                            OutputSpec(
                                kind=OutputKind.USER_OUTPUT,
                                arg=TensorArgument(name=output_node.name),
                                target=None,
                            )
                        )

                elif not isinstance(output_node, torch.fx.Node):
                    output_specs.append(
                        OutputSpec(
                            kind=OutputKind.USER_OUTPUT,
                            arg=ConstantArgument(name="", value=output_node),
                            target=None,
                        )
                    )

                else:
                    output_specs.append(
                        OutputSpec(
                            kind=OutputKind.USER_OUTPUT,
                            arg=TensorArgument(name=output_node.name),
                            target=None,
                        )
                    )

    new_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )

    return (
        new_signature,
        new_state_dict,
        new_constants,
        input_specs_to_delete,
        output_specs_to_delete,
    )


def create_exported_program_from_submodule(
    submodule: torch.fx.GraphModule,
    owning_program: ExportedProgram,
    tag: str,
    call_module_node: torch.fx.Node,
    is_submodule: bool,
) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]:
    """
    Creates an ExportedProgram from the given submodule using the parameters
    and buffers from the top-level owning program.

    Args:
        submodule: submodule to create an exported program from
        owning_program: exported program containing the parameters and buffers
            used within the submodule

    Returns:
        The ExportedProgram created from submodule

        input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
            been consumed by the delegate (param/buffer input nodes) and should
            be removed from the toplevel ExportedProgram.
        output_specs_to_delete (Dict[str, OutputSpec]): The output specs that
            have been consumed by the delegate (buffer mutation nodes) and
            should be removed from the toplevel ExportedProgram.
    """
    # Arrange the submodule's placeholders in order
    submodule = arrange_graph_placeholders(submodule, owning_program)

    # TODO: we probably need to arrange the outputs wrt buffer mutations.
    # Get updated graph signature
    (
        subgraph_signature,
        subgraph_state_dict,
        subgraph_constants,
        toplevel_input_specs_to_delete,
        toplevel_output_specs_to_delete,
    ) = _get_new_signature(
        owning_program, submodule, call_module_node, tag, is_submodule
    )

    in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
    out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]

    return (
        ExportedProgram(
            root=submodule,
            graph=submodule.graph,
            graph_signature=subgraph_signature,
            state_dict=subgraph_state_dict,
            range_constraints=copy.deepcopy(owning_program.range_constraints),
            module_call_graph=[
                ModuleCallEntry(
                    "",
                    ModuleCallSignature(
                        inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
                    ),
                )
            ],
            constants=subgraph_constants,
            verifiers=[owning_program.verifier],
        ),
        toplevel_input_specs_to_delete,
        toplevel_output_specs_to_delete,
    )


def create_submodule_from_nodes(
    gm: torch.fx.GraphModule,
    node_list: NodeList,
    tag: str,
    skip_legalize_graph: bool = False,
) -> Tuple[torch.fx.GraphModule, torch.fx.Node]:
    """
    Modifies the given graph module in-place to separate out the given nodes
    into a submodule. The given node_list should form a fully connected subgraph.

    Args:
        gm: The graph module that we want to partition
        node_list: A list of nodes that belong in the partition

    Returns:
        The submodule that has been partitioned, the call_module node in the
        toplevel graph module calling the submodule
    """
    sorted_nodes = topo_sort(node_list)

    submodule_name = "fused_" + tag
    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
        gm, sorted_nodes, submodule_name
    )

    _fixup_output_node(sub_gm)

    gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
    submodule_node = None
    for node in gm.graph.nodes:
        if node.op == "call_module":
            if node.target == submodule_name:
                submodule_node = node
            else:
                raise RuntimeError(
                    f"The submodule created with nodes {node_list} did not form \
one fully contained subgraph. Check that these nodes form a \
fully contained graph. Partitioned graph: {gm.graph}."
                )

    if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
        # If the original output is a single tensor, it has been
        # pytree.tree_flatten-ed to be a singleton list, so we want to replace
        # all uses with a getitem call to the 0th index of the result
        with gm.graph.inserting_after(submodule_node):
            proxy_out = torch.fx.Proxy(submodule_node)[0].node  # type: ignore[index]
            submodule_node.replace_all_uses_with(proxy_out)
            proxy_out.meta["val"] = submodule_node.meta["val"]
            # Reset the args since it was overwritten in the previous line
            proxy_out.args = (submodule_node, 0)
    else:
        # fuse_as_graphmodule will automatically propagate the metadata of the
        # partition's last node to the getitem nodes that appear after the
        # call_module node. However, in the case of delegation we do not want
        # these getitem nodes to contain irrelevant previous metadata
        # (ex. source_fn, nn_module_stack)
        for user_node in submodule_node.users:
            user_node.meta.pop("nn_module_stack", None)
            user_node.meta.pop("source_fn_stack", None)

    erase_nodes(gm, sorted_nodes)

    # Topologically sort the original gm with the newly created sub_gm
    # TODO: T153794167 Get rid of support for skipping legalize graph in create_submodule_from_nodes
    # once we transition to using fuse_by_partitions.
    if not skip_legalize_graph:
        legalize_graph(gm)

    # Get the call_module node
    submodule_node = None
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target == submodule_name:
            submodule_node = node
        elif node.op == "call_module":
            raise RuntimeError(
                f"The submodule created with nodes {node_list} did not form \
one fully contained subgraph. Check that these nodes form a \
fully contained graph. Partitioned graph: {gm.graph}."
            )

    assert (
        submodule_node is not None
    ), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}"

    return sub_gm, submodule_node


def get_lowered_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]:
    """
    Returns a list of lowered modules that are in the given graph (does not look
    into submodules). Specifically, the returned value is a list containing a
    tuple of (name of the lowered module that's stored in the graph module, the
    lowered module itself, and the fx node that called this lowered module).
    """
    lowered_submodules = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            name, module, node = _get_submodule(graph_module, node, 0)
            assert isinstance(module, LoweredBackendModule)
            lowered_submodules.append((name, module, node))
    return lowered_submodules


def get_lowered_backend_modules(
    graph_module: torch.fx.GraphModule,
) -> List[LoweredBackendModule]:
    """
    Returns a list of exported programs which were lowered by backend delegates
    """
    lowered_programs = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            lowered_backend_module = getattr(graph_module, node.args[0].name)
            lowered_programs.append(lowered_backend_module)

    return lowered_programs


def _unsafe_adjust_original_program(  # noqa: C901
    original_program: ExportedProgram,
    call_delegate_node: torch.fx.Node,
    input_specs_to_delete: Dict[str, InputSpec],
    output_specs_to_delete: Dict[str, OutputSpec],
) -> None:
    """
    Directly modify the original exported program's signature and state dict
    based on the consumed params/buffers in the delegate.
    """
    original_program._graph_signature.input_specs = [
        input_spec
        for input_spec in original_program.graph_signature.input_specs
        if input_spec.arg.name not in input_specs_to_delete
    ]

    currently_used_targets: Set[str] = {
        input_spec.target
        for input_spec in original_program._graph_signature.input_specs
        if input_spec.target is not None
    }

    original_program._graph_signature.output_specs = [
        output_spec
        for output_spec in original_program.graph_signature.output_specs
        if output_spec.arg.name not in output_specs_to_delete
    ]

    # Delete all parameters/buffers consumed by the created exported program
    # from the graph signature, state dict, and constants table
    for node in original_program.graph.nodes:
        if node.op == "placeholder":
            if node.name in input_specs_to_delete:
                assert len(node.users) == 0
                original_program.graph.erase_node(node)
        else:
            break

    for input_spec in input_specs_to_delete.values():
        input_target = input_spec.target
        assert input_target is not None

        if input_target in currently_used_targets:
            continue

        if input_spec.kind == InputKind.PARAMETER:
            del original_program._state_dict[input_target]
        elif input_spec.kind == InputKind.BUFFER:
            if input_spec.persistent:
                del original_program._state_dict[input_target]
            else:
                del original_program._constants[input_spec.target]
        elif input_spec.kind == InputKind.CONSTANT_TENSOR:
            del original_program._constants[input_spec.target]
        else:
            raise RuntimeError(f"Invalid input spec {input_spec} received")

    # Delete buffer mutations from the output which were consumed by the delegate
    toplevel_output_node = None
    for node in reversed(original_program.graph.nodes):
        if node.op == "output":
            toplevel_output_node = node
            break

    assert toplevel_output_node is not None
    assert (
        len(toplevel_output_node.args) == 1
    ), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"

    new_output_args = [
        arg
        for arg in toplevel_output_node.args[0]
        if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
    ]
    toplevel_output_node.args = (tuple(new_output_args),)

    # Delete the buffer mutation getitem nodes
    getitem_idxs: List[int] = []
    user_nodes = list(call_delegate_node.users.keys())
    for user in user_nodes:
        if user.name in output_specs_to_delete:
            assert (
                user.op == "call_function" and user.target == operator.getitem
            ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
            user_idx = user.args[1]
            assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
            getitem_idxs.append(user_idx)
            original_program.graph.erase_node(user)

    getitem_idxs.sort(reverse=True)

    # Adjust all the getitem indices after the deleted getitems
    user_nodes = list(call_delegate_node.users.keys())
    for user in user_nodes:
        assert user.op == "call_function" and user.target == operator.getitem
        user_idx = user.args[1]
        assert isinstance(user_idx, int)
        for i, idx in enumerate(getitem_idxs):
            if user_idx > idx:
                user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
                break

    original_program._validate()
