from functools import partial
from typing import Any, Callable, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "ShuffleNetV2",
    "ShuffleNet_V2_X0_5_Weights",
    "ShuffleNet_V2_X1_0_Weights",
    "ShuffleNet_V2_X1_5_Weights",
    "ShuffleNet_V2_X2_0_Weights",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5",
    "shufflenet_v2_x2_0",
]


def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The channel axis is viewed as ``(groups, channels_per_group)``, the two
    axes are swapped, and the result is flattened back to a single channel
    axis, so channels coming from different groups end up next to each other.

    Args:
        x: tensor whose ``size()`` unpacks to
            ``(batchsize, num_channels, height, width)``.
        groups: number of groups the channel axis is split into.

    Returns:
        A tensor with the same shape as ``x`` and its channels permuted.
    """
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # swap the group axis with the per-group channel axis; ``contiguous`` is
    # needed so the following ``view`` can flatten the result
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, num_channels, height, width)

    return x


class InvertedResidual(nn.Module):
    """Two-branch building block used by every stage of :class:`ShuffleNetV2`.

    * ``stride == 1``: the input is split channel-wise into two halves; the
      first half is passed through untouched and the second half goes through
      ``branch2``.
    * ``stride > 1``: the full input is fed to both ``branch1`` and
      ``branch2``.

    In both cases the two results are concatenated along the channel axis and
    mixed with :func:`channel_shuffle` using two groups.

    Args:
        inp: number of input channels.
        oup: number of output channels; each branch produces ``oup // 2``.
        stride: stride of the depthwise convolutions, must satisfy
            ``1 <= stride <= 3``.

    Raises:
        ValueError: if ``stride`` is out of range, or if ``stride == 1`` and
            ``inp != (oup // 2) << 1``.
    """

    def __init__(self, inp: int, oup: int, stride: int) -> None:
        super().__init__()

        if not (1 <= stride <= 3):
            raise ValueError("illegal stride value")
        self.stride = stride

        branch_features = oup // 2
        if (self.stride == 1) and (inp != branch_features << 1):
            # Adjacent literals are concatenated into one single-line message.
            raise ValueError(
                f"Invalid combination of stride {stride}, inp {inp} and oup {oup} values. "
                "If stride == 1 then inp should be equal to oup // 2 << 1."
            )

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            # Not used by ``forward`` when stride == 1; kept as an empty
            # container so the attribute always exists.
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(
                # With stride == 1 this branch only receives half of the input
                # channels (see ``forward``), i.e. ``branch_features`` of them.
                inp if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(
        i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False
    ) -> nn.Conv2d:
        """Build a ``nn.Conv2d`` with ``groups`` equal to its input channels ``i``."""
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to ``x`` and return the channel-shuffled result."""
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classification network.

    Layout: ``conv1`` -> ``maxpool`` -> ``stage2`` -> ``stage3`` -> ``stage4``
    -> ``conv5`` -> mean over the two spatial axes -> ``fc``.

    Args:
        stages_repeats: number of blocks in each of the three stages
            (length 3).
        stages_out_channels: output channels of ``conv1``, the three stages
            and ``conv5``, in that order (length 5).
        num_classes: number of output features of the final linear layer.
        inverted_residual: callable invoked as
            ``inverted_residual(in_channels, out_channels, stride)`` to build
            each block.

    Raises:
        ValueError: if ``stages_repeats`` does not have 3 entries or
            ``stages_out_channels`` does not have 5 entries.
    """

    def __init__(
        self,
        stages_repeats: List[int],
        stages_out_channels: List[int],
        num_classes: int = 1000,
        inverted_residual: Callable[..., nn.Module] = InvertedResidual,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential
        stage_names = [f"stage{i}" for i in [2, 3, 4]]
        # ``zip`` stops after the three stages, so the last entry of
        # ``_stage_out_channels`` (used for ``conv5``) is not consumed here.
        for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]):
            # The first block of a stage uses stride 2; the rest use stride 1.
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for _ in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # globalpool
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Run the network on ``x`` and return the output of the ``fc`` layer."""
        return self._forward_impl(x)


def _shufflenetv2(
    weights: Optional[WeightsEnum],
    progress: bool,
    *args: Any,
    **kwargs: Any,
) -> ShuffleNetV2:
    """Build a :class:`ShuffleNetV2` and optionally load pretrained weights.

    Args:
        weights: weights entry to load, or ``None`` to leave the model with
            its freshly constructed parameters.
        progress: forwarded to ``weights.get_state_dict``.
        *args: positional arguments forwarded to :class:`ShuffleNetV2`.
        **kwargs: keyword arguments forwarded to :class:`ShuffleNetV2`.

    Returns:
        The constructed model.
    """
    if weights is not None:
        # NOTE(review): presumably forces/validates ``kwargs["num_classes"]``
        # to match the number of categories of the weights -- confirm against
        # ``_ovewrite_named_param`` in ``._utils``.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ShuffleNetV2(*args, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


# Metadata shared by every weights entry in this module; individual entries
# spread it into their own ``meta`` dict and may override keys.
_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/ericsun99/Shufflenet-v2-Pytorch",
}
# Pretrained-weight entries available for ``shufflenet_v2_x0_5``.
class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/ericsun99/Shufflenet-v2-Pytorch
        url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 1366792,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 60.552,
                    "acc@5": 81.746,
                }
            },
            "_ops": 0.04,
            "_file_size": 5.282,
            "_docs": """These weights were trained from scratch to reproduce closely the results of the paper.""",
        },
    )
    # Alias used when callers ask for the default weights of this model.
    DEFAULT = IMAGENET1K_V1
