Shortcuts

Source code for executorch.exir.program._program

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import io
import logging
from typing import Any, Dict, List, Optional, Sequence, Set, Union

import torch
import torch._export

from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir._serialize._cord import Cord
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.emit import emit_program, EmitterOutput
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
from executorch.exir.error import ExportError
from executorch.exir.pass_manager import PassType
from executorch.exir.passes import (
    base_post_op_replace_passes,
    base_pre_op_replace_passes,
    EdgeToBackendOpsPass,
    MemoryFormatOpsPass,
    OpReplacePass,
)
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
    insert_write_back_for_buffers_pass,
)
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.print_program import pretty_print, print_program
from executorch.exir.schema import Program
from executorch.exir.tracer import _default_decomposition_table
from executorch.exir.verification.verifier import (
    EXIRATenDialectVerifier,
    EXIREdgeDialectVerifier,
    get_aten_verifier,
)
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch.export import ExportedProgram
from torch.export._remove_auto_functionalized_pass import (
    unsafe_remove_auto_functionalized_pass,
)
from torch.export.exported_program import (
    _get_updated_range_constraints,
    ConstantArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputSpec,
    TensorArgument,
)
from torch.fx import _pytree as fx_pytree
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_manager import PassManager
from torch.utils import _pytree as pytree

# Type alias for values flowing through emission; deliberately unconstrained.
Val = Any


def _get_updated_graph_signature(
    old_signature: ExportGraphSignature,
    new_gm: torch.fx.GraphModule,
) -> ExportGraphSignature:
    """
    Update the graph signature's user_input/user_outputs.
    """
    new_input_specs = []
    i = 0
    for node in new_gm.graph.nodes:
        if node.op != "placeholder":
            continue

        assert i < len(
            old_signature.input_specs
        ), "Number of inputs changed after transformation"
        old_input_spec = old_signature.input_specs[i]
        arg = (
            old_input_spec.arg
            if isinstance(old_input_spec.arg, ConstantArgument)
            # pyre-fixme[20]: Argument `class_fqn` expected.
            else type(old_input_spec.arg)(node.name)
        )
        new_input_specs.append(
            InputSpec(
                old_input_spec.kind,
                arg,
                old_input_spec.target,
                persistent=old_input_spec.persistent,
            )
        )
        i += 1

    output_node = list(new_gm.graph.nodes)[-1]
    assert output_node.op == "output"

    new_output_specs = []
    for i, node in enumerate(output_node.args[0]):
        assert i < len(
            old_signature.output_specs
        ), "Number of outputs changed after transformation"
        old_output_spec = old_signature.output_specs[i]
        arg = (
            old_output_spec.arg
            if isinstance(old_output_spec.arg, ConstantArgument)
            # pyre-fixme[20]: Argument `class_fqn` expected.
            else type(old_output_spec.arg)(node.name)
        )
        new_output_specs.append(
            OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
        )

    new_signature = ExportGraphSignature(
        input_specs=new_input_specs, output_specs=new_output_specs
    )
    return new_signature


def _transform(self, *passes: PassType) -> "ExportedProgram":
    pm = PassManager(list(passes))
    res = pm(self.graph_module)
    transformed_gm = res.graph_module if res is not None else self.graph_module
    assert transformed_gm is not None

    if transformed_gm is self.graph_module and not res.modified:
        return self

    transformed_ep = ExportedProgram(
        root=transformed_gm,
        graph=transformed_gm.graph,
        graph_signature=_get_updated_graph_signature(
            self.graph_signature, transformed_gm
        ),
        state_dict=self.state_dict,
        range_constraints=_get_updated_range_constraints(transformed_gm),
        module_call_graph=copy.deepcopy(self._module_call_graph),
        example_inputs=self.example_inputs,
        verifier=self.verifier,
        constants=self.constants,
    )
    transformed_ep.graph_module.meta.update(self.graph_module.meta)
    transformed_ep.graph_module.meta.update(res.graph_module.meta)
    return transformed_ep


def _copy_module(new_prog, new_gm):
    new_prog.meta.update(new_gm.meta)
    new_prog.graph = new_gm.graph
    submodules = [name for name, _ in new_prog.named_children()]
    for name in submodules:
        delattr(new_prog, name)
    for name, mod in new_gm.named_children():
        setattr(new_prog, name, mod)
    for node in new_gm.graph.nodes:
        if node.op == "get_attr":
            t = getattr(new_gm, node.target, None)
            if isinstance(t, torch.Tensor):
                setattr(new_prog, node.target, t)


