# Source code for torchvision.models.efficientnet
import copy
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch import nn, Tensor
from torchvision.ops import StochasticDepth
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
# Public names exported by this module: the model class, one weights enum per
# variant, and one builder function per variant.
__all__ = [
    "EfficientNet",
    "EfficientNet_B0_Weights",
    "EfficientNet_B1_Weights",
    "EfficientNet_B2_Weights",
    "EfficientNet_B3_Weights",
    "EfficientNet_B4_Weights",
    "EfficientNet_B5_Weights",
    "EfficientNet_B6_Weights",
    "EfficientNet_B7_Weights",
    "EfficientNet_V2_S_Weights",
    "EfficientNet_V2_M_Weights",
    "EfficientNet_V2_L_Weights",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "efficientnet_b3",
    "efficientnet_b4",
    "efficientnet_b5",
    "efficientnet_b6",
    "efficientnet_b7",
    "efficientnet_v2_s",
    "efficientnet_v2_m",
    "efficientnet_v2_l",
]
@dataclass
class _MBConvConfig:
expand_ratio: float
kernel: int
stride: int
input_channels: int
out_channels: int
num_layers: int
block: Callable[..., nn.Module]
@staticmethod
def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
return _make_divisible(channels * width_mult, 8, min_value)
class MBConvConfig(_MBConvConfig):
    """Stage configuration for MBConv blocks.

    Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper.
    Channel counts are scaled by ``width_mult`` and the layer count by ``depth_mult`` before being stored.
    """

    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        width_mult: float = 1.0,
        depth_mult: float = 1.0,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        scaled_in = self.adjust_channels(input_channels, width_mult)
        scaled_out = self.adjust_channels(out_channels, width_mult)
        scaled_depth = self.adjust_depth(num_layers, depth_mult)
        # Fall back to the regular MBConv block when the caller gives no builder.
        block_builder = MBConv if block is None else block
        super().__init__(expand_ratio, kernel, stride, scaled_in, scaled_out, scaled_depth, block_builder)

    @staticmethod
    def adjust_depth(num_layers: int, depth_mult: float):
        """Return ``num_layers * depth_mult`` rounded up to a whole number of layers."""
        rounded_up = math.ceil(num_layers * depth_mult)
        return int(rounded_up)
class FusedMBConvConfig(_MBConvConfig):
    """Stage configuration for FusedMBConv blocks.

    Stores information listed at Table 4 of the EfficientNetV2 paper. Unlike ``MBConvConfig``, the values
    are stored exactly as given (no width/depth scaling).
    """

    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        # Fall back to the FusedMBConv block when the caller gives no builder.
        block_builder = FusedMBConv if block is None else block
        super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block_builder)
