from functools import partial
from typing import Any, Callable, List, Optional, Sequence

import torch
from torch import nn, Tensor

from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "MobileNetV3",
    "MobileNet_V3_Large_Weights",
    "MobileNet_V3_Small_Weights",
    "mobilenet_v3_large",
    "mobilenet_v3_small",
]


class InvertedResidualConfig:
    # Stores information listed in Tables 1 and 2 of the MobileNetV3 paper
    def __init__(
        self,
        input_channels: int,
        kernel: int,
        expanded_channels: int,
        out_channels: int,
        use_se: bool,
        activation: str,
        stride: int,
        dilation: int,
        width_mult: float,
    ):
        self.input_channels = self.adjust_channels(input_channels, width_mult)
        self.kernel = kernel
        self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
        self.out_channels = self.adjust_channels(out_channels, width_mult)
        self.use_se = use_se
        self.use_hs = activation == "HS"
        self.stride = stride
        self.dilation = dilation

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        return _make_divisible(channels * width_mult, 8)


class InvertedResidual(nn.Module):
    # Implemented as described in section 5 of the MobileNetV3 paper
    def __init__(
        self,
        cnf: InvertedResidualConfig,
        norm_layer: Callable[..., nn.Module],
        se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid),
    ):
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU

        # expand
        if cnf.expanded_channels != cnf.input_channels:
            layers.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    cnf.expanded_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                )
            )

        # depthwise
        stride = 1 if cnf.dilation > 1 else cnf.stride
        layers.append(
            Conv2dNormActivation(
                cnf.expanded_channels,
                cnf.expanded_channels,
                kernel_size=cnf.kernel,
                stride=stride,
                dilation=cnf.dilation,
                groups=cnf.expanded_channels,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
            )
        )
        if cnf.use_se:
            squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
            layers.append(se_layer(cnf.expanded_channels, squeeze_channels))

        # project
        layers.append(
            Conv2dNormActivation(
                cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
            )
        )

        self.block = nn.Sequential(*layers)
        self.out_channels = cnf.out_channels
        self._is_cn = cnf.stride > 1

    def forward(self, input: Tensor) -> Tensor:
        result = self.block(input)
        if self.use_res_connect:
            result += input
        return result


class MobileNetV3(nn.Module):
    def __init__(
        self,
        inverted_residual_setting: List[InvertedResidualConfig],
        last_channel: int,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """
        MobileNet V3 main class

        Args:
            inverted_residual_setting (List[InvertedResidualConfig]): Network structure
            last_channel (int): The number of channels on the penultimate layer
            num_classes (int): Number of classes
            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
            dropout (float): The dropout probability
        """
        super().__init__()
        _log_api_usage_once(self)

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (
            isinstance(inverted_residual_setting, Sequence)
            and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
        ):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)

        layers: List[nn.Module] = []

        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=3,
                stride=2,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )

        # building inverted residual blocks
        for cnf in inverted_residual_setting:
            layers.append(block(cnf, norm_layer))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 6 * lastconv_input_channels
        layers.append(
            Conv2dNormActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(lastconv_output_channels, last_channel),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(last_channel, num_classes),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _mobilenet_v3_conf(
    arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, **kwargs: Any
):
    reduce_divider = 2 if reduced_tail else 1
    dilation = 2 if dilated else 1

    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

    if arch == "mobilenet_v3_large":
        inverted_residual_setting = [
            bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
            bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
            bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
            bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
            bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
            bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
            bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
            bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation),  # C4
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
        ]
        last_channel = adjust_channels(1280 // reduce_divider)  # C5
    elif arch == "mobilenet_v3_small":
        inverted_residual_setting = [
            bneck_conf(16, 3, 16, 16, True, "RE", 2, 1),  # C1
            bneck_conf(16, 3, 72, 24, False, "RE", 2, 1),  # C2
            bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
            bneck_conf(24, 5, 96, 40, True, "HS", 2, 1),  # C3
            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
            bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
            bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
            bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation),  # C4
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
        ]
        last_channel = adjust_channels(1024 // reduce_divider)  # C5
    else:
        raise ValueError(f"Unsupported model type {arch}")

    return inverted_residual_setting, last_channel


def _mobilenet_v3(
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> MobileNetV3:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}
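
# Illustrative sketch (not part of the upstream file): the two private helpers
# above compose as follows. ``_mobilenet_v3_conf`` produces the per-block
# configs plus the penultimate width, and ``MobileNetV3`` turns them into a
# module. The 224x224 input size below is an assumption for the example only.
#
# >>> setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large")
# >>> net = MobileNetV3(setting, last_channel)  # randomly initialized, no weights
# >>> net(torch.rand(1, 3, 224, 224)).shape
# torch.Size([1, 1000])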
class MobileNet_V3_Large_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 5483032,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.042,
                    "acc@5": 91.340,
                }
            },
            "_ops": 0.217,
            "_file_size": 21.114,
            "_docs": """These weights were trained from scratch by using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 5483032,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.274,
                    "acc@5": 92.566,
                }
            },
            "_ops": 0.217,
            "_file_size": 21.107,
            "_docs": """
                These weights improve marginally upon the results of the original paper by using a modified version of
                TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2
class MobileNet_V3_Small_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2542856,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 67.668,
                    "acc@5": 87.402,
                }
            },
            "_ops": 0.057,
            "_file_size": 9.829,
            "_docs": """
                These weights improve upon the results of the original paper by using a simple training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
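
# Illustrative sketch (not part of the upstream file): each enum entry bundles
# the checkpoint URL, its preprocessing preset, and metadata. The preset is a
# ``partial`` and must be instantiated before use.
#
# >>> w = MobileNet_V3_Small_Weights.DEFAULT  # resolves to IMAGENET1K_V1
# >>> preprocess = w.transforms()  # ImageClassification with crop_size=224
# >>> w.meta["_metrics"]["ImageNet-1K"]["acc@1"]
# 67.668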
@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1))
def mobilenet_v3_large(
    *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
    """
    Constructs a large MobileNetV3 architecture from
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__.

    Args:
        weights (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MobileNet_V3_Large_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mobilenet.MobileNetV3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MobileNet_V3_Large_Weights
        :members:
    """
    weights = MobileNet_V3_Large_Weights.verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
    return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
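
# Usage sketch (illustrative): end-to-end inference with pretrained weights.
# The random tensor stands in for a real image and is an assumption of the
# example; ``DEFAULT`` resolves to ``IMAGENET1K_V2``.
#
# >>> weights = MobileNet_V3_Large_Weights.DEFAULT
# >>> model = mobilenet_v3_large(weights=weights).eval()
# >>> batch = weights.transforms()(torch.rand(3, 256, 256)).unsqueeze(0)
# >>> with torch.inference_mode():
# ...     model(batch).argmax(dim=1).shape
# torch.Size([1])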
@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1))
def mobilenet_v3_small(
    *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
    """
    Constructs a small MobileNetV3 architecture from
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__.

    Args:
        weights (:class:`~torchvision.models.MobileNet_V3_Small_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MobileNet_V3_Small_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mobilenet.MobileNetV3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MobileNet_V3_Small_Weights
        :members:
    """
    weights = MobileNet_V3_Small_Weights.verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
    return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
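
# Usage sketch (illustrative): extra kwargs flow through ``_mobilenet_v3_conf``
# into ``MobileNetV3``, so the head can be resized when training from scratch;
# ``handle_legacy_interface`` also keeps the old ``pretrained=True`` flag
# working by mapping it to ``IMAGENET1K_V1``.
#
# >>> net = mobilenet_v3_small(num_classes=10)  # random init, 10-way head
# >>> net.classifier[-1].out_features
# 10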