# Source code for torch.ao.quantization.fx.custom_config
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.quant_type import QuantType, _quant_type_from_str, quant_type_to_str


__all__ = [
    "ConvertCustomConfig",
    "FuseCustomConfig",
    "PrepareCustomConfig",
    "StandaloneModuleConfigEntry",
]


# Keys of the legacy dictionary representation read by the ``from_dict`` and
# written by the ``to_dict`` methods of the config classes in this module.
# TODO: replace all usages with these constants

# Standalone modules, identified by name or by class.
STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name"
STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class"

# Custom module class swaps (float -> observed, observed -> quantized).
FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class"
OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class"

# Modules that must not be symbolically traced, by name or by class.
NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name"
NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class"

# Graph input/output positions that are already / should stay quantized.
INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs"
OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs"

# Attributes kept on the resulting graph module even if unused in forward.
PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes"
@dataclass
class StandaloneModuleConfigEntry:
    """
    Settings for one standalone module, as stored by
    ``PrepareCustomConfig.set_standalone_module_name`` and
    ``PrepareCustomConfig.set_standalone_module_class``.
    """

    # QConfigMapping for the prepare function called on the submodule;
    # None means the parent qconfig_mapping is used instead.
    qconfig_mapping: Optional[QConfigMapping]
    # Example inputs for the standalone submodule.
    example_inputs: Tuple[Any, ...]
    # Child PrepareCustomConfig; None means an empty PrepareCustomConfig is used.
    prepare_custom_config: Optional[PrepareCustomConfig]
    # BackendConfig; None means the parent backend_config is used instead.
    backend_config: Optional[BackendConfig]