def lift_constant_tensor_pass(ep):
    """
    Takes an ExportedProgram and returns the ExportedProgram modified in-place,
    with the constant tensors as buffers.

    Every ``get_attr`` node that resolves to a tensor on the graph module is
    replaced by a new placeholder, and the tensor is registered as a
    persistent buffer in both the graph signature and the state dict.
    """
    # Nothing to do for a graph without placeholders (no insertion anchor,
    # no input specs to splice into).
    if len([node for node in ep.graph.nodes if node.op == "placeholder"]) == 0:
        return ep

    graph_signature = ep.graph_signature
    buffers = graph_signature.buffers

    # Reuse the fake mode of the first node's example value so metadata on
    # the new placeholders matches the rest of the graph.
    fake_mode = list(ep.graph.nodes)[0].meta["val"].fake_mode
    first_user_input = None
    lifted_constants = []
    # New placeholders go before the first user input so they stay grouped
    # with the other non-user inputs (params/buffers).
    # NOTE(review): if the signature has no user inputs, `first_user_input`
    # stays None and `inserting_before(None)` appends at the graph end —
    # confirm that ordering is intended.
    for node in ep.graph.nodes:
        if node.op == "placeholder" and node.name in graph_signature.user_inputs:
            first_user_input = node
            break

    for node in ep.graph.nodes:
        if node.op == "get_attr":
            constant_tensor = getattr(ep.graph_module, node.target)
            if not isinstance(constant_tensor, torch.Tensor):
                continue

            # Name keyed off the current buffer count to avoid collisions.
            constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"

            with ep.graph.inserting_before(first_user_input):
                # Insert the constant node before the first user input
                const_placeholder_node = ep.graph.placeholder(constant_tensor_fqn)
                for k, v in node.meta.items():
                    const_placeholder_node.meta[k] = v
                if fake_mode is not None:
                    const_placeholder_node.meta["val"] = fake_mode.from_tensor(
                        constant_tensor, static_shapes=True
                    )
                else:
                    const_placeholder_node.meta["val"] = constant_tensor
                # Keep a handle to the real tensor on the example value.
                const_placeholder_node.meta["val"].constant = constant_tensor
                node.replace_all_uses_with(const_placeholder_node)
                ep.graph.erase_node(node)

                # Add the constant as a buffer to the graph signature
                lifted_constants.append(
                    InputSpec(
                        kind=InputKind.BUFFER,
                        arg=TensorArgument(name=const_placeholder_node.name),
                        target=constant_tensor_fqn,
                        persistent=True,
                    )
                )
                buffers.append(constant_tensor_fqn)
                ep.state_dict[constant_tensor_fqn] = constant_tensor

    # Splice all lifted-constant specs immediately before the first
    # USER_INPUT spec, mirroring the placeholder insertion order above.
    new_input_specs = []
    for s in graph_signature.input_specs:
        if s.kind == InputKind.USER_INPUT and len(lifted_constants) > 0:
            new_input_specs.extend(lifted_constants)
            lifted_constants.clear()
        new_input_specs.append(s)
    ep.graph_signature.input_specs = new_input_specs
    ep.graph_module.recompile()
    return ep


# Stub to ease migration from `transform` to private `_transform`
def transform_exported_program(ep, *passes: PassType) -> ExportedProgram:
    if hasattr(ep, "_transform"):
        return ep._transform(*passes)
    else:
        return ep.transform(*passes)


class HackedUpExportedProgramDONOTUSE(ExportedProgram):
    """Legacy ExportedProgram subclass with a hand-rolled eager ``__call__``.

    Kept for backward compatibility only (as the name says: do not use).
    ``__call__`` flattens user inputs per the recorded tree spec, runs the
    graph with an FX Interpreter using the params-then-buffers-then-args
    calling convention, writes mutated buffers back into the state dict, and
    unflattens the outputs.
    """

    def __init__(
        self,
        root,
        graph,
        graph_signature,
        call_spec,
        state_dict,
        range_constraints,
        module_call_graph,
        example_inputs,
        verifier,
    ):
        # NOTE: `call_spec` is accepted for signature compatibility but is
        # not forwarded; the base class is constructed without it.
        super().__init__(
            root=root,
            graph=graph,
            graph_signature=graph_signature,
            state_dict=state_dict,
            range_constraints=range_constraints,
            module_call_graph=module_call_graph,
            example_inputs=example_inputs,
            verifier=verifier,
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        import torch._export.error as error

        # Flatten user args according to the exported input tree spec.
        if self.call_spec.in_spec is not None:
            user_args = args
            try:
                args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec)  # type: ignore[assignment]
            except Exception:
                _, received_spec = pytree.tree_flatten(user_args)
                raise error.InternalError(
                    "Trying to flatten user inputs with exported input tree spec: \n"
                    f"{self.call_spec.in_spec}\n"
                    "but actually got inputs with tree spec of: \n"
                    f"{received_spec}"
                )

        ordered_params = tuple(
            self.state_dict[name] for name in self.graph_signature.parameters
        )
        ordered_buffers = tuple(
            self.state_dict[name] for name in self.graph_signature.buffers
        )

        with torch.no_grad():
            # NOTE: calling convention is first params, then buffers, then args as user supplied them.
            # See: torch/_functorch/aot_autograd.py#L1034
            res = torch.fx.Interpreter(self.graph_module).run(
                *ordered_params, *ordered_buffers, *args, enable_io_processing=False
            )

        if self.call_spec.out_spec is not None:
            # The first `num_mutated` results are updated buffer values; they
            # are stripped from the user-visible output and written back to
            # the state dict in the `finally` block below.
            mutation = self.graph_signature.buffers_to_mutate
            num_mutated = len(mutation)
            mutated_buffers = res[:num_mutated]

            # Exclude dependency token from final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                assertion_dep_token_index = list(assertion_dep_token.keys())[0]
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                # Buffer write-back happens even when unflattening fails.
                ix = 0
                for buffer in self.graph_signature.buffers_to_mutate.values():
                    self.state_dict[buffer] = mutated_buffers[ix]
                    ix += 1
        return res


