# Source code for torchvision.models.convnext
from functools import partial
from typing import Any, Callable, List, Optional, Sequence
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from ..ops.misc import Conv2dNormActivation, Permute
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
# Public API of this module: the ConvNeXt model class, one pretrained-weights
# enum per model size, and one builder function per model size.
__all__ = [
    # model class
    "ConvNeXt",
    # pretrained weights
    "ConvNeXt_Tiny_Weights",
    "ConvNeXt_Small_Weights",
    "ConvNeXt_Base_Weights",
    "ConvNeXt_Large_Weights",
    # builders
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "convnext_large",
]
class LayerNorm2d(nn.LayerNorm):
    """Layer normalization over dimension 1 of a 4-D tensor.

    ``nn.LayerNorm`` normalizes over the trailing dimension(s), so the input is
    permuted to put dimension 1 last, normalized with the module's own
    ``normalized_shape``/``weight``/``bias``/``eps``, and permuted back.
    Presumably used on ``(N, C, H, W)`` feature maps (channel-wise norm).
    """

    def forward(self, x: Tensor) -> Tensor:
        # (N, C, H, W) -> (N, H, W, C): move the normalized axis to the end.
        channels_last = x.permute(0, 2, 3, 1)
        normalized = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        # (N, H, W, C) -> (N, C, H, W): restore the original layout.
        return normalized.permute(0, 3, 1, 2)
class CNBlock(nn.Module):
    """ConvNeXt residual block.

    Branch: depthwise 7x7 convolution (``groups=dim``) -> permute to
    channels-last -> normalization -> ``Linear(dim, 4 * dim)`` -> GELU ->
    ``Linear(4 * dim, dim)`` -> permute back. The branch output is multiplied
    by a learnable per-channel ``layer_scale`` parameter, passed through
    ``StochasticDepth`` (mode ``"row"``) and added to the block input.

    Args:
        dim: Number of channels; the block preserves it.
        layer_scale: Initial value of every entry of the ``(dim, 1, 1)``
            layer-scale parameter.
        stochastic_depth_prob: Probability handed to ``StochasticDepth``.
        norm_layer: Normalization factory called with ``dim``. Defaults to
            ``nn.LayerNorm`` with ``eps=1e-6`` (applied channels-last).
    """

    def __init__(
        self,
        dim,
        layer_scale: float,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        make_norm = norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        hidden_dim = 4 * dim
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            # The norm and the two Linear layers act on the last axis,
            # so the channels are moved there and back afterwards.
            Permute([0, 2, 3, 1]),
            make_norm(dim),
            nn.Linear(in_features=dim, out_features=hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(in_features=hidden_dim, out_features=dim, bias=True),
            Permute([0, 3, 1, 2]),
        )
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        branch = self.stochastic_depth(self.layer_scale * self.block(input))
        # Residual connection.
        return input + branch
class CNBlockConfig:
    """Configuration of one ConvNeXt stage (see Section 3 of the ConvNeXt paper).

    Args:
        input_channels: Channel width of every block in the stage.
        out_channels: Channel width after the stage's downsampling layer, or
            ``None`` when the stage is not followed by downsampling.
        num_layers: Number of blocks in the stage.
    """

    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"input_channels={self.input_channels}"
            f", out_channels={self.out_channels}"
            f", num_layers={self.num_layers}"
            ")"
        )
