Source code for torch.ao.nn.intrinsic.qat.modules.conv_fused
# mypy: allow-untyped-defs
import math
from typing import ClassVar, Optional, Type

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.modules.utils import _pair, _single, _triple
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_conv_bn_weights


__all__ = [
    "ConvBn1d",
    "ConvBnReLU1d",
    "ConvReLU1d",
    "ConvBn2d",
    "ConvBnReLU2d",
    "ConvReLU2d",
    "ConvBn3d",
    "ConvBnReLU3d",
    "ConvReLU3d",
    "update_bn_stats",
    "freeze_bn_stats",
]

_BN_CLASS_MAP = {
    1: nn.BatchNorm1d,
    2: nn.BatchNorm2d,
    3: nn.BatchNorm3d,
}


class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
    _version = 2
    _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]]

    def __init__(
        self,
        # ConvNd args
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        bias,
        padding_mode,
        # BatchNormNd args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
        dim=2,
    ):
        nn.modules.conv._ConvNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            False,
            padding_mode,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

        self._enable_slow_path_for_better_numerical_stability = False

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)
        # note: below is actually for conv, not BN
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        super().reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def _forward(self, input):
        if self._enable_slow_path_for_better_numerical_stability:
            return self._forward_slow(input)
        return self._forward_approximate(input)

    def _forward_approximate(self, input):
        """Approximated method to fuse conv and bn. It requires only one forward pass.

        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
        """
        assert self.bn.running_var is not None
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape(weight_shape)
        )
        # using zero bias here since the bias for original conv
        # will be added later
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
        else:
            zero_bias = torch.zeros(
                self.out_channels, device=scaled_weight.device, dtype=input.dtype
            )
        conv = self._conv_forward(input, scaled_weight, zero_bias)
        conv_orig = conv / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape(bias_shape)
        conv = self.bn(conv_orig)
        return conv

    def _forward_slow(self, input):
        """A more accurate but slow method to compute conv bn fusion,
        following https://arxiv.org/pdf/1806.08342.pdf
        It requires two forward passes but handles the case bn.weight == 0

        Conv: Y = WX + B_c
        Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c

        Batch statistics:
          mean_Y = Y.mean()
                 = Y0.mean() + B_c
          var_Y = (Y - mean_Y)^2.mean()
                = (Y0 - Y0.mean())^2.mean()

        BN (r: bn.weight, beta: bn.bias):
          Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
            = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta

        Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
          Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
            = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta

        Fused Conv BN inference (running_std = sqrt(running_var + eps)):
          Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta

        QAT with fused conv bn:
          Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y)
                      - r * Y0.mean() / std_Y + beta
                  = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y)
                      - r * Y0.mean() / std_Y + beta
          Z_inference = conv(X, fake_quant(r * W / running_std))
                          - r * (running_mean - B_c) / running_std + beta
        """
        assert self.bn.running_var is not None
        assert self.bn.running_mean is not None

        # using zero bias here since the bias for original conv
        # will be added later
        zero_bias = torch.zeros(
            self.out_channels, device=self.weight.device, dtype=input.dtype
        )

        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1

        if self.bn.training:
            # needed to compute batch mean/std
            conv_out = self._conv_forward(input, self.weight, zero_bias)
            # update bn statistics
            with torch.no_grad():
                conv_out_bias = (
                    conv_out
                    if self.bias is None
                    else conv_out + self.bias.reshape(bias_shape)
                )
                self.bn(conv_out_bias)

            # fused conv + bn without bias using bn running statistics
            running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
            scale_factor = self.bn.weight / running_std
            scaled_weight = self.weight_fake_quant(
                self.weight * scale_factor.reshape(weight_shape)
            )
            # fused conv without bias for inference: (r * W / running_std) * X
            conv_bn = self._conv_forward(input, scaled_weight, zero_bias)

            avg_dims = [0] + list(range(2, len(self.weight.shape)))
            batch_mean = conv_out.mean(avg_dims)
            batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
                avg_dims
            )
            batch_std = torch.sqrt(batch_var + self.bn.eps)

            # scale to use batch std in training mode
            # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
            unscale_factor = running_std / batch_std
            conv_bn *= unscale_factor.reshape(bias_shape)

            fused_mean = batch_mean
            fused_std = batch_std
        else:
            # fused conv + bn without bias using bn running statistics
            running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
            scale_factor = self.bn.weight / running_std
            scaled_weight = self.weight_fake_quant(
                self.weight * scale_factor.reshape(weight_shape)
            )
            # fused conv without bias for inference: (r * W / running_std) * X
            conv_bn = self._conv_forward(input, scaled_weight, zero_bias)
            fused_mean = self.bn.running_mean - (
                self.bias if self.bias is not None else 0
            )
            fused_std = running_std

        # fused bias = beta - r * mean / std
        fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
        conv_bn += fused_bias.reshape(bias_shape)

        # HACK to let conv bias participate in loss to avoid DDP error (parameters
        # were not used in producing loss)
        if self.bias is not None:
            conv_bn += (self.bias - self.bias).reshape(bias_shape)

        return conv_bn

    def extra_repr(self):
        # TODO(jerryzh): extend
        return super().extra_repr()

    def forward(self, input):
        return self._forward(input)

    def train(self, mode=True):
        """
        Batchnorm's training behavior is using the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    # ===== Serialization version history =====
    #
    # Version 1/None
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- gamma : Tensor
    #   |--- beta : Tensor
    #   |--- running_mean : Tensor
    #   |--- running_var : Tensor
    #   |--- num_batches_tracked : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- bn : Module
    #        |--- weight : Tensor (moved from v1.self.gamma)
    #        |--- bias : Tensor (moved from v1.self.beta)
    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
    #        |--- running_var : Tensor (moved from v1.self.running_var)
    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        if version is None or version == 1:
            # BN related parameters and buffers were moved into the BN module for v2
            v2_to_v1_names = {
                "bn.weight": "gamma",
                "bn.bias": "beta",
                "bn.running_mean": "running_mean",
                "bn.running_var": "running_var",
                "bn.num_batches_tracked": "num_batches_tracked",
            }
            for v2_name, v1_name in v2_to_v1_names.items():
                if prefix + v1_name in state_dict:
                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
                    state_dict.pop(prefix + v1_name)
                elif prefix + v2_name in state_dict:
                    # there was a brief period where forward compatibility
                    # for this module was broken (between
                    # https://github.com/pytorch/pytorch/pull/38478
                    # and https://github.com/pytorch/pytorch/pull/38820)
                    # and modules emitted the v2 state_dict format while
                    # specifying that version == 1. This patches the forward
                    # compatibility issue by allowing the v2 style entries to
                    # be used.
                    pass
                elif strict:
                    missing_keys.append(prefix + v2_name)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.ao.quantization
        utilities or directly from user
        """
        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
        # has no __name__ (code is fine though)
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]  # type: ignore[index]
        qat_convbn = cls(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            bn.eps,
            bn.momentum,
            False,
            qconfig,
        )
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.bn.weight = bn.weight
        qat_convbn.bn.bias = bn.bias
        qat_convbn.bn.running_mean = bn.running_mean
        qat_convbn.bn.running_var = bn.running_var
        # mypy error: Cannot determine type of 'num_batches_tracked'
        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked
        return qat_convbn

    def to_float(self):
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            self.padding_mode,
        )
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())

        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
            # fuse bn into conv
            assert self.bn.running_var is not None and self.bn.running_mean is not None
            conv.weight, conv.bias = fuse_conv_bn_weights(
                conv.weight,
                conv.bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias,
            )

        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
            modules = []
            modules.append(conv)
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
            conv_relu.train(self.training)
            return conv_relu
        else:
            conv.train(self.training)
            return conv

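# --- Editor's illustration (not part of the original module) -----------------
# A minimal sketch of the identity behind _forward_approximate above: because a
# convolution is linear in its weight, scaling the weight per output channel by
# scale_factor = bn.weight / running_std and then dividing the conv output by
# the same per-channel factor recovers the unscaled conv output, so the real
# BatchNorm can still be applied while the weight fake quantizer sees the
# scaled weight it will see after fusion at inference time.  The tensors below
# (gamma, running_std) are illustrative stand-ins, not module state.
def _example_scale_unscale_identity():
    x = torch.randn(2, 3, 8, 8)
    weight = torch.randn(4, 3, 3, 3)
    gamma = torch.rand(4) + 0.1          # stand-in for bn.weight
    running_std = torch.rand(4) + 0.5    # stand-in for sqrt(running_var + eps)
    scale_factor = gamma / running_std
    conv = F.conv2d(x, weight * scale_factor.reshape(-1, 1, 1, 1))
    conv_orig = conv / scale_factor.reshape(1, -1, 1, 1)
    # expected: True -- scale/unscale round-trips to the plain convolution
    print(torch.allclose(conv_orig, F.conv2d(x, weight), atol=1e-5))
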
class ConvBn1d(_ConvBnNd, nn.Conv1d):
    r"""
    A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv1d` and
    :class:`torch.nn.BatchNorm1d`.

    Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d
    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
    _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBn1d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d

    def __init__(
        self,
        # Conv1d args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm1d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        _ConvBnNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _single(0),
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
            dim=1,
        )

class ConvBnReLU1d(ConvBn1d):
    r"""
    A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv1d` and
    :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.

    Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    # base class defines _FLOAT_MODULE as "ConvBn1d"
    _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBnReLU1d
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d
    _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d
    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU1d

    def __init__(
        self,
        # Conv1d args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm1d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
        )

    def forward(self, input):
        return F.relu(self._forward(input))

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(mod, use_precomputed_fake_quant)

class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
    r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
    FakeQuantize modules for weight for
    quantization aware training.

    We combined the interface of :class:`~torch.nn.Conv1d` and
    :class:`~torch.nn.ReLU`.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU1d]] = nni.ConvReLU1d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d
    _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = self.qconfig.weight()

    def forward(self, input):
        return F.relu(
            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

class ConvBn2d(_ConvBnNd, nn.Conv2d):
    r"""
    A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv2d` and
    :class:`torch.nn.BatchNorm2d`.

    Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE: ClassVar[Type[nni.ConvBn2d]] = nni.ConvBn2d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d
    _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm2d
    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None

    def __init__(
        self,
        # ConvNd args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm2d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        _ConvBnNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _pair(0),
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
            dim=2,
        )

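# --- Editor's illustration (not part of the original module) -----------------
# A minimal sketch, assuming the eager-mode torch.ao.quantization helpers
# (fuse_modules_qat, get_default_qat_qconfig), of how ConvBn2d.from_float above
# is typically reached: a float Conv2d + BatchNorm2d pair is first fused into
# an nni.ConvBn2d, given a QAT qconfig, and then converted.  Module indices and
# the helper name are illustrative.
def _example_convbn2d_from_float():
    from torch.ao import quantization as tq

    float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).train()
    fused = tq.fuse_modules_qat(float_model, [["0", "1"]])  # nni.ConvBn2d at "0"
    fused[0].qconfig = tq.get_default_qat_qconfig("fbgemm")
    qat_convbn = ConvBn2d.from_float(fused[0])
    out = qat_convbn(torch.randn(1, 3, 16, 16))
    print(type(qat_convbn).__name__, out.shape)
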
class ConvBnReLU2d(ConvBn2d):
    r"""
    A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv2d` and
    :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    # base class defines _FLOAT_MODULE as "ConvBn2d"
    _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d
    _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm2d]] = nn.BatchNorm2d
    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU2d]]] = nni.ConvReLU2d

    def __init__(
        self,
        # Conv2d args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm2d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
        )

    def forward(self, input):
        return F.relu(self._forward(input))

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(mod, use_precomputed_fake_quant)

class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
    r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
    FakeQuantize modules for weight for
    quantization aware training.

    We combined the interface of :class:`~torch.nn.Conv2d` and
    :class:`~torch.nn.ReLU`.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvReLU2d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d
    _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = self.qconfig.weight()

    def forward(self, input):
        return F.relu(
            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

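# --- Editor's illustration (not part of the original module) -----------------
# A minimal sketch of ConvReLU2d above: unlike the ConvBn* variants there is no
# BatchNorm to fold, so its forward pass is simply a convolution with
# fake-quantized weights followed by ReLU.  Assumes the default QAT qconfig;
# the helper name is illustrative.
def _example_convrelu2d():
    from torch.ao.quantization import get_default_qat_qconfig

    m = ConvReLU2d(3, 8, 3, qconfig=get_default_qat_qconfig("fbgemm"))
    y = m(torch.randn(1, 3, 16, 16))
    print(y.shape, bool((y >= 0).all()))  # ReLU guarantees non-negative outputs
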
class ConvBn3d(_ConvBnNd, nn.Conv3d):
    r"""
    A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv3d` and
    :class:`torch.nn.BatchNorm3d`.

    Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE: ClassVar[Type[nni.ConvBn3d]] = nni.ConvBn3d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d
    _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm3d
    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None

    def __init__(
        self,
        # ConvNd args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm3d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        _ConvBnNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _triple(0),
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
            dim=3,
        )

class ConvBnReLU3d(ConvBn3d):
    r"""
    A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv3d` and
    :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`.

    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d
    _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm3d]] = nn.BatchNorm3d
    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.ReLU]]] = nn.ReLU
    # module class after fusing bn into conv
    _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU3d]]] = nni.ConvReLU3d

    def __init__(
        self,
        # Conv3d args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm3d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
        )

    def forward(self, input):
        return F.relu(ConvBn3d._forward(self, input))

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
    r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
    FakeQuantize modules for weight for
    quantization aware training.

    We combined the interface of :class:`~torch.nn.Conv3d` and
    :class:`~torch.nn.ReLU`.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU3d]] = nni.ConvReLU3d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d
    _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        qconfig=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = self.qconfig.weight()

    def forward(self, input):
        return F.relu(
            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

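# --- Editor's illustration (not part of the original module) -----------------
# A minimal eager-mode sketch, assuming the standard torch.ao.quantization QAT
# workflow (get_default_qat_qconfig, fuse_modules_qat, prepare_qat), showing
# where the fused QAT modules in this file appear: after fusion and
# prepare_qat, a Conv2d/BatchNorm2d/ReLU triple is swapped for the
# ConvBnReLU2d class defined above.  The Net class and helper name are
# illustrative.
def _example_eager_qat_flow():
    from torch.ao import quantization as tq

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)
            self.bn = nn.BatchNorm2d(8)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    model = Net().train()
    model.qconfig = tq.get_default_qat_qconfig("fbgemm")
    model = tq.fuse_modules_qat(model, [["conv", "bn", "relu"]])
    model = tq.prepare_qat(model)
    print(type(model.conv))  # expected: the ConvBnReLU2d class defined above
    model(torch.randn(1, 3, 16, 16))  # QAT training steps would go here
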