# Pretrained-weight entries available for ``shufflenet_v2_x1_0``.
class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/ericsun99/Shufflenet-v2-Pytorch
        url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2278604,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.362,
                    "acc@5": 88.316,
                }
            },
            "_ops": 0.145,
            "_file_size": 8.791,
            "_docs": """These weights were trained from scratch to reproduce closely the results of the paper.""",
        },
    )
    # Alias used when callers ask for the default weights of this model.
    DEFAULT = IMAGENET1K_V1
# Pretrained-weight entries available for ``shufflenet_v2_x1_5``.
class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            # Overrides the "recipe" entry inherited from ``_COMMON_META``.
            "recipe": "https://github.com/pytorch/vision/pull/5906",
            "num_params": 3503624,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.996,
                    "acc@5": 91.086,
                }
            },
            "_ops": 0.296,
            "_file_size": 13.557,
            "_docs": """ These weights were trained from scratch by using TorchVision's `new training recipe <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. """,
        },
    )
    # Alias used when callers ask for the default weights of this model.
    DEFAULT = IMAGENET1K_V1
# Pretrained-weight entries available for ``shufflenet_v2_x2_0``.
class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            # Overrides the "recipe" entry inherited from ``_COMMON_META``.
            "recipe": "https://github.com/pytorch/vision/pull/5906",
            "num_params": 7393996,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.230,
                    "acc@5": 93.006,
                }
            },
            "_ops": 0.583,
            "_file_size": 28.433,
            "_docs": """ These weights were trained from scratch by using TorchVision's `new training recipe <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. """,
        },
    )
    # Alias used when callers ask for the default weights of this model.
    DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1))
def shufflenet_v2_x0_5(
    *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 architecture with 0.5x output channels, as described in
    `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
    <https://arxiv.org/abs/1807.11164>`__.

    Args:
        weights (:class:`~torchvision.models.ShuffleNet_V2_X0_5_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ShuffleNet_V2_X0_5_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.shufflenetv2.ShuffleNetV2``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/shufflenetv2.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ShuffleNet_V2_X0_5_Weights
        :members:
    """
    weights = ShuffleNet_V2_X0_5_Weights.verify(weights)

    # Positional lists are ``stages_repeats`` and ``stages_out_channels``.
    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1))
def shufflenet_v2_x1_0(
    *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 architecture with 1.0x output channels, as described in
    `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
    <https://arxiv.org/abs/1807.11164>`__.

    Args:
        weights (:class:`~torchvision.models.ShuffleNet_V2_X1_0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ShuffleNet_V2_X1_0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.shufflenetv2.ShuffleNetV2``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/shufflenetv2.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ShuffleNet_V2_X1_0_Weights
        :members:
    """
    weights = ShuffleNet_V2_X1_0_Weights.verify(weights)

    # Positional lists are ``stages_repeats`` and ``stages_out_channels``.
    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1))
def shufflenet_v2_x1_5(
    *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 architecture with 1.5x output channels, as described in
    `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
    <https://arxiv.org/abs/1807.11164>`__.

    Args:
        weights (:class:`~torchvision.models.ShuffleNet_V2_X1_5_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ShuffleNet_V2_X1_5_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.shufflenetv2.ShuffleNetV2``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/shufflenetv2.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ShuffleNet_V2_X1_5_Weights
        :members:
    """
    weights = ShuffleNet_V2_X1_5_Weights.verify(weights)

    # Positional lists are ``stages_repeats`` and ``stages_out_channels``.
    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1))
def shufflenet_v2_x2_0(
    *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 architecture with 2.0x output channels, as described in
    `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
    <https://arxiv.org/abs/1807.11164>`__.

    Args:
        weights (:class:`~torchvision.models.ShuffleNet_V2_X2_0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ShuffleNet_V2_X2_0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.shufflenetv2.ShuffleNetV2``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/shufflenetv2.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ShuffleNet_V2_X2_0_Weights
        :members:
    """
    weights = ShuffleNet_V2_X2_0_Weights.verify(weights)

    # Positional lists are ``stages_repeats`` and ``stages_out_channels``.
    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
# Docs
# Access comprehensive developer documentation for PyTorch
# To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Because Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.