Source code for torch.nn.intrinsic.qat.modules.conv_fused
import math
import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.qat as nnqat
import torch.nn.functional as F
from torch.nn import init
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.parameter import Parameter
from typing import TypeVar

_BN_CLASS_MAP = {
    1: nn.BatchNorm1d,
    2: nn.BatchNorm2d,
    3: nn.BatchNorm3d,
}

MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)


class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):

    _version = 2
    _FLOAT_MODULE = MOD

    def __init__(self,
                 # ConvNd args
                 in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups,
                 bias,
                 padding_mode,
                 # BatchNormNd args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None,
                 dim=2):
        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
                                         stride, padding, dilation, transposed,
                                         output_padding, groups, False, padding_mode)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)
        # note: below is actually for conv, not BN
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        super(_ConvBnNd, self).reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def _forward(self, input):
        assert self.bn.running_var is not None
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
        # using zero bias here since the bias for original conv
        # will be added later
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias)
        else:
            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device)
        conv = self._conv_forward(input, scaled_weight, zero_bias)
        conv_orig = conv / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape(bias_shape)
        conv = self.bn(conv_orig)
        return conv

    def extra_repr(self):
        # TODO(jerryzh): extend
        return super(_ConvBnNd, self).extra_repr()

    def forward(self, input):
        return self._forward(input)

    def train(self, mode=True):
        """
        Batchnorm's training behavior is using the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    # ===== Serialization version history =====
    #
    # Version 1/None
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- gamma : Tensor
    #   |--- beta : Tensor
    #   |--- running_mean : Tensor
    #   |--- running_var : Tensor
    #   |--- num_batches_tracked : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- bn : Module
    #        |--- weight : Tensor (moved from v1.self.gamma)
    #        |--- bias : Tensor (moved from v1.self.beta)
    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
    #        |--- running_var : Tensor (moved from v1.self.running_var)
    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        if version is None or version == 1:
            # BN related parameters and buffers were moved into the BN module for v2
            v2_to_v1_names = {
                'bn.weight': 'gamma',
                'bn.bias': 'beta',
                'bn.running_mean': 'running_mean',
                'bn.running_var': 'running_var',
                'bn.num_batches_tracked': 'num_batches_tracked',
            }
            for v2_name, v1_name in v2_to_v1_names.items():
                if prefix + v1_name in state_dict:
                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
                    state_dict.pop(prefix + v1_name)
                elif prefix + v2_name in state_dict:
                    # there was a brief period where forward compatibility
                    # for this module was broken (between
                    # https://github.com/pytorch/pytorch/pull/38478
                    # and https://github.com/pytorch/pytorch/pull/38820)
                    # and modules emitted the v2 state_dict format while
                    # specifying that version == 1. This patches the forward
                    # compatibility issue by allowing the v2 style entries to
                    # be used.
                    pass
                elif strict:
                    missing_keys.append(prefix + v2_name)

        super(_ConvBnNd, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys,
            unexpected_keys, error_msgs)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.quantization
        utilities or directly from user
        """
        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
        # has no __name__ (code is fine though)
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]
        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                         conv.stride, conv.padding, conv.dilation,
                         conv.groups, conv.bias is not None,
                         conv.padding_mode,
                         bn.eps, bn.momentum,
                         False,
                         qconfig)
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.bn.weight = bn.weight
        qat_convbn.bn.bias = bn.bias
        qat_convbn.bn.running_mean = bn.running_mean
        qat_convbn.bn.running_var = bn.running_var
        # mypy error: Cannot determine type of 'num_batches_tracked'
        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[has-type]
        return qat_convbn

    def to_float(self):
        modules = []
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            self.padding_mode)
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())
        modules.append(conv)

        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
            bn = cls._FLOAT_BN_MODULE(  # type: ignore[attr-defined]
                self.bn.num_features,
                self.bn.eps,
                self.bn.momentum,
                self.bn.affine,
                self.bn.track_running_stats)
            bn.weight = Parameter(self.bn.weight.detach())
            if self.bn.affine:
                bn.bias = Parameter(self.bn.bias.detach())
            modules.append(bn)

        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)

        result = cls._FLOAT_MODULE(*modules)  # type: ignore[operator]
        result.train(self.training)
        return result


class ConvBn1d(_ConvBnNd, nn.Conv1d):
    r"""
    A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv1d` and
    :class:`torch.nn.BatchNorm1d`.

    Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_BN_MODULE = nn.BatchNorm1d
    _FLOAT_RELU_MODULE = None
    _FLOAT_MODULE = nni.ConvBn1d
    _FLOAT_CONV_MODULE = nn.Conv1d

    def __init__(self,
                 # Conv1d args
                 in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=None,
                 padding_mode='zeros',
                 # BatchNorm1d args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
                           padding, dilation, False, _single(0), groups, bias, padding_mode,
                           eps, momentum, freeze_bn, qconfig, dim=1)


class ConvBnReLU1d(ConvBn1d):
    r"""
    A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv1d` and
    :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.

    Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    # base class defines _FLOAT_MODULE as "ConvBn1d"
    _FLOAT_MODULE = nni.ConvBnReLU1d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv1d
    _FLOAT_BN_MODULE = nn.BatchNorm1d
    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]

    def __init__(self,
                 # Conv1d args
                 in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=None,
                 padding_mode='zeros',
                 # BatchNorm1d args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode,
                         eps, momentum, freeze_bn, qconfig)

    def forward(self, input):
        return F.relu(ConvBn1d._forward(self, input))

    @classmethod
    def from_float(cls, mod):
        return super(ConvBnReLU1d, cls).from_float(mod)
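The BN folding in `_ConvBnNd._forward` scales the convolution weight by `bn.weight / sqrt(running_var + eps)` before fake quantization, runs the convolution with a zero bias, unscales the result, re-adds the convolution bias, and only then applies the real BatchNorm. With fake quantization disabled, the fused module should therefore match the original Conv1d followed by BatchNorm1d up to floating-point error from the scale/unscale round trip. A minimal sketch of that check (layer sizes, input shape, and the default 'fbgemm' QAT qconfig are illustrative assumptions, not part of the source above):

import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat
import torch.quantization as tq

# Build a float fused Conv1d + BatchNorm1d pair and convert it to the QAT module.
float_mod = nni.ConvBn1d(nn.Conv1d(3, 8, 3), nn.BatchNorm1d(8))
float_mod.qconfig = tq.get_default_qat_qconfig('fbgemm')
qat_mod = nniqat.ConvBn1d.from_float(float_mod)

# Disable fake quantization so only the fold/unfold arithmetic is exercised.
qat_mod.apply(tq.disable_fake_quant)

x = torch.randn(2, 3, 16)
# Both modules are in training mode, so BatchNorm uses batch statistics;
# the outputs should agree up to floating-point error.
print(torch.allclose(qat_mod(x), float_mod(x), atol=1e-5))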
class ConvBn2d(_ConvBnNd, nn.Conv2d):
    r"""
    A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv2d` and
    :class:`torch.nn.BatchNorm2d`.

    Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvBn2d
    _FLOAT_CONV_MODULE = nn.Conv2d
    _FLOAT_BN_MODULE = nn.BatchNorm2d
    _FLOAT_RELU_MODULE = None

    def __init__(self,
                 # ConvNd args
                 in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=None,
                 padding_mode='zeros',
                 # BatchNorm2d args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
                           padding, dilation, False, _pair(0), groups, bias, padding_mode,
                           eps, momentum, freeze_bn, qconfig, dim=2)
class ConvBnReLU2d(ConvBn2d):
    r"""
    A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv2d` and
    :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    # base class defines _FLOAT_MODULE as "ConvBn2d"
    _FLOAT_MODULE = nni.ConvBnReLU2d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv2d
    _FLOAT_BN_MODULE = nn.BatchNorm2d
    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]

    def __init__(self,
                 # Conv2d args
                 in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=None,
                 padding_mode='zeros',
                 # BatchNorm2d args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        super(ConvBnReLU2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                           padding, dilation, groups, bias, padding_mode,
                                           eps, momentum, freeze_bn, qconfig)

    def forward(self, input):
        return F.relu(ConvBn2d._forward(self, input))

    @classmethod
    def from_float(cls, mod):
        return super(ConvBnReLU2d, cls).from_float(mod)
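In the eager-mode QAT workflow these classes are not normally instantiated by hand: `torch.quantization.fuse_modules` first combines Conv2d, BatchNorm2d and ReLU into an `nni.ConvBnReLU2d`, and `prepare_qat` then swaps that fused float module for the `ConvBnReLU2d` defined here via `from_float`. A minimal sketch, with a toy Sequential model and the default 'fbgemm' QAT qconfig as illustrative assumptions (on recent PyTorch releases the fusion step for a model in training mode is spelled `fuse_modules_qat` instead):

import torch
import torch.nn as nn
import torch.quantization as tq

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
model.train()

# Fusing in training mode keeps BatchNorm as a submodule -> nni.ConvBnReLU2d at index '0'.
model = tq.fuse_modules(model, [['0', '1', '2']])

# prepare_qat swaps the fused float module for the QAT version defined above.
model.qconfig = tq.get_default_qat_qconfig('fbgemm')
tq.prepare_qat(model, inplace=True)
print(type(model[0]))   # torch.nn.intrinsic.qat ConvBnReLU2d

out = model(torch.randn(1, 3, 32, 32))   # training forward with fake quantization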
class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
    r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
    FakeQuantize modules for weight for
    quantization aware training.

    We combined the interface of :class:`~torch.nn.Conv2d` and
    :class:`~torch.nn.ReLU`.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvReLU2d
    _FLOAT_CONV_MODULE = nn.Conv2d
    _FLOAT_BN_MODULE = None
    _FLOAT_RELU_MODULE = nn.ReLU

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros',
                 qconfig=None):
        super(ConvReLU2d, self).__init__(in_channels, out_channels, kernel_size,
                                         stride=stride, padding=padding, dilation=dilation,
                                         groups=groups, bias=bias, padding_mode=padding_mode,
                                         qconfig=qconfig)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.weight_fake_quant = self.qconfig.weight()

    def forward(self, input):
        return F.relu(
            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))

    @classmethod
    def from_float(cls, mod):
        return super(ConvReLU2d, cls).from_float(mod)
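Unlike the ConvBn* variants, `ConvReLU2d` only fake-quantizes the weight and applies ReLU to the convolution output; no BatchNorm is involved. A minimal sketch with illustrative sizes and the default QAT qconfig as assumptions:

import torch
import torch.nn.intrinsic.qat as nniqat
import torch.quantization as tq

# Construct the QAT module directly; in practice prepare_qat creates it from nni.ConvReLU2d.
m = nniqat.ConvReLU2d(3, 8, 3, qconfig=tq.get_default_qat_qconfig('fbgemm'))
y = m(torch.randn(1, 3, 8, 8))
print(y.shape, (y >= 0).all().item())   # torch.Size([1, 8, 6, 6]) True, since ReLU is applied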
class ConvBn3d(_ConvBnNd, nn.Conv3d):
    r"""
    A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv3d` and
    :class:`torch.nn.BatchNorm3d`.

    Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvBn3d
    _FLOAT_CONV_MODULE = nn.Conv3d
    _FLOAT_BN_MODULE = nn.BatchNorm3d
    _FLOAT_RELU_MODULE = None

    def __init__(
        self,
        # ConvNd args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm3d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        _ConvBnNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _triple(0),
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
            dim=3,
        )
class ConvBnReLU3d(ConvBn3d):
    r"""
    A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
    attached with FakeQuantize modules for weight,
    used in quantization aware training.

    We combined the interface of :class:`torch.nn.Conv3d` and
    :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`.

    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvBnReLU3d  # type: ignore[assignment]
    _FLOAT_CONV_MODULE = nn.Conv3d
    _FLOAT_BN_MODULE = nn.BatchNorm3d
    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]

    def __init__(
        self,
        # Conv3d args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm3d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        super(ConvBnReLU3d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            qconfig,
        )

    def forward(self, input):
        return F.relu(ConvBn3d._forward(self, input))

    @classmethod
    def from_float(cls, mod):
        return super(ConvBnReLU3d, cls).from_float(mod)
class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
    r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
    FakeQuantize modules for weight for
    quantization aware training.

    We combined the interface of :class:`~torch.nn.Conv3d` and
    :class:`~torch.nn.ReLU`.

    Attributes:
        weight_fake_quant: fake quant module for weight

    """
    _FLOAT_MODULE = nni.ConvReLU3d
    _FLOAT_CONV_MODULE = nn.Conv3d
    _FLOAT_BN_MODULE = None
    _FLOAT_RELU_MODULE = nn.ReLU

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        qconfig=None,
    ):
        super(ConvReLU3d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = self.qconfig.weight()

    def forward(self, input):
        return F.relu(
            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
        )

    @classmethod
    def from_float(cls, mod):
        return super(ConvReLU3d, cls).from_float(mod)
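During QAT it is a common recipe to freeze BatchNorm running statistics for the last part of training. The fused ConvBn* modules expose `freeze_bn_stats()` and `update_bn_stats()` for this, and their `train()` override keeps a frozen BN in eval mode. A minimal sketch, with illustrative sizes and the default QAT qconfig as assumptions:

import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat
import torch.quantization as tq

float_mod = nni.ConvBn2d(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
float_mod.qconfig = tq.get_default_qat_qconfig('fbgemm')
qat_mod = nniqat.ConvBn2d.from_float(float_mod)

qat_mod.freeze_bn_stats()     # stop updating running_mean / running_var
qat_mod.train()               # the overridden train() leaves the frozen BN alone
print(qat_mod.bn.training)    # False: running statistics stay fixed

qat_mod.update_bn_stats()     # resume tracking BN statistics
print(qat_mod.bn.training)    # True

After training, `torch.quantization.convert` replaces these QAT modules with their quantized counterparts, folding the BatchNorm into the quantized convolution.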