from functools import partial
from typing import Any, Callable, List, Optional, Sequence

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from ..ops.misc import Conv2dNormActivation, Permute
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = [
    "ConvNeXt",
    "ConvNeXt_Tiny_Weights",
    "ConvNeXt_Small_Weights",
    "ConvNeXt_Base_Weights",
    "ConvNeXt_Large_Weights",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "convnext_large",
]


class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x


class CNBlock(nn.Module):
    def __init__(
        self,
        dim,
        layer_scale: float,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            Permute([0, 2, 3, 1]),
            norm_layer(dim),
            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
            nn.GELU(),
            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
            Permute([0, 3, 1, 2]),
        )
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        result = self.layer_scale * self.block(input)
        result = self.stochastic_depth(result)
        result += input
        return result


class CNBlockConfig:
    # Stores information listed at Section 3 of the ConvNeXt paper
    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "input_channels={input_channels}"
        s += ", out_channels={out_channels}"
        s += ", num_layers={num_layers}"
        s += ")"
        return s.format(**self.__dict__)


class ConvNeXt(nn.Module):
    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling
                layers.append(
                    nn.Sequential(
                        norm_layer(cnf.input_channels),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
        )

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _convnext(
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


_COMMON_META = {
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
    "_docs": """
        These weights improve upon the results of the original paper by using a modified version of TorchVision's
        `new training recipe
        <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
    """,
}
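
# Illustrative sketch (not part of the original torchvision module): the ConvNeXt constructor
# above is driven by a list of CNBlockConfig entries, one per stage, giving the block width,
# the width produced by the trailing 2x2 strided downsampling conv (None for the final stage),
# and the number of CNBlocks. For example, the "tiny" layout built by convnext_tiny() below
# could be constructed directly as:
#
#     block_setting = [
#         CNBlockConfig(96, 192, 3),
#         CNBlockConfig(192, 384, 3),
#         CNBlockConfig(384, 768, 9),
#         CNBlockConfig(768, None, 3),
#     ]
#     model = ConvNeXt(block_setting, stochastic_depth_prob=0.1)
#     out = model(torch.randn(1, 3, 224, 224))  # logits of shape (1, 1000)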
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Tiny model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
        :members:
    """
    weights = ConvNeXt_Tiny_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
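
# Minimal inference sketch (illustrative; assumes torchvision >= 0.13 where the Weights API and
# ImageClassification transform presets are available; the image path is a placeholder):
#
#     from torchvision.io import read_image
#     from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
#
#     weights = ConvNeXt_Tiny_Weights.IMAGENET1K_V1
#     model = convnext_tiny(weights=weights).eval()
#     preprocess = weights.transforms()             # ImageClassification preset bundled with the weights
#     img = read_image("path/to/image.jpg")         # uint8 CHW tensor
#     batch = preprocess(img).unsqueeze(0)
#     with torch.no_grad():
#         logits = model(batch)
#     label = weights.meta["categories"][logits.argmax(dim=1).item()]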
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(*, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Small model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
        :members:
    """
    weights = ConvNeXt_Small_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 27),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
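
# Hedged fine-tuning sketch: the classifier built in ConvNeXt.__init__ above is
# nn.Sequential(norm_layer, nn.Flatten, nn.Linear), so the final Linear sits at index 2 and can
# be swapped to adapt a pretrained backbone to a different label set (10 classes here is a
# hypothetical example):
#
#     model = convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1)
#     in_features = model.classifier[2].in_features   # 768 for the "small" variant
#     model.classifier[2] = nn.Linear(in_features, 10)
#     # only the new head starts from random init; the backbone can be frozen or fine-tuned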
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Base model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
        :members:
    """
    weights = ConvNeXt_Base_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(128, 256, 3),
        CNBlockConfig(256, 512, 3),
        CNBlockConfig(512, 1024, 27),
        CNBlockConfig(1024, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
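
# The builders pop "stochastic_depth_prob" from **kwargs before forwarding the rest to the
# ConvNeXt constructor, so both the regularization strength and the head size can be overridden
# when training from scratch (a sketch; 0.3 and 100 are arbitrary illustrative values, not
# recommended settings):
#
#     model = convnext_base(weights=None, stochastic_depth_prob=0.3, num_classes=100)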
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(*, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Large model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
        :members:
    """
    weights = ConvNeXt_Large_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 3),
        CNBlockConfig(768, 1536, 27),
        CNBlockConfig(1536, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
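
# A small comparison sketch across the four variants; parameter counts are computed at runtime
# rather than hard-coded, since building the models without weights is cheap:
#
#     for builder in (convnext_tiny, convnext_small, convnext_base, convnext_large):
#         m = builder(weights=None)
#         n_params = sum(p.numel() for p in m.parameters())
#         print(f"{builder.__name__}: {n_params / 1e6:.1f}M parameters")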