import copy
import math
import torch

from functools import partial
from torch import nn, Tensor
from typing import Any, Callable, List, Optional, Sequence

from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation
from ._utils import _make_divisible
from torchvision.ops import StochasticDepth


__all__ = [
    "EfficientNet",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "efficientnet_b3",
    "efficientnet_b4",
    "efficientnet_b5",
    "efficientnet_b6",
    "efficientnet_b7",
]


model_urls = {
    # Weights ported from https://github.com/rwightman/pytorch-image-models/
    "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
    "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
    "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
    "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
    "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
    # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
    "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
    "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
    "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
}


class MBConvConfig:
    """Configuration of one stage of MBConv blocks.

    Stores the information listed in Table 1 of the EfficientNet paper. Channel
    counts are scaled by ``width_mult`` and passed through ``_make_divisible``
    with divisor 8; the layer count is scaled by ``depth_mult`` and rounded up.
    """

    def __init__(self, expand_ratio: float, kernel: int, stride: int,
                 input_channels: int, out_channels: int, num_layers: int,
                 width_mult: float, depth_mult: float) -> None:
        self.expand_ratio = expand_ratio
        self.kernel = kernel
        self.stride = stride
        self.input_channels = self.adjust_channels(input_channels, width_mult)
        self.out_channels = self.adjust_channels(out_channels, width_mult)
        self.num_layers = self.adjust_depth(num_layers, depth_mult)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"expand_ratio={self.expand_ratio}"
            f", kernel={self.kernel}"
            f", stride={self.stride}"
            f", input_channels={self.input_channels}"
            f", out_channels={self.out_channels}"
            f", num_layers={self.num_layers}"
            ")"
        )

    @staticmethod
    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
        """Scale ``channels`` by ``width_mult`` and round with ``_make_divisible(..., 8, min_value)``."""
        return _make_divisible(channels * width_mult, 8, min_value)

    @staticmethod
    def adjust_depth(num_layers: int, depth_mult: float) -> int:
        """Scale ``num_layers`` by ``depth_mult``, rounding up to the next integer."""
        return int(math.ceil(num_layers * depth_mult))


class MBConv(nn.Module):
    """Mobile inverted bottleneck block: expand -> depthwise -> squeeze/excite -> project.

    Args:
        cnf (MBConvConfig): configuration of this single block.
        stochastic_depth_prob (float): probability given to ``StochasticDepth`` (mode ``"row"``),
            applied to the block output only when the residual connection is used.
        norm_layer (Callable[..., nn.Module]): normalization layer used after each convolution.
        se_layer (Callable[..., nn.Module]): squeeze-and-excitation module factory.

    Raises:
        ValueError: if ``cnf.stride`` is not 1 or 2.
    """

    def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module],
                 se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None:
        super().__init__()

        if not (1 <= cnf.stride <= 2):
            raise ValueError('illegal stride value')

        # A residual add is only shape-compatible when neither the spatial size
        # (stride) nor the channel count changes.
        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.SiLU

        # expand (skipped when the expansion would not change the channel count)
        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        if expanded_channels != cnf.input_channels:
            layers.append(ConvNormActivation(cnf.input_channels, expanded_channels, kernel_size=1,
                                             norm_layer=norm_layer, activation_layer=activation_layer))

        # depthwise (groups == channels)
        layers.append(ConvNormActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
                                         stride=cnf.stride, groups=expanded_channels,
                                         norm_layer=norm_layer, activation_layer=activation_layer))

        # squeeze and excitation; the squeeze width is derived from the block's
        # input channels, not from the expanded channels
        squeeze_channels = max(1, cnf.input_channels // 4)
        layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))

        # project (linear: no activation)
        layers.append(ConvNormActivation(expanded_channels, cnf.out_channels, kernel_size=1,
                                         norm_layer=norm_layer, activation_layer=None))

        self.block = nn.Sequential(*layers)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        result = self.block(input)
        if self.use_res_connect:
            result = self.stochastic_depth(result)
            result += input
        return result


