import warnings
from functools import partial
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = [
    "MNASNet",
    "MNASNet0_5_Weights",
    "MNASNet0_75_Weights",
    "MNASNet1_0_Weights",
    "MNASNet1_3_Weights",
    "mnasnet0_5",
    "mnasnet0_75",
    "mnasnet1_0",
    "mnasnet1_3",
]


# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow.
_BN_MOMENTUM = 1 - 0.9997


class _InvertedResidual(nn.Module):
    def __init__(
        self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
    ) -> None:
        super().__init__()
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")
        if kernel_size not in [3, 5]:
            raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}")
        mid_ch = in_ch * expansion_factor
        self.apply_residual = in_ch == out_ch and stride == 1
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum),
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)


def _stack(
    in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
) -> nn.Sequential:
    """Creates a stack of inverted residuals."""
    if repeats < 1:
        raise ValueError(f"repeats should be >= 1, instead got {repeats}")
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)


def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    """Asymmetric rounding to make `val` divisible by `divisor`. With default
    bias, will round up, unless the number is no more than 10% greater than the
    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88."""
    if not 0.0 < round_up_bias < 1.0:
        raise ValueError(f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}")
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor


def _get_depths(alpha: float) -> List[int]:
    """Scales tensor depths as in reference MobileNet code, prefers rounding up
    rather than down."""
    depths = [32, 16, 24, 40, 80, 96, 192, 320]
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]


class MNASNet(torch.nn.Module):
    """MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
    implements the B1 variant of the model.

    >>> model = MNASNet(1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """

    # Version 2 adds depth scaling in the initial stages of the network.
    _version = 2

    def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if alpha <= 0.0:
            raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}")
        self.alpha = alpha
        self.num_classes = num_classes
        depths = _get_depths(alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)

    def _load_from_state_dict(
        self,
        state_dict: Dict,
        prefix: str,
        local_metadata: Dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        version = local_metadata.get("version", None)
        if version not in [1, 2]:
            raise ValueError(f"version should be set to 1 or 2 instead of {version}")

        if version == 1 and not self.alpha == 1.0:
            # In the initial version of the model (v1), stem was fixed-size.
            # All other layer configurations were the same. This will patch
            # the model so that it's identical to v1. Model with alpha 1.0 is
            # unaffected.
            depths = _get_depths(self.alpha)
            v1_stem = [
                nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
                nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
                _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            ]
            for idx, layer in enumerate(v1_stem):
                self.layers[idx] = layer

            # The model is now identical to v1, and must be saved as such.
            self._version = 1
            warnings.warn(
                "A new version of MNASNet model has been implemented. "
                "Your checkpoint was saved using the previous version. "
""This checkpoint will load and work as before, but ""you may want to upgrade by training a newer model or ""transfer learning from an updated ImageNet checkpoint.",UserWarning,)super()._load_from_state_dict(state_dict,prefix,local_metadata,strict,missing_keys,unexpected_keys,error_msgs)_COMMON_META={"min_size":(1,1),"categories":_IMAGENET_CATEGORIES,"recipe":"https://github.com/1e100/mnasnet_trainer",}classMNASNet0_5_Weights(WeightsEnum):IMAGENET1K_V1=Weights(url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",transforms=partial(ImageClassification,crop_size=224),meta={**_COMMON_META,"num_params":2218512,"_metrics":{"ImageNet-1K":{"acc@1":67.734,"acc@5":87.490,}},"_docs":"""These weights reproduce closely the results of the paper.""",},)DEFAULT=IMAGENET1K_V1classMNASNet0_75_Weights(WeightsEnum):IMAGENET1K_V1=Weights(url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth",transforms=partial(ImageClassification,crop_size=224,resize_size=232),meta={**_COMMON_META,"recipe":"https://github.com/pytorch/vision/pull/6019","num_params":3170208,"_metrics":{"ImageNet-1K":{"acc@1":71.180,"acc@5":90.496,}},"_docs":""" These weights were trained from scratch by using TorchVision's `new training recipe <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. """,},)DEFAULT=IMAGENET1K_V1
class MNASNet1_0_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4383312,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.456,
                    "acc@5": 91.510,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
class MNASNet1_3_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/6019",
            "num_params": 6282256,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.506,
                    "acc@5": 93.522,
                }
            },
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
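
# The four weight enums above correspond to the depth multipliers 0.5, 0.75,
# 1.0, and 1.3, which ``_get_depths`` turns into per-stage channel counts. An
# illustrative check of the rounding behaviour (values follow from the rule in
# ``_round_to_multiple_of``; every depth is rounded to a multiple of 8):
#
#     >>> _get_depths(1.0)
#     [32, 16, 24, 40, 80, 96, 192, 320]
#     >>> _get_depths(0.5)
#     [16, 8, 16, 24, 40, 48, 96, 160]
#     >>> _round_to_multiple_of(83, 8), _round_to_multiple_of(84, 8)
#     (80, 88)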
def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MNASNet(alpha, **kwargs)

    if weights:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.5 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/pdf/1807.11626.pdf>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet0_5_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet0_5_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet0_5_Weights
        :members:
    """
    weights = MNASNet0_5_Weights.verify(weights)

    return _mnasnet(0.5, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.75 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/pdf/1807.11626.pdf>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet0_75_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet0_75_Weights
        :members:
    """
    weights = MNASNet0_75_Weights.verify(weights)

    return _mnasnet(0.75, weights, progress, **kwargs)
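
# The ``handle_legacy_interface`` decorator on these builders maps the
# deprecated ``pretrained`` boolean onto the ``weights`` argument. A sketch of
# the two equivalent call styles (the legacy form emits a deprecation warning):
#
#     >>> m1 = mnasnet0_5(weights=MNASNet0_5_Weights.IMAGENET1K_V1)  # preferred
#     >>> m2 = mnasnet0_5(pretrained=True)  # legacy, warns and maps to the same weights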
@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.0 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/pdf/1807.11626.pdf>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet1_0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet1_0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet1_0_Weights
        :members:
    """
    weights = MNASNet1_0_Weights.verify(weights)

    return _mnasnet(1.0, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.3 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/pdf/1807.11626.pdf>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet1_3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet1_3_Weights
        :members:
    """
    weights = MNASNet1_3_Weights.verify(weights)

    return _mnasnet(1.3, weights, progress, **kwargs)
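
# A minimal end-to-end inference sketch for the builders above. It assumes the
# pretrained weight file can be downloaded; the input here is random noise, so
# the predicted category is arbitrary and only illustrates the API.
if __name__ == "__main__":
    weights = MNASNet1_0_Weights.DEFAULT
    model = mnasnet1_0(weights=weights).eval()
    preprocess = weights.transforms()  # resize, crop, and normalize as the weights expect

    img = torch.rand(3, 256, 256)  # stand-in for a real image tensor
    batch = preprocess(img).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        logits = model(batch)
    print(weights.meta["categories"][logits.argmax(dim=1).item()])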