from typing import Type, Any, Callable, Union, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from .._internally_replaced_utils import load_state_dict_from_url
from ..utils import _log_api_usage_once


__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "wide_resnet50_2",
    "wide_resnet101_2",
]


# Download locations of the pretrained weights, keyed by the ``arch`` name passed to ``_resnet``.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
    "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
    "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
    "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
}


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding.

    The padding equals ``dilation``, so with ``stride=1`` the spatial size of the
    input is preserved. The convolution has no bias; a normalization layer is
    expected to follow it.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution without bias, used to change the channel count (and optionally stride)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions (used by ResNet-18/34).

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``, where the
    shortcut is ``x`` itself or ``downsample(x)`` when a downsample module is given.
    The output has ``planes * expansion`` (= ``planes``) channels.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            # Project the shortcut so its shape matches the residual branch.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions (used by ResNet-50 and deeper).

    The inner width is ``int(planes * (base_width / 64.0)) * groups``. The 3x3
    convolution carries ``stride``, ``groups`` and ``dilation``. The final 1x1
    convolution expands to ``planes * expansion`` output channels, and the
    shortcut is ``x`` or ``downsample(x)``.
    """

    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    # Ratio between the block's output channels and ``planes``.
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # base_width / groups let ResNeXt and Wide ResNet widen the inner 3x3 stage.
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            # Project the shortcut so its shape matches the residual branch.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet backbone plus classifier head.

    Layout: 7x7 stride-2 stem conv (3 input channels) -> norm -> ReLU ->
    3x3 stride-2 max-pool -> four residual stages (64, 128, 256, 512 planes) ->
    global average pool -> linear classifier.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: Number of blocks in each stage; entries 0..3 are used for
            ``layer1``..``layer4``.
        num_classes: Output size of the final ``nn.Linear``.
        zero_init_residual: If True, zero the weight of the last norm layer in each
            residual branch so every block starts out as an identity mapping.
        groups: Number of groups for the 3x3 convolutions inside the blocks.
        width_per_group: Base width forwarded to the blocks as ``base_width``.
        replace_stride_with_dilation: Three flags for ``layer2``..``layer4``. When a
            flag is True, that stage's stride 2 is replaced by dilation.
        norm_layer: Normalization layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have exactly 3 entries.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running channel count / dilation, mutated by _make_layer as stages are built.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks.

        Only the first block gets ``stride`` and, when the shape changes, a
        ``conv1x1`` + norm projection on its shortcut.

        If ``dilate`` is True, the stride is folded into ``self.dilation``: the
        spatial size is kept, and later blocks use the larger dilation.

        Side effects: updates ``self.inplanes`` (and ``self.dilation`` when dilating).
        """
        norm_layer = self._norm_layer
        downsample = None
        # The first block keeps the dilation in effect before this stage.
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run stem, the four stages, pooling and the classifier.

        Expects 3-channel images; presumably batched as (N, 3, H, W), to be confirmed.
        """
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    """Construct a :class:`ResNet` and optionally load pretrained weights.

    Args:
        arch: Key into ``model_urls``; it is only looked up when ``pretrained`` is True.
        block: Residual block class.
        layers: Blocks per stage.
        pretrained: If True, download the state dict for ``arch`` and load it strictly.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Forwarded to the :class:`ResNet` constructor.

    Raises:
        KeyError: if ``pretrained`` is True and ``arch`` is not in ``model_urls``.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`
    """
    return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`
    """
    return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`
    """
    return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`
    """
    return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`
    """
    return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)


def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`; ``groups`` and
            ``width_per_group`` are set to 32 and 4 and override any caller-supplied values
    """
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 4
    return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`; ``groups`` and
            ``width_per_group`` are set to 32 and 8 and override any caller-supplied values
    """
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 8
    return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`; ``width_per_group``
            is set to 128 and overrides any caller-supplied value
    """
    # Doubling the per-group base width doubles each Bottleneck's inner 3x3 width.
    kwargs["width_per_group"] = 64 * 2
    return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-101 has 2048-512-2048
    channels, and in Wide ResNet-101-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`; ``width_per_group``
            is set to 128 and overrides any caller-supplied value
    """
    # Doubling the per-group base width doubles each Bottleneck's inner 3x3 width.
    kwargs["width_per_group"] = 64 * 2
    return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
# Docs
# Access comprehensive developer documentation for PyTorch
# To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Because Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.