class MBConv(nn.Module):
    """Block made of an optional 1x1 expansion, a depthwise conv, squeeze-excitation and a 1x1 projection.

    A residual connection (with stochastic depth applied to the block output) is used when the block keeps
    both the spatial resolution (stride 1) and the channel count.
    """

    def __init__(
        self,
        cnf: MBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
        se_layer: Callable[..., nn.Module] = SqueezeExcitation,
    ) -> None:
        super().__init__()

        stride_is_valid = 1 <= cnf.stride <= 2
        if not stride_is_valid:
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        act = nn.SiLU
        hidden_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)

        stages: List[nn.Module] = []

        # expand: only needed when the expanded width differs from the input width
        if hidden_channels != cnf.input_channels:
            stages.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    hidden_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=act,
                )
            )

        # depthwise: one group per channel
        stages.append(
            Conv2dNormActivation(
                hidden_channels,
                hidden_channels,
                kernel_size=cnf.kernel,
                stride=cnf.stride,
                groups=hidden_channels,
                norm_layer=norm_layer,
                activation_layer=act,
            )
        )

        # squeeze and excitation: the squeeze width is derived from the block's input channels
        se_squeeze = max(1, cnf.input_channels // 4)
        stages.append(se_layer(hidden_channels, se_squeeze, activation=partial(nn.SiLU, inplace=True)))

        # project: no activation after the projection
        stages.append(
            Conv2dNormActivation(
                hidden_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
            )
        )

        self.block = nn.Sequential(*stages)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        """Run the block and, when enabled, add the residual connection."""
        out = self.block(input)
        if not self.use_res_connect:
            return out
        out = self.stochastic_depth(out)
        out += input
        return out
class FusedMBConv(nn.Module):
    """Block whose expansion and spatial convolution are fused into a single regular convolution.

    When the expanded width differs from the input width the block is "fused expand conv -> 1x1 projection";
    otherwise it is a single convolution straight to the output channels. A residual connection (with
    stochastic depth applied to the block output) is used when stride is 1 and channels are unchanged.
    """

    def __init__(
        self,
        cnf: FusedMBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
    ) -> None:
        super().__init__()

        stride_is_valid = 1 <= cnf.stride <= 2
        if not stride_is_valid:
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        act = nn.SiLU
        hidden_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        has_expansion = hidden_channels != cnf.input_channels

        # The first conv is the fused expand when expanding, and the whole block otherwise.
        first_conv_out = hidden_channels if has_expansion else cnf.out_channels
        stages: List[nn.Module] = [
            Conv2dNormActivation(
                cnf.input_channels,
                first_conv_out,
                kernel_size=cnf.kernel,
                stride=cnf.stride,
                norm_layer=norm_layer,
                activation_layer=act,
            )
        ]

        if has_expansion:
            # project: no activation after the projection
            stages.append(
                Conv2dNormActivation(
                    hidden_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
                )
            )

        self.block = nn.Sequential(*stages)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        """Run the block and, when enabled, add the residual connection."""
        out = self.block(input)
        if not self.use_res_connect:
            return out
        out = self.stochastic_depth(out)
        out += input
        return out
class EfficientNet(nn.Module):
    """EfficientNet V1 and V2 main class: conv feature extractor, global average pool, linear classifier."""

    def __init__(
        self,
        inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
        dropout: float,
        stochastic_depth_prob: float = 0.2,
        num_classes: int = 1000,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        last_channel: Optional[int] = None,
    ) -> None:
        """
        EfficientNet V1 and V2 main class

        Args:
            inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
            dropout (float): The dropout probability
            stochastic_depth_prob (float): The stochastic depth probability
            num_classes (int): Number of classes
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use.
                ``nn.BatchNorm2d`` is used when ``None``.
            last_channel (Optional[int]): The number of channels on the penultimate layer. When ``None``,
                four times the output channels of the last stage is used.

        Raises:
            ValueError: If ``inverted_residual_setting`` is empty.
            TypeError: If ``inverted_residual_setting`` is not a sequence of ``_MBConvConfig`` instances.
        """
        super().__init__()
        _log_api_usage_once(self)

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (
            isinstance(inverted_residual_setting, Sequence)
            and all(isinstance(s, _MBConvConfig) for s in inverted_residual_setting)
        ):
            raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        layers: List[nn.Module] = []

        # building first layer: 3 input channels -> input width of the first stage
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
            )
        )

        # building inverted residual blocks
        total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
        stage_block_id = 0
        for cnf in inverted_residual_setting:
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # copy to avoid modifications. shallow copy is enough
                block_cnf = copy.copy(cnf)

                # overwrite info if not the first conv in the stage: only the first block of a stage
                # changes the channel count and applies the stage's stride
                if stage:
                    block_cnf.input_channels = block_cnf.out_channels
                    block_cnf.stride = 1

                # adjust stochastic depth probability based on the depth of the stage block:
                # it grows linearly with the block index, starting at 0 for the first block
                sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks

                stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
                stage_block_id += 1

            layers.append(nn.Sequential(*stage))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
        layers.append(
            Conv2dNormActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.SiLU,
            )
        )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(lastconv_output_channels, num_classes),
        )

        # Weight initialisation, chosen per module type.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                init_range = 1.0 / math.sqrt(m.out_features)
                nn.init.uniform_(m.weight, -init_range, init_range)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Apply features, pooling, flattening (from dim 1) and the classifier, in that order."""
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier output for ``x``; delegates to ``_forward_impl``."""
        return self._forward_impl(x)
