Source code for torchvision.models.segmentation.lraspp
from collections import OrderedDict
from typing import Any, Dict
from torch import nn, Tensor
from torch.nn import functional as F
from ...utils import _log_api_usage_once
from .. import mobilenetv3
from .._utils import IntermediateLayerGetter
from ._utils import _load_weights
__all__ = ["LRASPP", "lraspp_mobilenet_v3_large"]
model_urls = {
"lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
}
class LRASPP(nn.Module):
"""
Implements a Lite R-ASPP Network for semantic segmentation from
`"Searching for MobileNetV3"
<https://arxiv.org/abs/1905.02244>`_.
Args:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"high" for the high level feature map and "low" for the low level feature map.
low_channels (int): the number of channels of the low level features.
high_channels (int): the number of channels of the high level features.
num_classes (int): number of output classes of the model (including the background).
inter_channels (int, optional): the number of channels for intermediate computations.
"""
def __init__(
self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128
) -> None:
super().__init__()
_log_api_usage_once(self)
self.backbone = backbone
self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)
def forward(self, input: Tensor) -> Dict[str, Tensor]:
features = self.backbone(input)
out = self.classifier(features)
out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False)
result = OrderedDict()
result["out"] = out
return result
class LRASPPHead(nn.Module):
def __init__(self, low_channels: int, high_channels: int, num_classes: int, inter_channels: int) -> None:
super().__init__()
self.cbr = nn.Sequential(
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True),
)
self.scale = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
nn.Sigmoid(),
)
self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)
def forward(self, input: Dict[str, Tensor]) -> Tensor:
low = input["low"]
high = input["high"]
x = self.cbr(high)
s = self.scale(high)
x = x * s
x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False)
return self.low_classifier(low) + self.high_classifier(x)
def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP:
backbone = backbone.features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
low_pos = stage_indices[-4] # use C2 here which has output_stride = 8
high_pos = stage_indices[-1] # use C5 which has output_stride = 16
low_channels = backbone[low_pos].out_channels
high_channels = backbone[high_pos].out_channels
backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"})
return LRASPP(backbone, low_channels, high_channels, num_classes)
[docs]def lraspp_mobilenet_v3_large(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
pretrained_backbone: bool = True,
**kwargs: Any,
) -> LRASPP:
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, the backbone will be pre-trained.
"""
if kwargs.pop("aux_loss", False):
raise NotImplementedError("This model does not use auxiliary loss")
if pretrained:
pretrained_backbone = False
backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
model = _lraspp_mobilenetv3(backbone, num_classes)
if pretrained:
arch = "lraspp_mobilenet_v3_large_coco"
_load_weights(arch, model, model_urls.get(arch, None), progress)
return model