from functools import partial
from typing import Any, cast, Dict, List, Optional, Union

import torch
import torch.nn as nn

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "VGG",
    "VGG11_Weights",
    "VGG11_BN_Weights",
    "VGG13_Weights",
    "VGG13_BN_Weights",
    "VGG16_Weights",
    "VGG16_BN_Weights",
    "VGG19_Weights",
    "VGG19_BN_Weights",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
]


class VGG(nn.Module):
    """VGG classification network.

    The network has three stages:

    * a convolutional ``features`` trunk;
    * a 7x7 adaptive average pool;
    * a three-layer fully connected ``classifier``.

    Args:
        features: Convolutional feature extractor, normally built by
            :func:`make_layers`. The first classifier layer is
            ``nn.Linear(512 * 7 * 7, 4096)`` applied after a ``(7, 7)``
            adaptive pool and ``flatten(x, 1)``, so ``features`` must output
            512 channels.
        num_classes: Output size of the last ``nn.Linear``. Default 1000.
        init_weights: If True, re-initialize every Conv2d / BatchNorm2d /
            Linear submodule (see the loop below). Default True.
        dropout: Probability for the two ``nn.Dropout`` layers in the
            classifier. Default 0.5.
    """

    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        # torchvision usage-logging hook (imported from ..utils).
        _log_api_usage_once(self)
        self.features = features
        # Fixed 7x7 spatial output, so the classifier input size does not
        # depend on the input image resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Submodule order defines the state_dict keys (classifier.0, .3, .6)
        # that pretrained checkpoints are loaded against; do not reorder.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    # He/Kaiming init matched to the ReLU that follows each conv.
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    # Identity affine transform at start of training.
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    # Weights ~ N(mean=0, std=0.01), zero bias.
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores for a batch.

        The input passes through ``features``, the 7x7 pool, a flatten from
        dim 1, and then ``classifier``. The channel count of ``x`` must match
        the first conv of ``features``.

        Returns:
            Unnormalized scores (no softmax is applied) with ``num_classes``
            entries per sample.
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False, in_channels: int = 3) -> nn.Sequential:
    """Build a VGG convolutional trunk from a layer configuration.

    Args:
        cfg: Layer spec. Each ``int`` entry adds a 3x3 conv (padding 1) with
            that many output channels, followed by ``nn.ReLU``. When
            ``batch_norm`` is True, an ``nn.BatchNorm2d`` is placed between
            the conv and the ReLU. The string ``"M"`` adds a 2x2 max-pool
            with stride 2.
        batch_norm: Whether to insert BatchNorm2d after every conv.
        in_channels: Channel count of the tensor fed to the first conv.
            Defaults to 3, the previously hard-coded value, so existing
            callers are unaffected.

    Returns:
        An ``nn.Sequential`` with the layers in ``cfg`` order. The positional
        indices become the ``features.<idx>`` state_dict keys used when
        loading pretrained checkpoints.
    """
    layers: List[nn.Module] = []
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            # The next conv consumes this conv's output channels.
            in_channels = v
    return nn.Sequential(*layers)


# Trunk configurations. Each has 8 / 10 / 13 / 16 convs, which together with
# the 3 classifier Linears gives VGG-11 ("A"), VGG-13 ("B"), VGG-16 ("D")
# and VGG-19 ("E").
cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
    """Shared builder behind the public ``vgg*`` constructors.

    Args:
        cfg: Key into ``cfgs`` (``"A"``, ``"B"``, ``"D"`` or ``"E"``).
        batch_norm: Build the BatchNorm variant of the trunk.
        weights: Pretrained weights to load, or None for random init.
        progress: Forwarded to ``weights.get_state_dict`` (download progress).
        **kwargs: Forwarded to :class:`VGG`.

    Returns:
        The constructed model, with ``weights`` loaded when given.
    """
    if weights is not None:
        # Every parameter is about to be overwritten by the checkpoint, so
        # skip the random initialization.
        kwargs["init_weights"] = False
        # Feature-only weights have categories=None; they leave num_classes
        # as the caller set it.
        if weights.meta["categories"] is not None:
            # Make the classifier head match the checkpoint. The helper is
            # defined in ._utils; presumably it rejects a conflicting
            # user-supplied value (not verified here).
            _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return model


# Metadata shared by the VGG weight entries; individual entries extend or
# override keys via {**_COMMON_META, ...}.
_COMMON_META = {
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
    "_docs": """These weights were trained from scratch by using a simplified training recipe.""",
}
class VGG16_Weights(WeightsEnum):
    """Pretrained weight sets for the VGG-16 (config "D", no BN) model."""

    # ImageNet-1K classification checkpoint trained with torchvision's recipe.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg16-397923af.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 138_357_544,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.592,
                    "acc@5": 90.382,
                }
            },
            "_ops": 15.47,
            "_file_size": 527.796,
        },
    )
    # Weights ported from https://github.com/amdegroot/ssd.pytorch/
    # Only the ``features`` trunk is usable (see "_docs" below).
    IMAGENET1K_FEATURES = Weights(
        url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            mean=(0.48235, 0.45882, 0.40784),
            # NOTE(review): std of 1/255 presumably rescales normalized input
            # back to the 0-255 range the original weights expect; confirm
            # against ImageClassification.
            std=(1.0 / 255.0,) * 3,
        ),
        meta={
            **_COMMON_META,
            "num_params": 138_357_544,
            # No classifier head, so no category labels.
            "categories": None,
            "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": float("nan"),
                    "acc@5": float("nan"),
                }
            },
            "_ops": 15.47,
            "_file_size": 527.802,
            "_docs": """ These weights can't be used for classification because they are missing values in the `classifier` module. Only the `features` module has valid values and can be used for feature extraction. The weights were trained using the original input standardization method as described in the paper. """,
        },
    )
    # Recommended default: same object as IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG11_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG11_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    Returns:
        VGG: the model built from config ``"A"`` without batch normalization,
        with ``weights`` loaded when provided.

    .. autoclass:: torchvision.models.VGG11_Weights
        :members:
    """
    checked = VGG11_Weights.verify(weights)
    # Config "A": 8 convs + 3 Linears = 11 weight layers; plain (no BN) trunk.
    return _vgg(cfg="A", batch_norm=False, weights=checked, progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG11_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    Returns:
        VGG: the model built from config ``"A"`` with batch normalization,
        with ``weights`` loaded when provided.

    .. autoclass:: torchvision.models.VGG11_BN_Weights
        :members:
    """
    checked = VGG11_BN_Weights.verify(weights)
    # Config "A" trunk with BatchNorm2d after every conv.
    return _vgg(cfg="A", batch_norm=True, weights=checked, progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG13_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG13_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    Returns:
        VGG: the model built from config ``"B"`` without batch normalization,
        with ``weights`` loaded when provided.

    .. autoclass:: torchvision.models.VGG13_Weights
        :members:
    """
    checked = VGG13_Weights.verify(weights)
    # Config "B": 10 convs + 3 Linears = 13 weight layers; plain (no BN) trunk.
    return _vgg(cfg="B", batch_norm=False, weights=checked, progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG13_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    Returns:
        VGG: the model built from config ``"B"`` with batch normalization,
        with ``weights`` loaded when provided.

    .. autoclass:: torchvision.models.VGG13_BN_Weights
        :members:
    """
    checked = VGG13_BN_Weights.verify(weights)
    # Config "B" trunk with BatchNorm2d after every conv.
    return _vgg(cfg="B", batch_norm=True, weights=checked, progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG16_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    Returns:
        VGG: the model built from config ``"D"`` without batch normalization,
        with ``weights`` loaded when provided.

    .. autoclass:: torchvision.models.VGG16_Weights
        :members:
    """
    checked = VGG16_Weights.verify(weights)
    # Config "D": 13 convs + 3 Linears = 16 weight layers; plain (no BN) trunk.
    return _vgg(cfg="D", batch_norm=False, weights=checked, progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG16_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    Returns:
        VGG: the model built from config ``"D"`` with batch normalization,
        with ``weights`` loaded when provided.

    .. autoclass:: torchvision.models.VGG16_BN_Weights
        :members:
    """
    checked = VGG16_BN_Weights.verify(weights)
    # Config "D" trunk with BatchNorm2d after every conv.
    return _vgg(cfg="D", batch_norm=True, weights=checked, progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG19_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    Returns:
        VGG: the model built from config ``"E"`` without batch normalization,
        with ``weights`` loaded when provided.

    .. autoclass:: torchvision.models.VGG19_Weights
        :members:
    """
    checked = VGG19_Weights.verify(weights)
    # Config "E": 16 convs + 3 Linears = 19 weight layers; plain (no BN) trunk.
    return _vgg(cfg="E", batch_norm=False, weights=checked, progress=progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-19-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG19_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    Returns:
        VGG: the model built from config ``"E"`` with batch normalization,
        with ``weights`` loaded when provided.

    .. autoclass:: torchvision.models.VGG19_BN_Weights
        :members:
    """
    weights = VGG19_BN_Weights.verify(weights)
    # Config "E" (16 convs + 3 Linears) with BatchNorm2d after every conv.
    return _vgg("E", True, weights, progress, **kwargs)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Because Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.