class EfficientNet(nn.Module):
    def __init__(
            self,
            inverted_residual_setting: List[MBConvConfig],
            dropout: float,
            stochastic_depth_prob: float = 0.2,
            num_classes: int = 1000,
            block: Optional[Callable[..., nn.Module]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            **kwargs: Any
    ) -> None:
        """
        EfficientNet main class

        Args:
            inverted_residual_setting (List[MBConvConfig]): Network structure
            dropout (float): The dropout probability of the classifier head
            stochastic_depth_prob (float): The stochastic depth probability; each block's
                probability grows linearly with its index, from 0 for the first block
            num_classes (int): Number of classes
            block (Optional[Callable[..., nn.Module]]): Module specifying the inverted residual
                building block for EfficientNet (defaults to ``MBConv``)
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization
                layer to use (defaults to ``nn.BatchNorm2d``)
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: if ``inverted_residual_setting`` is empty.
            TypeError: if ``inverted_residual_setting`` is not a sequence of ``MBConvConfig``.
        """
        super().__init__()

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (isinstance(inverted_residual_setting, Sequence) and
                  all(isinstance(s, MBConvConfig) for s in inverted_residual_setting)):
            raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")

        if block is None:
            block = MBConv

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        layers: List[nn.Module] = []

        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2,
                                         norm_layer=norm_layer, activation_layer=nn.SiLU))

        # building inverted residual blocks
        total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
        stage_block_id = 0
        for cnf in inverted_residual_setting:
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # copy to avoid modifications. shallow copy is enough
                block_cnf = copy.copy(cnf)

                # overwrite info if not the first conv in the stage: only the first
                # block of a stage changes channels / downsamples
                if stage:
                    block_cnf.input_channels = block_cnf.out_channels
                    block_cnf.stride = 1

                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks

                stage.append(block(block_cnf, sd_prob, norm_layer))
                stage_block_id += 1

            layers.append(nn.Sequential(*stage))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 4 * lastconv_input_channels
        layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
                                         norm_layer=norm_layer, activation_layer=nn.SiLU))

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(lastconv_output_channels, num_classes),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # NOTE(review): the range is scaled by out_features (not in_features);
                # presumably deliberate — confirm before changing.
                init_range = 1.0 / math.sqrt(m.out_features)
                nn.init.uniform_(m.weight, -init_range, init_range)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _efficientnet_conf(width_mult: float, depth_mult: float, **kwargs: Any) -> List[MBConvConfig]:
    """Return the seven B0 stage configs scaled by ``width_mult`` / ``depth_mult``.

    Extra ``kwargs`` are accepted and ignored.
    """
    bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult)
    # columns: expand_ratio, kernel, stride, input_channels, out_channels, num_layers
    inverted_residual_setting = [
        bneck_conf(1, 3, 1, 32, 16, 1),
        bneck_conf(6, 3, 2, 16, 24, 2),
        bneck_conf(6, 5, 2, 24, 40, 2),
        bneck_conf(6, 3, 2, 40, 80, 3),
        bneck_conf(6, 5, 1, 80, 112, 3),
        bneck_conf(6, 5, 2, 112, 192, 4),
        bneck_conf(6, 3, 1, 192, 320, 1),
    ]
    return inverted_residual_setting


def _efficientnet_model(
    arch: str,
    inverted_residual_setting: List[MBConvConfig],
    dropout: float,
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> EfficientNet:
    """Build an ``EfficientNet`` and optionally load the checkpoint registered for ``arch``.

    Raises:
        ValueError: if ``pretrained`` is True and ``model_urls`` has no entry for ``arch``.
    """
    model = EfficientNet(inverted_residual_setting, dropout, **kwargs)
    if pretrained:
        url = model_urls.get(arch)
        if url is None:
            raise ValueError(f"No checkpoint is available for model type {arch}")
        state_dict = load_state_dict_from_url(url, progress=progress)
        model.load_state_dict(state_dict)
    return model
def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
    """
    Constructs an EfficientNet B0 architecture from
    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``EfficientNet``. ``dropout`` may be passed to override the
            default classifier dropout of 0.2.
    """
    # Pop ``dropout`` so it is not passed both positionally and via ``**kwargs``
    # to ``_efficientnet_model`` (which would raise a TypeError).
    dropout = kwargs.pop("dropout", 0.2)
    inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.0, **kwargs)
    return _efficientnet_model("efficientnet_b0", inverted_residual_setting, dropout, pretrained, progress,
                               **kwargs)
def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
    """
    Constructs an EfficientNet B1 architecture from
    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``EfficientNet``. ``dropout`` may be passed to override the
            default classifier dropout of 0.2.
    """
    # Pop ``dropout`` so it is not passed both positionally and via ``**kwargs``
    # to ``_efficientnet_model`` (which would raise a TypeError).
    dropout = kwargs.pop("dropout", 0.2)
    inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.1, **kwargs)
    return _efficientnet_model("efficientnet_b1", inverted_residual_setting, dropout, pretrained, progress,
                               **kwargs)
def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
    """
    Constructs an EfficientNet B2 architecture from
    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``EfficientNet``. ``dropout`` may be passed to override the
            default classifier dropout of 0.3.
    """
    # Pop ``dropout`` so it is not passed both positionally and via ``**kwargs``
    # to ``_efficientnet_model`` (which would raise a TypeError).
    dropout = kwargs.pop("dropout", 0.3)
    inverted_residual_setting = _efficientnet_conf(width_mult=1.1, depth_mult=1.2, **kwargs)
    return _efficientnet_model("efficientnet_b2", inverted_residual_setting, dropout, pretrained, progress,
                               **kwargs)
def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
    """
    Constructs an EfficientNet B3 architecture from
    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``EfficientNet``. ``dropout`` may be passed to override the
            default classifier dropout of 0.3.
    """
    # Pop ``dropout`` so it is not passed both positionally and via ``**kwargs``
    # to ``_efficientnet_model`` (which would raise a TypeError).
    dropout = kwargs.pop("dropout", 0.3)
    inverted_residual_setting = _efficientnet_conf(width_mult=1.2, depth_mult=1.4, **kwargs)
    return _efficientnet_model("efficientnet_b3", inverted_residual_setting, dropout, pretrained, progress,
                               **kwargs)
def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
    """
    Constructs an EfficientNet B4 architecture from
    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``EfficientNet``. ``dropout`` may be passed to override the
            default classifier dropout of 0.4.
    """
    # Pop ``dropout`` so it is not passed both positionally and via ``**kwargs``
    # to ``_efficientnet_model`` (which would raise a TypeError).
    dropout = kwargs.pop("dropout", 0.4)
    inverted_residual_setting = _efficientnet_conf(width_mult=1.4, depth_mult=1.8, **kwargs)
    return _efficientnet_model("efficientnet_b4", inverted_residual_setting, dropout, pretrained, progress,
                               **kwargs)
def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
    """
    Constructs an EfficientNet B5 architecture from
    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``EfficientNet``. ``dropout`` (default 0.4) and ``norm_layer``
            (default ``BatchNorm2d`` with ``eps=0.001, momentum=0.01``) may be overridden.
    """
    # Pop the per-variant defaults so a caller-supplied value is not passed twice
    # (positionally / explicitly and via ``**kwargs``), which would raise a TypeError.
    dropout = kwargs.pop("dropout", 0.4)
    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01))
    inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs)
    return _efficientnet_model("efficientnet_b5", inverted_residual_setting, dropout, pretrained, progress,
                               norm_layer=norm_layer, **kwargs)
def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
    """
    Constructs an EfficientNet B6 architecture from
    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``EfficientNet``. ``dropout`` (default 0.5) and ``norm_layer``
            (default ``BatchNorm2d`` with ``eps=0.001, momentum=0.01``) may be overridden.
    """
    # Pop the per-variant defaults so a caller-supplied value is not passed twice
    # (positionally / explicitly and via ``**kwargs``), which would raise a TypeError.
    dropout = kwargs.pop("dropout", 0.5)
    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01))
    inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs)
    return _efficientnet_model("efficientnet_b6", inverted_residual_setting, dropout, pretrained, progress,
                               norm_layer=norm_layer, **kwargs)
def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
    """
    Constructs an EfficientNet B7 architecture from
    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``EfficientNet``. ``dropout`` (default 0.5) and ``norm_layer``
            (default ``BatchNorm2d`` with ``eps=0.001, momentum=0.01``) may be overridden.
    """
    # Pop the per-variant defaults so a caller-supplied value is not passed twice
    # (positionally / explicitly and via ``**kwargs``), which would raise a TypeError.
    dropout = kwargs.pop("dropout", 0.5)
    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01))
    inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs)
    return _efficientnet_model("efficientnet_b7", inverted_residual_setting, dropout, pretrained, progress,
                               norm_layer=norm_layer, **kwargs)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.