Source code for torchvision.models.quantization.resnet
from functools import partial
from typing import Any, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.models.resnet import (
    BasicBlock,
    Bottleneck,
    ResNet,
    ResNet18_Weights,
    ResNet50_Weights,
    ResNeXt101_32X8D_Weights,
    ResNeXt101_64X4D_Weights,
)

from ...transforms._presets import ImageClassification
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from .utils import _fuse_modules, _replace_relu, quantize_model


__all__ = [
    "QuantizableResNet",
    "ResNet18_QuantizedWeights",
    "ResNet50_QuantizedWeights",
    "ResNeXt101_32X8D_QuantizedWeights",
    "ResNeXt101_64X4D_QuantizedWeights",
    "resnet18",
    "resnet50",
    "resnext101_32x8d",
    "resnext101_64x4d",
]


class QuantizableBasicBlock(BasicBlock):
    """``BasicBlock`` whose residual add + ReLU is routed through a ``FloatFunctional``.

    Using ``FloatFunctional.add_relu`` instead of a plain tensor add lets the
    quantization workflow attach an observer to (and later quantize) the
    skip-connection arithmetic.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Run conv1/bn1/relu -> conv2/bn2, then add the (optionally downsampled) input and apply ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Fused add + ReLU via FloatFunctional so the op is quantizable.
        out = self.add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse ``conv1+bn1+relu`` and ``conv2+bn2`` in place.

        If a ``downsample`` submodule is present, its children ``"0"`` and ``"1"``
        (presumably conv and batch-norm -- confirm against ``BasicBlock``) are
        fused as well. ``is_qat`` is forwarded unchanged to ``_fuse_modules``.
        """
        _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
        if self.downsample:
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)


class QuantizableBottleneck(Bottleneck):
    """``Bottleneck`` block adapted for eager-mode quantization.

    The residual add + ReLU goes through a ``FloatFunctional``, and two dedicated
    non-inplace ReLU modules (``relu1``, ``relu2``) are created so that each
    conv/bn/relu triple has its own ReLU module to fuse.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.skip_add_relu = nn.quantized.FloatFunctional()
        # Separate ReLU instances per stage; presumably the base class shares a
        # single ``relu`` module, which cannot be fused into two groups -- TODO confirm.
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)

    def forward(self, x: Tensor) -> Tensor:
        """Run the three conv/bn stages, then add the (optionally downsampled) input and apply ReLU."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        # Fused add + ReLU via FloatFunctional so the op is quantizable.
        out = self.skip_add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse ``conv1+bn1+relu1``, ``conv2+bn2+relu2`` and ``conv3+bn3`` in place.

        If a ``downsample`` submodule is present, its children ``"0"`` and ``"1"``
        are fused as well. ``is_qat`` is forwarded unchanged to ``_fuse_modules``.
        """
        _fuse_modules(
            self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
        )
        if self.downsample:
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)


class QuantizableResNet(ResNet):
    """``ResNet`` wrapped with ``QuantStub``/``DeQuantStub`` at its input and output."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        """Quantize the input, run the ResNet body, and dequantize the output."""
        x = self.quant(x)
        # Ensure scriptability
        # super(QuantizableResNet,self).forward(x)
        # is not scriptable
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in ResNet models.

        Fuses the stem's ``conv1+bn1+relu`` and then asks every
        ``QuantizableBasicBlock`` / ``QuantizableBottleneck`` to fuse its own
        conv+bn(+relu) groups, to prepare the model for quantization.
        The model is modified in place. Note that this operation does not change
        numerics and the model after modification is still in floating point.

        Args:
            is_qat: forwarded unchanged to ``_fuse_modules`` and each block's
                ``fuse_model``.
        """
        _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
        for m in self.modules():
            # Exact type match (not ``isinstance``): only these two block classes are fused here.
            if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
                m.fuse_model(is_qat)


def _resnet(
    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableResNet:
    """Build a ``QuantizableResNet`` and optionally quantize it and load weights.

    Args:
        block: residual block class used for every stage.
        layers: number of blocks per stage.
        weights: optional weights enum; when given, ``num_classes`` (and
            ``backend`` if present in ``weights.meta``) are taken from its meta.
        progress: show a download progress bar when fetching weights.
        quantize: if True, run ``quantize_model`` before loading weights.
        **kwargs: forwarded to ``QuantizableResNet``; ``backend`` is popped
            first and defaults to ``"fbgemm"``.

    Returns:
        The constructed model.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    # ``backend`` is not a ResNet constructor argument, so remove it from kwargs.
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableResNet(block, layers, **kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    # Weights are loaded after quantization so a quantized state dict matches
    # the quantized module structure (presumed intent -- confirm).
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "backend": "fbgemm",
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
    "_docs": """
        These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized
        weights listed below.
    """,
}
@register_model(name="quantized_resnet18")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet18_Weights.IMAGENET1K_V1,
    )
)
def resnet18(
    *,
    weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNet-18 model from
    `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` or :class:`~torchvision.models.ResNet18_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNet18_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNet18_Weights
        :members:
        :noindex:
    """
    # The accepted weights enum depends on whether a quantized model was requested.
    weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)

    return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
@register_model(name="quantized_resnet50")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet50_Weights.IMAGENET1K_V1,
    )
)
def resnet50(
    *,
    weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNet-50 model from
    `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` or :class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNet50_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNet50_Weights
        :members:
        :noindex:
    """
    # The accepted weights enum depends on whether a quantized model was requested.
    weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)

    return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
@register_model(name="quantized_resnext101_32x8d")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
    )
)
def resnext101_32x8d(
    *,
    weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNeXt-101 32x8d model from
    `Aggregated Residual Transformations for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
        :members:
        :noindex:
    """
    # The accepted weights enum depends on whether a quantized model was requested.
    weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)

    # 32 groups x 8 channels per group: the "32x8d" configuration.
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
@register_model(name="quantized_resnext101_64x4d")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNeXt101_64X4D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNeXt101_64X4D_Weights.IMAGENET1K_V1,
    )
)
def resnext101_64x4d(
    *,
    weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNeXt-101 64x4d model from
    `Aggregated Residual Transformations for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
        :members:
        :noindex:
    """
    # The accepted weights enum depends on whether a quantized model was requested.
    weights = (ResNeXt101_64X4D_QuantizedWeights if quantize else ResNeXt101_64X4D_Weights).verify(weights)

    # 64 groups x 4 channels per group: the "64x4d" configuration.
    _ovewrite_named_param(kwargs, "groups", 64)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.