Source code for torch.ao.quantization.pt2e.export_utils
# mypy: allow-untyped-defs
import types

import torch
import torch.nn.functional as F

from torch.ao.quantization.utils import _assert_and_get_unique_device


__all__ = [
    "model_is_exported",
]


_EXPORTED_TRAINING_ATTR = "_exported_training"


class _WrapperModule(torch.nn.Module):
    """Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you
    are trying to export a callable.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        """Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`."""
        return self.fn(*args, **kwargs)
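
# Illustrative sketch (added commentary, not part of the upstream file):
# `_WrapperModule` lets a bare function stand in for an `nn.Module` where
# export APIs require one. Assumes `torch.export.export` is available
# (PyTorch 2.x); `add_one` is a placeholder:
#
#     def add_one(x):
#         return x + 1
#
#     wrapped = _WrapperModule(add_one)
#     exported = torch.export.export(wrapped, (torch.randn(2),))
#     print(exported.module()(torch.zeros(2)))  # tensor([1., 1.])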

def model_is_exported(m: torch.nn.Module) -> bool:
    """
    Return True if the `torch.nn.Module` was exported, False otherwise
    (e.g. if the model was FX symbolically traced or not traced at all).
    """
    return isinstance(m, torch.fx.GraphModule) and any(
        "val" in n.meta for n in m.graph.nodes
    )
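
# Illustrative sketch (added commentary, not part of the upstream file):
# an exported GraphModule carries fake-tensor "val" metadata on its nodes,
# while a plain FX symbolic trace does not, which is exactly what
# `model_is_exported` keys on. `M` is a placeholder module:
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return F.relu(x)
#
#     exported_gm = torch.export.export(M(), (torch.randn(2),)).module()
#     traced_gm = torch.fx.symbolic_trace(M())
#
#     model_is_exported(exported_gm)  # True
#     model_is_exported(traced_gm)    # False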

def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
    """
    Switch dropout patterns in the model between train and eval modes.

    Dropout has different behavior in train vs eval mode. For exported models,
    however, calling `model.train()` or `model.eval()` does not automatically switch
    the dropout behavior between the two modes, so here we need to rewrite the aten
    dropout patterns manually to achieve the same effect.

    See https://github.com/pytorch/pytorch/issues/103681.
    """
    # Avoid circular dependencies
    from .utils import _get_aten_graph_module_for_pattern

    # Needed to ensure subgraph matches are self-contained
    m.graph.eliminate_dead_code()
    m.recompile()

    from torch._export import gm_using_training_ir

    using_training_ir = gm_using_training_ir(m)

    for inplace in [False, True]:

        def dropout_train(x):
            return F.dropout(x, p=0.5, training=True, inplace=inplace)

        def dropout_eval(x):
            return F.dropout(x, p=0.5, training=False, inplace=inplace)

        example_inputs = (torch.randn(1),)
        if train_to_eval:
            match_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_train),
                example_inputs,
                using_training_ir=using_training_ir,
            )
            replacement_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_eval),
                example_inputs,
                using_training_ir=using_training_ir,
            )
        else:
            match_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_eval),
                example_inputs,
                using_training_ir=using_training_ir,
            )
            replacement_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_train),
                example_inputs,
                using_training_ir=using_training_ir,
            )

        from torch.fx.subgraph_rewriter import replace_pattern_with_filters

        replace_pattern_with_filters(
            m,
            match_pattern,
            replacement_pattern,
            match_filters=[],
            ignore_literals=True,
        )
        m.recompile()
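
# Illustrative sketch (added commentary, not part of the upstream file):
# conceptually, `_replace_dropout(m, train_to_eval=True)` rewrites each
# captured dropout subgraph so that it runs with training=False, roughly
# (the exact aten ops depend on which IR the model was captured with):
#
#     before:  torch.ops.aten.dropout.default(x, 0.5, True)   # masks + rescales
#     after:   torch.ops.aten.dropout.default(x, 0.5, False)  # identity
#
# The `p=0.5` in the pattern is only a stand-in: the rewrite is applied with
# `ignore_literals=True`, so dropout calls with any probability are matched.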
"""# TODO(Leslie): This function still fails to support custom momentum and eps value.# Enable this support in future updates.# Avoid circular dependenciesfrom.utilsimport_get_aten_graph_module_for_pattern# Needed to ensure subgraph matches are self-containedm.graph.eliminate_dead_code()m.recompile()fromtorch._exportimportgm_using_training_irusing_training_ir=gm_using_training_ir(m)defbn_train(x:torch.Tensor,bn_weight:torch.Tensor,bn_bias:torch.Tensor,bn_running_mean:torch.Tensor,bn_running_var:torch.Tensor,):returnF.batch_norm(x,bn_running_mean,bn_running_var,bn_weight,bn_bias,training=True)defbn_eval(x:torch.Tensor,bn_weight:torch.Tensor,bn_bias:torch.Tensor,bn_running_mean:torch.Tensor,bn_running_var:torch.Tensor,):returnF.batch_norm(x,bn_running_mean,bn_running_var,bn_weight,bn_bias,training=False)example_inputs=(torch.randn(1,1,3,3),# xtorch.randn(1),# bn_weighttorch.randn(1),# bn_biastorch.randn(1),# bn_running_meantorch.randn(1),# bn_running_var)device=_assert_and_get_unique_device(m)is_cuda=deviceisnotNoneanddevice.type=="cuda"bn_train_aten=_get_aten_graph_module_for_pattern(_WrapperModule(bn_train),example_inputs,is_cuda,using_training_ir=using_training_ir,)bn_eval_aten=_get_aten_graph_module_for_pattern(_WrapperModule(bn_eval),example_inputs,is_cuda,using_training_ir=using_training_ir,)iftrain_to_eval:match_pattern=bn_train_atenreplacement_pattern=bn_eval_atenelse:match_pattern=bn_eval_atenreplacement_pattern=bn_train_atenfromtorch.fx.subgraph_rewriterimportreplace_pattern_with_filtersreplace_pattern_with_filters(m,match_pattern,replacement_pattern,match_filters=[],ignore_literals=True,)m.recompile()# TODO: expose these under this namespace?def_move_exported_model_to_eval(model:torch.fx.GraphModule):""" Move an exported GraphModule to eval mode. This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm. QAT users should call this before performing inference on the model. This call is idempotent; if the model is already in eval mode, nothing will happen. """is_training=getattr(model,_EXPORTED_TRAINING_ATTR,True)ifnotis_training:returnmodelsetattr(model,_EXPORTED_TRAINING_ATTR,False)_replace_dropout(model,train_to_eval=True)_replace_batchnorm(model,train_to_eval=True)returnmodeldef_move_exported_model_to_train(model:torch.fx.GraphModule):""" Move an exported GraphModule to train mode. This is equivalent to model.train() but only for certain special ops like dropout, batchnorm. QAT users should call this before performing training on the model. This call is idempotent; if the model is already in train mode, nothing will happen. """is_training=getattr(model,_EXPORTED_TRAINING_ATTR,False)ifis_training:returnmodelsetattr(model,_EXPORTED_TRAINING_ATTR,True)_replace_dropout(model,train_to_eval=False)_replace_batchnorm(model,train_to_eval=False)returnmodeldef_allow_exported_model_train_eval(model:torch.fx.GraphModule):""" Allow users to call `model.train()` and `model.eval()` on an exported model, but with the effect of changing behavior between the two modes limited to special ops only, which are currently dropout and batchnorm. Note: This does not achieve the same effect as what `model.train()` and `model.eval()` does in eager models, but only provides an approximation. In particular, user code branching on `training` flag will not function correctly in general because the branch is already specialized at export time. Additionally, other ops beyond dropout and batchnorm that have different train/eval behavior will also not be converted properly. 
"""def_train(self,mode:bool=True):ifmode:_move_exported_model_to_train(self)else:_move_exported_model_to_eval(self)def_eval(self):_move_exported_model_to_eval(self)model.train=types.MethodType(_train,model)# type: ignore[method-assign]model.eval=types.MethodType(_eval,model)# type: ignore[method-assign]returnmodel