Source code for torch.ao.nn.intrinsic.quantized.modules.conv_relu
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F
from torch.nn.utils import fuse_conv_bn_weights

__all__ = [
    "ConvReLU1d",
    "ConvReLU2d",
    "ConvReLU3d",
]

_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding


# TODO: factor out the common parts to ConvNd
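The `_reverse_repeat_padding` helper converts a conv module's per-dimension padding into the layout that `F.pad` expects: dimensions in reverse order, each value repeated twice. A minimal sketch of that behavior (the name `reverse_repeat_padding_sketch` is illustrative; the real helper lives in `torch.ao.nn.quantized.modules.conv`):

def reverse_repeat_padding_sketch(padding):
    # e.g. Conv2d padding (pH, pW) becomes (pW, pW, pH, pH) for F.pad
    out = []
    for p in reversed(padding):
        out += [p, p]
    return out

assert reverse_repeat_padding_sketch((1, 2)) == [2, 2, 1, 1]
assert reverse_repeat_padding_sketch((3,)) == [3, 3]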
class ConvReLU1d(nnq.Conv1d):
    r"""
    A ConvReLU1d module is a fused module of Conv1d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv1d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != "zeros":
            # Padding in Conv1d is stored as (p, p), need to get (p,)
            _reversed_padding_repeated_twice = _reverse_repeat_padding(
                self.padding[:1]
            )
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv1d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU1d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                mod.bn.running_mean,
                mod.bn.running_var,
                mod.bn.eps,
                mod.bn.weight,
                mod.bn.bias,
            )
        return super().from_float(mod, use_precomputed_fake_quant)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d
        ), "BatchNorm1d should be fused into Conv1d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
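In eager-mode quantization, the float fused module `torch.ao.nn.intrinsic.ConvReLU1d` that `from_float` consumes is typically produced by `fuse_modules` rather than built by hand. A minimal sketch (the module names "0" and "1" come from the `nn.Sequential` container; fusion requires eval mode):

import torch
from torch import nn
import torch.ao.nn.intrinsic as nni

m = nn.Sequential(nn.Conv1d(3, 8, 3), nn.ReLU()).eval()
# fuse_modules replaces the (Conv1d, ReLU) pair with a single float
# intrinsic module, the input type expected by ConvReLU1d.from_float
fused = torch.ao.quantization.fuse_modules(m, [["0", "1"]])
assert isinstance(fused[0], nni.ConvReLU1d)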
class ConvReLU2d(nnq.Conv2d):
    r"""
    A ConvReLU2d module is a fused module of Conv2d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv2d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                mod.bn.running_mean,
                mod.bn.running_var,
                mod.bn.eps,
                mod.bn.weight,
                mod.bn.bias,
            )
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d
        ), "BatchNorm2d should be fused into Conv2d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
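End to end, a module like the class above is usually reached through the eager post-training static quantization flow (fuse, prepare, calibrate, convert) rather than constructed directly. A sketch, assuming the fbgemm/x86 quantized engine is available:

import torch
from torch import nn
import torch.ao.nn.intrinsic.quantized as nniq

model = nn.Sequential(
    torch.ao.quantization.QuantStub(),
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    torch.ao.quantization.DeQuantStub(),
).eval()
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
fused = torch.ao.quantization.fuse_modules(model, [["1", "2"]])
prepared = torch.ao.quantization.prepare(fused)
prepared(torch.randn(1, 3, 8, 8))  # calibrate observers on sample data
quantized = torch.ao.quantization.convert(prepared)
assert isinstance(quantized[1], nniq.ConvReLU2d)  # "QuantizedConvReLU2d"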
class ConvReLU3d(nnq.Conv3d):
    r"""
    A ConvReLU3d module is a fused module of Conv3d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv3d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != "zeros":
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(
                input, _reversed_padding_repeated_twice, mode=self.padding_mode
            )
        return torch.ops.quantized.conv3d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                mod.bn.running_mean,
                mod.bn.running_var,
                mod.bn.eps,
                mod.bn.weight,
                mod.bn.bias,
            )
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert (
            type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d
        ), "BatchNorm3d should be fused into Conv3d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
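Unlike the 1d and 2d variants, the 3d constructor rejects reflection padding up front, before any weight packing happens. A small sketch of that failure mode:

import torch.ao.nn.intrinsic.quantized as nniq

# the assert in ConvReLU3d.__init__ fires before super().__init__
try:
    nniq.ConvReLU3d(3, 8, 3, padding=1, padding_mode="reflect")
except AssertionError as e:
    print(e)  # "Conv3d does not support reflection padding"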