@compatibility(is_backward_compatible=False)
class ExirExportedProgram:
    """Legacy wrapper that carries an ExportedProgram through lowering.

    Tracks whether the edge-lowering passes have already run so that
    ``to_executorch`` can reject programs that skipped ``to_edge``.
    """

    def __init__(
        self,
        exported_program: ExportedProgram,
        after_to_edge_passes: bool,
    ):
        self.exported_program = exported_program

        # Add a flag to denote whether to_edge is called on this program
        # to detect misusage of directly calling to_executorch without to_edge
        self.after_to_edge_passes = after_to_edge_passes

    def transform(self, *passes: PassType) -> "ExirExportedProgram":
        """Apply ``passes`` to the wrapped program in place; returns self."""
        self.exported_program = _transform(self.exported_program, *passes)
        return self

    def __call__(self, *args: Any) -> Any:
        """Run the wrapped program eagerly on ``args``."""
        return self.exported_program.module()(*args)

    # TODO(ycao): Change this to a composable function.
    def to_edge(
        self, config: Optional[EdgeCompileConfig] = None
    ) -> "ExirExportedProgram":
        """Lower the wrapped ATen-dialect program to edge dialect."""
        config = config or EdgeCompileConfig()
        assert isinstance(
            self.exported_program.graph_module, torch.fx.GraphModule
        ), f"type is instead: {type(self.exported_program.graph_module).__name__}"

        return _to_edge(self, config)

    def dump(self) -> None:
        """Print the wrapped program's FX graph to stdout."""
        print(self.exported_program.graph_module.graph)

    def to_executorch(
        self,
        config: Optional[ExecutorchBackendConfig] = None,
    ) -> "ExecutorchProgram":
        """Lower the edge-dialect program to an ExecutorchProgram.

        Raises:
            RuntimeError: if ``to_edge`` has not been called first.
        """
        if not self.after_to_edge_passes:
            raise RuntimeError("Must run to_edge before to_executorch.")
        config = config or ExecutorchBackendConfig()
        new_gm = self.exported_program.graph_module
        for p in edge_to_executorch_passes(config):
            new_gm_res = p(new_gm)
            assert new_gm_res is not None
            new_gm = new_gm_res.graph_module

        # This is tech debt on tech debt. memory planning pass inherits from some pass infra for GMs.
        # This isnt enough info now so i cant use call I have to use some new function 'run'.
        # Existing user passes dont use run so Im just cheating here because they dont need to work on mutable buffers yet.
        # After exir.capture is gone I will clean up the memory planning infra to be consistent.
        # Frankly all of exir has big code quality issues because of the migrations that need to be addressed.
        new_gm_res = config.memory_planning_pass(new_gm)  # pyre-ignore[19]
        assert new_gm_res is not None
        new_gm = new_gm_res.graph_module
        # Deep-copy so the transformed graph can be grafted onto a fresh
        # program without mutating `self`.
        new_prog = ExirExportedProgram(
            copy.deepcopy(self.exported_program), self.after_to_edge_passes
        )
        _copy_module(new_prog.exported_program.graph_module, new_gm)
        executorch_prog = ExecutorchProgram(
            new_prog,
            emit_stacktrace=config.emit_stacktrace,
            extract_delegate_segments=config.extract_delegate_segments,
            extract_constant_segment=config.extract_constant_segment,
            segment_alignment=config.segment_alignment,
            constant_tensor_alignment=config.constant_tensor_alignment,
            delegate_alignment=config.delegate_alignment,
        )
        # Carry graph-module metadata over to the emitted program.
        executorch_prog.graph_module.meta.update(new_gm.meta)
        executorch_prog.graph_module.meta.update(
            self.exported_program.graph_module.meta
        )
        return executorch_prog

    def __deepcopy__(
        self, memo: Optional[Dict[int, Any]] = None
    ) -> "ExirExportedProgram":

        new_eep = ExirExportedProgram(
            copy.deepcopy(self.exported_program, memo),
            self.after_to_edge_passes,
        )
        return new_eep


