Source code for torchvision.models.quantization.googlenet
import warnings
from functools import partial
from typing import Any, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

from ...transforms._presets import ImageClassification
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from ..googlenet import BasicConv2d, GoogLeNet, GoogLeNet_Weights, GoogLeNetOutputs, Inception, InceptionAux
from .utils import _fuse_modules, _replace_relu, quantize_model


__all__ = [
    "QuantizableGoogLeNet",
    "GoogLeNet_QuantizedWeights",
    "googlenet",
]


class QuantizableBasicConv2d(BasicConv2d):
    """``BasicConv2d`` subclass that owns its activation as an ``nn.ReLU`` submodule.

    The activation is stored as ``self.relu`` so that :meth:`fuse_model` can
    name it, together with ``conv`` and ``bn``, when fusing.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to the parent constructor, then attach ``relu``."""
        super().__init__(*args, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse the ``conv``, ``bn`` and ``relu`` submodules of this block in place.

        Args:
            is_qat: passed through unchanged to ``_fuse_modules``.
        """
        names_to_fuse = ["conv", "bn", "relu"]
        _fuse_modules(self, names_to_fuse, is_qat, inplace=True)


class QuantizableInception(Inception):
    """``Inception`` subclass built with ``QuantizableBasicConv2d`` as its ``conv_block``.

    The values returned by ``_forward`` are concatenated through a
    ``FloatFunctional`` module instead of a direct tensor op.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Construct the parent with ``conv_block=QuantizableBasicConv2d`` and add ``cat``."""
        super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs)  # type: ignore[misc]
        self.cat = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Run ``_forward`` and concatenate its outputs along dimension 1."""
        pieces = self._forward(x)
        return self.cat.cat(pieces, 1)


class QuantizableInceptionAux(InceptionAux):
    """``InceptionAux`` subclass using ``QuantizableBasicConv2d`` and a module-based ReLU."""

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Construct the parent with ``conv_block=QuantizableBasicConv2d`` and add ``relu``."""
        super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs)  # type: ignore[misc]
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Compute the auxiliary output for ``x``.

        The shape notes on each step are carried over from the original code.
        """
        # input -- aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        pooled = F.adaptive_avg_pool2d(x, (4, 4))  # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        conv_out = self.conv(pooled)  # N x 128 x 4 x 4
        flat = torch.flatten(conv_out, 1)  # N x 2048
        hidden = self.fc1(flat)
        hidden = self.relu(hidden)  # N x 1024
        hidden = self.dropout(hidden)  # N x 1024
        return self.fc2(hidden)  # N x 1000 (num_classes)


class QuantizableGoogLeNet(GoogLeNet):
    """``GoogLeNet`` subclass assembled from the quantizable blocks defined in this module.

    ``quant`` is applied to the transformed input and ``dequant`` to the main
    output of ``_forward``.
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Construct the parent with the quantizable block classes and add the stubs."""
        quantizable_blocks = [QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux]
        super().__init__(*args, blocks=quantizable_blocks, **kwargs)  # type: ignore[misc]
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> GoogLeNetOutputs:
        """Run the network on ``x``.

        Under TorchScript a ``GoogLeNetOutputs`` tuple is always returned, and
        a warning is emitted unless the model is training with ``aux_logits``
        enabled. Otherwise the result of ``eager_outputs`` is returned.
        """
        x = self.quant(self._transform_input(x))
        x, aux1, aux2 = self._forward(x)
        # Only the main output goes through ``dequant``; the other two values
        # are handed on exactly as ``_forward`` returned them.
        x = self.dequant(x)
        aux_defined = self.training and self.aux_logits
        # NOTE(review): the (aux2, aux1) argument order below mirrors the original
        # code; ``_forward``, ``GoogLeNetOutputs`` and ``eager_outputs`` are defined
        # in another module -- confirm there before reordering.
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted QuantizableGoogleNet always returns GoogleNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in googlenet model

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point

        Args:
            is_qat: passed through unchanged to each block's ``fuse_model``.
        """
        for submodule in self.modules():
            # Exact type match: instances of subclasses are not fused here.
            if type(submodule) is not QuantizableBasicConv2d:
                continue
            submodule.fuse_model(is_qat)
class GoogLeNet_QuantizedWeights(WeightsEnum):
    """Pretrained weight entries for the quantized GoogLeNet model.

    ``IMAGENET1K_FBGEMM_V1`` bundles the checkpoint URL, the preprocessing
    transform (``ImageClassification`` with ``crop_size=224``) and the
    checkpoint metadata. ``DEFAULT`` is an alias of that entry.
    """

    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c81f6644.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 6624904,
            "min_size": (15, 15),
            "categories": _IMAGENET_CATEGORIES,
            "backend": "fbgemm",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
            # Float checkpoint these quantized weights were derived from.
            "unquantized": GoogLeNet_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.826,
                    "acc@5": 89.404,
                }
            },
            "_ops": 1.498,
            "_file_size": 12.618,
            "_docs": """ These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized weights listed below. """,
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1
@register_model(name="quantized_googlenet")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else GoogLeNet_Weights.IMAGENET1K_V1,
    )
)
def googlenet(
    *,
    weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableGoogLeNet:
    """GoogLeNet (Inception v1) model architecture from `Going Deeper with Convolutions <http://arxiv.org/abs/1409.4842>`__.

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    When ``weights`` is given, ``aux_logits``, ``init_weights``, ``num_classes``
    and (if the weights metadata has one) ``backend`` are set in ``kwargs`` via
    ``_ovewrite_named_param``; ``transform_input`` is set to ``True`` only if the
    caller did not pass it. After the state dict is loaded, the auxiliary heads
    are removed unless the caller passed a truthy ``aux_logits``; if they are
    kept, a warning is emitted because they are not pretrained.

    Args:
        weights (:class:`~torchvision.models.quantization.GoogLeNet_QuantizedWeights` or :class:`~torchvision.models.GoogLeNet_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.GoogLeNet_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableGoogLeNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/googlenet.py>`_
            for more details about this class. A ``backend`` entry, if present,
            is removed before construction and used for ``quantize_model``
            (default ``"fbgemm"``).

    Returns:
        The constructed ``QuantizableGoogLeNet``.

    .. autoclass:: torchvision.models.quantization.GoogLeNet_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.GoogLeNet_Weights
        :members:
        :noindex:
    """
    weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights)

    # Record the caller's choice first: ``aux_logits`` is set to True below
    # whenever weights are loaded.
    original_aux_logits = kwargs.get("aux_logits", False)
    if weights is not None:
        if "transform_input" not in kwargs:
            _ovewrite_named_param(kwargs, "transform_input", True)
        _ovewrite_named_param(kwargs, "aux_logits", True)
        _ovewrite_named_param(kwargs, "init_weights", False)
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    # ``backend`` is popped so that it is not passed to the model constructor.
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableGoogLeNet(**kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if not original_aux_logits:
            # The caller did not ask for auxiliary heads: disable and drop them.
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
        else:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )

    return model
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.