from typing import Dict, Any, List, Callable, Tuple, Optional, Set

import torch
from torch.fx import GraphModule
from torch.fx._symbolic_trace import Tracer
from torch.fx.node import Target, Node, Argument
from torch.nn.intrinsic import _FusedModule
from .fx import Fuser  # noqa: F401
from .fx import prepare, convert  # noqa: F401
from .fx import get_tensorrt_backend_config_dict  # noqa: F401
from .fx.graph_module import ObservedGraphModule
from .fx.qconfig_utils import (
    check_is_valid_convert_custom_config_dict,
    check_is_valid_fuse_custom_config_dict,
    check_is_valid_prepare_custom_config_dict,
    check_is_valid_qconfig_dict,
)
from .fx.utils import graph_pretty_str  # noqa: F401
from .fx.utils import get_custom_module_class_keys  # noqa: F401


def _check_is_graph_module(model: torch.nn.Module) -> None:
    if not isinstance(model, GraphModule):
        raise ValueError(
            "input model must be a GraphModule, got type: "
            + str(type(model))
            + ". Please make sure to follow the tutorials."
        )


def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
    r"""Swap FloatFunctional with FXFloatFunctional"""
    modules_to_swap = []
    for name, module in model.named_children():
        if isinstance(module, torch.nn.quantized.FloatFunctional):
            modules_to_swap.append(name)
        else:
            _swap_ff_with_fxff(module)

    for name in modules_to_swap:
        del model._modules[name]
        model._modules[name] = torch.nn.quantized.FXFloatFunctional()


def _fuse_fx(
    graph_module: GraphModule,
    is_qat: bool,
    fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    r"""Internal helper function to fuse modules in preparation for quantization

    Args:
        graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
    """
    _check_is_graph_module(graph_module)
    fuser = Fuser()
    return fuser.fuse(
        graph_module, is_qat, fuse_custom_config_dict, backend_config_dict
    )


class Scope(object):
    """Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in a Graph of GraphModule. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # This will be a call_method Node in GraphModule,
                # scope for this would be (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self):
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type


class ScopeContextManager(object):
    """A context manager to track the Scope of a Node during symbolic tracing.
    When entering a forward function of a Module, we'll update the scope information of
    the current module, and when we exit, we'll restore the previous scope information.
"""def__init__(self,scope:Scope,current_module:torch.nn.Module,current_module_path:str):super().__init__()self.prev_module_type=scope.module_typeself.prev_module_path=scope.module_pathself.scope=scopeself.scope.module_path=current_module_pathself.scope.module_type=type(current_module)def__enter__(self):returndef__exit__(self,*args):self.scope.module_path=self.prev_module_pathself.scope.module_type=self.prev_module_typereturnclassQuantizationTracer(Tracer):def__init__(self,skipped_module_names:List[str],skipped_module_classes:List[Callable]):super().__init__()self.skipped_module_names=skipped_module_namesself.skipped_module_classes=skipped_module_classes# NB: initialized the module_type of top level module to None# we are assuming people won't configure the model with the type of top level# module here, since people can use "" for global config# We can change this if there is a use case that configures# qconfig using top level module typeself.scope=Scope("",None)self.node_name_to_scope:Dict[str,Tuple[str,type]]={}self.record_stack_traces=Truedefis_leaf_module(self,m:torch.nn.Module,module_qualified_name:str)->bool:return((m.__module__.startswith("torch.nn")andnotisinstance(m,torch.nn.Sequential))ormodule_qualified_nameinself.skipped_module_namesortype(m)inself.skipped_module_classesorisinstance(m,_FusedModule))defcall_module(self,m:torch.nn.Module,forward:Callable[...,Any],args:Tuple[Any,...],kwargs:Dict[str,Any],)->Any:module_qualified_name=self.path_of_module(m)# Creating scope with information of current module# scope will be restored automatically upon exitwithScopeContextManager(self.scope,m,module_qualified_name):returnsuper().call_module(m,forward,args,kwargs)defcreate_node(self,kind:str,target:Target,args:Tuple[Argument,...],kwargs:Dict[str,Argument],name:Optional[str]=None,type_expr:Optional[Any]=None,)->Node:node=super().create_node(kind,target,args,kwargs,name,type_expr)self.node_name_to_scope[node.name]=(self.scope.module_path,self.scope.module_type,)returnnodedef_prepare_fx(model:torch.nn.Module,qconfig_dict:Any,is_qat:bool,prepare_custom_config_dict:Optional[Dict[str,Any]]=None,equalization_qconfig_dict:Optional[Dict[str,Any]]=None,backend_config_dict:Optional[Dict[str,Any]]=None,is_standalone_module:bool=False,)->ObservedGraphModule:r""" Internal helper function for prepare_fx Args: `model`, `qconfig_dict`, `prepare_custom_config_dict`, `equalization_qonfig_dict`: see docs for :func:`~torch.ao.quantization.prepare_fx` `is_standalone_module`: a boolean flag indicates whether we are quantizing a standalone module or not, a standalone module is a submodule of the parent module that is not inlined in theforward graph of the parent module, the way we quantize standalone module is described in: :func:`~torch.ao.quantization._prepare_standalone_module_fx` """ifprepare_custom_config_dictisNone:prepare_custom_config_dict={}ifequalization_qconfig_dictisNone:equalization_qconfig_dict={}check_is_valid_qconfig_dict(qconfig_dict)check_is_valid_prepare_custom_config_dict(prepare_custom_config_dict)check_is_valid_qconfig_dict(equalization_qconfig_dict)skipped_module_names=prepare_custom_config_dict.get("non_traceable_module_name",[])skipped_module_classes=prepare_custom_config_dict.get("non_traceable_module_class",[])# swap FloatFunctional with FXFloatFunctional_swap_ff_with_fxff(model)# symbolically trace the modelifnotis_standalone_module:# standalone module and custom module config are applied in top level 
        standalone_module_name_configs = prepare_custom_config_dict.get(
            "standalone_module_name", []
        )
        skipped_module_names += [
            config[0] for config in standalone_module_name_configs
        ]

        standalone_module_class_configs = prepare_custom_config_dict.get(
            "standalone_module_class", []
        )
        skipped_module_classes += [
            config[0] for config in standalone_module_class_configs
        ]
        float_custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict, "float_to_observed_custom_module_class"
        )
        skipped_module_classes += float_custom_module_classes

    preserved_attributes = prepare_custom_config_dict.get("preserved_attributes", [])
    tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
    graph_module = GraphModule(model, tracer.trace(model))
    for attr_name in preserved_attributes:
        setattr(graph_module, attr_name, getattr(model, attr_name))
    graph_module = _fuse_fx(
        graph_module, is_qat, prepare_custom_config_dict, backend_config_dict
    )
    prepared = prepare(
        graph_module,
        qconfig_dict,
        is_qat,
        tracer.node_name_to_scope,
        prepare_custom_config_dict=prepare_custom_config_dict,
        equalization_qconfig_dict=equalization_qconfig_dict,
        backend_config_dict=backend_config_dict,
        is_standalone_module=is_standalone_module,
    )

    for attr_name in preserved_attributes:
        setattr(prepared, attr_name, getattr(model, attr_name))
    return prepared


def _prepare_standalone_module_fx(
    model: torch.nn.Module,
    qconfig_dict: Any,
    is_qat: bool,
    prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    r"""[Internal use only] Prepare a standalone module, so that it can be used
    when quantizing the parent module.

    standalone_module means it is a submodule that is not inlined in the parent
    module, and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module

    Returns:

        * model(GraphModule): prepared standalone module. It has these attributes:

            * `_standalone_module_input_quantized_idxs(List[Int])`: a list of
              indices for the graph inputs that are expected to be quantized,
              same as the input_quantized_idxs configuration provided
              for the standalone module
            * `_standalone_module_output_quantized_idxs(List[Int])`: a list of
              indices for the graph outputs that are quantized,
              same as the output_quantized_idxs configuration provided
              for the standalone module

    """
    return _prepare_fx(
        model,
        qconfig_dict,
        is_qat,
        prepare_custom_config_dict,
        backend_config_dict=backend_config_dict,
        is_standalone_module=True,
    )
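

# A minimal illustrative sketch (not part of the original API): how
# QuantizationTracer records the containing module of each traced node in
# `node_name_to_scope`. `_DemoSub` and `_DemoModel` are hypothetical modules
# used only for demonstration.
def _example_trace_scopes() -> Dict[str, Tuple[str, type]]:
    class _DemoSub(torch.nn.Module):
        def forward(self, x):
            return x.transpose(1, 2)

    class _DemoModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sub = _DemoSub()

        def forward(self, x):
            return self.sub(x.transpose(1, 2))

    tracer = QuantizationTracer(skipped_module_names=[], skipped_module_classes=[])
    tracer.trace(_DemoModel())
    # Expected entries (per the Scope docstring above):
    #   "transpose"   -> ("", None)           # in the top level forward
    #   "transpose_1" -> ("sub", _DemoSub)    # inside self.sub
    return tracer.node_name_to_scope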


def fuse_fx(
    model: torch.nn.Module, fuse_custom_config_dict: Optional[Dict[str, Any]] = None
) -> GraphModule:
    r"""Fuse modules like conv+bn, conv+bn+relu etc. The model must be in eval mode.
    Fusion rules are defined in torch.quantization.fx.fusion_pattern.py

    Args:

        * `model`: a torch.nn.Module model
        * `fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g.::

            fuse_custom_config_dict = {
                "additional_fuser_method_mapping": {
                    (Module1, Module2): fuse_module1_module2
                },

                # Attributes that are not used in forward function will
                # be removed when constructing GraphModule, this is a list of attributes
                # to preserve as an attribute of the GraphModule even when they are
                # not used in the code, these attributes will also persist through deepcopy
                "preserved_attributes": ["preserved_attr"],
            }

    Example::

        from torch.ao.quantization import fuse_fx
        m = Model().eval()
        m = fuse_fx(m)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
    assert not model.training, "fuse_fx only works on models in eval mode"
    check_is_valid_fuse_custom_config_dict(fuse_custom_config_dict)
    graph_module = torch.fx.symbolic_trace(model)
    preserved_attributes: Set[str] = set()
    if fuse_custom_config_dict:
        preserved_attributes = set(
            fuse_custom_config_dict.get("preserved_attributes", [])
        )
    for attr_name in preserved_attributes:
        setattr(graph_module, attr_name, getattr(model, attr_name))
    return _fuse_fx(graph_module, False, fuse_custom_config_dict)
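

# A minimal illustrative sketch (hypothetical `_DemoConvBnRelu` module) of
# fuse_fx with "preserved_attributes": attributes that are unused in
# forward() would normally be dropped when the GraphModule is constructed,
# but listing them here keeps them on the fused module.
def _example_fuse_with_preserved_attributes() -> GraphModule:
    class _DemoConvBnRelu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.bn = torch.nn.BatchNorm2d(8)
            self.relu = torch.nn.ReLU()
            self.version = "v1"  # not used in forward()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    m = _DemoConvBnRelu().eval()  # fuse_fx requires eval mode
    fused = fuse_fx(m, {"preserved_attributes": ["version"]})
    assert fused.version == "v1"  # survived GraphModule construction
    return fused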


def prepare_fx(
    model: torch.nn.Module,
    qconfig_dict: Any,
    prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
    equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> ObservedGraphModule:
    r"""Prepare a model for post training static quantization

    Args:
      * `model`: torch.nn.Module model, must be in eval mode

      * `qconfig_dict`: qconfig_dict is a dictionary with the following configurations::

          qconfig_dict = {
              # optional, global config
              "": qconfig?,

              # optional, used for module and function types
              # could also be split into module_types and function_types if we prefer
              "object_type": [
                  (torch.nn.Conv2d, qconfig?),
                  (torch.nn.functional.add, qconfig?),
                  ...,
              ],

              # optional, used for module names
              "module_name": [
                  ("foo.bar", qconfig?),
                  ...,
              ],

              # optional, matched in order, first match takes precedence
              "module_name_regex": [
                  ("foo.*bar.*conv[0-9]+", qconfig?),
                  ...,
              ],

              # optional, used for matching object type invocations in a submodule by
              # order
              # TODO(future PR): potentially support multiple indices ('0,1') and/or
              # ranges ('0:3').
              "module_name_object_type_order": [
                  # fully_qualified_name, object_type, index, qconfig
                  ("foo.bar", torch.nn.functional.linear, 0, qconfig?),
              ],

              # priority (in increasing order):
              #   global, object_type, module_name_regex, module_name,
              #   module_name_object_type_order
              # qconfig == None means fusion and quantization should be skipped for anything
              # matching the rule
          }

      * `prepare_custom_config_dict`: customization configuration dictionary for
        quantization tool::

          prepare_custom_config_dict = {
              # optional: specify the path for standalone modules
              # These modules are symbolically traced and quantized as one unit
              "standalone_module_name": [
                  # module_name, qconfig_dict, prepare_custom_config_dict, backend_config_dict
                  ("submodule.standalone",
                      None,  # qconfig_dict for the prepare function called in the submodule,
                             # None means use qconfig from parent qconfig_dict
                      {"input_quantized_idxs": [], "output_quantized_idxs": []},  # prepare_custom_config_dict
                      {})  # backend_config_dict, TODO: point to README doc when it's ready
              ],

              "standalone_module_class": [
                  # module_class, qconfig_dict, prepare_custom_config_dict, backend_config_dict
                  (StandaloneModule,
                      None,  # qconfig_dict for the prepare function called in the submodule,
                             # None means use qconfig from parent qconfig_dict
                      {"input_quantized_idxs": [0], "output_quantized_idxs": [0]},  # prepare_custom_config_dict
                      {})  # backend_config_dict, TODO: point to README doc when it's ready
              ],

              # user will manually define the corresponding observed
              # module class which has a from_float class method that converts
              # float custom module to observed custom module
              # (only needed for static quantization)
              "float_to_observed_custom_module_class": {
                  "static": {
                      CustomModule: ObservedCustomModule
                  }
              },

              # the qualified names for the submodules that are not symbolically traceable
              "non_traceable_module_name": [
                  "non_traceable_module"
              ],

              # the module classes that are not symbolically traceable
              # we'll also put dynamic/weight_only custom module here
              "non_traceable_module_class": [
                  NonTraceableModule
              ],

              # Additional fuser_method mapping
              "additional_fuser_method_mapping": {
                  (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
              },

              # Additional module mapping for qat
              "additional_qat_module_mapping": {
                  torch.nn.intrinsic.ConvBn2d: torch.nn.qat.ConvBn2d
              },

              # Additional fusion patterns
              "additional_fusion_pattern": {
                  (torch.nn.BatchNorm2d, torch.nn.Conv2d): ConvReluFusionHandler
              },

              # Additional quantization patterns
              "additional_quant_pattern": {
                  torch.nn.Conv2d: ConvReluQuantizeHandler,
                  (torch.nn.ReLU, torch.nn.Conv2d): ConvReluQuantizeHandler,
              },

              # By default, inputs and outputs of the graph are assumed to be in
              # fp32. Providing `input_quantized_idxs` will set the inputs with the
              # corresponding indices to be quantized. Providing
              # `output_quantized_idxs` will set the outputs with the corresponding
              # indices to be quantized.
              "input_quantized_idxs": [0],
              "output_quantized_idxs": [0],

              # Attributes that are not used in forward function will
              # be removed when constructing GraphModule, this is a list of attributes
              # to preserve as an attribute of the GraphModule even when they are
              # not used in the code, these attributes will also persist through deepcopy
              "preserved_attributes": ["preserved_attr"],
          }

      * `equalization_qconfig_dict`: equalization_qconfig_dict is a dictionary
        with a similar structure as qconfig_dict except it will contain
        configurations specific to equalization techniques such as input-weight
        equalization.

      * `backend_config_dict`: a dictionary that specifies how operators are quantized
        in a backend, this includes how the operators are observed,
        supported fusion patterns, how quantize/dequantize ops are inserted,
        supported dtypes etc. The structure of the dictionary is still WIP
        and will change in the future, please don't use right now.


    Return:
      A GraphModule with observers (configured by qconfig_dict), ready for calibration

    Example::

        import torch
        from torch.ao.quantization import get_default_qconfig
        from torch.ao.quantization import prepare_fx

        float_model.eval()
        qconfig = get_default_qconfig('fbgemm')
        def calibrate(model, data_loader):
            model.eval()
            with torch.no_grad():
                for image, target in data_loader:
                    model(image)

        qconfig_dict = {"": qconfig}
        prepared_model = prepare_fx(float_model, qconfig_dict)
        # Run calibration
        calibrate(prepared_model, sample_inference_data)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
    assert not model.training, "prepare_fx only works for models in eval mode"
    return _prepare_fx(
        model,
        qconfig_dict,
        False,  # is_qat
        prepare_custom_config_dict,
        equalization_qconfig_dict,
        backend_config_dict,
    )
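

# A minimal illustrative sketch of combining the qconfig_dict rule types
# documented above: a global qconfig plus a "module_name" entry set to None.
# Per the stated precedence, the "module_name" rule overrides the global ""
# rule, so quantization is skipped for that submodule. The name "head.linear"
# is hypothetical; the model passed in is assumed to contain it.
def _example_prepare_with_module_name_override(
    float_model: torch.nn.Module,
) -> ObservedGraphModule:
    from torch.ao.quantization import get_default_qconfig

    qconfig_dict = {
        "": get_default_qconfig("fbgemm"),       # global default
        "module_name": [("head.linear", None)],  # skip quantizing this submodule
    }
    return prepare_fx(float_model.eval(), qconfig_dict)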


def prepare_qat_fx(
    model: torch.nn.Module,
    qconfig_dict: Any,
    prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> ObservedGraphModule:
    r"""Prepare a model for quantization aware training

    Args:
      * `model`: torch.nn.Module model, must be in train mode
      * `qconfig_dict`: see :func:`~torch.ao.quantization.prepare_fx`
      * `prepare_custom_config_dict`: see :func:`~torch.ao.quantization.prepare_fx`
      * `backend_config_dict`: see :func:`~torch.ao.quantization.prepare_fx`

    Return:
      A GraphModule with fake quant modules (configured by qconfig_dict), ready for
      quantization aware training

    Example::

        import torch
        from torch.ao.quantization import get_default_qat_qconfig
        from torch.ao.quantization import prepare_qat_fx

        qconfig = get_default_qat_qconfig('fbgemm')
        def train_loop(model, train_data):
            model.train()
            for image, target in data_loader:
                ...

        float_model.train()
        qconfig_dict = {"": qconfig}
        prepared_model = prepare_qat_fx(float_model, qconfig_dict)
        # Run training
        train_loop(prepared_model, train_data)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
    assert model.training, "prepare_qat_fx only works for models in train mode"
    return _prepare_fx(
        model,
        qconfig_dict,
        True,  # is_qat
        prepare_custom_config_dict,
        backend_config_dict=backend_config_dict,
    )
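

# A minimal illustrative sketch of the full QAT flow built from the APIs in
# this file: prepare -> fine-tune -> convert. `train_loop` and `train_data`
# are hypothetical; `convert_fx` is defined later in this file and resolves
# at call time.
def _example_qat_end_to_end(
    float_model: torch.nn.Module,
    train_loop: Callable[[torch.nn.Module, Any], None],
    train_data: Any,
) -> torch.nn.Module:
    from torch.ao.quantization import get_default_qat_qconfig

    qconfig_dict = {"": get_default_qat_qconfig("fbgemm")}
    prepared = prepare_qat_fx(float_model.train(), qconfig_dict)
    train_loop(prepared, train_data)  # fine-tune with fake-quant inserted
    prepared.eval()
    return convert_fx(prepared)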


def _convert_fx(
    graph_module: GraphModule,
    is_reference: bool,
    convert_custom_config_dict: Optional[Dict[str, Any]] = None,
    is_standalone_module: bool = False,
    _remove_qconfig: bool = True,
    qconfig_dict: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
    """`is_standalone_module`: see docs in
    :func:`~torch.ao.quantization.prepare_standalone_module_fx`
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}

    _check_is_graph_module(graph_module)
    check_is_valid_convert_custom_config_dict(convert_custom_config_dict)

    quantized = convert(
        graph_module,
        is_reference,
        convert_custom_config_dict,
        is_standalone_module,
        _remove_qconfig_flag=_remove_qconfig,
        convert_qconfig_dict=qconfig_dict,
    )

    preserved_attributes = convert_custom_config_dict.get("preserved_attributes", [])
    for attr_name in preserved_attributes:
        setattr(quantized, attr_name, getattr(graph_module, attr_name))
    return quantized


def convert_fx(
    graph_module: GraphModule,
    is_reference: bool = False,
    convert_custom_config_dict: Optional[Dict[str, Any]] = None,
    _remove_qconfig: bool = True,
    qconfig_dict: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
    r"""Convert a calibrated or trained model to a quantized model

    Args:
        * `graph_module`: A prepared and calibrated/trained model (GraphModule)

        * `is_reference`: flag for whether to produce a reference quantized model,
          which will be a common interface between pytorch quantization and
          other backends like accelerators

        * `convert_custom_config_dict`: dictionary for custom configurations for convert function::

            convert_custom_config_dict = {
                # additional object (module/operator) mappings that will overwrite the default
                # module mapping
                "additional_object_mapping": {
                    "static": {
                        FloatModule: QuantizedModule,
                        float_op: quantized_op
                    },
                    "dynamic": {
                        FloatModule: DynamicallyQuantizedModule,
                        float_op: dynamically_quantized_op
                    },
                },

                # user will manually define the corresponding quantized
                # module class which has a from_observed class method that converts
                # observed custom module to quantized custom module
                "observed_to_quantized_custom_module_class": {
                    "static": {
                        ObservedCustomModule: QuantizedCustomModule
                    },
                    "dynamic": {
                        ObservedCustomModule: QuantizedCustomModule
                    },
                    "weight_only": {
                        ObservedCustomModule: QuantizedCustomModule
                    }
                },

                # Attributes that are not used in forward function will
                # be removed when constructing GraphModule, this is a list of attributes
                # to preserve as an attribute of the GraphModule even when they are
                # not used in the code
                "preserved_attributes": ["preserved_attr"],
            }

        * `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.

        * `qconfig_dict`: qconfig_dict with either same keys as what is passed to
          the qconfig_dict in `prepare_fx` API, with same values or `None`, or
          additional keys with values set to `None`.

          For each entry whose value is set to None, we skip quantizing that entry in the model::

            qconfig_dict = {
                # used for object_type, skip quantizing torch.nn.functional.add
                "object_type": [
                    (torch.nn.functional.add, None),
                    (torch.nn.functional.linear, qconfig_from_prepare),
                    ...,
                ],

                # used for module names, skip quantizing "foo.bar"
                "module_name": [
                    ("foo.bar", None),
                    ...,
                ],
            }

    Return:
        A quantized model (GraphModule)

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        quantized_model = convert_fx(prepared_model)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
    return _convert_fx(
        graph_module,
        is_reference,
        convert_custom_config_dict,
        _remove_qconfig=_remove_qconfig,
        qconfig_dict=qconfig_dict,
    )
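

# A minimal illustrative sketch of the convert-time qconfig_dict documented
# above: passing an entry whose value is None skips quantizing the matching
# ops even though they were observed during prepare.
def _example_convert_skip_functional_add(
    prepared_model: ObservedGraphModule,
) -> torch.nn.Module:
    convert_qconfig_dict = {
        # skip quantizing torch.nn.functional.add at convert time
        "object_type": [(torch.nn.functional.add, None)],
    }
    return convert_fx(prepared_model, qconfig_dict=convert_qconfig_dict)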


def _convert_standalone_module_fx(
    graph_module: GraphModule,
    is_reference: bool = False,
    convert_custom_config_dict: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
    r"""[Internal use only] Convert a model produced by
    :func:`~torch.ao.quantization.prepare_standalone_module_fx`
    to a quantized model.

    Returns a quantized standalone module. Whether input/output is quantized is
    specified by `input_quantized_idxs` and `output_quantized_idxs` in
    prepare_custom_config_dict; please see the docs for prepare_fx for details.
    """
    return _convert_fx(
        graph_module,
        is_reference,
        convert_custom_config_dict,
        is_standalone_module=True,
    )
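

# A minimal illustrative sketch tying the public APIs in this file together
# for post training static quantization: prepare_fx (which fuses and inserts
# observers), calibration, then convert_fx. `calibration_data` is a
# hypothetical iterable of model inputs.
def _example_post_training_quantization(
    float_model: torch.nn.Module, calibration_data: Any
) -> torch.nn.Module:
    from torch.ao.quantization import get_default_qconfig

    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    prepared = prepare_fx(float_model.eval(), qconfig_dict)
    with torch.no_grad():
        for example_input in calibration_data:
            prepared(example_input)  # run observers to collect statistics
    return convert_fx(prepared)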