def _efficientnet(
    inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
    dropout: float,
    last_channel: Optional[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> EfficientNet:
    """Build an ``EfficientNet`` and, when ``weights`` is given, load its state dict.

    With ``weights`` set, ``num_classes`` in ``kwargs`` is set (via ``_ovewrite_named_param``) to the
    number of categories listed in the weights' metadata before the model is constructed.
    """
    use_pretrained = weights is not None

    if use_pretrained:
        category_count = len(weights.meta["categories"])
        _ovewrite_named_param(kwargs, "num_classes", category_count)

    model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)

    if use_pretrained:
        state_dict = weights.get_state_dict(progress=progress, check_hash=True)
        model.load_state_dict(state_dict)

    return model
def _efficientnet_conf(
    arch: str,
    **kwargs: Any,
) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
    """Return the per-stage configuration and the penultimate channel count for ``arch``.

    Each row below is ``(expand_ratio, kernel, stride, input_channels, out_channels, num_layers)``.
    The ``efficientnet_b*`` family requires ``width_mult`` and ``depth_mult`` in ``kwargs`` and returns
    ``None`` as the channel count; the ``efficientnet_v2_*`` variants return 1280.

    Raises:
        ValueError: If ``arch`` matches none of the supported prefixes.
    """
    if arch.startswith("efficientnet_b"):
        width_mult = kwargs.pop("width_mult")
        depth_mult = kwargs.pop("depth_mult")
        scaled_mbconv = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult)
        rows = [
            (1, 3, 1, 32, 16, 1),
            (6, 3, 2, 16, 24, 2),
            (6, 5, 2, 24, 40, 2),
            (6, 3, 2, 40, 80, 3),
            (6, 5, 1, 80, 112, 3),
            (6, 5, 2, 112, 192, 4),
            (6, 3, 1, 192, 320, 1),
        ]
        return [scaled_mbconv(*row) for row in rows], None

    # The V2 variants start with fused stages, followed by regular MBConv stages.
    if arch.startswith("efficientnet_v2_s"):
        fused_rows = [
            (1, 3, 1, 24, 24, 2),
            (4, 3, 2, 24, 48, 4),
            (4, 3, 2, 48, 64, 4),
        ]
        mbconv_rows = [
            (4, 3, 2, 64, 128, 6),
            (6, 3, 1, 128, 160, 9),
            (6, 3, 2, 160, 256, 15),
        ]
    elif arch.startswith("efficientnet_v2_m"):
        fused_rows = [
            (1, 3, 1, 24, 24, 3),
            (4, 3, 2, 24, 48, 5),
            (4, 3, 2, 48, 80, 5),
        ]
        mbconv_rows = [
            (4, 3, 2, 80, 160, 7),
            (6, 3, 1, 160, 176, 14),
            (6, 3, 2, 176, 304, 18),
            (6, 3, 1, 304, 512, 5),
        ]
    elif arch.startswith("efficientnet_v2_l"):
        fused_rows = [
            (1, 3, 1, 32, 32, 4),
            (4, 3, 2, 32, 64, 7),
            (4, 3, 2, 64, 96, 7),
        ]
        mbconv_rows = [
            (4, 3, 2, 96, 192, 10),
            (6, 3, 1, 192, 224, 19),
            (6, 3, 2, 224, 384, 25),
            (6, 3, 1, 384, 640, 7),
        ]
    else:
        raise ValueError(f"Unsupported model type {arch}")

    stages: List[Union[MBConvConfig, FusedMBConvConfig]] = [FusedMBConvConfig(*row) for row in fused_rows]
    stages.extend(MBConvConfig(*row) for row in mbconv_rows)
    return stages, 1280
# Metadata entries shared by every weights definition in this module.
_COMMON_META: Dict[str, Any] = {
    "categories": _IMAGENET_CATEGORIES,
}

# Shared metadata for the EfficientNet V1 (B0-B7) weights; individual weights may override "recipe".
_COMMON_META_V1 = {
    **_COMMON_META,
    "min_size": (1, 1),
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v1",
}

# Shared metadata for the EfficientNet V2 (S/M/L) weights.
_COMMON_META_V2 = {
    **_COMMON_META,
    "min_size": (33, 33),
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v2",
}
# Pretrained weights available for ``efficientnet_b0``.
class EfficientNet_B0_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth",
        # Inference preprocessing: resize to 256, crop to 224, bicubic interpolation.
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 5288548,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.692,
                    "acc@5": 93.532,
                }
            },
            "_ops": 0.386,
            "_file_size": 20.451,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    # Alias for the weights used by default.
    DEFAULT = IMAGENET1K_V1
# Pretrained weights available for ``efficientnet_b1``.
class EfficientNet_B1_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b1_rwightman-bac287d4.pth",
        # Inference preprocessing: resize to 256, crop to 240, bicubic interpolation.
        transforms=partial(
            ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 7794184,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.642,
                    "acc@5": 94.186,
                }
            },
            "_ops": 0.687,
            "_file_size": 30.134,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
        # Inference preprocessing: resize to 255, crop to 240, bilinear interpolation.
        transforms=partial(
            ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 7794184,
            # Overrides the shared V1 recipe link.
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.838,
                    "acc@5": 94.934,
                }
            },
            "_ops": 0.687,
            "_file_size": 30.136,
            # The text below is kept flush-left so the string's content is unchanged.
            "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
        },
    )
    # Alias for the weights used by default.
    DEFAULT = IMAGENET1K_V2