@compatibility(is_backward_compatible=False)
class ExecutorchProgram:
    """Legacy container for an emitted and serialized ExecuTorch program.

    Wraps a lowered ExirExportedProgram and lazily performs the final
    lowering steps: emission (``program`` property) and serialization
    (``buffer`` / ``write_to_file``), caching each result after first use.
    """

    def __init__(
        self,
        exir_exported_program: ExirExportedProgram,
        emit_stacktrace: bool,
        extract_delegate_segments: bool,
        extract_constant_segment: bool,
        segment_alignment: int,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
    ) -> None:
        if not exir_exported_program.after_to_edge_passes:
            raise RuntimeError(
                "Need to call prog.to_edge prior to constructing ExecutorchProgram."
            )
        self.exported_program = exir_exported_program.exported_program
        # Lazily-populated caches: serialized binary (Cord), its contiguous
        # bytes form, and the emitter output.
        self._pte_data: Optional[Cord] = None
        self._buffer: Optional[bytes] = None
        self._emitter_output: Optional[EmitterOutput] = None
        # Emission/serialization knobs, captured verbatim from the config.
        self._emit_stacktrace: bool = emit_stacktrace
        self._extract_delegate_segments: bool = extract_delegate_segments
        self._extract_constant_segment: bool = extract_constant_segment
        self._segment_alignment: int = segment_alignment
        self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment
        self._delegate_alignment: Optional[int] = delegate_alignment

    def _get_pte_data(self) -> Cord:
        # Serialize once and cache; `self.program` triggers emission if it
        # has not happened yet.
        if self._pte_data is None:
            self._pte_data = _serialize_pte_binary(
                program=self.program,
                extract_delegate_segments=self._extract_delegate_segments,
                extract_constant_segment=self._extract_constant_segment,
                segment_alignment=self._segment_alignment,
                constant_tensor_alignment=self._constant_tensor_alignment,
                delegate_alignment=self._delegate_alignment,
            )
        return self._pte_data

    @property
    def buffer(self) -> bytes:
        """Returns the serialized ExecuTorch binary as a byte string.

        Note that the call to `buffer` may allocate a very large amount of
        contiguous memory, depending on the model size. If writing to a file,
        use `write_to_file` which won't incur additional copies.
        """
        # TODO(T181494963): update pybinding to remove buffer cache, which can consume large
        # amounts of memory longer than necessary.
        if self._buffer is None:
            self._buffer = bytes(self._get_pte_data())
        return self._buffer

    @property
    def program(self) -> Program:
        """The emitted executorch Program; emission runs (and is cached) on first access."""
        if self._emitter_output is None:
            self._emitter_output = emit_program(
                self.exported_program, self._emit_stacktrace
            )
        return self._emitter_output.program

    @property
    def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
        # Empty until emission has happened (see `program`).
        if self._emitter_output:
            return self._emitter_output.debug_handle_map
        return {}

    @property
    def delegate_map(
        self,
    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
        # Empty until emission has happened (see `program`).
        if self._emitter_output:
            return self._emitter_output.method_to_delegate_debug_id_map
        return {}

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.exported_program.graph_module

    # TODO (zhxchen17) Change this to property.
    def dump_graph_module(self) -> torch.fx.GraphModule:
        return self.exported_program.graph_module

    def dump_exported_program(self) -> ExportedProgram:
        return self.exported_program

    def write_to_file(self, open_file: io.BufferedIOBase) -> None:
        """
        Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over
        `buffer`, as it writes to file without copying into a contiguous block of memory first,
        reducing the peak memory usage.
        """
        self._get_pte_data().write_to_file(open_file)


def _get_aten_to_edge_passes(config: EdgeCompileConfig):
    """Build the (pre, post) op-replacement pass lists for aten-to-edge.

    ``config._skip_type_promotion`` gates RemoveMixedTypeOperators in the pre
    list; ``config._skip_dim_order`` gates MemoryFormatOpsPass in the post
    list. Fresh lists are returned; the shared base lists are not mutated.
    """
    # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
    # use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
    # well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
    # It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.

    pre_op_replace_passes = list(base_pre_op_replace_passes)
    if not config._skip_type_promotion:
        pre_op_replace_passes.append(RemoveMixedTypeOperators())

    post_op_replace_passes = []
    if not config._skip_dim_order:
        post_op_replace_passes.append(MemoryFormatOpsPass())
    post_op_replace_passes.extend(base_post_op_replace_passes)

    return pre_op_replace_passes, post_op_replace_passes


def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
    """Lower an ExirExportedProgram from ATen dialect to edge dialect.

    Operates on a deep copy, so ``ep`` itself is left untouched. Returns a
    new ExirExportedProgram with ``after_to_edge_passes`` set to True and an
    edge-dialect verifier attached.
    """
    if config._check_ir_validity:
        try:
            EXIRATenDialectVerifier()(ep.exported_program.graph_module)
        except ExportError:
            logging.info(
                "If you'd like to disable IR validation checking, please set _check_ir_validity in EdgeCompileConfig, "
                "like *.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))."
            )
            raise

    dialect = ep.exported_program.dialect
    if dialect == "ATEN":
        # Rewrap the program so the transforms below validate against an
        # explicit ATen verifier (gated by the same config flag).
        ep = ExirExportedProgram(
            ExportedProgram(
                root=ep.exported_program.graph_module,
                graph=ep.exported_program.graph_module.graph,
                graph_signature=ep.exported_program.graph_signature,
                state_dict=ep.exported_program.state_dict,
                range_constraints=ep.exported_program.range_constraints,
                module_call_graph=ep.exported_program.module_call_graph,
                example_inputs=ep.exported_program.example_inputs,
                verifier=get_aten_verifier(enable=config._check_ir_validity),
                constants=ep.exported_program.constants,
            ),
            False,
        )
    pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)

    # Deep-copy before transforming so the caller's program is not mutated.
    new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes)
    if dialect == "ATEN":
        # Lift any tensor constants still attached as get_attrs (e.g. ones
        # created by the passes above) into buffers before further lowering.
        new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)

    new_gm = new_ep.exported_program.graph_module
    if config._use_edge_ops:
        new_gm_res = OpReplacePass()(new_gm)
        assert new_gm_res is not None
        new_gm = new_gm_res.graph_module

    for p in post_op_replace_passes:
        new_gm_res = p(new_gm)
        assert new_gm_res is not None
        new_gm = new_gm_res.graph_module

    # Rebuild the ExportedProgram around the transformed graph module, now
    # carrying an edge-dialect verifier and an updated graph signature.
    new_ep.exported_program = ExportedProgram(
        root=new_gm,
        graph=new_gm.graph,
        graph_signature=_get_updated_graph_signature(
            new_ep.exported_program.graph_signature, new_gm
        ),
        state_dict=new_ep.exported_program.state_dict,
        range_constraints=new_ep.exported_program.range_constraints,
        module_call_graph=new_ep.exported_program.module_call_graph,
        example_inputs=new_ep.exported_program.example_inputs,
        verifier=EXIREdgeDialectVerifier(
            check_edge_ops=config._use_edge_ops,
            enable=config._check_ir_validity,
            class_only=True,
        ),
        constants=new_ep.exported_program.constants,
    )
    new_ep.after_to_edge_passes = True
    return new_ep


def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]:
    """Assemble the pass pipeline lowering edge dialect to executorch.

    User-supplied passes from ``config.passes`` run first, followed by the
    fixed lowering sequence and the config-supplied shape-eval and
    out-variant passes.
    """
    # pyre-ignore
    passes: List[PassType] = list(config.passes)
    passes.append(SpecPropPass())
    # ExecuTorch backend ops are unable to handle unbacked symints. So after
    # this pass, passes cannot be Interpreter-based, because it will fail if
    # there exists an unbacked symint operation.
    passes.append(EdgeToBackendOpsPass())
    passes.append(RemoveGraphAssertsPass())
    passes.append(config.sym_shape_eval_pass)
    passes.append(config.to_out_var_pass)
    return passes


def to_edge(
    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
    constant_methods: Optional[Dict[str, Any]] = None,
    compile_config: Optional[EdgeCompileConfig] = None,
) -> "EdgeProgramManager":
    """
    :func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in
    ATen dialect. Upon construction those programs are transformed into edge dialect.

    Args:
        programs: Can be a single ExportedProgram or a dictionary mapping function names
            to their corresponding ExportedPrograms. If only a single ExportedProgram is
            provided it will be assigned the name "forward".

        constant_methods: An optional dictionary of method name to the constant value
            returned by that method in eager mode. Often used to store config information
            on Edge models.

        compile_config: An optional argument used to provide greater control over the
            transformation to edge dialect process.

    Returns:
        EdgeProgramManager
    """
    # Guard against a common positional-argument mistake: passing the config
    # where constant_methods belongs.
    assert not isinstance(constant_methods, EdgeCompileConfig)
    config = compile_config or EdgeCompileConfig()
    if not isinstance(programs, dict):
        aten_programs = {"forward": programs}
    else:
        aten_programs = programs

    edge_programs: Dict[str, ExportedProgram] = {}
    for name, program in aten_programs.items():
        # Decompose to Core ATen
        program = program.run_decompositions(_default_decomposition_table())

        if config._check_ir_validity:
            try:
                EXIRATenDialectVerifier()(program.graph_module)
            except ExportError as e:
                logging.info(f"Input program {name} is not in ATen dialect.")
                raise e

        pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)

        passes = []
        passes.append(
            ReplaceViewOpsWithViewCopyOpsPass()
        )  # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
        passes.extend(pre_op_replace_passes)
        if config._use_edge_ops:
            passes.append(OpReplacePass())

        gm = program.graph_module
        for p in passes:
            gm_res = p(gm)
            assert gm_res is not None
            gm = gm_res.graph_module

        # Rebuild the ExportedProgram around the transformed graph module,
        # attaching an edge-dialect verifier and refreshed signature.
        edge_program = ExportedProgram(
            root=gm,
            graph=gm.graph,
            graph_signature=_get_updated_graph_signature(program.graph_signature, gm),
            state_dict=program.state_dict,
            range_constraints=program.range_constraints,
            module_call_graph=program.module_call_graph,
            example_inputs=program.example_inputs,
            verifier=EXIREdgeDialectVerifier(
                check_edge_ops=config._use_edge_ops,
                enable=config._check_ir_validity,
                class_only=True,
            ),
            constants=program.constants,
        )
        # Lift the tensor constants created in ScalarToTensorPass
        edge_program = lift_constant_tensor_pass(edge_program)
        edge_program = _transform(edge_program, *post_op_replace_passes)
        edge_programs[name] = edge_program
    return EdgeProgramManager(edge_programs, constant_methods, config)
class EdgeProgramManager:
    """
    Package of one or more `ExportedPrograms` in Edge dialect. Designed to simplify
    lowering to ExecuTorch. See: https://pytorch.org/executorch/stable/ir-exir.html

    Allows easy applications of transforms across a collection of exported programs
    including the delegation of subgraphs.

    Manages the second link in the lowering chain of ATen -> Edge -> ExecuTorch.
    """

    def __init__(
        self,
        edge_programs: Dict[str, ExportedProgram],
        constant_methods: Optional[Dict[str, Any]] = None,
        compile_config: Optional[EdgeCompileConfig] = None,
    ):
        """
        Should not be called directly by users. User should use :func:'to_edge' instead.

        Constructs an EdgeProgramManager from an existing set of exported programs in
        edge dialect.
        """
        config = compile_config or EdgeCompileConfig()
        for name, program in edge_programs.items():
            try:
                EXIREdgeDialectVerifier(
                    check_edge_ops=config._use_edge_ops,
                    enable=config._check_ir_validity,
                )(program.graph_module)
            except ExportError as e:
                # NOTE(review): the message says "aten dialect" but the check
                # above is the *edge* verifier — the text likely should say
                # "edge dialect".
                logging.info(f"Input program {name} is not in aten dialect.")
                raise e

        self._edge_programs = edge_programs
        self._config_methods = constant_methods

    @property
    def methods(self) -> Set[str]:
        """
        Returns the set of methods in this EdgeProgramManager.
        """
        return set(self._edge_programs.keys())

    @property
    def config_methods(self) -> Set[str]:
        """
        Returns the set of config methods in this EdgeProgramManager.
        """
        return set(self._config_methods.keys()) if self._config_methods else set()

    def exported_program(self, method_name: str = "forward") -> ExportedProgram:
        """
        Returns the ExportedProgram specified by 'method_name'.
        """
        return self._edge_programs[method_name]

    def transform(
        self,
        passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
        check_ir_validity: bool = True,
        # We should also probably add check_edge_ops here as well
    ) -> "EdgeProgramManager":
        """
        Transforms the program according to the provided passes.

        Args:
            passes: The passes can either be a list of passes, or a
                dictionary mapping method names to lists of passes. If it is
                just a list of passes, all methods in the given EdgeProgramManager
                will be transformed with the provided passes. If it is a
                dictionary, only method names specified in the dictionary will be
                transformed with their corresponding passes.

        Returns:
            EdgeProgramManager: A copy of the calling EdgeProgramManager with the
            transformations applied.
        """
        new_programs: Dict[str, ExportedProgram] = {}
        if isinstance(passes, dict):
            # Only the named methods are transformed; the rest are deep-copied
            # untouched so the returned manager is fully independent.
            for name, program in self._edge_programs.items():
                if name in passes.keys():
                    new_programs[name] = _transform(program, *passes[name])
                    EXIREdgeDialectVerifier(enable=check_ir_validity)(
                        new_programs[name].graph_module
                    )
                else:
                    new_programs[name] = copy.deepcopy(program)
        else:  # apply passes to every method
            for name, program in self._edge_programs.items():
                new_programs[name] = _transform(program, *passes)
                EXIREdgeDialectVerifier(enable=check_ir_validity)(
                    new_programs[name].graph_module
                )

        config = EdgeCompileConfig(_check_ir_validity=check_ir_validity)
        return EdgeProgramManager(
            new_programs, copy.deepcopy(self._config_methods), config
        )

    def to_backend(
        self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
    ) -> "EdgeProgramManager":
        """
        Returns a semantically-equivalent program to the one given as input,
        but with portions of each program in the EdgeProgramManager targeted
        for delegation as determined by the partitioner.

        Args:
            partitioner: The partitioner can either be a Partitioner subclass instance, or a
                dictionary mapping method names to Partitioner subclass instance. If it is a
                Partitioner subclass, all programs in the given EdgeProgramManager
                will be lowered using the given partitioner. If it is a
                dictionary, only method names specified in the dictionary will be
                lowered with the given partitioner.

                The Partitioner subclass instance is in charge with tagging portions of the
                input program for delegation. A valid partitioner must return PartitionerResult
                including valid partition_tags: Dict[str, DelegationSpec], where each key is a
                tag name and the nodes with same tag will be fused as one subgraph and delegated
                to backend specified in delegation spec.

        Returns:
            EdgeProgramManager: A copy of the calling EdgeProgramManager with the
            specified subgraphs lowered.
        """
        new_edge_programs: Dict[str, ExportedProgram] = {}
        if isinstance(partitioner, dict):
            # Only the named methods are lowered; the rest are deep-copied
            # untouched.
            for name, program in self._edge_programs.items():
                if name in partitioner.keys():
                    new_edge_programs[name] = to_backend(program, partitioner[name])
                else:
                    new_edge_programs[name] = copy.deepcopy(program)

        else:  # apply partitioner to every method
            for name, program in self._edge_programs.items():
                new_edge_programs[name] = to_backend(program, partitioner)

        # Delegated programs contain backend-specific ops, so IR validation is
        # disabled on the resulting manager.
        config = EdgeCompileConfig(_check_ir_validity=False)
        return EdgeProgramManager(
            new_edge_programs, copy.deepcopy(self._config_methods), config
        )

    def to_executorch(
        self, config: Optional[ExecutorchBackendConfig] = None
    ) -> "ExecutorchProgramManager":
        """
        Transforms the program to the ExecuTorch backend.

        Args:
            config: An optional argument used to provide greater control over
                the transformation to the ExecuTorch backend.

        Returns:
            ExecutorchProgramManager: A manager representing the state of the
            EdgeProgramManager after it has been transformed to the ExecuTorch backend.
        """
        config = config if config else ExecutorchBackendConfig()
        execution_programs: Dict[str, ExportedProgram] = {}
        for name, program in self._edge_programs.items():
            program = unsafe_remove_auto_functionalized_pass(program)
            gm, new_signature = insert_write_back_for_buffers_pass(program)
            new_gm = program.graph_module
            for p in edge_to_executorch_passes(config):
                new_gm_res = p(new_gm)
                assert new_gm_res is not None
                new_gm = new_gm_res.graph_module
                if isinstance(p, SpecPropPass):
                    # Note that this is a hacky way to get around the fact that
                    # placeholder nodes corresponding to the parameters of the graph module
                    # shall not participate in memory planning. It increases runtime memory
                    # footprint.
                    # Proper way would be to have ExportPass work with ExportedProgram
                    # instead of GraphModule. This is because ExportPass should work
                    # on top of the export artifact of torch.export which is ExportedProgram.
                    # Working with GraphModule does not provide all the information contained
                    # in the ExportedProgram
                    # TODO(who?)
                    p.update_placeholder_tensor_specs(program, new_gm)

            # TODO(jakeszwe): Follow up with compiler on if the deepcopy is
            # necessary and if so how to make it work
            if hasattr(config.memory_planning_pass, "run"):
                new_gm_res = config.memory_planning_pass.run(  # pyre-ignore[16]
                    new_gm, new_signature
                )
            else:
                new_gm_res = config.memory_planning_pass(new_gm)  # pyre-ignore[19]
            assert new_gm_res is not None
            new_gm = new_gm_res.graph_module

            # Graft the fully-lowered graph back onto the program in place.
            _copy_module(program.graph_module, new_gm)
            execution_programs[name] = program

        return ExecutorchProgramManager(
            execution_programs, self._config_methods, config
        )