class ConvNeXt(nn.Module):
    """ConvNeXt classification network from `A ConvNet for the 2020s
    <https://arxiv.org/abs/2201.03545>`_.

    Layout: a stem (4x4 convolution with stride 4 from 3 input channels,
    followed by normalization), then for every entry of ``block_setting`` a
    stage of residual blocks optionally followed by a downsampling layer
    (normalization + 2x2 stride-2 convolution), then global average pooling
    and a classifier head (normalization, flatten, linear).

    Args:
        block_setting (List[CNBlockConfig]): One config per stage.
            ``input_channels`` is the width of the stage's blocks,
            ``num_layers`` the number of blocks, and ``out_channels`` (if not
            ``None``) the width produced by the downsampling layer appended to
            that stage.
        stochastic_depth_prob (float): Stochastic-depth probability of the last
            block. The per-block probability grows linearly from 0 for the first
            block to this value for the last one.
        layer_scale (float): Passed unchanged to every block as its
            ``layer_scale`` argument.
        num_classes (int): Number of output logits.
        block (Callable[..., nn.Module], optional): Block factory, called as
            ``block(channels, layer_scale, sd_prob)``. Defaults to :class:`CNBlock`.
        norm_layer (Callable[..., nn.Module], optional): Normalization factory
            used in the stem, the downsampling layers and the classifier head.
            Defaults to :class:`LayerNorm2d` with ``eps=1e-6``.
        **kwargs: Accepted for interface compatibility; currently ignored.

    Raises:
        ValueError: If ``block_setting`` is empty.
        TypeError: If ``block_setting`` is not a sequence of :class:`CNBlockConfig`.
    """

    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (
            isinstance(block_setting, Sequence) and all(isinstance(s, CNBlockConfig) for s in block_setting)
        ):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem: non-overlapping 4x4 patches of the 3-channel input.
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        # Denominator of the linear stochastic-depth ramp. For two or more
        # blocks this is exactly (total_stage_blocks - 1.0); the max() only
        # matters for a single-block network, which previously raised
        # ZeroDivisionError (its one block now gets probability 0).
        sd_denominator = max(total_stage_blocks - 1.0, 1.0)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / sd_denominator
                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling
                layers.append(
                    nn.Sequential(
                        norm_layer(cnf.input_channels),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # The head's width is the last stage's output width: its
        # out_channels if it downsamples, otherwise its input_channels.
        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
        )

        # Initialize every Conv2d/Linear in the model (stem, blocks,
        # downsampling layers and head): truncated-normal weights, zero bias.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run features -> global average pool -> classifier and return logits."""
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
def _convnext(
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    """Build a :class:`ConvNeXt`, optionally loading pretrained weights into it.

    Without ``weights`` the model is built directly from the arguments. With
    ``weights``, ``num_classes`` in ``kwargs`` is first set via
    ``_ovewrite_named_param`` to the number of categories in
    ``weights.meta["categories"]``. The model is then built, and the state dict
    returned by ``weights.get_state_dict(progress=progress)`` is loaded into it.
    """
    if weights is None:
        return ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    pretrained_model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
    pretrained_model.load_state_dict(weights.get_state_dict(progress=progress))
    return pretrained_model
# Metadata shared by every ConvNeXt weights entry below. Each entry spreads it
# with ``**_COMMON_META`` and adds its own "num_params" and "_metrics".
_COMMON_META = {
    # NOTE(review): presumably the smallest supported input (H, W) — confirm.
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
    # Text with reStructuredText link markup; its content (including the
    # unindented lines) is runtime data and must be kept byte-for-byte.
    "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
}
class ConvNeXt_Tiny_Weights(WeightsEnum):
    # Pretrained weights for ``convnext_tiny``. (The ``[docs]`` prefix that
    # preceded this class was a documentation-scrape artifact and a SyntaxError.)
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
        # Evaluation preprocessing: resize to 236, then center-crop 224.
        transforms=partial(ImageClassification, crop_size=224, resize_size=236),
        meta={
            **_COMMON_META,
            "num_params": 28589128,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.520,
                    "acc@5": 96.146,
                }
            },
        },
    )
    # Alias for IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
class ConvNeXt_Small_Weights(WeightsEnum):
    # Pretrained weights for ``convnext_small``. (The ``[docs]`` prefix that
    # preceded this class was a documentation-scrape artifact and a SyntaxError.)
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
        # Evaluation preprocessing: resize to 230, then center-crop 224.
        transforms=partial(ImageClassification, crop_size=224, resize_size=230),
        meta={
            **_COMMON_META,
            "num_params": 50223688,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.616,
                    "acc@5": 96.650,
                }
            },
        },
    )
    # Alias for IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
class ConvNeXt_Base_Weights(WeightsEnum):
    # Pretrained weights for ``convnext_base``. (The ``[docs]`` prefix that
    # preceded this class was a documentation-scrape artifact and a SyntaxError.)
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
        # Evaluation preprocessing: resize to 232, then center-crop 224.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88591464,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.062,
                    "acc@5": 96.870,
                }
            },
        },
    )
    # Alias for IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
class ConvNeXt_Large_Weights(WeightsEnum):
    # Pretrained weights for ``convnext_large``. (The ``[docs]`` prefix that
    # preceded this class was a documentation-scrape artifact and a SyntaxError.)
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
        # Evaluation preprocessing: resize to 232, then center-crop 224.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 197767336,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.414,
                    "acc@5": 96.976,
                }
            },
        },
    )
    # Alias for IMAGENET1K_V1.
    DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Tiny model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. ``stochastic_depth_prob`` defaults to ``0.1`` for this variant.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    Returns:
        ConvNeXt: A model with four stages of widths 96/192/384/768 and 3/3/9/3 blocks.

    .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
        :members:
    """
    weights = ConvNeXt_Tiny_Weights.verify(weights)

    # (input_channels, out_channels after downsampling, num_blocks) per stage.
    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Small model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. ``stochastic_depth_prob`` defaults to ``0.4`` for this variant.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    Returns:
        ConvNeXt: A model with four stages of widths 96/192/384/768 and 3/3/27/3 blocks.

    .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
        :members:
    """
    weights = ConvNeXt_Small_Weights.verify(weights)

    # (input_channels, out_channels after downsampling, num_blocks) per stage.
    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 27),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Base model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. ``stochastic_depth_prob`` defaults to ``0.5`` for this variant.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    Returns:
        ConvNeXt: A model with four stages of widths 128/256/512/1024 and 3/3/27/3 blocks.

    .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
        :members:
    """
    weights = ConvNeXt_Base_Weights.verify(weights)

    # (input_channels, out_channels after downsampling, num_blocks) per stage.
    block_setting = [
        CNBlockConfig(128, 256, 3),
        CNBlockConfig(256, 512, 3),
        CNBlockConfig(512, 1024, 27),
        CNBlockConfig(1024, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Large model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. ``stochastic_depth_prob`` defaults to ``0.5`` for this variant.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    Returns:
        ConvNeXt: A model with four stages of widths 192/384/768/1536 and 3/3/27/3 blocks.

    .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
        :members:
    """
    weights = ConvNeXt_Large_Weights.verify(weights)

    # (input_channels, out_channels after downsampling, num_blocks) per stage.
    block_setting = [
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 3),
        CNBlockConfig(768, 1536, 27),
        CNBlockConfig(1536, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)