# Source code for torchvision.models.convnext
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation
from ..ops.stochastic_depth import StochasticDepth
from ..utils import _log_api_usage_once
# Names exported by ``from <module> import *``: the model class and the four
# builder functions.
__all__ = [
"ConvNeXt",
"convnext_tiny",
"convnext_small",
"convnext_base",
"convnext_large",
]
# Maps each builder name to the URL of its checkpoint file. ``_convnext``
# looks the architecture name up here when ``pretrained`` is true. The value
# type is Optional, so an entry may be None.
_MODELS_URLS: Dict[str, Optional[str]] = {
"convnext_tiny": "https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
"convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth",
"convnext_base": "https://download.pytorch.org/models/convnext_base-6075fbad.pth",
"convnext_large": "https://download.pytorch.org/models/convnext_large-ea097f82.pth",
}
class LayerNorm2d(nn.LayerNorm):
    """``nn.LayerNorm`` applied along axis 1 of a 4-D tensor.

    ``forward`` moves axis 1 to the last position, runs ``F.layer_norm`` with
    this module's ``normalized_shape``, ``weight``, ``bias`` and ``eps``, and
    then restores the original axis order, so the output has the same shape
    as the input.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` over its axis 1.

        Args:
            x: A 4-D tensor (the permutations below use four axes);
                presumably laid out as (N, C, H, W) -- confirm with callers.

        Returns:
            A tensor with the same axis order as ``x``.
        """
        # layer_norm normalizes trailing dimensions, so bring axis 1 last.
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # Undo the permutation so callers get their original layout back.
        return x.permute(0, 3, 1, 2)
