importenumimportoperatorimporttorchimporttorch.nnasnnimporttorch.nn.intrinsic.quantizedasnniqimporttorch.nn.quantizedasnnqtoq=torch.ops.quantizedfromtypingimportTuple,Callable,Dict,Set,List,Optional,Unionfromtorch.fximportGraphModulefromtorch.fx.graphimportNodefromtorch.ao.quantizationimport(ObserverBase,FakeQuantizeBase,)fromtorch.ao.quantization.utilsimportgetattr_from_fqnfromtorch.ao.quantization.quantizeimportis_activation_post_processfrom.ns_typesimportNSNodeTargetType,NSResultsType# TODO(future PR): consider deleting this enum and using the torch types# directly. This might be tricky because it is not a one to one mapping.classNodeInputOrOutputType(enum.Enum):FP32=enum.auto()# torch.floatINT8=enum.auto()# torch.qint8 or torch.quint8FP16=enum.auto()# torch.float16UNKNOWN=enum.auto()# we cannot determine input/output dtype# TODO(future PR): while these functions can support multiple dtypes,# for the purposes of numerical debugging we want to get the actual# dtype used in the model. We will likely need some kind of dtype# propagation to estimate this.FP32_OR_INT8=enum.auto()# either torch.float or torch.quint8 or torch.qint8# TODO(future PRs): dynamic quant, fake quant, etcdefget_node_first_input_and_output_type(node:Node,gm:GraphModule,logger_cls:Callable,node_type_to_io_type_map:Dict[str,Set[NSNodeTargetType]],)->Tuple[NodeInputOrOutputType,NodeInputOrOutputType]:# TODO(future PR): clean this upFUNS_IO_TYPE_FP32=node_type_to_io_type_map["funs_io_type_fp32"]FUNS_IO_TYPE_FP16=node_type_to_io_type_map["funs_io_type_fp16"]FUNS_IO_TYPE_INT8=node_type_to_io_type_map["funs_io_type_int8"]FUNS_IO_TYPE_FP32_OR_INT8=node_type_to_io_type_map["funs_io_type_fp32_or_int8"]MODS_IO_TYPE_FP32=node_type_to_io_type_map["mods_io_type_fp32"]MODS_IO_TYPE_INT8=node_type_to_io_type_map["mods_io_type_int8"]MODS_IO_TYPE_FP32_OR_INT8=node_type_to_io_type_map["mods_io_type_fp32_or_int8"]METHS_IO_TYPE_FP32_OR_INT8=node_type_to_io_type_map["meths_io_type_fp32_or_int8"]ifnode.op=="call_function":ifnode.targetinFUNS_IO_TYPE_FP32:return(NodeInputOrOutputType.FP32,NodeInputOrOutputType.FP32)ifnode.targetinFUNS_IO_TYPE_FP16:return(NodeInputOrOutputType.FP16,NodeInputOrOutputType.FP16)elifnode.targetinFUNS_IO_TYPE_INT8:return(NodeInputOrOutputType.INT8,NodeInputOrOutputType.INT8)elifnode.targetinFUNS_IO_TYPE_FP32_OR_INT8:return(NodeInputOrOutputType.FP32_OR_INT8,NodeInputOrOutputType.FP32_OR_INT8,)else:return(NodeInputOrOutputType.UNKNOWN,NodeInputOrOutputType.UNKNOWN)elifnode.op=="call_module":assertnode.op=="call_module"assertisinstance(node.target,str)mod=getattr_from_fqn(gm,node.target)ifisinstance(mod,(logger_cls,ObserverBase,FakeQuantizeBase)):# type: ignore[arg-type]# A logger or observer's input and output type is the output# type of the preceding node.first_arg=node.args[0]assertisinstance(first_arg,Node)(_prev_node_input_type,prev_node_output_type,)=get_node_first_input_and_output_type(first_arg,gm,logger_cls,node_type_to_io_type_map)return(prev_node_output_type,prev_node_output_type)is_known_fp32_input_module=any(isinstance(mod,target_type)fortarget_typeinMODS_IO_TYPE_FP32# type: ignore[arg-type])is_known_int8_input_module=any(isinstance(mod,target_type)fortarget_typeinMODS_IO_TYPE_INT8# type: ignore[arg-type])is_known_fp32_or_int8_input_module=any(isinstance(mod,target_type)fortarget_typeinMODS_IO_TYPE_FP32_OR_INT8# type: ignore[arg-type])ifis_known_fp32_input_module:return(NodeInputOrOutputType.FP32,NodeInputOrOutputType.FP32)elifis_known_int8_input_module:return(NodeInputOrOutputType.INT8,NodeInputOrOutputType.INT8)elifis_known_fp32_or_int8_input_module:return(NodeInputOrOutputType.FP32_OR_INT8,NodeInputOrOutputType.FP32_OR_INT8,)else:return(NodeInputOrOutputType.UNKNOWN,NodeInputOrOutputType.UNKNOWN)elifnode.op=="call_method":ifnode.target=="dequantize":# Dequantize is a special node because it allows multiple input types.# So, we look up the output type of the previous node and return that# as the input type of this node instance.prev_node=node.args[0]assertisinstance(prev_node,Node)(_prev_node_input_type,prev_node_output_type,)=get_node_first_input_and_output_type(prev_node,gm,logger_cls,node_type_to_io_type_map)return(prev_node_output_type,NodeInputOrOutputType.FP32)elifnode.target=="to":# to is a special node because it allows multiple input types.# So, we look up the output type of the previous node and return that# as the input type of this node instance. We also look up the target# of to and return the correct output type.prev_node=node.args[0]assertisinstance(prev_node,Node)(_prev_node_input_type,prev_node_output_type,)=get_node_first_input_and_output_type(prev_node,gm,logger_cls,node_type_to_io_type_map)cur_node_dtype_target=node.args[1]assert(cur_node_dtype_targetistorch.float16),f"{cur_node_dtype_target} handling needs to be added"return(prev_node_output_type,NodeInputOrOutputType.FP16)elifnode.targetinMETHS_IO_TYPE_FP32_OR_INT8:return(NodeInputOrOutputType.FP32_OR_INT8,NodeInputOrOutputType.FP32_OR_INT8,)return(NodeInputOrOutputType.UNKNOWN,NodeInputOrOutputType.UNKNOWN)else:return(NodeInputOrOutputType.UNKNOWN,NodeInputOrOutputType.UNKNOWN)defget_node_input_qparams(node:Node,gm:GraphModule,node_type_to_io_type_map:Dict[str,Set[NSNodeTargetType]],)->Optional[Tuple[Union[torch.Tensor,float],Union[torch.Tensor,int]]]:""" Returns the qparams (scale, zero_point) of the first input to `node`, if they can be inferred from the graph. """prev_node=node.args[0]ifnotisinstance(prev_node,Node):returnNoneMODS_IO_TYPE_FP32_OR_INT8=node_type_to_io_type_map["mods_io_type_fp32_or_int8"]def_get_scale_zp_from_function_args(node,gm,scale_arg_idx,zp_arg_idx):scale_node,zp_node=node.args[scale_arg_idx],node.args[zp_arg_idx]assertisinstance(scale_node,Node)andisinstance(scale_node.target,str)assertisinstance(zp_node,Node)andisinstance(zp_node.target,str)scale_obj=getattr_from_fqn(gm,scale_node.target)zp_obj=getattr_from_fqn(gm,zp_node.target)return(scale_obj,zp_obj)ifprev_node.op=="call_function":# quantize - read the args directlyifprev_node.target==torch.quantize_per_tensor:return_get_scale_zp_from_function_args(prev_node,gm,1,2)elifprev_node.targetin(toq.add,toq.add_relu,toq.mul,toq.mul_relu):return_get_scale_zp_from_function_args(prev_node,gm,2,3)returnNone# TODO(future PR): handle more functionals# TODO(future PR): handle functional ops which inherit qparams from inputelifprev_node.op=="call_module":# get type of the moduleassertisinstance(prev_node.target,str)module_obj=getattr_from_fqn(gm,prev_node.target)ifisinstance(module_obj,(nnq.Linear,nnq.Conv1d,nnq.Conv2d,nniq.ConvReLU2d,nnq.Conv3d,nnq.BatchNorm2d,nnq.BatchNorm3d,nnq.ConvTranspose1d,nnq.ConvTranspose2d,nnq.ELU,nnq.GroupNorm,nnq.InstanceNorm1d,nnq.InstanceNorm2d,nnq.InstanceNorm3d,nnq.LayerNorm,nnq.Hardswish,nnq.LeakyReLU,nnq.ReLU6,nniq.BNReLU2d,nniq.BNReLU3d,nniq.ConvReLU1d,nniq.ConvReLU2d,nniq.ConvReLU3d,nniq.LinearReLU,),):return(module_obj.scale,module_obj.zero_point)# type: ignore[return-value]is_known_fp32_or_int8_input_module=any(isinstance(module_obj,target_type)fortarget_typeinMODS_IO_TYPE_FP32_OR_INT8# type: ignore[arg-type])ifis_known_fp32_or_int8_input_module:returnget_node_input_qparams(prev_node,gm,node_type_to_io_type_map)returnNonedefreturn_first_non_observer_node(node:Node,gm:GraphModule,)->Node:""" If node is not an observer, returns it. If node is an observer, navigates up the graph and returns the first parent which is not an observer. For example, graph: (node_non_obs), node = node_non_obs : returns node_non_obs graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs """ifnode.op=="call_module":node_obj=getattr_from_fqn(gm,node.target)# type: ignore[arg-type]ifis_activation_post_process(node_obj):assertlen(node.args)==1assertisinstance(node.args[0],Node)node=node.args[0]# code duplication intended, not worth refactoringassertisinstance(node.target,str)node_obj=getattr_from_fqn(gm,node.target)ifis_activation_post_process(node_obj):assertlen(node.args)==1assertisinstance(node.args[0],Node)node=node.args[0]returnnodedefget_number_of_non_param_args(node:Node,gm:GraphModule,)->int:""" Assumes that all non-param args occur first. Returns the number of non-param args expected for a node. For example, for F.linear(x, weight, bias) Returns 1, because x is a non-param arg and weight and bias are params. For lstm_mod(x, hid) Returns 2, because both x and hid are non-param args. """ifnode.op=="call_module":node_obj=getattr_from_fqn(gm,node.target)# type: ignore[arg-type]ifisinstance(node_obj,nn.LSTM):return2# default is 1return1defget_arg_indices_of_inputs_to_log(node:Node)->List[int]:""" Returns the indices of args of the node which we should attach loggers to, if input logging is enabled. For example, * for (x + y), returns [0, 1] * for (1 + y), returns [1] * for (x + 1), returns [0] * for (linear(x, w, b)) returns [0] * by default, returns [0] """iflen(node.args)==0:return[]ifnode.op=="call_function"and(# TODO(future PR): use relationship map instead of hardcodingnode.targetin(torch.add,torch.ops.quantized.add,operator.add)ornode.targetin(torch.mul,torch.ops.quantized.mul,operator.mul)):result=[]foriinrange(2):iftype(node.args[i])==Node:result.append(i)returnresultreturn[0]defget_target_type_str(node:Node,gm:GraphModule)->str:""" Returns a string representation of the type of the function or module pointed to by this node, or '' for other node types. """target_type=""ifnode.opin("call_function","call_method"):target_type=torch.typename(node.target)elifnode.op=="call_module":assertisinstance(node.target,str)target_mod=getattr_from_fqn(gm,node.target)target_type=torch.typename(target_mod)returntarget_typedefrekey_logger_info_on_node_name_of_model(results:NSResultsType,model_name:str,)->NSResultsType:""" Rekeys the layer name of a results dictionary to use node names from `model_name`. For example, transforms {'base_op_1_0': {'node_output': {'model_a': [{'ref_node_name': 'linear1', ...}]}}} into {'linear1': {'node_output': {'model_a': [{'ref_node_name': 'linear1', ...}]}}} Note: we cannot use these node names directly because they are not guaranteed to be consistent across models. This is why we extract the results first and rekey afterwards. """new_results={}forold_layer_name,result_type_to_resultsinresults.items():new_layer_name=Nonefor_result_type,model_name_to_resultsinresult_type_to_results.items():forcur_model_name,list_of_resultsinmodel_name_to_results.items():ifcur_model_name==model_name:assertlen(list_of_results)new_layer_name=list_of_results[0]["ref_node_name"]else:continueifnew_layer_nameisnotNone:new_results[new_layer_name]=result_type_to_resultselse:new_results[old_layer_name]=result_type_to_resultsreturnnew_resultsdefmaybe_add_missing_fqns(results:NSResultsType)->None:""" If `fqn` entries are filled in for one of the models in `results`, copies them over to any models which do not have them filled out. A common use case benefitting from this is comparing a model prepared by quantization to a quantized model. In this case, the model prepared by quantization would have `fqn` entries, and the quantized model would not. """# Check in the first result to find any model with fqn entries defined.model_name_with_fqns=Noneforlayer_name,result_type_to_resultsinresults.items():forresult_type,model_name_to_resultsinresult_type_to_results.items():formodel_name,model_resultsinmodel_name_to_results.items():iflen(model_results)>0:ifmodel_results[0]["fqn"]isnotNone:model_name_with_fqns=model_namebreakbreakbreakifmodel_name_with_fqns:forlayer_name,result_type_to_resultsinresults.items():forresult_type,model_name_to_resultsinresult_type_to_results.items():ref_model_results=model_name_to_results[model_name_with_fqns]formodel_name,model_resultsinmodel_name_to_results.items():ifmodel_name==model_name_with_fqns:continueforiinrange(len(model_results)):fqn=ref_model_results[i]["fqn"]model_results[i]["fqn"]=fqndefmaybe_dequantize_first_two_tensor_args_and_handle_tuples(f):definner(*args,**kwargs):a0,a1,*a_other=argsif(isinstance(a0,tuple)andisinstance(a1,tuple))or(isinstance(a0,list)andisinstance(a1,list)):results=[]forel0,el1inzip(a0,a1):new_args=(el0,el1,*a_other)results.append(inner(*new_args,**kwargs))returnresultselifisinstance(a0,torch.Tensor)andisinstance(a1,torch.Tensor):ifa0.is_quantized:a0=a0.dequantize()ifa1.is_quantized:a1=a1.dequantize()# for the purposes of this util, only handle floatsifa0.dtype!=torch.floatora1.dtype!=torch.float:returnNonenew_args=(a0,a1,*a_other)returnf(*new_args,**kwargs)returninner
[docs]@maybe_dequantize_first_two_tensor_args_and_handle_tuplesdefcompute_sqnr(x:torch.Tensor,y:torch.Tensor)->torch.Tensor:""" Computes the SQNR between `x` and `y`. Args: x: Tensor or tuple of tensors y: Tensor or tuple of tensors Return: float or tuple of floats """Ps=torch.norm(x)Pn=torch.norm(x-y)return20*torch.log10(Ps/Pn)
[docs]@maybe_dequantize_first_two_tensor_args_and_handle_tuplesdefcompute_normalized_l2_error(x:torch.Tensor,y:torch.Tensor)->torch.Tensor:""" Computes the normalized L2 error between `x` and `y`. Args: x: Tensor or tuple of tensors y: Tensor or tuple of tensors Return: float or tuple of floats """returntorch.sqrt(((x-y)**2).sum()/(x**2).sum())
[docs]@maybe_dequantize_first_two_tensor_args_and_handle_tuplesdefcompute_cosine_similarity(x:torch.Tensor,y:torch.Tensor)->torch.Tensor:""" Computes the cosine similarity between `x` and `y`. Args: x: Tensor or tuple of tensors y: Tensor or tuple of tensors Return: float or tuple of floats """# For convolutions, the shape of the quantized weight has one additional# dimension compared to the shape of the fp32 weight. Match the shapes# to enable cosine similarity comparison.x=x.reshape(1,-1)y=y.reshape(1,-1)returntorch.nn.functional.cosine_similarity(x,y)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.