Source code for torchvision.models.segmentation.fcn
from typing import Optional
from torch import nn
from .. import resnet
from .._utils import IntermediateLayerGetter
from ._utils import _SimpleSegmentationModel, _load_weights
__all__ = ["FCN", "fcn_resnet50", "fcn_resnet101"]
model_urls = {
"fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
"fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
}
class FCN(_SimpleSegmentationModel):
"""
Implements FCN model from
`"Fully Convolutional Networks for Semantic Segmentation"
<https://arxiv.org/abs/1411.4038>`_.
Args:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"out" for the last feature map used, and "aux" if an auxiliary classifier
is used.
classifier (nn.Module): module that takes the "out" element returned from
the backbone and returns a dense prediction.
aux_classifier (nn.Module, optional): auxiliary classifier used during training
"""
pass
class FCNHead(nn.Sequential):
def __init__(self, in_channels: int, channels: int) -> None:
inter_channels = in_channels // 4
layers = [
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU(),
nn.Dropout(0.1),
nn.Conv2d(inter_channels, channels, 1),
]
super().__init__(*layers)
def _fcn_resnet(
backbone: resnet.ResNet,
num_classes: int,
aux: Optional[bool],
) -> FCN:
return_layers = {"layer4": "out"}
if aux:
return_layers["layer3"] = "aux"
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
aux_classifier = FCNHead(1024, num_classes) if aux else None
classifier = FCNHead(2048, num_classes)
return FCN(backbone, classifier, aux_classifier)
[docs]def fcn_resnet50(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
pretrained_backbone: bool = True,
) -> FCN:
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool, optional): If True, it uses an auxiliary loss
pretrained_backbone (bool): If True, the backbone will be pre-trained.
"""
if pretrained:
aux_loss = True
pretrained_backbone = False
backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
model = _fcn_resnet(backbone, num_classes, aux_loss)
if pretrained:
arch = "fcn_resnet50_coco"
_load_weights(arch, model, model_urls.get(arch, None), progress)
return model
[docs]def fcn_resnet101(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
pretrained_backbone: bool = True,
) -> FCN:
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool, optional): If True, it uses an auxiliary loss
pretrained_backbone (bool): If True, the backbone will be pre-trained.
"""
if pretrained:
aux_loss = True
pretrained_backbone = False
backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
model = _fcn_resnet(backbone, num_classes, aux_loss)
if pretrained:
arch = "fcn_resnet101_coco"
_load_weights(arch, model, model_urls.get(arch, None), progress)
return model