import torch
import torch.nn as nn
from torch.nn.modules.utils import _single, _pair, _triple
from torch.ao.nn.intrinsic import _FusedModule
from typing import Tuple, TypeVar, Union
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t

__all__ = ["Conv1d", "Conv2d", "Conv3d"]

MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)


class _ConvNd(nn.modules.conv._ConvNd):
    """Common base for quantization-aware-training (QAT) conv modules.

    Behaves like the corresponding float convolution, except the weight is
    routed through a fake-quantize module (``self.weight_fake_quant``,
    built from ``qconfig.weight``) on every forward pass.
    """

    _FLOAT_MODULE = MOD

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Tuple[int, ...],
                 stride: Tuple[int, ...],
                 padding: Tuple[int, ...],
                 dilation: Tuple[int, ...],
                 transposed: bool,
                 output_padding: Tuple[int, ...],
                 groups: int,
                 bias: bool,
                 padding_mode: str,
                 qconfig=None,
                 device=None,
                 dtype=None) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        nn.modules.conv._ConvNd.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding,
            dilation, transposed, output_padding, groups, bias, padding_mode,
            **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        # The qconfig supplies the constructor for the weight fake-quantizer.
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        # Ordinary convolution, but computed with the fake-quantized weight.
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @staticmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        Args:
           `mod`: a float module, either produced by torch.ao.quantization
           utilities or directly from user
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat." + cls.__name__ + ".from_float only works for " +
            cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        )
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        # A fused module (e.g. conv+relu) stores the conv as its first child.
        if issubclass(type(mod), _FusedModule):
            mod = mod[0]  # type: ignore[index]
        float_qconfig = mod.qconfig
        qat_conv = cls(mod.in_channels,
                       mod.out_channels,
                       mod.kernel_size,
                       stride=mod.stride,
                       padding=mod.padding,
                       dilation=mod.dilation,
                       groups=mod.groups,
                       bias=mod.bias is not None,
                       padding_mode=mod.padding_mode,
                       qconfig=float_qconfig)
        # Share (not copy) the float module's parameters.
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        """Convert this qat module — a single qat conv, or a fused qat
        conv-relu — back to a floating point module."""
        cls = type(self)
        float_conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined, operator]
            self.in_channels,
            self.out_channels,
            self.kernel_size,  # type: ignore[arg-type]
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.dilation,  # type: ignore[arg-type]
            self.groups,
            self.bias is not None,
            self.padding_mode)
        float_conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            float_conv.bias = torch.nn.Parameter(self.bias.detach())
        if not issubclass(cls, _FusedModule):
            return float_conv
        # conv relu: rebuild the fused wrapper around the float conv.
        assert hasattr(cls, "_FLOAT_RELU_MODULE")
        relu_mod = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
        fused_mod = cls._FLOAT_MODULE(float_conv, relu_mod)  # type: ignore[arg-type, attr-defined, operator]
        fused_mod.train(self.training)
        return fused_mod


class Conv1d(_ConvNd, nn.Conv1d):
    r"""
    A Conv1d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as :class:`~torch.nn.Conv1d`

    Similar to :class:`~torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv1d
    _FLOAT_CONV_MODULE = nn.Conv1d

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: Union[str, _size_1_t] = 0,
                 dilation: _size_1_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 qconfig=None,
                 device=None,
                 dtype=None) -> None:
        # String paddings ('same'/'valid') pass through untouched; numeric
        # ones are normalized to 1-tuples like the other spatial arguments.
        if isinstance(padding, str):
            padding_ = padding
        else:
            padding_ = _single(padding)
        super().__init__(
            in_channels,
            out_channels,
            _single(kernel_size),
            stride=_single(stride),
            padding=padding_,
            dilation=_single(dilation),
            transposed=False,
            output_padding=_single(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype)

    @classmethod
    def from_float(cls, mod):
        return super().from_float(cls, mod)
class Conv2d(_ConvNd, nn.Conv2d):
    r"""
    A Conv2d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv2d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
    for documentation.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv2d
    _FLOAT_CONV_MODULE = nn.Conv2d

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_2_t,
                 stride: _size_2_t = 1,
                 padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 qconfig=None,
                 device=None,
                 dtype=None) -> None:
        # 'same'/'valid' string paddings are forwarded verbatim; numeric
        # paddings are normalized to 2-tuples like the other spatial args.
        pad_arg = padding if isinstance(padding, str) else _pair(padding)
        super().__init__(
            in_channels,
            out_channels,
            _pair(kernel_size),
            stride=_pair(stride),
            padding=pad_arg,
            dilation=_pair(dilation),
            transposed=False,
            output_padding=_pair(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype)

    def forward(self, input):
        # Convolve with the fake-quantized weight.
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod):
        return super().from_float(cls, mod)
class Conv3d(_ConvNd, nn.Conv3d):
    r"""
    A Conv3d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv3d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
    for documentation.

    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv3d
    _FLOAT_CONV_MODULE = nn.Conv3d

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_3_t,
                 stride: _size_3_t = 1,
                 padding: Union[str, _size_3_t] = 0,
                 dilation: _size_3_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 qconfig=None,
                 device=None,
                 dtype=None) -> None:
        # 'same'/'valid' string paddings are forwarded verbatim; numeric
        # paddings are normalized to 3-tuples like the other spatial args.
        pad_arg = padding if isinstance(padding, str) else _triple(padding)
        super().__init__(
            in_channels,
            out_channels,
            _triple(kernel_size),
            stride=_triple(stride),
            padding=pad_arg,
            dilation=_triple(dilation),
            transposed=False,
            output_padding=_triple(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype)

    def forward(self, input):
        # Convolve with the fake-quantized weight.
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod):
        return super().from_float(cls, mod)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.