from functools import partial
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.init as init

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"]


class Fire(nn.Module):
    def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
        super().__init__()
        self.inplanes = inplanes
        # Squeeze layer: 1x1 convolutions reduce the channel count before the expand layers.
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        # Expand layers: parallel 1x1 and 3x3 convolutions whose outputs are concatenated.
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat(
            [self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], 1
        )


class SqueezeNet(nn.Module):
    def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.num_classes = num_classes
        if version == "1_0":
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == "1_1":
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            # FIXME: Is this needed? SqueezeNet should only be called from the
            # FIXME: squeezenet1_x() functions
            # FIXME: This checking is not done for the other models
            raise ValueError(f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected")

        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        # Global average pooling over a 1x1 conv classifier replaces fully-connected layers.
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1))
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.classifier(x)
        return torch.flatten(x, 1)


def _squeezenet(
    version: str,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> SqueezeNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = SqueezeNet(version, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


_COMMON_META = {
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
    "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}
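
# NOTE: The SqueezeNet1_0_Weights / SqueezeNet1_1_Weights definitions referenced
# below did not survive extraction. The classes here are a minimal reconstruction
# following the WeightsEnum pattern used throughout torchvision (and implied by
# the `partial` / `ImageClassification` imports above). The checkpoint URLs are
# the historical torchvision download URLs and should be verified against the
# release you use; per-weight metadata (min_size, num_params, accuracy metrics)
# is omitted rather than guessed.


class SqueezeNet1_0_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Assumed historical checkpoint URL; verify before relying on it.
        url="https://download.pytorch.org/models/squeezenet1_0-a815701f.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
        },
    )
    DEFAULT = IMAGENET1K_V1


class SqueezeNet1_1_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Assumed historical checkpoint URL; verify before relying on it.
        url="https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
        },
    )
    DEFAULT = IMAGENET1K_V1
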
[docs]@handle_legacy_interface(weights=("pretrained",SqueezeNet1_0_Weights.IMAGENET1K_V1))defsqueezenet1_0(*,weights:Optional[SqueezeNet1_0_Weights]=None,progress:bool=True,**kwargs:Any)->SqueezeNet:"""SqueezeNet model architecture from the `SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size <https://arxiv.org/abs/1602.07360>`_ paper. Args: weights (:class:`~torchvision.models.SqueezeNet1_0_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.SqueezeNet1_0_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.squeezenet.SqueezeNet`` base class. Please refer to the `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_ for more details about this class. .. autoclass:: torchvision.models.SqueezeNet1_0_Weights :members: """weights=SqueezeNet1_0_Weights.verify(weights)return_squeezenet("1_0",weights,progress,**kwargs)
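
# Example (an addition, not part of the upstream file): a minimal sketch of how
# the weights API above is typically used end to end. The helper name is
# hypothetical; ``img`` is assumed to be a 3xHxW image tensor, and downloading
# the checkpoint requires network access.
def _example_squeezenet1_0_inference(img: torch.Tensor) -> str:
    weights = SqueezeNet1_0_Weights.IMAGENET1K_V1
    model = squeezenet1_0(weights=weights)
    model.eval()
    preprocess = weights.transforms()  # resize/crop/normalize matching the training recipe
    batch = preprocess(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # Map the best logit back to a human-readable ImageNet category.
    return weights.meta["categories"][logits.argmax().item()]
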
[docs]@handle_legacy_interface(weights=("pretrained",SqueezeNet1_1_Weights.IMAGENET1K_V1))defsqueezenet1_1(*,weights:Optional[SqueezeNet1_1_Weights]=None,progress:bool=True,**kwargs:Any)->SqueezeNet:"""SqueezeNet 1.1 model from the `official SqueezeNet repo <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_. SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters than SqueezeNet 1.0, without sacrificing accuracy. Args: weights (:class:`~torchvision.models.SqueezeNet1_1_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.SqueezeNet1_1_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. **kwargs: parameters passed to the ``torchvision.models.squeezenet.SqueezeNet`` base class. Please refer to the `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_ for more details about this class. .. autoclass:: torchvision.models.SqueezeNet1_1_Weights :members: """weights=SqueezeNet1_1_Weights.verify(weights)return_squeezenet("1_1",weights,progress,**kwargs)
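
# Example (an addition, not part of the upstream file): ``handle_legacy_interface``
# translates the deprecated ``pretrained`` boolean into the new ``weights``
# argument, so the two calls in this hypothetical sketch load the same
# checkpoint; the first also emits a deprecation warning.
def _example_legacy_vs_new_api() -> None:
    old = squeezenet1_1(pretrained=True)  # deprecated spelling
    new = squeezenet1_1(weights=SqueezeNet1_1_Weights.IMAGENET1K_V1)
    assert all(
        torch.equal(p_old, p_new)
        for p_old, p_new in zip(old.state_dict().values(), new.state_dict().values())
    )
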
# The dictionary below is an internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs

model_urls = _ModelURLs(
    {
        "squeezenet1_0": SqueezeNet1_0_Weights.IMAGENET1K_V1.url,
        "squeezenet1_1": SqueezeNet1_1_Weights.IMAGENET1K_V1.url,
    }
)