# Pretrained weights available for ``efficientnet_b2``.
class EfficientNet_B2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b2_rwightman-c35c1473.pth",
        # Inference preprocessing: resize to 288, crop to 288, bicubic interpolation.
        transforms=partial(
            ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 9109994,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.608,
                    "acc@5": 95.310,
                }
            },
            "_ops": 1.088,
            "_file_size": 35.174,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    # Alias for the weights used by default.
    DEFAULT = IMAGENET1K_V1
# Pretrained weights available for ``efficientnet_b3``.
class EfficientNet_B3_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth",
        # Inference preprocessing: resize to 320, crop to 300, bicubic interpolation.
        transforms=partial(
            ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 12233232,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.008,
                    "acc@5": 96.054,
                }
            },
            "_ops": 1.827,
            "_file_size": 47.184,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    # Alias for the weights used by default.
    DEFAULT = IMAGENET1K_V1
# Pretrained weights available for ``efficientnet_b4``.
class EfficientNet_B4_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth",
        # Inference preprocessing: resize to 384, crop to 380, bicubic interpolation.
        transforms=partial(
            ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 19341616,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.384,
                    "acc@5": 96.594,
                }
            },
            "_ops": 4.394,
            "_file_size": 74.489,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    # Alias for the weights used by default.
    DEFAULT = IMAGENET1K_V1
# Pretrained weights available for ``efficientnet_b5``.
class EfficientNet_B5_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-1a07897c.pth",
        # Inference preprocessing: resize to 456, crop to 456, bicubic interpolation.
        transforms=partial(
            ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 30389784,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.444,
                    "acc@5": 96.628,
                }
            },
            "_ops": 10.266,
            "_file_size": 116.864,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    # Alias for the weights used by default.
    DEFAULT = IMAGENET1K_V1
# Pretrained weights available for ``efficientnet_b6``.
class EfficientNet_B6_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-24a108a5.pth",
        # Inference preprocessing: resize to 528, crop to 528, bicubic interpolation.
        transforms=partial(
            ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 43040704,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.008,
                    "acc@5": 96.916,
                }
            },
            "_ops": 19.068,
            "_file_size": 165.362,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    # Alias for the weights used by default.
    DEFAULT = IMAGENET1K_V1
# Pretrained weights available for ``efficientnet_b7``.
class EfficientNet_B7_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-c5b4e57e.pth",
        # Inference preprocessing: resize to 600, crop to 600, bicubic interpolation.
        transforms=partial(
            ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 66347960,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.122,
                    "acc@5": 96.908,
                }
            },
            "_ops": 37.746,
            "_file_size": 254.675,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    # Alias for the weights used by default.
    DEFAULT = IMAGENET1K_V1
# Pretrained weights available for ``efficientnet_v2_s``.
class EfficientNet_V2_S_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
        # Inference preprocessing: resize to 384, crop to 384, bilinear interpolation.
        transforms=partial(
            ImageClassification,
            crop_size=384,
            resize_size=384,
            interpolation=InterpolationMode.BILINEAR,
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 21458488,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.228,
                    "acc@5": 96.878,
                }
            },
            "_ops": 8.366,
            "_file_size": 82.704,
            # The text below is kept flush-left so the string's content is unchanged.
            "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
        },
    )
    # Alias for the weights used by default.
    DEFAULT = IMAGENET1K_V1
# Pretrained weights available for ``efficientnet_v2_m``.
class EfficientNet_V2_M_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
        # Inference preprocessing: resize to 480, crop to 480, bilinear interpolation.
        transforms=partial(
            ImageClassification,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BILINEAR,
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 54139356,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.112,
                    "acc@5": 97.156,
                }
            },
            "_ops": 24.582,
            "_file_size": 208.01,
            # The text below is kept flush-left so the string's content is unchanged.
            "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
        },
    )
    # Alias for the weights used by default.
    DEFAULT = IMAGENET1K_V1
# Pretrained weights available for ``efficientnet_v2_l``.
class EfficientNet_V2_L_Weights(WeightsEnum):
    # Weights ported from https://github.com/google/automl/tree/master/efficientnetv2
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
        # Inference preprocessing: resize to 480, crop to 480, bicubic interpolation, and an
        # explicit mean/std of 0.5 per channel.
        transforms=partial(
            ImageClassification,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BICUBIC,
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 118515272,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.808,
                    "acc@5": 97.788,
                }
            },
            "_ops": 56.08,
            "_file_size": 454.573,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    # Alias for the weights used by default.
    DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
def efficientnet_b0(
    *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B0 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B0_Weights
        :members:
    """
    weights = EfficientNet_B0_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
    # A caller-supplied ``dropout`` in kwargs overrides this variant's default of 0.2.
    return _efficientnet(
        inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
def efficientnet_b1(
    *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B1 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B1_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B1_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B1_Weights
        :members:
    """
    weights = EfficientNet_B1_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
    # A caller-supplied ``dropout`` in kwargs overrides this variant's default of 0.2.
    return _efficientnet(
        inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
def efficientnet_b2(
    *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B2 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B2_Weights
        :members:
    """
    weights = EfficientNet_B2_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
    # A caller-supplied ``dropout`` in kwargs overrides this variant's default of 0.3.
    return _efficientnet(
        inverted_residual_setting, kwargs.pop("dropout", 0.3), last_channel, weights, progress, **kwargs
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
def efficientnet_b3(
    *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B3 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B3_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B3_Weights
        :members:
    """
    weights = EfficientNet_B3_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
    # A caller-supplied ``dropout`` in kwargs overrides this variant's default of 0.3.
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.3),
        last_channel,
        weights,
        progress,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
def efficientnet_b4(
    *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B4 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B4_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B4_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B4_Weights
        :members:
    """
    weights = EfficientNet_B4_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
    # A caller-supplied ``dropout`` in kwargs overrides this variant's default of 0.4.
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.4),
        last_channel,
        weights,
        progress,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
def efficientnet_b5(
    *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B5 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B5_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B5_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B5_Weights
        :members:
    """
    weights = EfficientNet_B5_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
    # Pop the per-variant defaults out of kwargs so that a caller-supplied ``dropout`` or
    # ``norm_layer`` overrides them instead of colliding with the explicit arguments below.
    dropout = kwargs.pop("dropout", 0.4)
    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01))
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        norm_layer=norm_layer,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
def efficientnet_b6(
    *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B6 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B6_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B6_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B6_Weights
        :members:
    """
    weights = EfficientNet_B6_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
    # Pop the per-variant defaults out of kwargs so that a caller-supplied ``dropout`` or
    # ``norm_layer`` overrides them instead of colliding with the explicit arguments below.
    dropout = kwargs.pop("dropout", 0.5)
    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01))
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        norm_layer=norm_layer,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
def efficientnet_b7(
    *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B7 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B7_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B7_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B7_Weights
        :members:
    """
    weights = EfficientNet_B7_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
    # Pop the per-variant defaults out of kwargs so that a caller-supplied ``dropout`` or
    # ``norm_layer`` overrides them instead of colliding with the explicit arguments below.
    dropout = kwargs.pop("dropout", 0.5)
    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01))
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        norm_layer=norm_layer,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
def efficientnet_v2_s(
    *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-S architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_V2_S_Weights
        :members:
    """
    weights = EfficientNet_V2_S_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
    # Pop the per-variant defaults out of kwargs so that a caller-supplied ``dropout`` or
    # ``norm_layer`` overrides them instead of colliding with the explicit arguments below.
    dropout = kwargs.pop("dropout", 0.2)
    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-03))
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        norm_layer=norm_layer,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
def efficientnet_v2_m(
    *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-M architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_M_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_M_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_V2_M_Weights
        :members:
    """
    weights = EfficientNet_V2_M_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
    # Pop the per-variant defaults out of kwargs so that a caller-supplied ``dropout`` or
    # ``norm_layer`` overrides them instead of colliding with the explicit arguments below.
    dropout = kwargs.pop("dropout", 0.3)
    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-03))
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        norm_layer=norm_layer,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
def efficientnet_v2_l(
    *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-L architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_L_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_L_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_V2_L_Weights
        :members:
    """
    weights = EfficientNet_V2_L_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
    # Pop the per-variant defaults out of kwargs so that a caller-supplied ``dropout`` or
    # ``norm_layer`` overrides them instead of colliding with the explicit arguments below.
    dropout = kwargs.pop("dropout", 0.4)
    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-03))
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        norm_layer=norm_layer,
        **kwargs,
    )