class ExecutorchProgramManager:
    """
    Package of one or more `ExportedPrograms` in Execution dialect. Designed to simplify
    lowering to ExecuTorch. See: https://pytorch.org/executorch/stable/ir-exir.html

    When the ExecutorchProgramManager is constructed the ExportedPrograms in execution
    dialect are used to form the executorch binary (in a process called emission) and
    then serialized to a buffer.

    Manages the final link in the lowering chain of ATen -> Edge -> ExecuTorch.
    """

    def __init__(
        self,
        execution_programs: Dict[str, ExportedProgram],
        config_methods: Optional[Dict[str, Any]] = None,
        backend_config: Optional[ExecutorchBackendConfig] = None,
    ):
        """
        End users should not call this constructor directly. Instead, they should use
        :func:'to_executorch' to construct an ExecutorchProgramManager.

        Constructs an ExecutorchProgramManager from a set of exported programs in
        execution dialect.

        Args:
            execution_programs: A dictionary of method name to the corresponding
                ExportedProgram.

            config_methods: A dictionary of method name to the config value returned
                by that method in eager mode.

            backend_config: An optional argument used to provide greater control over
                the emission and serialization.
        """
        # Set up methods
        self._execution_programs: Dict[str, ExportedProgram] = execution_programs
        self._config_methods: Optional[Dict[str, Any]] = config_methods

        backend_config = backend_config or ExecutorchBackendConfig()

        # Emit methods (emission happens eagerly, at construction time).
        self._emitter_output: EmitterOutput = emit_program(
            self._execution_programs,
            backend_config.emit_stacktrace,
            self._config_methods,
        )

        # Serialize emitter output, ready to be written to a file.
        self._pte_data: Cord = _serialize_pte_binary(
            program=self._emitter_output.program,
            extract_delegate_segments=backend_config.extract_delegate_segments,
            extract_constant_segment=backend_config.extract_constant_segment,
            segment_alignment=backend_config.segment_alignment,
            constant_tensor_alignment=backend_config.constant_tensor_alignment,
            delegate_alignment=backend_config.delegate_alignment,
        )
        # Contiguous-bytes cache, populated lazily by the `buffer` property.
        self._buffer: Optional[bytes] = None

    @property
    def methods(self) -> Set[str]:
        """
        Returns the set of methods in this ExecutorchProgramManager.
        """
        return set(self._execution_programs.keys())

    @property
    def config_methods(self) -> Set[str]:
        """
        Returns the set of config methods in this ExecutorchProgramManager.
        """
        return set(self._config_methods.keys()) if self._config_methods else set()

    def exported_program(self, method_name: str = "forward") -> ExportedProgram:
        """
        Returns the ExportedProgram specified by 'method_name'.
        """
        return self._execution_programs[method_name]

    def dump_executorch_program(self, verbose: bool = False) -> None:
        """
        Prints the ExecuTorch binary in a human readable format.

        Args:
            verbose (bool):
                If False prints the binary in a condensed format.
                If True prints the binary 1-1 with the specification in the schema.
        """
        if verbose:
            pretty_print(self._emitter_output.program)
        else:
            print_program(self._emitter_output.program)

    @property
    def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
        return self._emitter_output.debug_handle_map

    @property
    def delegate_map(
        self,
    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
        return self._emitter_output.method_to_delegate_debug_id_map

    @property
    def executorch_program(self) -> Program:
        """
        Returns the object that represents the ExecuTorch binary before serialization.
        """
        return self._emitter_output.program

    @property
    def buffer(self) -> bytes:
        """Returns the serialized ExecuTorch binary as a byte string.

        Note that the call to `buffer` may allocate a very large amount of
        contiguous memory, depending on the model size. If writing to a file,
        use `write_to_file` which won't incur additional copies.
        """
        # TODO(T181494963): update pybinding to remove buffer cache, which can consume large
        # amounts of memory longer than necessary.
        if self._buffer is None:
            self._buffer = bytes(self._pte_data)
        return self._buffer

    def write_to_file(self, open_file: io.BufferedIOBase) -> None:
        """
        Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over
        `buffer`, as it writes to file without copying into a contiguous block of memory first,
        reducing the peak memory usage.
        """
        self._pte_data.write_to_file(open_file)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources