from functools import partial
from typing import Any, Callable, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "ResNet",
    "ResNet18_Weights",
    "ResNet34_Weights",
    "ResNet50_Weights",
    "ResNet101_Weights",
    "ResNet152_Weights",
    "ResNeXt50_32X4D_Weights",
    "ResNeXt101_32X8D_Weights",
    "ResNeXt101_64X4D_Weights",
    "Wide_ResNet50_2_Weights",
    "Wide_ResNet101_2_Weights",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "resnext101_64x4d",
    "wide_resnet50_2",
    "wide_resnet101_2",
]


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution (self.conv2)
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep Residual Learning for Image Recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ResNet(block, layers, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}
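
# Illustrative sketch (not part of torchvision): the ResNet class is fully described by a
# block type plus a per-stage block count, which is exactly what the builder functions
# below hand to _resnet. The helper name and the 224x224 dummy input are assumptions made
# only for this example.
def _example_build_resnet_from_blocks() -> None:
    untrained_r18 = ResNet(BasicBlock, [2, 2, 2, 2])  # same layout as resnet18()
    untrained_r50 = ResNet(Bottleneck, [3, 4, 6, 3])  # same layout as resnet50()
    dummy = torch.randn(1, 3, 224, 224)
    # Both models end in a 1000-way classifier by default (num_classes=1000).
    print(untrained_r18(dummy).shape)  # torch.Size([1, 1000])
    print(untrained_r50(dummy).shape)  # torch.Size([1, 1000])
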

class ResNet18_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.758,
                    "acc@5": 89.078,
                }
            },
            "_ops": 1.814,
            "_file_size": 44.661,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
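
# Illustrative sketch (not part of torchvision): each Weights entry bundles a checkpoint
# URL, an inference preset, and metadata. The function name below is hypothetical; the
# transforms()/meta accessors are the standard WeightsEnum API.
def _example_inspect_resnet18_weights() -> None:
    weights = ResNet18_Weights.DEFAULT  # alias for IMAGENET1K_V1
    preprocess = weights.transforms()   # ImageClassification preset with crop_size=224
    fake_image = torch.randint(0, 256, (3, 500, 400), dtype=torch.uint8)  # stand-in for a real image
    batch = preprocess(fake_image).unsqueeze(0)
    print(batch.shape)  # torch.Size([1, 3, 224, 224])
    print(weights.meta["_metrics"]["ImageNet-1K"]["acc@1"])  # 69.758
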

class ResNet34_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet34-b627a593.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 21797672,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.314,
                    "acc@5": 91.420,
                }
            },
            "_ops": 3.664,
            "_file_size": 83.275,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1

class ResNet50_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.130,
                    "acc@5": 92.862,
                }
            },
            "_ops": 4.089,
            "_file_size": 97.781,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.858,
                    "acc@5": 95.434,
                }
            },
            "_ops": 4.089,
            "_file_size": 97.79,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2

class ResNet101_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 44549160,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.374,
                    "acc@5": 93.546,
                }
            },
            "_ops": 7.801,
            "_file_size": 170.511,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 44549160,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.886,
                    "acc@5": 95.780,
                }
            },
            "_ops": 7.801,
            "_file_size": 170.53,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2

class ResNet152_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 60192808,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.312,
                    "acc@5": 94.046,
                }
            },
            "_ops": 11.514,
            "_file_size": 230.434,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 60192808,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.284,
                    "acc@5": 96.002,
                }
            },
            "_ops": 11.514,
            "_file_size": 230.474,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2

class ResNeXt50_32X4D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25028904,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.618,
                    "acc@5": 93.698,
                }
            },
            "_ops": 4.23,
            "_file_size": 95.789,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25028904,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.198,
                    "acc@5": 95.340,
                }
            },
            "_ops": 4.23,
            "_file_size": 95.833,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2

class ResNeXt101_32X8D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.312,
                    "acc@5": 94.526,
                }
            },
            "_ops": 16.414,
            "_file_size": 339.586,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.834,
                    "acc@5": 96.228,
                }
            },
            "_ops": 16.414,
            "_file_size": 339.673,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2

class ResNeXt101_64X4D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83455272,
            "recipe": "https://github.com/pytorch/vision/pull/5935",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.246,
                    "acc@5": 96.454,
                }
            },
            "_ops": 15.46,
            "_file_size": 319.318,
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1

class Wide_ResNet50_2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 68883240,
            "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.468,
                    "acc@5": 94.086,
                }
            },
            "_ops": 11.398,
            "_file_size": 131.82,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 68883240,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.602,
                    "acc@5": 95.758,
                }
            },
            "_ops": 11.398,
            "_file_size": 263.124,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2

class Wide_ResNet101_2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 126886696,
            "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.848,
                    "acc@5": 94.284,
                }
            },
            "_ops": 22.753,
            "_file_size": 242.896,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 126886696,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.510,
                    "acc@5": 96.020,
                }
            },
            "_ops": 22.753,
            "_file_size": 484.747,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2

@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

    Args:
        weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNet18_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNet18_Weights
        :members:
    """
    weights = ResNet18_Weights.verify(weights)

    return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
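
# Illustrative sketch (not part of torchvision): end-to-end inference with the builder
# above. The random uint8 tensor stands in for a real image; with real data you would pass
# a PIL image or image tensor through weights.transforms() in the same way.
def _example_resnet18_inference() -> None:
    weights = ResNet18_Weights.DEFAULT
    model = resnet18(weights=weights)
    model.eval()
    fake_image = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
    batch = weights.transforms()(fake_image).unsqueeze(0)
    with torch.no_grad():
        scores = model(batch).softmax(dim=1)
    top_idx = int(scores.argmax())
    print(weights.meta["categories"][top_idx], float(scores[0, top_idx]))
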

@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

    Args:
        weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNet34_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNet34_Weights
        :members:
    """
    weights = ResNet34_Weights.verify(weights)

    return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)

@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

    .. note::
       The bottleneck of TorchVision places the stride for downsampling to the second 3x3
       convolution while the original paper places it to the first 1x1 convolution.
       This variant improves the accuracy and is known as `ResNet V1.5
       <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.

    Args:
        weights (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNet50_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNet50_Weights
        :members:
    """
    weights = ResNet50_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
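
# Illustrative sketch (not part of torchvision): the V1 and V2 entries are distinct
# checkpoints with different accuracy and different inference presets (V2 resizes the
# smaller edge to 232 before the 224 crop), so a model should always be paired with the
# transforms of the weights it loaded. The function name is hypothetical.
def _example_resnet50_v1_vs_v2() -> None:
    v1 = ResNet50_Weights.IMAGENET1K_V1
    v2 = ResNet50_Weights.IMAGENET1K_V2
    model = resnet50(weights=v2).eval()
    preprocess = v2.transforms()  # ImageClassification(crop_size=224, resize_size=232)
    print(v1.meta["_metrics"]["ImageNet-1K"]["acc@1"])  # 76.130
    print(v2.meta["_metrics"]["ImageNet-1K"]["acc@1"])  # 80.858
    print(model(preprocess(torch.randint(0, 256, (3, 300, 300), dtype=torch.uint8)).unsqueeze(0)).shape)
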

@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

    .. note::
       The bottleneck of TorchVision places the stride for downsampling to the second 3x3
       convolution while the original paper places it to the first 1x1 convolution.
       This variant improves the accuracy and is known as `ResNet V1.5
       <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.

    Args:
        weights (:class:`~torchvision.models.ResNet101_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNet101_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNet101_Weights
        :members:
    """
    weights = ResNet101_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)

@register_model()
@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

    .. note::
       The bottleneck of TorchVision places the stride for downsampling to the second 3x3
       convolution while the original paper places it to the first 1x1 convolution.
       This variant improves the accuracy and is known as `ResNet V1.5
       <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.

    Args:
        weights (:class:`~torchvision.models.ResNet152_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNet152_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNet152_Weights
        :members:
    """
    weights = ResNet152_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)

@register_model()
@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1))
def resnext50_32x4d(
    *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """ResNeXt-50 32x4d model from
    `Aggregated Residual Transformations for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

    Args:
        weights (:class:`~torchvision.models.ResNeXt50_32X4D_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNeXt50_32X4D_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNeXt50_32X4D_Weights
        :members:
    """
    weights = ResNeXt50_32X4D_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
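
# Illustrative sketch (not part of torchvision): the ResNeXt builders reuse the same
# Bottleneck block and only pin groups/width_per_group, so the equivalent architecture can
# be obtained by passing those kwargs to ResNet directly. The function name is hypothetical.
def _example_resnext50_equivalent() -> None:
    via_builder = resnext50_32x4d()  # groups=32, width_per_group=4 are set for you
    by_hand = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4)
    # Same layer layout, hence the same number of parameters.
    assert sum(p.numel() for p in via_builder.parameters()) == sum(p.numel() for p in by_hand.parameters())
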

@register_model()
@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1))
def resnext101_32x8d(
    *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """ResNeXt-101 32x8d model from
    `Aggregated Residual Transformations for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

    Args:
        weights (:class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNeXt101_32X8D_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
        :members:
    """
    weights = ResNeXt101_32X8D_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)

@register_model()
@handle_legacy_interface(weights=("pretrained", ResNeXt101_64X4D_Weights.IMAGENET1K_V1))
def resnext101_64x4d(
    *, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """ResNeXt-101 64x4d model from
    `Aggregated Residual Transformations for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

    Args:
        weights (:class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNeXt101_64X4D_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
        :members:
    """
    weights = ResNeXt101_64X4D_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "groups", 64)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)

@register_model()
@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
def wide_resnet50_2(
    *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """Wide ResNet-50-2 model from
    `Wide Residual Networks <https://arxiv.org/abs/1605.07146>`_.

    The model is the same as ResNet except that the number of channels in the bottleneck
    is twice as large in every block. The number of channels in the outer 1x1
    convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.

    Args:
        weights (:class:`~torchvision.models.Wide_ResNet50_2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Wide_ResNet50_2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Wide_ResNet50_2_Weights
        :members:
    """
    weights = Wide_ResNet50_2_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
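
# Illustrative sketch (not part of torchvision): "width 2" means width_per_group is doubled
# from 64 to 128, which doubles the 3x3 bottleneck width (512 -> 1024 in the last stage)
# while the outer 1x1 projections keep their 2048 output channels. The function name is
# hypothetical.
def _example_wide_resnet50_widths() -> None:
    wide = wide_resnet50_2()
    plain = resnet50()
    print(wide.layer4[0].conv2.in_channels, plain.layer4[0].conv2.in_channels)    # 1024 512
    print(wide.layer4[0].conv3.out_channels, plain.layer4[0].conv3.out_channels)  # 2048 2048
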

@register_model()
@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1))
def wide_resnet101_2(
    *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """Wide ResNet-101-2 model from
    `Wide Residual Networks <https://arxiv.org/abs/1605.07146>`_.

    The model is the same as ResNet except that the number of channels in the bottleneck
    is twice as large in every block. The number of channels in the outer 1x1
    convolutions is unchanged, e.g. the last block in ResNet-101 has 2048-512-2048
    channels, while in Wide ResNet-101-2 it has 2048-1024-2048.

    Args:
        weights (:class:`~torchvision.models.Wide_ResNet101_2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Wide_ResNet101_2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Wide_ResNet101_2_Weights
        :members:
    """
    weights = Wide_ResNet101_2_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
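
# Illustrative sketch (not part of torchvision): a common fine-tuning pattern. When weights
# are given, the builders force num_classes to match the checkpoint (1000 for ImageNet), so
# a new task head is attached afterwards by replacing model.fc. The function name and the
# 10-class head are assumptions made only for this example.
def _example_finetune_head(num_classes: int = 10) -> ResNet:
    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    for p in model.parameters():
        p.requires_grad = False  # optionally freeze the pretrained backbone
    model.fc = nn.Linear(model.fc.in_features, num_classes)  # new, trainable classifier head
    return model
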