Shortcuts

Source code for executorch.sdk.etrecord._etrecord

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import pickle
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from zipfile import BadZipFile, ZipFile

from executorch import exir
from executorch.exir import (
    EdgeProgramManager,
    ExecutorchProgram,
    ExecutorchProgramManager,
    ExirExportedProgram,
    ExportedProgram,
)
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap

from executorch.exir.serde.export_serialize import SerializedArtifact
from executorch.exir.serde.serialize import deserialize, serialize
from executorch.sdk.bundled_program.core import BundledProgram

from executorch.sdk.bundled_program.schema.bundled_program_schema import Value

# Alias for one set of program outputs (for example the expected outputs of a
# single bundled-program test case), represented as a list of `Value` objects.
ProgramOutput = List[Value]

try:
    # `enum.StrEnum` was only added in python 3.11
    # pyre-ignore
    from enum import StrEnum
except ImportError:
    from enum import Enum

    class StrEnum(str, Enum):
        """Minimal stand-in for `enum.StrEnum` on Python versions before 3.11."""

        def __str__(self) -> str:
            # The stdlib StrEnum renders a member as its plain string value.
            # Without this override `str(member)` would fall back to
            # `Enum.__str__` and yield "ClassName.MEMBER" on older interpreters,
            # so the two import paths would behave differently.
            return str.__str__(self)


class ETRecordReservedFileNames(StrEnum):
    """
    Entry names inside an ETRecord zip archive that are reserved by the
    ETRecord format itself.

    `generate_etrecord` rejects any user provided module name that contains
    one of these values.
    """

    # Empty marker entry; `parse_etrecord` requires it to accept the archive.
    ETRECORD_IDENTIFIER = "ETRECORD_V0"
    # Serialized edge dialect exported program. Its state dict is stored under
    # the same name with a "_state_dict" suffix.
    EDGE_DIALECT_EXPORTED_PROGRAM = "edge_dialect_exported_program"
    # NOTE(review): in the code visible here this name is only used through the
    # reserved-name check; no entry with this name is written or read - confirm.
    ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module"
    # JSON encoded debug handle map of the ExecuTorch program.
    DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
    # JSON encoded delegate map of the ExecuTorch program.
    DELEGATE_MAP_NAME = "delegate_map"
    # Pickled reference outputs; only written when a `BundledProgram` is given.
    REFERENCE_OUTPUTS = "reference_outputs"


@dataclass
class ETRecord:
    """
    In-memory form of an ETRecord file, as built by `parse_etrecord`.

    Every field defaults to `None`.
    """

    # Edge dialect exported program deserialized from the archive, if present.
    edge_dialect_program: Optional[ExportedProgram] = None
    # Remaining exported programs, keyed by their archive entry name
    # ("<module name>/<method name>").
    graph_map: Optional[Dict[str, ExportedProgram]] = None
    # Debug handle map loaded from its JSON entry in the archive.
    _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None
    # Delegate map loaded from its JSON entry in the archive.
    _delegate_map: Optional[
        Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
    ] = None
    # Expected outputs keyed by method name; only stored in the archive when
    # the ETRecord was generated from a `BundledProgram`.
    _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None


def _handle_exported_program(
    etrecord_zip: ZipFile, module_name: str, method_name: str, ep: ExportedProgram
) -> None:
    """
    Serializes `ep` and stores it in the ETRecord archive.

    Two entries are written: the serialized program under
    `<module_name>/<method_name>` and its state dict under the same name with
    a `_state_dict` suffix.
    """
    assert isinstance(ep, ExportedProgram)

    artifact = serialize(ep)
    program_bytes = artifact.exported_program
    assert isinstance(program_bytes, bytes)

    program_entry = f"{module_name}/{method_name}"
    state_dict_entry = f"{module_name}/{method_name}_state_dict"

    etrecord_zip.writestr(program_entry, program_bytes)
    etrecord_zip.writestr(state_dict_entry, artifact.state_dict)


def _handle_export_module(
    etrecord_zip: ZipFile,
    export_module: Union[
        ExirExportedProgram,
        EdgeProgramManager,
        ExportedProgram,
    ],
    module_name: str,
) -> None:
    """
    Writes every exported program contained in `export_module` to the archive.

    An `ExirExportedProgram` or a bare `ExportedProgram` is stored as the single
    method "forward". An `EdgeProgramManager` contributes one entry per method
    it reports. Any other type raises a `RuntimeError`.
    """
    # Single-program wrappers: unwrap and store as "forward".
    if isinstance(export_module, ExirExportedProgram):
        _handle_exported_program(
            etrecord_zip, module_name, "forward", export_module.exported_program
        )
        return

    if isinstance(export_module, ExportedProgram):
        _handle_exported_program(etrecord_zip, module_name, "forward", export_module)
        return

    manager_types = (EdgeProgramManager, exir.program._program.EdgeProgramManager)
    if not isinstance(export_module, manager_types):
        raise RuntimeError(f"Unsupported graph module type. {type(export_module)}")

    # Multi-method manager: one archive entry per method.
    for method_name in export_module.methods:
        program = export_module.exported_program(method_name)
        _handle_exported_program(etrecord_zip, module_name, method_name, program)


def _handle_edge_dialect_exported_program(
    etrecord_zip: ZipFile, edge_dialect_exported_program: ExportedProgram
) -> None:
    """
    Serializes the edge dialect program and stores it under its reserved name.

    The serialized program and its state dict are written as two separate
    entries; the state dict entry carries a `_state_dict` suffix.
    """
    artifact = serialize(edge_dialect_exported_program)
    program_bytes = artifact.exported_program
    assert isinstance(program_bytes, bytes)

    program_entry = ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM
    state_dict_entry = (
        f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_state_dict"
    )

    etrecord_zip.writestr(program_entry, program_bytes)
    etrecord_zip.writestr(state_dict_entry, artifact.state_dict)


def _get_reference_outputs(
    bundled_program: BundledProgram,
) -> Dict[str, List[ProgramOutput]]:
    """
    Collects the expected outputs of every test case in the bundled program.

    Returns a dict keyed by method name whose values list the expected outputs
    of that method's test cases, in test case order. Raises a `ValueError` as
    soon as a test case without expected outputs is encountered.
    """
    outputs_by_method = {}
    for suite in bundled_program.method_test_suites:
        collected = []
        outputs_by_method[suite.method_name] = collected
        for case in suite.test_cases:
            expected = case.expected_outputs
            if not expected:
                raise ValueError(
                    f"Missing at least one set of expected outputs for method {suite.method_name}."
                )
            collected.append(expected)
    return outputs_by_method


def generate_etrecord(
    etrecord_path: str,
    edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
    executorch_program: Union[
        ExecutorchProgram,
        ExecutorchProgramManager,
        BundledProgram,
    ],
    export_modules: Optional[
        Dict[
            str,
            Union[
                ExportedProgram,
                ExirExportedProgram,
                EdgeProgramManager,
            ],
        ]
    ] = None,
) -> None:
    """
    Generates an `ETRecord` from the given objects, serializes it and saves it to the given path.

    The following entries are written to the `ETRecord` archive: every exported program
    present in the `export_modules` dict, the exported program of the edge dialect program
    object, and the debug handle map and delegate map of the ExecuTorch program for SDK
    tooling usage. When a `BundledProgram` is passed in, the expected outputs of its test
    cases are additionally stored as reference outputs.

    Args:
        etrecord_path: Path to where the `ETRecord` file will be saved to.
        edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge()
        executorch_program: The ExecuTorch program for this model returned by the call to `to_executorch()` or the `BundledProgram` of this model
        export_modules[Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
            value being the corresponding exported module. The exported graph modules can be either the
            output of `torch.export()` or `exir.to_edge()`.

    Returns:
        None

    Raises:
        RuntimeError: If a name in `export_modules` contains a reserved ETRecord name, or if
            `edge_dialect_program` (or an export module) has an unsupported type.
    """
    # Use the archive as a context manager so that it is finalized and its file
    # handle released on every exit path, instead of relying on garbage collection.
    with ZipFile(etrecord_path, "w") as etrecord_zip:
        # Write the magic file identifier that will be used to verify that this file
        # is an etrecord when it's used later in the SDK tooling.
        etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")

        if export_modules is not None:
            for module_name, export_module in export_modules.items():
                # Substring match: user names may not even contain a reserved name.
                contains_reserved_name = any(
                    reserved_name in module_name
                    for reserved_name in ETRecordReservedFileNames
                )
                if contains_reserved_name:
                    raise RuntimeError(
                        f"The name {module_name} provided in the export_modules dict is a reserved name in the ETRecord namespace."
                    )
                _handle_export_module(etrecord_zip, export_module, module_name)

        if isinstance(
            edge_dialect_program,
            (EdgeProgramManager, exir.program._program.EdgeProgramManager),
        ):
            _handle_edge_dialect_exported_program(
                etrecord_zip,
                edge_dialect_program.exported_program(),
            )
        elif isinstance(edge_dialect_program, ExirExportedProgram):
            _handle_edge_dialect_exported_program(
                etrecord_zip,
                edge_dialect_program.exported_program,
            )
        else:
            raise RuntimeError(
                f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}."
            )

        # When a BundledProgram is passed in, extract the reference outputs and save in a file
        if isinstance(executorch_program, BundledProgram):
            reference_outputs = _get_reference_outputs(executorch_program)
            etrecord_zip.writestr(
                ETRecordReservedFileNames.REFERENCE_OUTPUTS,
                # @lint-ignore PYTHONPICKLEISBAD
                pickle.dumps(reference_outputs),
            )
            # The maps below live on the wrapped ExecuTorch program.
            executorch_program = executorch_program.executorch_program

        etrecord_zip.writestr(
            ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,
            json.dumps(executorch_program.debug_handle_map),
        )

        etrecord_zip.writestr(
            ETRecordReservedFileNames.DELEGATE_MAP_NAME,
            json.dumps(executorch_program.delegate_map),
        )
