import warnings
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch import Tensor

from .._internally_replaced_utils import load_state_dict_from_url
from ..utils import _log_api_usage_once


__all__ = ["MNASNet", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"]

# Checkpoint URL per builder name. A value of None means `_load_pretrained`
# will refuse the request with a ValueError.
_MODEL_URLS = {
    "mnasnet0_5": "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
    "mnasnet0_75": None,
    "mnasnet1_0": "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
    "mnasnet1_3": None,
}

# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow.
_BN_MOMENTUM = 1 - 0.9997


class _InvertedResidual(nn.Module):
    """Inverted residual block: 1x1 expansion, depthwise conv, linear 1x1 projection.

    The input is added back onto the output only when the block keeps both the
    channel count (``in_ch == out_ch``) and the resolution (``stride == 1``).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Depthwise kernel size; must be 3 or 5.
        stride: Depthwise stride; must be 1 or 2.
        expansion_factor: Multiplier applied to ``in_ch`` to get the hidden width.
        bn_momentum: Momentum passed to every ``BatchNorm2d`` in the block.
    """

    def __init__(
        self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
    ) -> None:
        super().__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5]
        mid_ch = in_ch * expansion_factor
        self.apply_residual = in_ch == out_ch and stride == 1

        ops = [
            # Pointwise expansion.
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise: one group per channel, "same"-style padding.
            nn.Conv2d(
                mid_ch,
                mid_ch,
                kernel_size,
                padding=kernel_size // 2,
                stride=stride,
                groups=mid_ch,
                bias=False,
            ),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise projection: deliberately no activation afterwards.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum),
        ]
        self.layers = nn.Sequential(*ops)

    def forward(self, input: Tensor) -> Tensor:
        """Run the block, adding the skip connection when ``apply_residual`` is set."""
        out = self.layers(input)
        if not self.apply_residual:
            return out
        return out + input


def _stack(
    in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
) -> nn.Sequential:
    """Build ``repeats`` inverted residual blocks chained in an ``nn.Sequential``.

    Only the first block maps ``in_ch`` to ``out_ch`` with the requested
    ``stride``; every following block maps ``out_ch`` to ``out_ch`` with stride 1.
    """
    assert repeats >= 1
    blocks = [_InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)]
    blocks.extend(
        _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum)
        for _ in range(repeats - 1)
    )
    return nn.Sequential(*blocks)


def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    """Round ``val`` to a multiple of ``divisor``, biased towards rounding up.

    The nearest multiple (never less than ``divisor``) is used unless it falls
    below ``round_up_bias * val``, in which case the next multiple up is
    returned. With the default bias: (83, 8) -> 80, but (84, 8) -> 88.
    """
    assert 0.0 < round_up_bias < 1.0
    nearest = max(divisor, int(val + divisor / 2) // divisor * divisor)
    if nearest >= round_up_bias * val:
        return nearest
    return nearest + divisor


def _get_depths(alpha: float) -> List[int]:
    """Scale the eight base channel counts by ``alpha``, each rounded to a multiple of 8.

    Rounding goes through ``_round_to_multiple_of``, which prefers rounding up
    rather than down.
    """
    base_depths = [32, 16, 24, 40, 80, 96, 192, 320]
    return [_round_to_multiple_of(base * alpha, 8) for base in base_depths]


def _stem(first_ch: int, second_ch: int) -> List[nn.Module]:
    """Return the eight stem layers shared by the current and the v1 layout.

    The stem is a strided 3x3 convolution to ``first_ch`` channels followed by
    a depthwise-separable convolution (no skip) projecting to ``second_ch``.
    """
    return [
        # First layer: regular conv.
        nn.Conv2d(3, first_ch, 3, padding=1, stride=2, bias=False),
        nn.BatchNorm2d(first_ch, momentum=_BN_MOMENTUM),
        nn.ReLU(inplace=True),
        # Depthwise separable, no skip.
        nn.Conv2d(first_ch, first_ch, 3, padding=1, stride=1, groups=first_ch, bias=False),
        nn.BatchNorm2d(first_ch, momentum=_BN_MOMENTUM),
        nn.ReLU(inplace=True),
        nn.Conv2d(first_ch, second_ch, 1, padding=0, stride=1, bias=False),
        nn.BatchNorm2d(second_ch, momentum=_BN_MOMENTUM),
    ]


class MNASNet(torch.nn.Module):
    """MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.

    This implements the B1 variant of the model.

    Args:
        alpha: Depth multiplier applied to the channel counts; must be positive.
        num_classes: Size of the classifier output.
        dropout: Dropout probability applied before the final linear layer.

    >>> model = MNASNet(1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """

    # Version 2 adds depth scaling in the initial stages of the network.
    _version = 2

    def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
        super().__init__()
        _log_api_usage_once(self)
        assert alpha > 0.0
        self.alpha = alpha
        self.num_classes = num_classes
        depths = _get_depths(alpha)

        layers = [
            *_stem(depths[0], depths[1]),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))

        # Re-initialise every conv, batch-norm and linear layer in the network.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.kaiming_uniform_(module.weight, mode="fan_out", nonlinearity="sigmoid")
                nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Extract features, average over dimensions 2 and 3, then classify."""
        features = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        pooled = features.mean([2, 3])
        return self.classifier(pooled)

    def _load_from_state_dict(
        self,
        state_dict: Dict,
        prefix: str,
        local_metadata: Dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """Load a checkpoint, rebuilding the v1 stem first when the checkpoint needs it.

        The checkpoint version is read from ``local_metadata["version"]`` and
        must be 1 or 2; anything else (including a missing entry) fails the
        assertion below.
        """
        version = local_metadata.get("version")
        assert version in [1, 2]

        if version == 1 and self.alpha != 1.0:
            # In the initial version of the model (v1), stem was fixed-size.
            # All other layer configurations were the same. This will patch
            # the model so that it's identical to v1. Model with alpha 1.0 is
            # unaffected.
            depths = _get_depths(self.alpha)
            v1_stem = _stem(32, 16)
            v1_stem.append(_stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM))
            for idx, layer in enumerate(v1_stem):
                self.layers[idx] = layer

            # The model is now identical to v1, and must be saved as such.
            self._version = 1
            warnings.warn(
                "A new version of MNASNet model has been implemented. "
                "Your checkpoint was saved using the previous version. "
                "This checkpoint will load and work as before, but "
                "you may want to upgrade by training a newer model or "
                "transfer learning from an updated ImageNet checkpoint.",
                UserWarning,
            )

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )


def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
    """Download the checkpoint registered for ``model_name`` and load it into ``model``.

    Raises:
        ValueError: If ``model_name`` is not in ``_MODEL_URLS`` or its entry is None.
    """
    checkpoint_url = _MODEL_URLS.get(model_name)
    if checkpoint_url is None:
        raise ValueError(f"No checkpoint is available for model type {model_name}")
    model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""MNASNet with depth multiplier of 0.5 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``MNASNet`` constructor.
    """
    net = MNASNet(0.5, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_5", net, progress)
    return net
def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""MNASNet with depth multiplier of 0.75 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    ``_MODEL_URLS`` registers no checkpoint for this variant, so passing
    ``pretrained=True`` makes ``_load_pretrained`` raise ``ValueError``.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``MNASNet`` constructor.
    """
    net = MNASNet(0.75, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_75", net, progress)
    return net
def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""MNASNet with depth multiplier of 1.0 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``MNASNet`` constructor.
    """
    net = MNASNet(1.0, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_0", net, progress)
    return net
def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""MNASNet with depth multiplier of 1.3 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    ``_MODEL_URLS`` registers no checkpoint for this variant, so passing
    ``pretrained=True`` makes ``_load_pretrained`` raise ``ValueError``.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``MNASNet`` constructor.
    """
    net = MNASNet(1.3, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_3", net, progress)
    return net
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.