GraphInfo
- class torch.onnx.verification.GraphInfo(graph, input_args, params_dict, export_options=<factory>, id='', _EXCLUDED_NODE_KINDS=frozenset({'aten::ScalarImplicit', 'prim::Constant', 'prim::ListConstruct'}))
GraphInfo contains validation information for a TorchScript graph and the ONNX graph converted from it.
- all_mismatch_leaf_graph_info()
Return a list of all leaf GraphInfo objects that have a mismatch.
- Return type
list[GraphInfo]
- essential_node_count()
Return the number of nodes in the subgraph, excluding nodes whose kind is in _EXCLUDED_NODE_KINDS.
- Return type
int
- essential_node_kinds()
Return the set of node kinds in the subgraph excluding those in _EXCLUDED_NODE_KINDS.
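The filtering behind essential_node_count() and essential_node_kinds() can be illustrated in plain Python. This is a sketch, not the library code: it assumes the subgraph is represented as a plain list of node-kind strings instead of a torch.Graph, and the helper names are hypothetical. The excluded set is copied from the default shown in the class signature.

```python
from __future__ import annotations

# Default exclusion set, copied from the GraphInfo signature above.
EXCLUDED_NODE_KINDS = frozenset(
    {"aten::ScalarImplicit", "prim::Constant", "prim::ListConstruct"}
)


def essential_node_count(node_kinds: list[str]) -> int:
    """Count nodes whose kind is not excluded (hypothetical helper)."""
    return sum(1 for kind in node_kinds if kind not in EXCLUDED_NODE_KINDS)


def essential_node_kinds(node_kinds: list[str]) -> set[str]:
    """Collect the distinct kinds that are not excluded (hypothetical helper)."""
    return {kind for kind in node_kinds if kind not in EXCLUDED_NODE_KINDS}


if __name__ == "__main__":
    # Toy subgraph: two constants feeding an add, followed by a relu.
    kinds = ["prim::Constant", "prim::Constant", "aten::add", "aten::relu"]
    print(essential_node_count(kinds))  # 2
    print(sorted(essential_node_kinds(kinds)))  # ['aten::add', 'aten::relu']
```

Counting and collecting share the same predicate, so a subgraph made only of constants and list constructions has an essential node count of 0 and an empty kind set.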
- export_repro(repro_dir=None, name=None)
Export the subgraph to ONNX, along with its input and output data, so that the result can be reproduced.
The repro directory will contain the following files:
dir
├── test_<name>
│   ├── model.onnx
│   └── test_data_set_0
│       ├── input_0.pb
│       ├── input_1.pb
│       ├── output_0.pb
│       └── output_1.pb
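A repro directory with this layout can be located with a few lines of standard-library Python. The helpers below (collect_repro_files and _index are hypothetical names, not part of torch.onnx) only find files by the naming pattern shown in the tree; decoding the .pb payloads is out of scope for this sketch. Files are ordered by their numeric suffix, because plain string sorting would put input_10.pb before input_2.pb.

```python
import tempfile
from pathlib import Path


def _index(path: Path) -> int:
    """Numeric suffix of input_<i>.pb / output_<i>.pb (e.g. 'input_10' -> 10)."""
    return int(path.stem.rsplit("_", 1)[1])


def collect_repro_files(repro_dir: str, name: str) -> dict:
    """Return the model path and ordered data files of one repro (hypothetical helper)."""
    test_dir = Path(repro_dir) / f"test_{name}"
    data_dir = test_dir / "test_data_set_0"
    return {
        "model": test_dir / "model.onnx",
        "inputs": sorted(data_dir.glob("input_*.pb"), key=_index),
        "outputs": sorted(data_dir.glob("output_*.pb"), key=_index),
    }


if __name__ == "__main__":
    # Build an empty mock of the layout above, then read it back.
    with tempfile.TemporaryDirectory() as tmp:
        data_dir = Path(tmp) / "test_demo" / "test_data_set_0"
        data_dir.mkdir(parents=True)
        (data_dir.parent / "model.onnx").touch()
        for fname in ("input_0.pb", "input_1.pb", "output_0.pb", "output_1.pb"):
            (data_dir / fname).touch()
        files = collect_repro_files(tmp, "demo")
        print(files["model"].name, [p.name for p in files["inputs"]])
        # model.onnx ['input_0.pb', 'input_1.pb']
```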
- find_mismatch(options=None)
Find all mismatches between the TorchScript IR graph and the exported ONNX model.
Performs a binary search over the model graph to find the minimal subgraphs that exhibit the mismatch. A GraphInfo object is created for each subgraph, recording the test inputs and export options as well as the validation results.
- Parameters
options (VerificationOptions | None) – The verification options.
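The bisection idea can be illustrated without torch. In the sketch below, a "graph" is just a list of node-kind strings, and a caller-supplied predicate stands in for "export this slice to ONNX and compare its outputs with torch"; Partition, bisect_mismatch and mismatch_leaves are hypothetical names. The split point (first half = n // 2) is an assumption chosen so that 5 nodes split into 2 + 3, as in the pretty_print_tree example. A slice is only split further if it mismatches, and a mismatch leaf is a mismatching slice none of whose halves mismatch.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Partition:
    """Hypothetical stand-in for one GraphInfo node: a slice of the graph and its halves."""

    id: str
    nodes: list[str]
    mismatch: bool
    upper: Optional[Partition] = None  # first half, id + "0"
    lower: Optional[Partition] = None  # second half, id + "1"


def bisect_mismatch(
    nodes: list[str], has_mismatch: Callable[[list[str]], bool], part_id: str = ""
) -> Partition:
    """Check a slice, then split it in two and recurse only if it mismatches."""
    part = Partition(part_id, nodes, has_mismatch(nodes))
    if part.mismatch and len(nodes) > 1:
        mid = len(nodes) // 2  # assumption: 5 nodes -> 2 + 3
        part.upper = bisect_mismatch(nodes[:mid], has_mismatch, part_id + "0")
        part.lower = bisect_mismatch(nodes[mid:], has_mismatch, part_id + "1")
    return part


def mismatch_leaves(part: Optional[Partition]) -> list[Partition]:
    """Collect mismatching partitions whose children (if any) all match."""
    if part is None or not part.mismatch:
        return []
    children = [c for c in (part.upper, part.lower) if c is not None]
    if not any(c.mismatch for c in children):
        return [part]
    return [leaf for c in children for leaf in mismatch_leaves(c)]


if __name__ == "__main__":
    # Simulated check: pretend every slice containing aten::relu mismatches.
    kinds = ["aten::add", "aten::relu", "aten::mul", "aten::relu", "aten::sub"]
    root = bisect_mismatch(kinds, lambda ns: "aten::relu" in ns)
    print([p.id for p in mismatch_leaves(root)])  # ['01', '110']
```

With this simulated predicate, the leaf ids come out as '01' and '110', the same ids listed under "Mismatch leaf subgraphs" in the tree example.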
- find_partition(id)
Find the GraphInfo object with the given id.
- Return type
GraphInfo | None
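Judging from the tree example under pretty_print_tree(), partition ids behave like prefix codes: the root's id is the empty string, and each split appends "0" or "1" to the parent's id. A lookup can therefore walk down one id character per level. The sketch below illustrates this on a tree built from a node count alone; PartitionNode, build and find_partition here are hypothetical stand-ins, "0" is assumed to denote the first half, and the n // 2 split point is an assumption chosen to match the example.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class PartitionNode:
    """Hypothetical tree node: `upper` has id + "0", `lower` has id + "1"."""

    id: str
    size: int
    upper: Optional[PartitionNode] = None
    lower: Optional[PartitionNode] = None


def build(size: int, node_id: str = "") -> PartitionNode:
    """Split a run of `size` nodes in half until single nodes remain."""
    node = PartitionNode(node_id, size)
    if size > 1:
        half = size // 2  # assumption: 5 -> 2 + 3, as in the example tree
        node.upper = build(half, node_id + "0")
        node.lower = build(size - half, node_id + "1")
    return node


def find_partition(root: PartitionNode, target: str) -> Optional[PartitionNode]:
    """Walk down from the root, choosing a child by the next id character."""
    node: Optional[PartitionNode] = root
    while node is not None:
        if node.id == target:
            return node
        if not target.startswith(node.id):
            return None  # target cannot lie in this subtree
        step = target[len(node.id)]  # safe: target is strictly longer than node.id
        if step == "0":
            node = node.upper
        elif step == "1":
            node = node.lower
        else:
            return None  # ids only contain "0" and "1"
    return None


if __name__ == "__main__":
    root = build(5)
    print(find_partition(root, "110").size)  # 1
    print(find_partition(root, "2"))  # None
```

Each step consumes one character, so a lookup costs at most one step per tree level rather than a scan of every partition.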
- has_mismatch()
Return True if the subgraph has an output mismatch between torch and ONNX.
- Return type
bool
- pretty_print_mismatch(graph=False)
Pretty print details of the mismatch between torch and ONNX.
- Parameters
graph (bool) – If True, print the ATen JIT graph and ONNX graph.
- pretty_print_tree()
Pretty print the GraphInfo tree.
Each node represents a subgraph and shows the number of nodes in that subgraph, followed by X if the subgraph has an output mismatch between torch and ONNX, or a check mark (✓) if it does not.
The id of the subgraph is shown under the node. The GraphInfo object for any subgraph can be retrieved by calling graph_info.find_partition(id).
Example:
==================================== Tree: =====================================
5 X   __2 X    __1 ✓
id:  |  id: 0 |  id: 00
     |        |
     |        |__1 X (aten::relu)
     |           id: 01
     |
     |__3 X    __1 ✓
        id: 1 |  id: 10
              |
              |__2 X     __1 X (aten::relu)
                 id: 11 |  id: 110
                        |
                        |__1 ✓
                           id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}
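The example can also be read programmatically. The sketch below works from a hand-transcribed copy of the tree above; TREE, KINDS and all function names are hypothetical and not part of torch.onnx. Because each child id extends its parent's id by one character, sorting the ids lexicographically yields a depth-first order, which is enough for an indented outline (not the library's ASCII layout). The two summary sections are then recomputed from the same data.

```python
from collections import Counter

# Hand-transcribed from the example tree: id -> (node count, has mismatch).
TREE = {
    "": (5, True), "0": (2, True), "00": (1, False), "01": (1, True),
    "1": (3, True), "10": (1, False), "11": (2, True),
    "110": (1, True), "111": (1, False),
}
# Hand-transcribed kinds of the two mismatching single-node subgraphs.
KINDS = {"01": {"aten::relu"}, "110": {"aten::relu"}}


def render(tree):
    """Indented outline; sorted prefix ids visit parents before their children."""
    lines = []
    for pid in sorted(tree):
        count, mismatch = tree[pid]
        mark = "X" if mismatch else "✓"
        lines.append(f"{'    ' * len(pid)}{count} {mark}  id: {pid}")
    return lines


def mismatch_leaf_ids(tree):
    """Mismatching ids whose '0' and '1' children are absent or match."""
    return [
        pid
        for pid in sorted(tree)
        if tree[pid][1] and not any(tree.get(pid + c, (0, False))[1] for c in "01")
    ]


def mismatch_kind_counts(leaf_ids, kinds):
    """Tally node kinds across the mismatch leaves."""
    return dict(Counter(k for pid in leaf_ids for k in kinds.get(pid, ())))


if __name__ == "__main__":
    print("\n".join(render(TREE)))
    leaves = mismatch_leaf_ids(TREE)
    print(leaves)  # ['01', '110']
    print(mismatch_kind_counts(leaves, KINDS))  # {'aten::relu': 2}
```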
- verify_export(options)
Verify the export from TorchScript IR graph to ONNX.
Export the TorchScript IR graph to ONNX, with the inputs, parameters and export options recorded in this object. Then verify the exported ONNX graph against the original TorchScript IR graph under the provided verification options.
- Parameters
options (VerificationOptions) – The verification options.
- Returns
A 4-tuple (error, onnx_graph, onnx_outs, pt_outs):
error: the AssertionError raised during the verification, or None if no error is raised.
onnx_graph: the exported ONNX graph in TorchScript IR format.
onnx_outs: the outputs from running the exported ONNX model under the ONNX backend specified in options.
pt_outs: the outputs from running the TorchScript IR graph.
- Return type
tuple
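The error-reporting contract of verify_export (return the AssertionError instead of raising it, and None on success) can be sketched without torch. In this illustration the outputs are flat lists of floats, math.isclose stands in for the verifier's tensor comparison, and the tolerances are placeholder values, not VerificationOptions defaults; check_outputs and verify are hypothetical names.

```python
import math
from typing import Optional, Sequence


def check_outputs(
    onnx_outs: Sequence[float], pt_outs: Sequence[float], rtol: float, atol: float
) -> None:
    """Raise AssertionError at the first output that differs beyond tolerance."""
    if len(onnx_outs) != len(pt_outs):
        raise AssertionError(
            f"output count differs: onnx={len(onnx_outs)} torch={len(pt_outs)}"
        )
    for i, (o, p) in enumerate(zip(onnx_outs, pt_outs)):
        if not math.isclose(o, p, rel_tol=rtol, abs_tol=atol):
            raise AssertionError(f"output {i}: onnx={o} torch={p}")


def verify(
    onnx_outs: Sequence[float],
    pt_outs: Sequence[float],
    rtol: float = 1e-3,  # placeholder tolerance
    atol: float = 1e-7,  # placeholder tolerance
) -> Optional[AssertionError]:
    """Return the caught AssertionError, or None when all outputs agree."""
    try:
        check_outputs(onnx_outs, pt_outs, rtol, atol)
    except AssertionError as error:
        return error
    return None


if __name__ == "__main__":
    error = verify([1.0, 2.5], [1.0, 2.0])
    has_mismatch = error is not None  # a has_mismatch-style check on the stored error
    print(has_mismatch, error)  # True output 1: onnx=2.5 torch=2.0
```

Raising explicitly instead of using assert statements keeps the check active under python -O. Returning the error rather than propagating it presumably lets a caller such as find_mismatch record the outcome for each subgraph and continue searching.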