def parse_etrecord(etrecord_path: str) -> ETRecord:
    """
    Parses an `ETRecord` file and returns an `ETRecord` object that contains the deserialized
    exported programs, the debug handle map, the delegate map and, when present in the file,
    the reference outputs.

    In the graph map in the returned `ETRecord` object if a model with multiple entry points was provided
    originally by the user during `ETRecord` generation then each entry point will be stored as a separate
    graph module in the `ETRecord` object with the name being `the original module name + "/" + the
    name of the entry point`.

    Args:
        etrecord_path: Path to the `ETRecord` file.

    Returns:
        `ETRecord` object.

    Raises:
        RuntimeError: If the file is not a valid zip archive or does not contain the
            ETRecord identifier entry.
    """
    try:
        etrecord_zip = ZipFile(etrecord_path, "r")
    except BadZipFile as err:
        raise RuntimeError("Invalid etrecord file passed in.") from err

    graph_map: Dict[str, ExportedProgram] = {}
    debug_handle_map = None
    delegate_map = None
    edge_dialect_program = None
    reference_outputs = None

    # Close the archive (and its file handle) once everything has been read,
    # including when validation or deserialization fails.
    with etrecord_zip:
        file_list = etrecord_zip.namelist()

        if ETRecordReservedFileNames.ETRECORD_IDENTIFIER not in file_list:
            raise RuntimeError(
                "ETRecord identifier missing from etrecord file passed in. Either an invalid file was passed in or the file is corrupt."
            )

        serialized_exported_program_files = set()
        serialized_state_dict_files = set()
        for entry in file_list:
            if entry == ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME:
                debug_handle_map = json.loads(etrecord_zip.read(entry))
            elif entry == ETRecordReservedFileNames.DELEGATE_MAP_NAME:
                delegate_map = json.loads(etrecord_zip.read(entry))
            elif entry == ETRecordReservedFileNames.ETRECORD_IDENTIFIER:
                continue
            elif entry == ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM:
                serialized_artifact = SerializedArtifact(
                    etrecord_zip.read(entry),
                    etrecord_zip.read(f"{entry}_state_dict"),
                    b"",
                )
                edge_dialect_program = deserialize(serialized_artifact)
            elif entry == ETRecordReservedFileNames.REFERENCE_OUTPUTS:
                # NOTE: unpickling can execute arbitrary code; only parse
                # ETRecord files that come from a trusted source.
                # @lint-ignore PYTHONPICKLEISBAD
                reference_outputs = pickle.loads(etrecord_zip.read(entry))
            else:
                # Everything else is a user module entry: either a serialized
                # program or its companion state dict.
                if entry.endswith("state_dict"):
                    serialized_state_dict_files.add(entry)
                else:
                    serialized_exported_program_files.add(entry)

        # Pair every serialized program with its state dict and deserialize it.
        for serialized_file in serialized_exported_program_files:
            serialized_state_dict_file = f"{serialized_file}_state_dict"
            assert (
                serialized_state_dict_file in serialized_state_dict_files
            ), f"Could not find corresponding state dict file for {serialized_file}."
            serialized_artifact = SerializedArtifact(
                etrecord_zip.read(serialized_file),
                etrecord_zip.read(serialized_state_dict_file),
                b"",
            )
            graph_map[serialized_file] = deserialize(serialized_artifact)

    return ETRecord(
        edge_dialect_program=edge_dialect_program,
        graph_map=graph_map,
        _debug_handle_map=debug_handle_map,
        _delegate_map=delegate_map,
        _reference_outputs=reference_outputs,
    )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources