import torch
import warnings
from functools import partial
from torch import nn
from torch import Tensor
from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation
from ._utils import _make_divisible
from typing import Callable, Any, Optional, List


__all__ = ['MobileNetV2', 'mobilenet_v2']


model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


# necessary for backwards compatibility
class _DeprecatedConvBNAct(ConvNormActivation):
    """Deprecated alias of ``ConvNormActivation`` with MobileNetV2 defaults.

    Emits a ``FutureWarning`` on construction. It then fills in
    ``norm_layer=nn.BatchNorm2d`` and ``activation_layer=nn.ReLU6`` when the
    caller omitted them or passed ``None``, and delegates to
    ``ConvNormActivation``.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The ConvBNReLU/ConvBNActivation classes are deprecated and will be removed in future versions. "
            "Use torchvision.ops.misc.ConvNormActivation instead.",
            FutureWarning,
            # Attribute the warning to the code that instantiated the class, not to this file.
            stacklevel=2,
        )
        if kwargs.get("norm_layer", None) is None:
            kwargs["norm_layer"] = nn.BatchNorm2d
        if kwargs.get("activation_layer", None) is None:
            kwargs["activation_layer"] = nn.ReLU6
        super().__init__(*args, **kwargs)


ConvBNReLU = _DeprecatedConvBNAct
ConvBNActivation = _DeprecatedConvBNAct


class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual block.

    The layer sequence is: an optional 1x1 expansion conv, then a 3x3-style
    depthwise conv, then a linear 1x1 projection with no activation. A
    residual connection is added when the input and output shapes allow it.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: multiplier on ``inp`` giving the hidden (expanded) width.
            When it equals 1, the pointwise expansion layer is skipped.
        norm_layer: normalization layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``stride`` is not 1 or 2.
    """

    def __init__(
        self,
        inp: int,
        oup: int,
        stride: int,
        expand_ratio: int,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InvertedResidual, self).__init__()
        self.stride = stride
        # Explicit check rather than `assert`, which is removed under `python -O`.
        if stride not in [1, 2]:
            raise ValueError("stride should be 1 or 2 instead of {}".format(stride))

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        # `x + conv(x)` needs matching shapes: no downsampling and an unchanged channel count.
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer,
                                             activation_layer=nn.ReLU6))
        layers.extend([
            # dw
            ConvNormActivation(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer,
                               activation_layer=nn.ReLU6),
            # pw-linear: conv + norm only, deliberately without an activation
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            norm_layer(oup),
        ])
        self.conv = nn.Sequential(*layers)
        self.out_channels = oup
        # NOTE(review): not read in this file; presumably consumed by external code — confirm before changing.
        self._is_cn = stride > 1

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.2,
    ) -> None:
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure, as a list of ``[t, c, n, s]`` rows:
                expansion ratio, output channels (before ``width_mult``), number of repeated
                blocks, and stride of the first block in the stage (later blocks use stride 1)
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number.
                Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            norm_layer: Module specifying the normalization layer to use
            dropout (float): Dropout probability applied before the final linear classifier. Default: 0.2

        Raises:
            ValueError: if ``inverted_residual_setting`` is empty or its first row does not have 4 elements.
        """
        super(MobileNetV2, self).__init__()

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "and a list of 4-element lists, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        # max(1.0, ...): a width_mult below 1 never shrinks the final feature layer.
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [
            ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
        ]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                # Only the first block of each stage applies the stage stride.
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(ConvNormActivation(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer,
                                           activation_layer=nn.ReLU6))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded unchanged to the ``MobileNetV2`` constructor (e.g. ``num_classes``,
            ``width_mult``). When ``pretrained`` is True, they must describe an architecture whose
            parameter shapes match the downloaded checkpoint; otherwise ``load_state_dict`` raises.

    Returns:
        MobileNetV2: the constructed model, with pretrained weights loaded if requested.
    """
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.