# mypy: allow-untyped-defs"""Functions to verify exported ONNX model is functionally equivalent to original PyTorch model.ONNX Runtime is required, and is used as the ONNX backend for export verification."""from__future__importannotationsimportcontextlibimportcopyimportdataclassesimportdatetimeimportdifflibimportenumimportfunctoolsimportioimportitertoolsimportosimporttempfileimportwarningsfromtypingimportAny,Callable,Collection,Mapping,Sequence,Tuple,Unionimportnumpyasnpimportnumpy.typingasnptimporttorchimporttorch._C._onnxas_C_onnxfromtorchimport_Cfromtorch.onnximport_constants,_experimental,utilsfromtorch.onnx._globalsimportGLOBALSfromtorch.onnx._internalimportonnx_proto_utilsfromtorch.typesimportNumber_ORT_PROVIDERS=("CPUExecutionProvider",)_NumericType=Union[Number,torch.Tensor,np.ndarray]_ModelType=Union[torch.nn.Module,torch.jit.ScriptModule]_InputArgsType=Union[torch.Tensor,Tuple[Any,...]]_InputKwargsType=Mapping[str,Any]_OutputsType=Union[Sequence[_NumericType],Sequence]classOnnxBackend(enum.Enum):"""Enum class for ONNX backend used for export verification."""REFERENCE="ONNXReferenceEvaluator"ONNX_RUNTIME_CPU="CPUExecutionProvider"ONNX_RUNTIME_CUDA="CUDAExecutionProvider"
[docs]@dataclasses.dataclassclassVerificationOptions:"""Options for ONNX export verification. Attributes: flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of Tensors for ONNX. Set this to False if nested structures are to be preserved for ONNX, which is usually the case with exporting ScriptModules. Default True. ignore_none: Whether to ignore None type in torch output, which is usually the case with tracing. Set this to False, if torch output should keep None type, which is usually the case with exporting ScriptModules. Default to True. check_shape: Whether to check the shapes between PyTorch and ONNX Runtime outputs are exactly the same. Set this to False to allow output shape broadcasting. Default to True. check_dtype: Whether to check the dtypes between PyTorch and ONNX Runtime outputs are consistent. Default to True. backend: ONNX backend for verification. Default to OnnxBackend.ONNX_RUNTIME_CPU. rtol: relative tolerance in comparison between ONNX and PyTorch outputs. atol: absolute tolerance in comparison between ONNX and PyTorch outputs. remained_onnx_input_idx: If provided, only the specified inputs will be passed to the ONNX model. Supply a list when there are unused inputs in the model. Since unused inputs will be removed in the exported ONNX model, supplying all inputs will cause an error on unexpected inputs. This parameter tells the verifier which inputs to pass into the ONNX model. acceptable_error_percentage: acceptable percentage of element mismatches in comparison. It should be a float of value between 0.0 and 1.0. """flatten:bool=Trueignore_none:bool=Truecheck_shape:bool=Truecheck_dtype:bool=Truebackend:OnnxBackend=OnnxBackend.ONNX_RUNTIME_CPUrtol:float=1e-3atol:float=1e-7remained_onnx_input_idx:Sequence[int]|None=Noneacceptable_error_percentage:float|None=None
def_flatten_tuples(elem):flattened=[]fortinelem:ifisinstance(t,tuple):flattened.extend(_flatten_tuples(t))else:flattened.append(t)returnflattened# TODO(justinchuby): Add type checking by narrowing down the return type when input is Nonedef_to_numpy(elem)->list|npt.NDArray:ifisinstance(elem,torch.Tensor):ifelem.requires_grad:returnelem.detach().cpu().numpy()else:returnelem.cpu().numpy()elifisinstance(elem,(list,tuple)):return[_to_numpy(inp)forinpinelem]elifisinstance(elem,(bool,int,float)):returnnp.array(elem)elifisinstance(elem,dict):flattened=[]forkinelem:flattened.extend([_to_numpy(k),_to_numpy(elem[k])])returnflattenedreturnelemdef_inline_flatten_list(inputs,res_list)->list:foriininputs:res_list.append(i)ifnotisinstance(i,(list,tuple))else_inline_flatten_list(i,res_list)returnres_listdef_unpack_to_numpy(values,cast_onnx_accepted=True)->list:value_unpacked=[]forvalueinvalues:value_unpacked.extend(utils.unpack_quantized_tensor(value,cast_onnx_accepted=cast_onnx_accepted))return[_to_numpy(v)forvinvalue_unpacked]def_run_onnx(onnx_session,inputs)->_OutputsType:kw_inputs={}ifinputsandisinstance(inputs[-1],dict):kw_inputs=inputs[-1]inputs=inputs[:-1]inputs=_unpack_to_numpy(_flatten_tuples(inputs))ort_inputs={}forinput_name,inputinkw_inputs.items():ort_inputs[input_name]=_to_numpy(input)inputs=_to_numpy(inputs)ifhasattr(onnx_session,"get_inputs"):# onnxruntime.InferenceSessioninput_names=[i.nameforiinonnx_session.get_inputs()]elifhasattr(onnx_session,"input_names"):# onnx.reference.ReferenceEvaluatorinput_names=onnx_session.input_nameselse:raiseValueError(f"Unknown ONNX backend type: {type(onnx_session)}.")fori,inputinenumerate(inputs):ifi==len(input_names)orinput_names[i]inort_inputs:raiseValueError(f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}. "f"input names: {input_names}.")ort_inputs[input_names[i]]=inputonnx_outs=onnx_session.run(None,ort_inputs)returnonnx_outsdef_ort_session(model:str|io.BytesIO,ort_providers:Sequence[str]=_ORT_PROVIDERS):try:importonnxruntime# type: ignore[import]exceptImportErrorase:raiseImportError("onnxruntime is required for export verification.")fromeifort_providersisNone:ort_providers=_ORT_PROVIDERSsession_options=onnxruntime.SessionOptions()# suppress ort warnings.# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.session_options.log_severity_level=3ort_session=onnxruntime.InferenceSession(modelifisinstance(model,str)elsemodel.getvalue(),session_options,providers=ort_providers,)returnort_sessiondef_onnx_reference_evaluator_session(model:str|io.BytesIO):try:importonnxfromonnximportreferenceasonnx_reference# type: ignore[attr-defined]exceptImportErrorasexc:raiseImportError("onnx >= 1.13 is required for reference evaluator.")fromexcproto=(onnx.load(model)# type: ignore[attr-defined]ifisinstance(model,str)elseonnx.load_model_from_string(model.getvalue())# type: ignore[attr-defined])onnx_session=onnx_reference.ReferenceEvaluator(proto)returnonnx_sessiondef_onnx_backend_session(model:str|io.BytesIO,backend:OnnxBackend):ifbackend==OnnxBackend.REFERENCE:onnx_session=_onnx_reference_evaluator_session(model)elifbackendin{OnnxBackend.ONNX_RUNTIME_CPU,OnnxBackend.ONNX_RUNTIME_CUDA}:onnx_session=_ort_session(model,(backend.value,))else:raiseValueError(f"Unsupported backend: {backend}")returnonnx_sessiondef_compare_onnx_pytorch_outputs_in_np(onnx_outs:_OutputsType,pt_outs:_OutputsType,options:VerificationOptions,):assert(len(onnx_outs)==len(pt_outs)),f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"acceptable_error_percentage=options.acceptable_error_percentageifacceptable_error_percentageand(acceptable_error_percentage>1.0oracceptable_error_percentage<0.0):raiseValueError("If set, acceptable_error_percentage should be between 0.0 and 1.0")forort_out,pt_outinzip(onnx_outs,pt_outs):try:# TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.ifnotoptions.check_shape:# Allow different but broadcastable output shapes.ort_out,pt_out=np.broadcast_arrays(ort_out,pt_out)torch.testing.assert_close(ort_out,pt_out,rtol=options.rtol,atol=options.atol,check_dtype=options.check_dtype,equal_nan=True,)exceptAssertionErrorase:ifacceptable_error_percentage:error_percentage=1-np.sum(np.isclose(ort_out,pt_out,rtol=options.rtol,atol=options.atol))/np.prod(ort_out.shape)iferror_percentage<=acceptable_error_percentage:warnings.warn(f"Suppressed AssertionError:\n{e}.\n"f"Error percentage {error_percentage} "f"within acceptable range {acceptable_error_percentage}.")continueifort_out.dtype==np.uint8orort_out.dtype==np.int8:warnings.warn("ONNX output is quantized")ifpt_out.dtype==np.uint8orpt_out.dtype==np.int8:warnings.warn("PyTorch output is quantized")raisedef_compare_onnx_pytorch_outputs(onnx_outs:_OutputsType,pt_outs:Any,options:VerificationOptions,):""" Compare ONNX and PyTorch outputs. Args: onnx_outs: outputs from ONNX backend. pt_outs: outputs from PyTorch. options: options for verification. Raises: AssertionError: if outputs from ONNX model and PyTorch model are not equal up to specified precision. ValueError: if arguments provided are invalid. """ifoptions.ignore_none:# torch.jit._flatten filters None typept_outs,_=torch.jit._flatten(pt_outs)else:pt_outs=_inline_flatten_list([pt_outs],[])pt_outs_np=_unpack_to_numpy(pt_outs,cast_onnx_accepted=False)onnx_outs=_inline_flatten_list(onnx_outs,[])_compare_onnx_pytorch_outputs_in_np(onnx_outs,pt_outs_np,options)def_prepare_input_for_pytorch(args,kwargs):"""Prepare input for PyTorch model execution. Any future changes/formatting to the input before dispatching to the PyTorch model should be made in this function. Args: args: positional arguments for PyTorch model forward method. kwargs: keyword arguments for PyTorch model forward method. Returns: args: positional arguments for PyTorch model forward method. kwargs: keyword arguments for PyTorch model forward method. """ifisinstance(args,(torch.Tensor,dict)):args=(args,)# In-place operators will update input tensor data as well.# Thus inputs are replicated before every forward call.args=copy.deepcopy(args)ifkwargs:kwargs=copy.deepcopy(kwargs)else:kwargs={}returnargs,kwargsdef_prepare_input_for_export(args,kwargs):"""Prepare input for ONNX model export. Any future changes/formatting to the input before dispatching to the :func:`torch.onnx.export` api should be made in this function. Args: args: positional arguments for PyTorch model forward method. kwargs: keyword arguments for PyTorch model forward method. Returns: onnx_inputs: positional arguments for ONNX model export, as `args` in :func:`torch.onnx.export`. """args,kwargs=_prepare_input_for_pytorch(args,kwargs)ifnotkwargsandlen(args)>0andisinstance(args[-1],dict):onnx_inputs=args+({},)elifkwargs:onnx_inputs=args+(kwargs,)else:onnx_inputs=argsreturnonnx_inputsdef_prepare_input_for_onnx(args,kwargs,remained_onnx_input_idx:Sequence[int]|None,flatten:bool):"""Prepare input for ONNX model execution in ONNX backend. Any future changes/formatting to the input before dispatching to the ONNX backend run should be made in this function. Args: args: positional arguments for PyTorch model forward method. kwargs: keyword arguments for PyTorch model forward method. remained_onnx_input_idx: indices of inputs to be used for ONNX model execution. flatten: whether to flatten the input before dispatching to the ONNX model execution. Returns: onnx_inputs: positional arguments for ONNX model execution in ONNX backend. """onnx_inputs=_prepare_input_for_export(args,kwargs)ifflatten:onnx_inputs,_=torch.jit._flatten(onnx_inputs)elifonnx_inputsandonnx_inputs[-1]=={}:# Handle empty kwargs (normally removed by flatten).onnx_inputs=onnx_inputs[:-1]ifremained_onnx_input_idxisnotNone:return[onnx_inputs[i]foriinremained_onnx_input_idx]else:returnonnx_inputsdef_try_clone_model(model):"""Used for preserving original model in case forward mutates model states."""try:returncopy.deepcopy(model)exceptException:warnings.warn("Failed to clone model. Model state might be mutated during verification.")returnmodeldef_compare_onnx_pytorch_model(pt_model:_ModelType,onnx_model_f:str|io.BytesIO,input_args:_InputArgsType,input_kwargs:_InputKwargsType|None,additional_test_inputs:Sequence[_InputArgsType]|None,options:VerificationOptions,):"""Compare outputs from ONNX model runs with outputs from PyTorch model runs. Args: pt_model: PyTorch model. onnx_model_f: ONNX model file path or file-like object. input_args: positional arguments for PyTorch model forward method. input_kwargs: keyword arguments for PyTorch model forward method. additional_test_inputs: additional positional arguments for PyTorch model forward method. options: options for verification. Raises: AssertionError: if outputs from ONNX model and PyTorch model are not equal up to specified precision. """onnx_session=_onnx_backend_session(onnx_model_f,options.backend)defcompare_onnx_pytorch_model_with_input(input_args,input_kwargs):pt_args,pt_kwargs=_prepare_input_for_pytorch(input_args,input_kwargs)# TODO: remove this and treat mutating model separately. See #77679pt_model_copy=_try_clone_model(pt_model)pt_outs=pt_model_copy(*pt_args,**pt_kwargs)onnx_inputs=_prepare_input_for_onnx(input_args,input_kwargs,options.remained_onnx_input_idx,options.flatten)onnx_outs=_run_onnx(onnx_session,onnx_inputs)_compare_onnx_pytorch_outputs(onnx_outs=onnx_outs,pt_outs=pt_outs,options=options,)compare_onnx_pytorch_model_with_input(input_args,input_kwargs)ifadditional_test_inputs:fortest_input_argsinadditional_test_inputs:compare_onnx_pytorch_model_with_input(test_input_args,{})class_GraphDiff:"""A class to represent the difference between two graphs."""def__init__(self,graph_a:_C.Graph,graph_b:_C.Graph):"""Construct a _GraphDiff object. Args: graph_a (_C.Graph): First graph to compare. graph_b (_C.Graph): Second graph to compare. """self.graph_a=graph_aself.graph_b=graph_bdef__str__(self):"""See function :func:`diff_report`."""returnself.diff_report()def_indent(self,lines:str)->str:return"\n".join(["\t"+lineforlineinlines.splitlines()])defdiff_report(self)->str:"""Return a string representation of the graph difference. The report shows the first pair of nodes that diverges. It also shows the source location of the pair of nodes. Returns: graph_diff_report (str): A string representation of the graph difference. """graph_a=self.graph_agraph_b=self.graph_bgraph_a_str=str(graph_a)graph_b_str=str(graph_b)ifgraph_a_str==graph_b_str:return""graph_diff=difflib.ndiff(graph_a_str.splitlines(True),graph_b_str.splitlines(True))graph_diff_report=["Graph diff:",self._indent("".join(graph_diff))]fornode_a,node_binitertools.zip_longest(graph_a.nodes(),graph_b.nodes()):ifstr(node_a)!=str(node_b):graph_diff_report.append("First diverging operator:")node_diff=difflib.ndiff(str(node_a).splitlines(True),str(node_b).splitlines(True))source_printout=["node diff:",self._indent("".join(node_diff))]stack_a=node_a.sourceRange()ifnode_aelseNoneifstack_a:source_printout.extend(["Former source location:",self._indent(str(stack_a))])stack_b=node_b.sourceRange()ifnode_belseNoneifstack_b:source_printout.extend(["Latter source location:",self._indent(str(stack_b))])graph_diff_report.extend(source_printout)breakreturn"\n".join(graph_diff_report)def_check_graph_diff(model:torch.nn.Module|torch.jit.ScriptModule,test_input_groups:Sequence[tuple[tuple[Any,...],Mapping[str,Any]]],export_options:_experimental.ExportOptions,model_to_graph_func:Callable[[torch.nn.Module,tuple[Any,...],Mapping[str,Any],_experimental.ExportOptions,],_C.Graph,],)->str:"""Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`. Args: model: See :func:`check_export_model_diff`. test_input_groups: See :func:`check_export_model_diff`. export_options: See :func:`check_export_model_diff`. model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph. Returns: graph_diff_report (str): A string representation of the graph difference. """iflen(test_input_groups)<2:raiseValueError("Need at least two groups of test inputs to compare.")ref_jit_graph=Noneforargs,kwargsintest_input_groups:jit_graph=model_to_graph_func(model,args,kwargs,export_options)ifref_jit_graphisNone:ref_jit_graph=jit_graphcontinuegraph_diff_report=_GraphDiff(ref_jit_graph,jit_graph).diff_report()ifgraph_diff_report:returngraph_diff_reportreturn""def_traced_graph_from_model(model:torch.nn.Module|torch.jit.ScriptModule,args:tuple[Any,...],kwargs:Mapping[str,Any],export_options:_experimental.ExportOptions,)->_C.Graph:"""As part of the ONNX export steps, create a traced JIT graph from a PyTorch model. Args: model: See :func:`check_export_model_diff`. args: See :func:`check_export_model_diff`. kwargs: See :func:`check_export_model_diff`. export_options: See :func:`check_export_model_diff`. Returns: jit_graph (_C.Graph): A traced JIT graph. """training=export_options.trainingverbose=export_options.verbosewithutils.exporter_context(model,training,verbose):export_inputs=_prepare_input_for_export(args,kwargs)model=utils._pre_trace_quant_model(model,export_inputs)jit_graph,_,_,_=utils._create_jit_graph(model,export_inputs)returnjit_graphdef_onnx_graph_from_model(model:torch.nn.Module|torch.jit.ScriptModule,args:tuple[Any,...],kwargs:Mapping[str,Any],export_options:_experimental.ExportOptions,)->_C.Graph:"""As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model. Args: model: See :func:`check_export_model_diff`. args: See :func:`check_export_model_diff`. kwargs: See :func:`check_export_model_diff`. export_options: See :func:`check_export_model_diff`. Returns: onnx_graph (_C.Graph): An ONNX JIT graph. """# TODO: refactor utils.py to remove duplicated code of context setup. See #78834opset_version=export_options.opset_versionoperator_export_type=export_options.operator_export_typeexport_modules_as_functions=export_options.export_modules_as_functionstraining=export_options.trainingverbose=export_options.verbosedynamic_axes=export_options.dynamic_axesinput_names=export_options.input_namesoutput_names=export_options.output_namesifopset_versionisNone:opset_version=_constants.ONNX_DEFAULT_OPSETutils._setup_trace_module_map(model,export_modules_as_functions)ifnotoperator_export_type:operator_export_type=_C_onnx.OperatorExportTypes.ONNXGLOBALS.export_onnx_opset_version=opset_versionGLOBALS.operator_export_type=operator_export_typewithutils.exporter_context(model,training,verbose):do_constant_folding=utils._decide_constant_folding(export_options.do_constant_folding,operator_export_type,training)ifdynamic_axesisNone:dynamic_axes={}utils._validate_dynamic_axes(dynamic_axes,model,input_names,output_names)export_inputs=_prepare_input_for_export(args,kwargs)export_inputs=utils._decide_input_format(model,export_inputs)onnx_graph,_,_=utils._model_to_graph(model,export_inputs,verbose,input_names,output_names,operator_export_type,do_constant_folding,training=training,dynamic_axes=dynamic_axes,)returnonnx_graphdef_onnx_graph_from_aten_graph(graph:torch.Graph,export_options:_experimental.ExportOptions,params_dict:dict[str,Any]|None=None,)->tuple[torch.Graph,dict[str,Any]]:ifparams_dictisNone:params_dict={}operator_export_type=export_options.operator_export_typedynamic_axes=export_options.dynamic_axesor{}input_names=export_options.input_namestraining=export_options.trainingdo_constant_folding=export_options.do_constant_foldingopset_version=export_options.opset_versionor_constants.ONNX_DEFAULT_OPSETGLOBALS.export_onnx_opset_version=opset_versionGLOBALS.operator_export_type=operator_export_typedo_constant_folding=utils._decide_constant_folding(do_constant_folding,operator_export_type,training)# TODO: Below is doing aten graph to onnx. It should be abstracted as a# function in torch/onnx/utils.py.graph=graph.copy()graph=utils._optimize_graph(graph,operator_export_type,params_dict=params_dict,dynamic_axes=dynamic_axes,input_names=input_names,)iftrainingisNoneortraining==_C_onnx.TrainingMode.EVAL:params_dict=torch._C._jit_pass_onnx_eval_peephole(graph,params_dict)if(do_constant_foldingandopset_version>=_constants.ONNX_CONSTANT_FOLDING_MIN_OPSET):params_dict=_C._jit_pass_onnx_constant_fold(graph,params_dict,opset_version)_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)ifGLOBALS.onnx_shape_inference:_C._jit_pass_onnx_graph_shape_type_inference(graph,params_dict,opset_version)params_dict=_C._jit_pass_onnx_eliminate_unused_items(graph,params_dict)# For ONNX opset < 9, constants only have three data types: float16, float, double.# In this pass transform constants of other data types to float/double + cast operator.ifopset_version<9:_C._jit_pass_onnx_cast_all_constant_to_floating(graph)params_dict=_C._jit_pass_filter_non_tensor_arguments(params_dict)_C._jit_decay_packed_param_input_types(graph)_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)ifexport_options.verbose:print("ONNX graph: ",graph)returngraph,params_dictdef_onnx_proto_from_onnx_graph(onnx_graph:torch.Graph,export_options:_experimental.ExportOptions,params_dict:dict[str,Any],)->tuple[bytes,Mapping[str,bytes]]:opset_version=export_options.opset_versionor_constants.ONNX_DEFAULT_OPSETdynamic_axes=export_options.dynamic_axesor{}operator_export_type=export_options.operator_export_typeval_keep_init_as_ip=utils._decide_keep_init_as_input(export_options.keep_initializers_as_inputs,operator_export_type,opset_version,)val_add_node_names=utils._decide_add_node_names(True,operator_export_type)custom_opsets=export_options.custom_opsetsor{}proto,export_map,_,_=onnx_graph._export_onnx(# type: ignore[attr-defined]params_dict,opset_version,dynamic_axes,False,operator_export_type,notexport_options.verbose,val_keep_init_as_ip,custom_opsets,val_add_node_names,"",{},)returnproto,export_mapdefcheck_export_model_diff(model:torch.nn.Module|torch.jit.ScriptModule,test_input_groups:Sequence[tuple[tuple[Any,...],Mapping[str,Any]]],export_options:_experimental.ExportOptions|None=None,)->str:"""Verify exported model discrepancy between different groups of inputs. A graph is exported for each group of inputs. The exported graphs are then compared to each other, and discrepancies of first pair of nodes are reported. This function first checks the jit graph. If no discrepancies were found, it then checks the onnx graph. Unless otherwise specified, the jit/ONNX graph is expected to be the same, regardless of the inputs used for exporting. A discrepancy implies the graph exported is not accurate when run on other groups of inputs, which will typically results in runtime errors or mismatching output. Args: model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported. test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence of input groups to be used to export the model. Each input group is a pair of (args, kwargs). export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions object that controls the export behavior. Returns: str: A string containing the diff of the exported models. """export_options=(_experimental.ExportOptions()ifexport_optionsisNoneelseexport_options)jit_diff_report=_check_graph_diff(model,test_input_groups,export_options,_traced_graph_from_model)ifjit_diff_report:returnjit_diff_reportreturn_check_graph_diff(model,test_input_groups,export_options,_onnx_graph_from_model)defverify(model:_ModelType,input_args:_InputArgsType,input_kwargs:_InputKwargsType|None=None,do_constant_folding:bool=True,dynamic_axes:Mapping[str,Mapping[int,str]|Mapping[str,Sequence[int]]]|None=None,input_names:Sequence[str]|None=None,output_names:Sequence[str]|None=None,training:_C_onnx.TrainingMode=_C_onnx.TrainingMode.EVAL,opset_version:int|None=None,keep_initializers_as_inputs:bool=True,verbose:bool=False,fixed_batch_size:bool=False,use_external_data:bool=False,additional_test_inputs:Sequence[_InputArgsType]|None=None,options:VerificationOptions|None=None,):"""Verify model export to ONNX against original PyTorch model. Args: model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`. input_args (tuple): See :func:`torch.onnx.export`. input_kwargs (dict): See :func:`torch.onnx.export`. do_constant_folding (bool, optional): See :func:`torch.onnx.export`. dynamic_axes (dict, optional): See :func:`torch.onnx.export`. input_names (list, optional): See :func:`torch.onnx.export`. output_names (list, optional): See :func:`torch.onnx.export`. training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`. opset_version (int, optional): See :func:`torch.onnx.export`. keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`. verbose (bool, optional): See :func:`torch.onnx.export`. fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases. use_external_data (bool, optional): Explicitly specify whether to export the model with external data. additional_test_inputs (list, optional): List of tuples. Each tuple is a group of input arguments to test. Currently only *args are supported. options (_VerificationOptions, optional): A _VerificationOptions object that controls the verification behavior. Raises: AssertionError: if outputs from ONNX model and PyTorch model are not equal up to specified precision. ValueError: if arguments provided are invalid. """ifoptionsisNone:options=VerificationOptions()iftraining==torch.onnx.TrainingMode.TRAINING:model.train()eliftraining==torch.onnx.TrainingMode.EVAL:model.eval()withtorch.no_grad(),contextlib.ExitStack()asstack:model_f:str|io.BytesIO=io.BytesIO()ifuse_external_data:tmpdir_path=stack.enter_context(tempfile.TemporaryDirectory())model_f=os.path.join(tmpdir_path,"model.onnx")inputs_for_export=_prepare_input_for_export(input_args,input_kwargs)# TODO(#77679): remove this and treat mutating model separately.model_copy=_try_clone_model(model)utils._export(model,inputs_for_export,model_f,opset_version=opset_version,do_constant_folding=do_constant_folding,keep_initializers_as_inputs=keep_initializers_as_inputs,dynamic_axes=dynamic_axes,input_names=input_names,output_names=output_names,fixed_batch_size=fixed_batch_size,training=training,verbose=verbose,)_compare_onnx_pytorch_model(pt_model=model_copy,onnx_model_f=model_f,input_args=input_args,input_kwargs=input_kwargs,additional_test_inputs=additional_test_inputs,options=options,)defverify_aten_graph(graph:torch.Graph,input_args:tuple[Any,...],export_options:_experimental.ExportOptions,params_dict:dict[str,Any]|None=None,verification_options:VerificationOptions|None=None,)->tuple[AssertionError|None,torch.Graph,_OutputsType,_OutputsType]:ifverification_optionsisNone:verification_options=VerificationOptions()ifparams_dictisNone:params_dict={}original_jit_graph=graphgraph=graph.copy()# Execute aten graph and get reference torch jit outputs.graph_inputs=list(graph.inputs())jit_inputs=tuple([argforargininput_argsifargisnotNone])weights=[params_dict[v.debugName()]forvingraph_inputs[len(jit_inputs):]]assertall(wisnotNoneforwinweights)# TODO: Only copy the argument if mutation is detected in Graph.jit_inputs=copy.deepcopy(jit_inputs)jit_input_and_parameters=jit_inputs+tuple(weights)jit_outs=torch._C._jit_interpret_graph(graph,jit_input_and_parameters)# type: ignore[attr-defined]ifnotisinstance(jit_outs,(list,tuple)):jit_outs=[jit_outs]# Convert aten graph to onnx graph.graph,onnx_params_dict=_onnx_graph_from_aten_graph(graph,export_options,params_dict)proto,export_map=_onnx_proto_from_onnx_graph(graph,export_options,onnx_params_dict)model_f:str|io.BytesIO=io.BytesIO()onnx_proto_utils._export_file(proto,model_f,export_map)# NOTE: Verification is unstable. Try catch to emit information for debugging.try:# NOTE: Input might be dce'ed, so we need to remove those from the input args.new_input_names={v.debugName()forvingraph.inputs()}new_input_args=[]forv,arginzip(original_jit_graph.inputs(),input_args):ifv.debugName()innew_input_names:new_input_args.append(arg)input_args=tuple(new_input_args)onnx_inputs=_prepare_input_for_onnx(input_args,{},verification_options.remained_onnx_input_idx,verification_options.flatten,)onnx_session=_onnx_backend_session(model_f,verification_options.backend)onnx_outs=_run_onnx(onnx_session,onnx_inputs)delonnx_session# To free device memorytry:_compare_onnx_pytorch_outputs(onnx_outs=onnx_outs,pt_outs=jit_outs,options=verification_options,)exceptAssertionErrorase:returne,graph,jit_outs,onnx_outsreturnNone,graph,jit_outs,onnx_outsexceptExceptionase:print("Unexpected error during verification.")print("jit graph: ",original_jit_graph)print("onnx graph: ",graph)raiseeclassGraphInfoPrettyPrinter:graph_info:GraphInfo|Noneupper_printer:GraphInfoPrettyPrinter|Nonelower_printer:GraphInfoPrettyPrinter|Nonegraph_str_lambdas:Mapping[int,str]connector_str_lambdas:Mapping[int,str]children_str_lambdas:Mapping[int,str]def__init__(self,graph_info:GraphInfo|None):self.graph_info=graph_infoif(graph_infoisnotNoneandgraph_info.upper_graph_infoisnotNoneandgraph_info.lower_graph_infoisnotNone):self.upper_printer=GraphInfoPrettyPrinter(graph_info.upper_graph_info)self.lower_printer=GraphInfoPrettyPrinter(graph_info.lower_graph_info)else:self.upper_printer=Noneself.lower_printer=Nonedef_total_rows(self)->int:ifself.graph_infoisNone:return1ifself.upper_printerandself.lower_printer:return(self.upper_printer._total_rows()+self.lower_printer._total_rows()+1)return2# Two lines: node count + id.def_node_count_segment_str(self)->str:ifself.graph_infoisNone:return"..."node_count=self.graph_info.essential_node_count()has_mismatch=self.graph_info.has_mismatch()error_node_kind=(f"({self.graph_info.essential_node_kinds().pop()})"ifnode_count==1andhas_mismatchelse"")returnf"{node_count}{'X'ifhas_mismatchelsechr(0x2713)}{error_node_kind}"def_graph_id_segment_str(self)->str:ifself.graph_infoisNone:return""returnf"id: {self.graph_info.id}"def_max_segment_columns(self)->int:returnmax(map(len,(self._node_count_segment_str(),self._graph_id_segment_str())))def_graph_segment_str_at_line(self,line:int)->str:"""Get the string representation of the graph segment at the given line."""ifline==0:result_str=self._node_count_segment_str()result_str+=" "*(self._max_segment_columns()-len(result_str))returnresult_strifline==1:result_str=self._graph_id_segment_str()result_str+=" "*(self._max_segment_columns()-len(result_str))returnresult_strif0<=line<self._total_rows():return" "*self._max_segment_columns()return""def_connector_segment_str_at_line(self,line:int)->str:"""Get the connector segment string at the given line."""ifself.upper_printerisNoneandself.lower_printerisNone:return""upper_total_rows=self.upper_printer._total_rows()ifself.upper_printerelse1lower_total_rows=self.lower_printer._total_rows()ifself.lower_printerelse1ifline==0:return" __"elifline<upper_total_rows+1:return" | "elifline==upper_total_rows+1:return" |__"elifline<upper_total_rows+lower_total_rows+1:return" "return""def_children_str_at_line(self,line:int)->str:"""Get the string representation of the children at the given line. Recursively calls `_str_at_line` on children nodes. """ifself.upper_printerisNoneandself.lower_printerisNone:return""upper_total_rows=self.upper_printer._total_rows()ifself.upper_printerelse1lower_total_rows=self.lower_printer._total_rows()ifself.lower_printerelse1if0<=line<upper_total_rows:return(self.upper_printer._str_at_line(line)ifself.upper_printerelse"...")elifupper_total_rows<line<upper_total_rows+lower_total_rows+1:return(self.lower_printer._str_at_line(line-upper_total_rows-1)ifself.lower_printerelse"...")return""def_str_at_line(self,line:int)->str:"""Get the string representation of the graph at the given line."""return(self._graph_segment_str_at_line(line)+self._connector_segment_str_at_line(line)+self._children_str_at_line(line))defpretty_print(self):ifself.graph_infoisNone:print(None)return# Print tree.print(" Tree: ".center(80,"="))total_rows=self._total_rows()forlineinrange(total_rows):print(self._str_at_line(line).rstrip())ifself.graph_info.has_mismatch():# Summarize leaf subgraphs with mismatch.print(" Mismatch leaf subgraphs: ".center(80,"="))print([graph_info.idforgraph_infoinself.graph_info.all_mismatch_leaf_graph_info()])# Summarize node kinds with mismatch.mismatch_node_kinds:dict[str,int]={}forgraph_infoinself.graph_info.all_mismatch_leaf_graph_info():node_kinds=graph_info.essential_node_kinds()iflen(node_kinds)==1:node_kind=node_kinds.pop()mismatch_node_kinds[node_kind]=(mismatch_node_kinds.get(node_kind,0)+1)print(" Mismatch node kinds: ".center(80,"="))print(mismatch_node_kinds)else:print(" No mismatch found. ".center(80,"="))classOnnxTestCaseRepro:def__init__(self,repro_dir):self.repro_dir=repro_dirself.proto,self.inputs,self.outputs=onnx_proto_utils.load_test_case(repro_dir)@classmethoddefcreate_test_case_repro(cls,proto:bytes,inputs,outputs,dir:str,name:str|None=None):"""Create a repro under "{dir}/test_{name}" for an ONNX test case. The test case contains the model and the inputs/outputs data. The directory structure is as follows: dir \u251c\u2500\u2500 test_<name> \u2502 \u251c\u2500\u2500 model.onnx \u2502 \u2514\u2500\u2500 test_data_set_0 \u2502 \u251c\u2500\u2500 input_0.pb \u2502 \u251c\u2500\u2500 input_1.pb \u2502 \u251c\u2500\u2500 output_0.pb \u2502 \u2514\u2500\u2500 output_1.pb Args: proto: ONNX model proto. inputs: Inputs to the model. outputs: Outputs of the model. dir: Directory to save the repro. name: Name of the test case. If not specified, a name based on current time will be generated. Returns: Path to the repro. """ifnameisNone:name=datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")returnonnx_proto_utils.export_as_test_case(proto,_to_numpy(inputs),_to_numpy(outputs),name,dir,)defvalidate(self,options:VerificationOptions):"""Run the ONNX test case with options.backend, and compare with the expected outputs. Args: options: Options for validation. Raise: AssertionError: if outputs from options.backend and expected outputs are not equal up to specified precision. """onnx_session=_onnx_backend_session(io.BytesIO(self.proto),options.backend)run_outputs=onnx_session.run(None,self.inputs)ifhasattr(onnx_session,"get_outputs"):output_names=[o.nameforoinonnx_session.get_outputs()]elifhasattr(onnx_session,"output_names"):output_names=onnx_session.output_nameselse:raiseValueError(f"Unknown onnx session type: {type(onnx_session)}")expected_outs=[self.outputs[name]fornameinoutput_names]_compare_onnx_pytorch_outputs_in_np(run_outputs,expected_outs,options)
[docs]@dataclasses.dataclassclassGraphInfo:"""GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph."""graph:torch.Graphinput_args:tuple[Any,...]params_dict:dict[str,Any]export_options:_experimental.ExportOptions=dataclasses.field(default_factory=_experimental.ExportOptions)mismatch_error:AssertionError|None=dataclasses.field(default=None,init=False)pt_outs:Sequence[_NumericType]|None=dataclasses.field(default=None,init=False)upper_graph_info:GraphInfo|None=dataclasses.field(default=None,init=False)lower_graph_info:GraphInfo|None=dataclasses.field(default=None,init=False)id:str=dataclasses.field(default="")_onnx_graph:torch.Graph|None=dataclasses.field(init=False,default=None)_EXCLUDED_NODE_KINDS:frozenset[str]=frozenset({"prim::Constant","prim::ListConstruct","aten::ScalarImplicit"})
[docs]defclear(self):"""Clear states and results of previous verification."""self.mismatch_error=Noneself.pt_outs=Noneself._onnx_graph=Noneself.upper_graph_info=Noneself.lower_graph_info=None
[docs]defpretty_print_tree(self):"""Pretty print `GraphInfo` tree. Each node represents a subgraph, showing the number of nodes in the subgraph and a check mark if the subgraph has output mismatch between torch and ONNX. The id of the subgraph is shown under the node. The `GraphInfo` object for any subgraph can be retrieved by calling `graph_info.find_partition(id)`. Example:: ==================================== Tree: ===================================== 5 X __2 X __1 \u2713 id: | id: 0 | id: 00 | | | |__1 X (aten::relu) | id: 01 | |__3 X __1 \u2713 id: 1 | id: 10 | |__2 X __1 X (aten::relu) id: 11 | id: 110 | |__1 \u2713 id: 111 =========================== Mismatch leaf subgraphs: =========================== ['01', '110'] ============================= Mismatch node kinds: ============================= {'aten::relu': 2} """GraphInfoPrettyPrinter(self).pretty_print()
[docs]defpretty_print_mismatch(self,graph:bool=False):"""Pretty print details of the mismatch between torch and ONNX. Args: graph: If True, print the ATen JIT graph and ONNX graph. """print(f" Mismatch info for graph partition {self.id}: ".center(80,"="))ifgraph:print(" ATen JIT graph ".center(80,"="))# TODO: A more compact graph printer.# * Drop stride, grad, device information.# * Show source location on a separate line.print(self.graph)ifself._onnx_graphisnotNone:print(" ONNX graph ".center(80,"="))print(self._onnx_graph)ifself.has_mismatch():print(" Mismatch error ".center(80,"="))print(self.mismatch_error)else:print(" No mismatch ".center(80,"="))
[docs]defhas_mismatch(self)->bool:"""Return True if the subgraph has output mismatch between torch and ONNX."""returnself.mismatch_errorisnotNone
[docs]defessential_node_count(self)->int:"""Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""returnsum(1forninself.graph.nodes()ifn.kind()notinself._EXCLUDED_NODE_KINDS)
[docs]defessential_node_kinds(self)->set[str]:"""Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""return{n.kind()forninself.graph.nodes()ifn.kind()notinself._EXCLUDED_NODE_KINDS}
[docs]defall_mismatch_leaf_graph_info(self)->list[GraphInfo]:"""Return a list of all leaf `GraphInfo` objects that have mismatch."""ifnotself.has_mismatch():return[]no_mismatch_children=(self.upper_graph_infoisNoneornotself.upper_graph_info.has_mismatch())and(self.lower_graph_infoisNoneornotself.lower_graph_info.has_mismatch())ifno_mismatch_children:return[self]results=[]ifself.upper_graph_infoisnotNone:results+=self.upper_graph_info.all_mismatch_leaf_graph_info()ifself.lower_graph_infoisnotNone:results+=self.lower_graph_info.all_mismatch_leaf_graph_info()returnresults
[docs]deffind_partition(self,id:str)->GraphInfo|None:"""Find the `GraphInfo` object with the given id."""ifid==self.id:returnselfcurrent_length=len(self.id)iflen(id)>current_length:ifid[current_length]=="0"andself.upper_graph_infoisnotNone:returnself.upper_graph_info.find_partition(id)elifid[current_length]=="1"andself.lower_graph_infoisnotNone:returnself.lower_graph_info.find_partition(id)returnNone
[docs]defexport_repro(self,repro_dir:str|None=None,name:str|None=None)->str:"""Export the subgraph to ONNX along with the input/output data for repro. The repro directory will contain the following files:: dir \u251c\u2500\u2500 test_<name> \u2502 \u251c\u2500\u2500 model.onnx \u2502 \u2514\u2500\u2500 test_data_set_0 \u2502 \u251c\u2500\u2500 input_0.pb \u2502 \u251c\u2500\u2500 input_1.pb \u2502 \u251c\u2500\u2500 output_0.pb \u2502 \u2514\u2500\u2500 output_1.pb Args: repro_dir: The directory to export the repro files to. Defaults to current working directory if None. name: An optional name for the test case folder: "test_{name}". Returns: The path to the exported repro directory. """ifrepro_dirisNone:repro_dir=os.getcwd()repro_dir=os.path.join(repro_dir,"onnx_debug")onnx_graph,onnx_params_dict=_onnx_graph_from_aten_graph(self.graph,self.export_options,self.params_dict)proto,_=_onnx_proto_from_onnx_graph(onnx_graph,self.export_options,onnx_params_dict)returnOnnxTestCaseRepro.create_test_case_repro(proto,self.input_args,self.pt_outs,repro_dir,name)
def_graph_partition_pivot(self)->int:"""Find the pivot index to partition the graph. The pivot is the node that splits the graph into two parts. Each part should have the similar amount of nodes, excluding non essential ops, defined in `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`. If the graph has an odd number of nodes, the upper part will have one more node. If the graph does not have any node that can be partitioned, return -1. Returns: The index of the pivot node. """included_node_indices=[ifori,ninenumerate(self.graph.nodes())ifn.kind()notinself._EXCLUDED_NODE_KINDS]half_idx=len(included_node_indices)//2-1ifhalf_idx>=0andlen(included_node_indices)>half_idx:returnincluded_node_indices[half_idx]+1return-1def_partition_upper_graph(self)->torch.Graph:pivot=self._graph_partition_pivot()ifpivot==-1:returntorch.Graph()graph=self.graph.copy()# Copy to not mutate parent graph.original_outputs=list(graph.outputs())def_process_bridge_value_for_upper(new_outputs:list[torch.Value],bridge_value:torch.Value)->torch.Value:# Add bridge values as upper graph outputs.new_outputs.append(bridge_value)returnbridge_valuenew_outputs:list[torch.Value]=[]process_bridge_value_for_upper=functools.partial(_process_bridge_value_for_upper,new_outputs)_,dropped_nodes,complete_upper_nodes_set,_=self._partition_nodes(graph,pivot,process_bridge_value_for_upper)for_inenumerate(original_outputs):graph.eraseOutput(0)foroutputinnew_outputs:graph.registerOutput(output)fornodeinreversed(dropped_nodes):node.destroy()fori,inputinreversed(list(enumerate(list(graph.inputs())))):if(not_has_uses_by_nodes(input,complete_upper_nodes_set)andinputnotinnew_outputs):try:graph.eraseInput(i)exceptRuntimeErrorase:print(input,graph)raiseereturngraphdef_partition_lower_graph(self)->torch.Graph:pivot=self._graph_partition_pivot()ifpivot==-1:returntorch.Graph()graph=self.graph.copy()# Copy to not mutate parent graph.original_outputs=list(graph.outputs())original_inputs=list(graph.inputs())def_process_bridge_value_for_lower(graph:torch.Graph,bridge_value:torch.Value)->torch.Value:# Add bridge values as lower graph inputs.new_input=graph.addInput()bridge_value.replaceAllUsesWith(new_input)new_input.copyMetadata(bridge_value)returnnew_inputprocess_bridge_value_for_lower=functools.partial(_process_bridge_value_for_lower,graph)upper_nodes,lower_nodes,_,complete_lower_nodes_set=self._partition_nodes(graph,pivot,process_bridge_value_for_lower)new_outputs=[outputforoutputinoriginal_outputsif_produced_by(output,lower_nodes)]for_inenumerate(original_outputs):graph.eraseOutput(0)foroutputinnew_outputs:graph.registerOutput(output)forinputinoriginal_inputs:if_has_uses_by_nodes(input,complete_lower_nodes_set):new_input=graph.addInput()input.replaceAllUsesWith(new_input)new_input.copyMetadata(input)fornodeinreversed(upper_nodes):ifnodenotincomplete_lower_nodes_set:try:node.destroy()exceptRuntimeErrorase:print(node,graph)raiseefor_inoriginal_inputs:graph.eraseInput(0)returngraphdef_partition_node(self,node:torch.Node,complete_upper_nodes_set:set[torch.Node],complete_lower_nodes_set:set[torch.Node],original_graph_outputs:set[torch.Value],covered_bridge_values:set[torch.Value],process_bridge_value:Callable[[torch.Value],torch.Value],):ifnodeincomplete_lower_nodes_set:returnif(_node_has_uses_by(node,complete_lower_nodes_set)andnode.kind()inself._EXCLUDED_NODE_KINDS):complete_lower_nodes_set.update(_all_nodes([node]))forinputinnode.inputs():ifinputincovered_bridge_values:continueself._partition_node(input.node(),complete_upper_nodes_set,complete_lower_nodes_set,original_graph_outputs,covered_bridge_values,process_bridge_value,)else:foroutputinnode.outputs():ifoutputincovered_bridge_values:continueif(_has_uses_by_nodes(output,complete_lower_nodes_set)oroutputinoriginal_graph_outputs):covered_bridge_values.add(process_bridge_value(output))def_partition_nodes(self,graph:torch.Graph,pivot:int,process_bridge_value:Callable[[torch.Value],torch.Value],)->tuple[list[torch.Node],list[torch.Node],set[torch.Node],set[torch.Node]]:nodes=list(graph.nodes())upper_nodes=nodes[:pivot]lower_nodes=nodes[pivot:]# `upper_nodes` and `complete_upper_nodes_set` differs in that the latter# recursively contains nodes in subblock of `upper_nodes`.# The same applies for `lower_nodes` and `complete_lower_nodes_set`.# With addition that `complete_lower_nodes_set` will include nodes that# are determined to be copied from `upper_nodes` to `lower_nodes`.complete_upper_nodes_set=_all_nodes(upper_nodes)complete_lower_nodes_set=_all_nodes(lower_nodes)original_graph_outputs=set(graph.outputs())# Bridge values are values produced from upper graph, and consumed# by lower graph. These values need to be become upper graph outputs# and lower graph inputs, to bridge the interaction.# Start with all graph inputs marked as covered. If any graph input is# needed by lower graph, just keep it in lower graph inputs later.covered_bridge_values=set(graph.inputs())fornodeinupper_nodes:self._partition_node(node,complete_upper_nodes_set,complete_lower_nodes_set,original_graph_outputs,covered_bridge_values,process_bridge_value,)return(upper_nodes,lower_nodes,complete_upper_nodes_set,complete_lower_nodes_set,)def_bridge_kwargs(self):pt_outs=self.pt_outsgraph_outputs=list(self.graph.outputs())assertpt_outsisnotNoneassertlen(graph_outputs)==len(pt_outs),f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}"return{v.debugName():oforv,oinzip(graph_outputs,pt_outs)}def_args_and_params_for_partition_graph(self,graph:torch.Graph,bridge_kwargs:Mapping[str,_NumericType|Sequence[_NumericType]],full_kwargs:Mapping[str,torch.Tensor],full_params:Mapping[str,torch.Tensor],):input_names=[input.debugName()forinputingraph.inputs()]args=tuple(bridge_kwargs[k]forkininput_namesifkinbridge_kwargs)args+=tuple(full_kwargs[k]forkininput_namesifkinfull_kwargs)params={k:full_params[k]forkininput_namesifkinfull_params}assertlen(args)+len(params)==len(input_names),f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}"returnargs,params
[docs]defverify_export(self,options:VerificationOptions)->tuple[AssertionError|None,torch.Graph,_OutputsType,_OutputsType]:""" Verify the export from TorchScript IR graph to ONNX. Export the TorchScript IR graph to ONNX, with the inputs, parameters and export options recorded in this object. Then verify the exported ONNX graph against the original TorchScript IR graph under the provided verification options. Args: options: The verification options. Returns: error: The AssertionError raised during the verification. Returns None if no error is raised. onnx_graph: The exported ONNX graph in TorchScript IR format. onnx_outs: The outputs from running exported ONNX model under the onnx backend in `options`. pt_outs: The outputs from running the TorchScript IR graph. """returnverify_aten_graph(self.graph,input_args=self.input_args,params_dict=self.params_dict,export_options=self.export_options,verification_options=options,)
[docs]deffind_mismatch(self,options:VerificationOptions|None=None,):""" Find all mismatches between the TorchScript IR graph and the exported onnx model. Binary searches the model graph to find the minimal subgraph that exhibits the mismatch. A `GraphInfo` object is created for each subgraph, recording the test inputs and export options, as well as the validation results. Args: options: The verification options. """self.clear()ifoptionsisNone:options=VerificationOptions()ifself.export_options.verbose:print(self.graph)iflen(list(self.graph.outputs()))==0:returnassertlen(self.input_args)+len(self.params_dict)==len(list(self.graph.inputs())),(f"Number of graph inputs({len(list(self.graph.inputs()))}) does not match "f"the provided tensor arguments({len(self.input_args)} + {len(self.params_dict)}).")self.mismatch_error,self._onnx_graph,self.pt_outs,_=self.verify_export(options)ifself.mismatch_errorisNone:# No mismatch found in graph.returnifself.essential_node_count()<=1:# Reached leaf node, no more partitioning.returnfull_kwargs={k.debugName():vfork,vinzip(self.graph.inputs(),self.input_args)}full_params=self.params_dictupper_graph=self._partition_upper_graph()upper_args,upper_params=self._args_and_params_for_partition_graph(upper_graph,{},full_kwargs,full_params)self.upper_graph_info=GraphInfo(upper_graph,upper_args,upper_params,self.export_options,id=self.id+"0",)self.upper_graph_info.find_mismatch(options)bridge_kwargs=self.upper_graph_info._bridge_kwargs()lower_graph=self._partition_lower_graph()lower_args,lower_params=self._args_and_params_for_partition_graph(lower_graph,bridge_kwargs,full_kwargs,full_params)self.lower_graph_info=GraphInfo(lower_graph,lower_args,lower_params,self.export_options,id=self.id+"1",)self.lower_graph_info.find_mismatch(options)
[docs]deffind_mismatch(model:torch.nn.Module|torch.jit.ScriptModule,input_args:tuple[Any,...],do_constant_folding:bool=True,training:_C_onnx.TrainingMode=_C_onnx.TrainingMode.EVAL,opset_version:int|None=None,keep_initializers_as_inputs:bool=True,verbose:bool=False,options:VerificationOptions|None=None,)->GraphInfo:r"""Find all mismatches between the original model and the exported model. Experimental. The API is subject to change. This tool helps debug the mismatch between the original PyTorch model and exported ONNX model. It binary searches the model graph to find the minimal subgraph that exhibits the mismatch. Args: model: The model to be exported. input_args: The input arguments to the model. do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`. training: Same as `training` in :func:`torch.onnx.export`. opset_version: Same as `opset_version` in :func:`torch.onnx.export`. keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`. verbose: Same as `verbose` in :func:`torch.onnx.export`. options: The options for the mismatch verification. Returns: A GraphInfo object that contains the mismatch information. Example:: >>> import torch >>> import torch.onnx.verification >>> torch.manual_seed(0) >>> opset_version = 15 >>> # Define a custom symbolic function for aten::relu. >>> # The custom symbolic function is incorrect, which will result in mismatches. >>> def incorrect_relu_symbolic_function(g, self): ... return self >>> torch.onnx.register_custom_op_symbolic( ... "aten::relu", ... incorrect_relu_symbolic_function, ... opset_version=opset_version, ... ) >>> class Model(torch.nn.Module): ... def __init__(self) -> None: ... super().__init__() ... self.layers = torch.nn.Sequential( ... torch.nn.Linear(3, 4), ... torch.nn.ReLU(), ... torch.nn.Linear(4, 5), ... torch.nn.ReLU(), ... torch.nn.Linear(5, 6), ... ) ... def forward(self, x): ... return self.layers(x) >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) >>> graph_info = torch.onnx.verification.find_mismatch( ... Model(), ... (torch.randn(2, 3),), ... opset_version=opset_version, ... ) ===================== Mismatch info for graph partition : ====================== ================================ Mismatch error ================================ Tensor-likes are not close! Mismatched elements: 12 / 12 (100.0%) Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed) Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed) ==================================== Tree: ===================================== 5 X __2 X __1 \u2713 id: | id: 0 | id: 00 | | | |__1 X (aten::relu) | id: 01 | |__3 X __1 \u2713 id: 1 | id: 10 | |__2 X __1 X (aten::relu) id: 11 | id: 110 | |__1 \u2713 id: 111 =========================== Mismatch leaf subgraphs: =========================== ['01', '110'] ============================= Mismatch node kinds: ============================= {'aten::relu': 2} """ifoptionsisNone:options=VerificationOptions()ifopset_versionisNone:opset_version=_constants.ONNX_DEFAULT_OPSET"""From aten graph, do binary search on graph partition to find operator export discrepancy."""# TODO: Copied from utils.py `export` until `_optimize_graph`.iftraining==torch.onnx.TrainingMode.TRAINING:model.train()eliftraining==torch.onnx.TrainingMode.EVAL:model.eval()withtorch.no_grad():inputs_for_export=_prepare_input_for_export(input_args,{})args=utils._decide_input_format(model,inputs_for_export)model=utils._pre_trace_quant_model(model,args)graph,params,_torch_out,_module=utils._create_jit_graph(model,args)params_dict=utils._get_named_param_dict(graph,params)utils._apply_friendly_debug_names(graph,params_dict)graph_info=GraphInfo(graph,input_args,params_dict,_experimental.ExportOptions(do_constant_folding=do_constant_folding,training=training,opset_version=opset_version,keep_initializers_as_inputs=keep_initializers_as_inputs,verbose=verbose,),)graph_info.find_mismatch(options)graph_info.pretty_print_mismatch()graph_info.pretty_print_tree()returngraph_info
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.