[docs]defset_standalone_module_name(self,module_name:str,qconfig_mapping:Optional[QConfigMapping],example_inputs:Tuple[Any,...],prepare_custom_config:Optional[PrepareCustomConfig],backend_config:Optional[BackendConfig])->PrepareCustomConfig:""" Set the configuration for running a standalone module identified by ``module_name``. If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. If ``backend_config`` is None, the parent ``backend_config`` will be used instead. """self.standalone_module_names[module_name]= \
StandaloneModuleConfigEntry(qconfig_mapping,example_inputs,prepare_custom_config,backend_config)returnself
[docs]defset_standalone_module_class(self,module_class:Type,qconfig_mapping:Optional[QConfigMapping],example_inputs:Tuple[Any,...],prepare_custom_config:Optional[PrepareCustomConfig],backend_config:Optional[BackendConfig])->PrepareCustomConfig:""" Set the configuration for running a standalone module identified by ``module_class``. If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. If ``backend_config`` is None, the parent ``backend_config`` will be used instead. """self.standalone_module_classes[module_class]= \
StandaloneModuleConfigEntry(qconfig_mapping,example_inputs,prepare_custom_config,backend_config)returnself
[docs]defset_float_to_observed_mapping(self,float_class:Type,observed_class:Type,quant_type:QuantType=QuantType.STATIC)->PrepareCustomConfig:""" Set the mapping from a custom float module class to a custom observed module class. The observed module class must have a ``from_float`` class method that converts the float module class to the observed module class. This is currently only supported for static quantization. """ifquant_type!=QuantType.STATIC:raiseValueError("set_float_to_observed_mapping is currently only supported for static quantization")ifquant_typenotinself.float_to_observed_mapping:self.float_to_observed_mapping[quant_type]={}self.float_to_observed_mapping[quant_type][float_class]=observed_classreturnself
[docs]defset_non_traceable_module_names(self,module_names:List[str])->PrepareCustomConfig:""" Set the modules that are not symbolically traceable, identified by name. """self.non_traceable_module_names=module_namesreturnself
[docs]defset_non_traceable_module_classes(self,module_classes:List[Type])->PrepareCustomConfig:""" Set the modules that are not symbolically traceable, identified by class. """self.non_traceable_module_classes=module_classesreturnself
[docs]defset_input_quantized_indexes(self,indexes:List[int])->PrepareCustomConfig:""" Set the indexes of the inputs of the graph that should be quantized. Inputs are otherwise assumed to be in fp32 by default instead. """self.input_quantized_indexes=indexesreturnself
[docs]defset_output_quantized_indexes(self,indexes:List[int])->PrepareCustomConfig:""" Set the indexes of the outputs of the graph that should be quantized. Outputs are otherwise assumed to be in fp32 by default instead. """self.output_quantized_indexes=indexesreturnself
[docs]defset_preserved_attributes(self,attributes:List[str])->PrepareCustomConfig:""" Set the names of the attributes that will persist in the graph module even if they are not used in the model's ``forward`` method. """self.preserved_attributes=attributesreturnself
# TODO: remove this
[docs]@classmethoddeffrom_dict(cls,prepare_custom_config_dict:Dict[str,Any])->PrepareCustomConfig:""" Create a ``PrepareCustomConfig`` from a dictionary with the following items: "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) tuples "standalone_module_class" a list of (module_class, qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) tuples "float_to_observed_custom_module_class": a nested dictionary mapping from quantization mode to an inner mapping from float module classes to observed module classes, e.g. {"static": {FloatCustomModule: ObservedCustomModule}} "non_traceable_module_name": a list of modules names that are not symbolically traceable "non_traceable_module_class": a list of module classes that are not symbolically traceable "input_quantized_idxs": a list of indexes of graph inputs that should be quantized "output_quantized_idxs": a list of indexes of graph outputs that should be quantized "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` This function is primarily for backward compatibility and may be removed in the future. """def_get_qconfig_mapping(obj:Any,dict_key:str)->Optional[QConfigMapping]:""" Convert the given object into a QConfigMapping if possible, else throw an exception. """ifisinstance(obj,QConfigMapping)orobjisNone:returnobjifisinstance(obj,Dict):returnQConfigMapping.from_dict(obj)raiseValueError("Expected QConfigMapping in prepare_custom_config_dict[\"%s\"], got '%s'"%(dict_key,type(obj)))def_get_prepare_custom_config(obj:Any,dict_key:str)->Optional[PrepareCustomConfig]:""" Convert the given object into a PrepareCustomConfig if possible, else throw an exception. 
"""ifisinstance(obj,PrepareCustomConfig)orobjisNone:returnobjifisinstance(obj,Dict):returnPrepareCustomConfig.from_dict(obj)raiseValueError("Expected PrepareCustomConfig in prepare_custom_config_dict[\"%s\"], got '%s'"%(dict_key,type(obj)))def_get_backend_config(obj:Any,dict_key:str)->Optional[BackendConfig]:""" Convert the given object into a BackendConfig if possible, else throw an exception. """ifisinstance(obj,BackendConfig)orobjisNone:returnobjifisinstance(obj,Dict):returnBackendConfig.from_dict(obj)raiseValueError("Expected BackendConfig in prepare_custom_config_dict[\"%s\"], got '%s'"%(dict_key,type(obj)))conf=cls()for(module_name,qconfig_dict,example_inputs,_prepare_custom_config_dict,backend_config_dict)in\
prepare_custom_config_dict.get(STANDALONE_MODULE_NAME_DICT_KEY,[]):qconfig_mapping=_get_qconfig_mapping(qconfig_dict,STANDALONE_MODULE_NAME_DICT_KEY)prepare_custom_config=_get_prepare_custom_config(_prepare_custom_config_dict,STANDALONE_MODULE_NAME_DICT_KEY)backend_config=_get_backend_config(backend_config_dict,STANDALONE_MODULE_NAME_DICT_KEY)conf.set_standalone_module_name(module_name,qconfig_mapping,example_inputs,prepare_custom_config,backend_config)for(module_class,qconfig_dict,example_inputs,_prepare_custom_config_dict,backend_config_dict)in\
prepare_custom_config_dict.get(STANDALONE_MODULE_CLASS_DICT_KEY,[]):qconfig_mapping=_get_qconfig_mapping(qconfig_dict,STANDALONE_MODULE_CLASS_DICT_KEY)prepare_custom_config=_get_prepare_custom_config(_prepare_custom_config_dict,STANDALONE_MODULE_CLASS_DICT_KEY)backend_config=_get_backend_config(backend_config_dict,STANDALONE_MODULE_CLASS_DICT_KEY)conf.set_standalone_module_class(module_class,qconfig_mapping,example_inputs,prepare_custom_config,backend_config)forquant_type_name,custom_module_mappinginprepare_custom_config_dict.get(FLOAT_TO_OBSERVED_DICT_KEY,{}).items():quant_type=_quant_type_from_str(quant_type_name)forfloat_class,observed_classincustom_module_mapping.items():conf.set_float_to_observed_mapping(float_class,observed_class,quant_type)conf.set_non_traceable_module_names(prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_NAME_DICT_KEY,[]))conf.set_non_traceable_module_classes(prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_CLASS_DICT_KEY,[]))conf.set_input_quantized_indexes(prepare_custom_config_dict.get(INPUT_QUANTIZED_INDEXES_DICT_KEY,[]))conf.set_output_quantized_indexes(prepare_custom_config_dict.get(OUTPUT_QUANTIZED_INDEXES_DICT_KEY,[]))conf.set_preserved_attributes(prepare_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY,[]))returnconf
[docs]defto_dict(self)->Dict[str,Any]:""" Convert this ``PrepareCustomConfig`` to a dictionary with the items described in :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`. """def_make_tuple(key:Any,e:StandaloneModuleConfigEntry):qconfig_dict=e.qconfig_mapping.to_dict()ife.qconfig_mappingelseNoneprepare_custom_config_dict=e.prepare_custom_config.to_dict()ife.prepare_custom_configelseNonereturn(key,qconfig_dict,e.example_inputs,prepare_custom_config_dict,e.backend_config)d:Dict[str,Any]={}formodule_name,sm_config_entryinself.standalone_module_names.items():ifSTANDALONE_MODULE_NAME_DICT_KEYnotind:d[STANDALONE_MODULE_NAME_DICT_KEY]=[]d[STANDALONE_MODULE_NAME_DICT_KEY].append(_make_tuple(module_name,sm_config_entry))formodule_class,sm_config_entryinself.standalone_module_classes.items():ifSTANDALONE_MODULE_CLASS_DICT_KEYnotind:d[STANDALONE_MODULE_CLASS_DICT_KEY]=[]d[STANDALONE_MODULE_CLASS_DICT_KEY].append(_make_tuple(module_class,sm_config_entry))forquant_type,float_to_observed_mappinginself.float_to_observed_mapping.items():ifFLOAT_TO_OBSERVED_DICT_KEYnotind:d[FLOAT_TO_OBSERVED_DICT_KEY]={}d[FLOAT_TO_OBSERVED_DICT_KEY][quant_type_to_str(quant_type)]=float_to_observed_mappingiflen(self.non_traceable_module_names)>0:d[NON_TRACEABLE_MODULE_NAME_DICT_KEY]=self.non_traceable_module_namesiflen(self.non_traceable_module_classes)>0:d[NON_TRACEABLE_MODULE_CLASS_DICT_KEY]=self.non_traceable_module_classesiflen(self.input_quantized_indexes)>0:d[INPUT_QUANTIZED_INDEXES_DICT_KEY]=self.input_quantized_indexesiflen(self.output_quantized_indexes)>0:d[OUTPUT_QUANTIZED_INDEXES_DICT_KEY]=self.output_quantized_indexesiflen(self.preserved_attributes)>0:d[PRESERVED_ATTRIBUTES_DICT_KEY]=self.preserved_attributesreturnd
[docs]classConvertCustomConfig:""" Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`. Example usage:: convert_custom_config = ConvertCustomConfig() \ .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \ .set_preserved_attributes(["attr1", "attr2"]) """def__init__(self):self.observed_to_quantized_mapping:Dict[QuantType,Dict[Type,Type]]={}self.preserved_attributes:List[str]=[]
[docs]defset_observed_to_quantized_mapping(self,observed_class:Type,quantized_class:Type,quant_type:QuantType=QuantType.STATIC)->ConvertCustomConfig:""" Set the mapping from a custom observed module class to a custom quantized module class. The quantized module class must have a ``from_observed`` class method that converts the observed module class to the quantized module class. """ifquant_typenotinself.observed_to_quantized_mapping:self.observed_to_quantized_mapping[quant_type]={}self.observed_to_quantized_mapping[quant_type][observed_class]=quantized_classreturnself
[docs]defset_preserved_attributes(self,attributes:List[str])->ConvertCustomConfig:""" Set the names of the attributes that will persist in the graph module even if they are not used in the model's ``forward`` method. """self.preserved_attributes=attributesreturnself
# TODO: remove this
[docs]@classmethoddeffrom_dict(cls,convert_custom_config_dict:Dict[str,Any])->ConvertCustomConfig:""" Create a ``ConvertCustomConfig`` from a dictionary with the following items: "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization mode to an inner mapping from observed module classes to quantized module classes, e.g.:: { "static": {FloatCustomModule: ObservedCustomModule}, "dynamic": {FloatCustomModule: ObservedCustomModule}, "weight_only": {FloatCustomModule: ObservedCustomModule} } "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` This function is primarily for backward compatibility and may be removed in the future. """conf=cls()forquant_type_name,custom_module_mappinginconvert_custom_config_dict.get(OBSERVED_TO_QUANTIZED_DICT_KEY,{}).items():quant_type=_quant_type_from_str(quant_type_name)forobserved_class,quantized_classincustom_module_mapping.items():conf.set_observed_to_quantized_mapping(observed_class,quantized_class,quant_type)conf.set_preserved_attributes(convert_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY,[]))returnconf
[docs]defto_dict(self)->Dict[str,Any]:""" Convert this ``ConvertCustomConfig`` to a dictionary with the items described in :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`. """d:Dict[str,Any]={}forquant_type,observed_to_quantized_mappinginself.observed_to_quantized_mapping.items():ifOBSERVED_TO_QUANTIZED_DICT_KEYnotind:d[OBSERVED_TO_QUANTIZED_DICT_KEY]={}d[OBSERVED_TO_QUANTIZED_DICT_KEY][quant_type_to_str(quant_type)]=observed_to_quantized_mappingiflen(self.preserved_attributes)>0:d[PRESERVED_ATTRIBUTES_DICT_KEY]=self.preserved_attributesreturnd
class FuseCustomConfig:
    """
    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`.

    Example usage::

        fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
    """

    def __init__(self):
        self.preserved_attributes: List[str] = []

    def set_preserved_attributes(self, attributes: List[str]) -> FuseCustomConfig:
        """
        Set the names of the attributes that will persist in the graph module even if they are not used in
        the model's ``forward`` method.

        Returns ``self`` so calls can be chained.
        """
        self.preserved_attributes = attributes
        return self

    # TODO: remove this (dict-based API kept only for backward compatibility)
    @classmethod
    def from_dict(cls, fuse_custom_config_dict: Dict[str, Any]) -> FuseCustomConfig:
        """
        Create a ``FuseCustomConfig`` from a dictionary with the following items:

            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``

        This function is primarily for backward compatibility and may be removed in the future.
        """
        conf = cls()
        conf.set_preserved_attributes(fuse_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []))
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``FuseCustomConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig.from_dict`.

        Items whose value would be empty are omitted.
        """
        d: Dict[str, Any] = {}
        if len(self.preserved_attributes) > 0:
            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
        return d
# Docs
# Access comprehensive developer documentation for PyTorch
# To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.