class Permute(nn.Module):
    """Module wrapper around ``torch.permute`` with a fixed axis order.

    Args:
        dims: The axis order passed to ``torch.permute`` on every call.
    """

    def __init__(self, dims: List[int]):
        super().__init__()
        self.dims = dims

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` with its axes reordered according to ``self.dims``."""
        return torch.permute(x, self.dims)
class CNBlock(nn.Module):
    """Residual block: depthwise 7x7 conv, norm, and a two-layer MLP.

    The inner ``block`` runs, in order: a depthwise ``Conv2d`` (``groups=dim``,
    kernel 7, padding 3), a permutation that moves axis 1 last, the norm
    layer, ``Linear(dim -> 4*dim)``, ``GELU``, ``Linear(4*dim -> dim)``, and a
    permutation that moves the last axis back to position 1. Its output is
    multiplied by the learnable ``layer_scale`` parameter, passed through
    ``StochasticDepth`` in ``"row"`` mode, and added to the block input.

    Args:
        dim: Number of input (and output) channels.
        layer_scale: Initial value of every entry of the ``(dim, 1, 1)``
            scale parameter.
        stochastic_depth_prob: Probability handed to ``StochasticDepth``.
        norm_layer: Callable building the norm layer from ``dim``. Defaults
            to ``nn.LayerNorm`` with ``eps=1e-6``.
    """

    def __init__(
        self,
        dim: int,
        layer_scale: float,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        # NOTE: the order of the submodules fixes the state_dict key names
        # ("block.0", "block.2", ...); do not reorder.
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            # Linear and the default LayerNorm act on the last axis.
            Permute([0, 2, 3, 1]),
            norm_layer(dim),
            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
            nn.GELU(),
            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
            Permute([0, 3, 1, 2]),
        )
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input + stochastic_depth(layer_scale * block(input))``."""
        result = self.layer_scale * self.block(input)
        result = self.stochastic_depth(result)
        # In-place add on the freshly computed branch output, not on ``input``.
        result += input
        return result
class CNBlockConfig:
    """Configuration of one ConvNeXt stage.

    Stores information listed at Section 3 of the ConvNeXt paper.

    Args:
        input_channels: Channel count used by the blocks of this stage.
        out_channels: Channel count produced by the downsampling layer that
            follows the stage, or ``None`` when no downsampling follows.
        num_layers: Number of blocks in the stage.
    """

    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        """Return ``ClassName(input_channels=..., out_channels=..., num_layers=...)``."""
        return (
            f"{self.__class__.__name__}("
            f"input_channels={self.input_channels}"
            f", out_channels={self.out_channels}"
            f", num_layers={self.num_layers}"
            ")"
        )
class ConvNeXt(nn.Module):
    """ConvNeXt classifier assembled from a list of stage configurations.

    ``features`` is a ``Sequential`` holding, in order: a stem (4x4 conv with
    stride 4 followed by the norm layer, no activation), then for every entry
    of ``block_setting`` a ``Sequential`` of ``num_layers`` blocks and, when
    ``out_channels`` is not ``None``, a downsampling ``Sequential`` (norm
    layer followed by a 2x2 conv with stride 2). The result is pooled by
    ``AdaptiveAvgPool2d(1)`` and fed to ``classifier`` (norm, flatten, linear).

    Args:
        block_setting: Non-empty sequence of ``CNBlockConfig``.
        stochastic_depth_prob: Stochastic-depth probability of the last
            block; earlier blocks get a linearly scaled-down value.
        layer_scale: Initial layer-scale value passed to every block.
        num_classes: Output size of the final linear layer.
        block: Callable building a block from
            ``(channels, layer_scale, sd_prob)``. Defaults to ``CNBlock``.
        norm_layer: Callable building a norm layer from a channel count.
            Defaults to ``LayerNorm2d`` with ``eps=1e-6``.
        **kwargs: Accepted but not used by this constructor.

    Raises:
        ValueError: If ``block_setting`` is empty.
        TypeError: If ``block_setting`` is not a sequence of
            ``CNBlockConfig``.
    """

    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (
            isinstance(block_setting, Sequence)
            and all(isinstance(s, CNBlockConfig) for s in block_setting)
        ):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            ConvNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        # Block i gets prob * i / (total - 1). With a single block that
        # divisor would be zero, so clamp it to 1; the lone block (i == 0)
        # then receives probability 0, matching the first block of any
        # deeper network.
        sd_denominator = max(total_stage_blocks - 1.0, 1.0)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / sd_denominator
                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling
                layers.append(
                    nn.Sequential(
                        norm_layer(cnf.input_channels),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # The classifier width is whatever the last stage leaves behind: its
        # out_channels if it downsamples, otherwise its input_channels.
        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
        )

        # Initialise every conv / linear weight from a truncated normal
        # (std 0.02) and zero every bias.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run ``features``, then ``avgpool``, then ``classifier`` on ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier output for ``x`` (delegates to ``_forward_impl``)."""
        return self._forward_impl(x)
def _convnext(
    arch: str,
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    """Build a ``ConvNeXt`` and optionally load its checkpoint.

    Args:
        arch: Key into ``_MODELS_URLS`` identifying the checkpoint.
        block_setting: Stage configurations passed to ``ConvNeXt``.
        stochastic_depth_prob: Passed to ``ConvNeXt``.
        pretrained: If True, download the state dict for ``arch`` and load
            it into the model.
        progress: Passed to ``load_state_dict_from_url``.
        **kwargs: Extra keyword arguments passed to ``ConvNeXt``.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``pretrained`` is True and ``_MODELS_URLS`` has no
            URL for ``arch`` (key missing or value ``None``).
    """
    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
    if pretrained:
        # The table's value type is Optional[str]; treat a None entry the
        # same as a missing key instead of passing None to the downloader.
        url = _MODELS_URLS.get(arch)
        if url is None:
            raise ValueError(f"No checkpoint is available for model type {arch}")
        state_dict = load_state_dict_from_url(url, progress=progress)
        model.load_state_dict(state_dict)
    return model
def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt Tiny model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: ``stochastic_depth_prob`` (default 0.1) is extracted here;
            every other keyword is forwarded to ``_convnext``.
    """
    # Stages as (input_channels, out_channels, num_layers); the last stage
    # has no downsampling layer after it.
    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    return _convnext("convnext_tiny", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)
def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt Small model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: ``stochastic_depth_prob`` (default 0.4) is extracted here;
            every other keyword is forwarded to ``_convnext``.
    """
    # Stages as (input_channels, out_channels, num_layers); the last stage
    # has no downsampling layer after it.
    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 27),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
    return _convnext("convnext_small", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)
def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt Base model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: ``stochastic_depth_prob`` (default 0.5) is extracted here;
            every other keyword is forwarded to ``_convnext``.
    """
    # Stages as (input_channels, out_channels, num_layers); the last stage
    # has no downsampling layer after it.
    block_setting = [
        CNBlockConfig(128, 256, 3),
        CNBlockConfig(256, 512, 3),
        CNBlockConfig(512, 1024, 27),
        CNBlockConfig(1024, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext("convnext_base", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)
def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    r"""ConvNeXt Large model architecture from the
    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: ``stochastic_depth_prob`` (default 0.5) is extracted here;
            every other keyword is forwarded to ``_convnext``.
    """
    # Stages as (input_channels, out_channels, num_layers); the last stage
    # has no downsampling layer after it.
    block_setting = [
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 3),
        CNBlockConfig(768, 1536, 27),
        CNBlockConfig(1536, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext("convnext_large", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs)