"""This module contains utility method for mobile model optimization and lint."""importtorchfromenumimportEnumfromtorch._CimportMobileOptimizerTypefromtypingimportOptional,Set,List,AnyStrclassLintCode(Enum):BUNDLED_INPUT=1REQUIRES_GRAD=2DROPOUT=3BATCHNORM=4
def optimize_for_mobile(
        script_module: torch.jit.ScriptModule,
        optimization_blocklist: Optional[Set[MobileOptimizerType]] = None,
        preserved_methods: Optional[List[AnyStr]] = None,
        backend: str = 'CPU') -> torch.jit.RecursiveScriptModule:
    """
    Args:
        script_module: An instance of torch script module with type of ScriptModule.
        optimization_blocklist: A set with type of MobileOptimizerType. When the set is not passed,
            the optimization method will run all the optimizer passes; otherwise, it will only run
            the optimization passes that are not included in optimization_blocklist.
        preserved_methods: A list of methods that need to be preserved when the freeze_module pass is invoked.
        backend: Device type to use for running the resulting model ('CPU' (default), 'Vulkan' or 'Metal').
    Returns:
        A new optimized torch script module
    """
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise TypeError(
            'Got {}, but ScriptModule is expected.'.format(type(script_module)))

    if optimization_blocklist is None:
        optimization_blocklist = set()

    if preserved_methods is None:
        preserved_methods = []

    # Convert potential byte arrays into strings (if there are any) to pass type checking.
    # Here we use a new name as assigning it back to preserved_methods will invoke
    # mypy errors (i.e. List[AnyStr] = List[str])
    preserved_methods_str: List[str] = [str(method) for method in preserved_methods]

    bundled_inputs_attributes = _get_bundled_inputs_preserved_attributes(script_module, preserved_methods_str)
    if all([hasattr(script_module, method) for method in bundled_inputs_attributes]):
        preserved_methods_str = list(set(preserved_methods_str + bundled_inputs_attributes))

    non_exist_methods = []
    for method in preserved_methods_str:
        if not hasattr(script_module, method):
            non_exist_methods.append(method)
    if non_exist_methods:
        raise AttributeError(
            'The following methods to preserve do not exist in script_module: {}'.format(
                ', '.join(non_exist_methods)))

    backend = backend.lower()
    if backend == 'cpu':
        optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(
            script_module._c,
            optimization_blocklist,
            preserved_methods_str)
    elif backend == 'vulkan':
        optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods_str)
    elif backend == 'metal':
        optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods_str)
    else:
        raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan' or 'Metal'")

    return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)

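# A minimal usage sketch (illustrative only, not part of the original module).
# `MyModel` and the output path are hypothetical placeholders; the model is
# assumed to be scriptable and switched to eval() mode before scripting.
#
#     model = MyModel().eval()
#     scripted = torch.jit.script(model)
#     optimized = optimize_for_mobile(scripted, backend='CPU')
#     torch.jit.save(optimized, "model_optimized.pt")
#
# Passing backend='Vulkan' or backend='Metal' instead rewrites the graph for the
# corresponding GPU backend, and optimization_blocklist can be used to skip
# individual MobileOptimizerType passes.
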
def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
    """
    Args:
        script_module: An instance of torch script module with type of ScriptModule

    Returns:
        lint_map: A list of dictionaries that contains module lints
    """
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise TypeError(
            'Got {}, but ScriptModule is expected.'.format(type(script_module)))

    lint_list = []

    if not hasattr(script_module, "_generate_bundled_inputs_for_forward"):
        lint_list.append({"name": LintCode.BUNDLED_INPUT.name,
                          "message": "No bundled input for forward, please add bundled inputs "
                                     "before saving the module using torch.utils.bundled_inputs."
                                     "augment_model_with_bundled_inputs."})

    for name, param in script_module.named_parameters():
        if param.requires_grad:
            lint_list.append({"name": LintCode.REQUIRES_GRAD.name,
                              "message": "Param {} requires grad, please set torch.no_grad() to reduce "
                                         "memory usage and improve computation speed during the "
                                         "inference phase.".format(name)})

    op_names = torch.jit.export_opnames(script_module)
    for op_name in op_names:
        if "dropout" in op_name:
            lint_list.append({"name": LintCode.DROPOUT.name,
                              "message": "Operator {} exists, remember to call eval() before "
                                         "saving the module and call torch.utils.mobile_optimizer."
                                         "optimize_for_mobile to drop the dropout operator.".format(op_name)})
        if "batch_norm" in op_name:
            lint_list.append({"name": LintCode.BATCHNORM.name,
                              "message": "Operator {} exists, remember to call eval() before "
                                         "saving the module and call torch.utils.mobile_optimizer."
                                         "optimize_for_mobile to drop the batch_norm operator.".format(op_name)})

    return lint_list


def _get_bundled_inputs_preserved_attributes(script_module: torch.jit.ScriptModule,
                                             preserved_methods: List[str]) -> List[str]:
    bundled_inputs_attributes = []
    # Has bundled inputs for forward
    if hasattr(script_module, 'get_all_bundled_inputs'):
        bundled_inputs_attributes.append('get_all_bundled_inputs')
        bundled_inputs_attributes.append('get_num_bundled_inputs')

    # Bundled inputs in module after the change that introduced bundled inputs for multiple functions
    if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
        bundled_inputs_attributes.append('get_bundled_inputs_functions_and_info')
        all_info = script_module.get_bundled_inputs_functions_and_info()
        for function_name in all_info:
            if function_name not in preserved_methods:
                bundled_inputs_attributes.append(function_name)
            bundled_inputs_attributes.append("get_all_bundled_inputs_for_" + function_name)
            bundled_inputs_attributes.append("_bundled_inputs_deflated_" + function_name)

    return bundled_inputs_attributes

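# A minimal sketch of consuming the lint output (illustrative only, not part of
# the original module); `scripted` is assumed to be a torch.jit.ScriptModule as
# in the sketch above.
#
#     lints = generate_mobile_module_lints(scripted)
#     for lint in lints:
#         print("[{}] {}".format(lint["name"], lint["message"]))
#
# Each entry is a dict whose "name" key holds a LintCode member name (e.g.
# "DROPOUT") and whose "message" key holds the human-readable suggestion.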