Source code for executorch.sdk.bundled_program.config
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from dataclasses import dataclass
from typing import get_args, List, Optional, Sequence, Union
import torch
from torch.utils._pytree import tree_flatten
from typing_extensions import TypeAlias
"""
The data types currently supported for element to be bundled. It should be
consistent with the types in bundled_program.schema.Value.
"""
ConfigValue: TypeAlias = Union[
torch.Tensor,
int,
bool,
float,
]
"""
The data type of the input for method single execution.
"""
MethodInputType: TypeAlias = Sequence[ConfigValue]
"""
The data type of the output for method single execution.
"""
MethodOutputType: TypeAlias = Sequence[torch.Tensor]
"""
All supported types for input/expected output of MethodTestCase.
Namedtuple is also supported and listed implicity since it is a subclass of tuple.
"""
# pyre-ignore
DataContainer: TypeAlias = Union[list, tuple, dict]
class MethodTestCase:
"""Test case with inputs and expected outputs
The expected_outputs are optional and only required if the user wants to verify model outputs after execution.
"""
def __init__(
self,
inputs: MethodInputType,
expected_outputs: Optional[MethodOutputType] = None,
) -> None:
"""Single test case for verifying specific method
Args:
input: All inputs required by eager_model with specific inference method for one-time execution.
It is worth mentioning that, although both bundled program and ET runtime apis support setting input
other than `torch.tensor` type, only the input in `torch.tensor` type will be actually updated in
the method, and the rest of the inputs will just do a sanity check if they match the default value in method.
expected_output: Expected output of given input for verification. It can be None if user only wants to use the test case for profiling.
Returns:
self
"""
# TODO(gasoonjia): Update type check logic.
# pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sannity check.
self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs)
self.expected_outputs: List[ConfigValue] = []
if expected_outputs is not None:
# pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sannity check.
self.expected_outputs = self._flatten_and_sanity_check(expected_outputs)
def _flatten_and_sanity_check(
self, unflatten_data: DataContainer
) -> List[ConfigValue]:
"""Flat the given data and check its legality
Args:
unflatten_data: Data needs to be flatten.
Returns:
flatten_data: Flatten data with legal type.
"""
flatten_data, _ = tree_flatten(unflatten_data)
for data in flatten_data:
assert isinstance(
data,
get_args(ConfigValue),
), "The type of input {} with type {} is not supported.\n".format(
data, type(data)
)
assert not isinstance(
data,
type(None),
), "The input {} should not be in null type.\n".format(data)
return flatten_data
[docs]@dataclass
class MethodTestSuite:
"""All test info related to verify method
Attributes:
method_name: Name of the method to be verified.
test_cases: All test cases for verifying the method.
"""
method_name: str
test_cases: Sequence[MethodTestCase]