from typing import Callable, Any, List

import torch
import torch.nn as nn
from torch import Tensor

from .._internally_replaced_utils import load_state_dict_from_url
from ..utils import _log_api_usage_once


__all__ = ["ShuffleNetV2", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0"]

# Checkpoint URLs keyed by architecture name. A ``None`` entry means no
# checkpoint is published; ``_shufflenetv2`` raises ValueError if
# ``pretrained=True`` is requested for such an architecture.
model_urls = {
    "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
    "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
    "shufflenetv2_x1.5": None,
    "shufflenetv2_x2.0": None,
}


def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    """Interleave the channels of ``x`` across ``groups`` groups.

    The channel axis is split into ``groups`` contiguous groups, the group and
    within-group axes are swapped, and the result is flattened back, so the
    channel at position ``g * channels_per_group + c`` moves to
    ``c * groups + g``.

    Args:
        x: tensor with exactly four dimensions, read as
            ``(batch, channels, height, width)`` by the unpacking below.
        groups: number of groups; assumed to divide the channel count evenly
            (otherwise the ``view`` call fails).

    Returns:
        Tensor with the same shape as ``x`` and permuted channels.
    """
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # transpose() yields a non-contiguous tensor; contiguous() is required
    # before the following view().
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x


class InvertedResidual(nn.Module):
    """ShuffleNetV2 building unit.

    With ``stride == 1`` the input is split in half along dim 1; the first half
    is passed through unchanged and the second half through ``branch2``. With
    ``stride > 1`` the full input goes through both ``branch1`` and
    ``branch2``. In both cases the two results are concatenated along dim 1
    and channel-shuffled with two groups.

    Args:
        inp: number of input channels.
        oup: number of output channels; each branch produces ``oup // 2``.
        stride: stride of the depthwise convolutions, must be in ``[1, 3]``.

    Raises:
        ValueError: if ``stride`` is outside ``[1, 3]``, or if ``stride == 1``
            and ``inp != (oup // 2) << 1`` (the pass-through half must have
            the same width as ``branch2``'s output).
    """

    def __init__(self, inp: int, oup: int, stride: int) -> None:
        super().__init__()

        if not (1 <= stride <= 3):
            raise ValueError("illegal stride value")
        self.stride = stride

        branch_features = oup // 2
        # Explicit check rather than ``assert`` so the validation still runs
        # under ``python -O``.
        if (self.stride == 1) and (inp != branch_features << 1):
            raise ValueError(
                f"Invalid combination of stride {stride}, inp {inp} and oup {oup} values. "
                f"If stride == 1 then inp should be equal to oup // 2 << 1."
            )

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            # Empty placeholder so ``branch1`` always exists; forward() does
            # not use it when stride == 1.
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(
                # stride == 1: only the split-off half (branch_features channels)
                # reaches this branch; otherwise it sees the full input.
                inp if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(
        i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False
    ) -> nn.Conv2d:
        """Return a ``Conv2d`` with ``groups=i``, i.e. one group per input channel."""
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x: Tensor) -> Tensor:
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classifier.

    Layout: ``conv1`` (3x3, stride 2, 3 input channels) -> ``maxpool`` ->
    ``stage2`` -> ``stage3`` -> ``stage4`` -> ``conv5`` (1x1) -> mean over
    dims 2 and 3 -> ``fc``. Each stage is one stride-2 unit followed by
    ``repeats - 1`` stride-1 units.

    Args:
        stages_repeats: unit counts for stage2, stage3 and stage4 (3 entries).
        stages_out_channels: output channels of conv1, stage2, stage3, stage4
            and conv5 (5 entries).
        num_classes: output size of the final linear layer.
        inverted_residual: factory called as ``(in_channels, out_channels, stride)``
            to build each unit; defaults to :class:`InvertedResidual`.

    Raises:
        ValueError: if either list has the wrong length.
    """

    def __init__(
        self,
        stages_repeats: List[int],
        stages_out_channels: List[int],
        num_classes: int = 1000,
        inverted_residual: Callable[..., nn.Module] = InvertedResidual,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential
        stage_names = [f"stage{i}" for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]):
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for _ in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # globalpool
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2:
    """Build a :class:`ShuffleNetV2` and optionally load a checkpoint into it.

    Args:
        arch: key into ``model_urls`` identifying the checkpoint.
        pretrained: if True, download the checkpoint for ``arch`` and load it.
        progress: forwarded to ``load_state_dict_from_url``.
        *args, **kwargs: forwarded to the :class:`ShuffleNetV2` constructor.

    Raises:
        ValueError: if ``pretrained`` is True and ``model_urls[arch]`` is None.
    """
    model = ShuffleNetV2(*args, **kwargs)

    if pretrained:
        model_url = model_urls[arch]
        if model_url is None:
            raise ValueError(f"No checkpoint is available for model type {arch}")
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)

    return model


def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the :class:`ShuffleNetV2` constructor
            (e.g. ``num_classes``, ``inverted_residual``)
    """
    return _shufflenetv2("shufflenetv2_x0.5", pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)


def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the :class:`ShuffleNetV2` constructor
            (e.g. ``num_classes``, ``inverted_residual``)
    """
    return _shufflenetv2("shufflenetv2_x1.0", pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)


def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
            Note: ``model_urls`` has no checkpoint for this variant, so
            ``pretrained=True`` currently raises ``ValueError``.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the :class:`ShuffleNetV2` constructor
            (e.g. ``num_classes``, ``inverted_residual``)
    """
    return _shufflenetv2("shufflenetv2_x1.5", pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
            Note: ``model_urls`` has no checkpoint for this variant, so
            ``pretrained=True`` currently raises ``ValueError``.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the :class:`ShuffleNetV2` constructor
            (e.g. ``num_classes``, ``inverted_residual``)
    """
    return _shufflenetv2("shufflenetv2_x2.0", pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
# Residual footer text from the PyTorch documentation web page this source was extracted from:
# Docs
# Access comprehensive developer documentation for PyTorch
# To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.