Source code for torchvision.models.segmentation.fcn
from functools import partial
from typing import Any, Optional

from torch import nn

from ...transforms._presets import SemanticSegmentation
from .._api import register_model, Weights, WeightsEnum
from .._meta import _VOC_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter
from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights
from ._utils import _SimpleSegmentationModel


__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"]


class FCN(_SimpleSegmentationModel):
    """
    Implements FCN model from
    `"Fully Convolutional Networks for Semantic Segmentation"
    <https://arxiv.org/abs/1411.4038>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """

    pass


class FCNHead(nn.Sequential):
    def __init__(self, in_channels: int, channels: int) -> None:
        inter_channels = in_channels // 4
        layers = [
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1),
        ]

        super().__init__(*layers)


_COMMON_META = {
    "categories": _VOC_CATEGORIES,
    "min_size": (1, 1),
    "_docs": """
        These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC
        dataset.
    """,
}
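NOTE: In the torchvision source file, the two weight enums (``FCN_ResNet50_Weights``, ``FCN_ResNet101_Weights``) and the private builder ``_fcn_resnet`` are defined between ``_COMMON_META`` and the factory functions below, but they were lost when this page was extracted. What follows is a hedged reconstruction, not the verbatim source: the ``_fcn_resnet`` logic is inferred from how it is called below and from the ResNet channel counts, while the checkpoint URL and the entries omitted from ``meta`` are placeholders rather than the real hashed file name and measured numbers.

def _fcn_resnet(
    backbone: ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> FCN:
    # Expose "layer4" as the main feature map; optionally also "layer3" for the aux head.
    return_layers = {"layer4": "out"}
    if aux:
        return_layers["layer3"] = "aux"
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    # ResNet-50/101: layer3 outputs 1024 channels, layer4 outputs 2048.
    aux_classifier = FCNHead(1024, num_classes) if aux else None
    classifier = FCNHead(2048, num_classes)
    return FCN(backbone, classifier, aux_classifier)


class FCN_ResNet50_Weights(WeightsEnum):
    COCO_WITH_VOC_LABELS_V1 = Weights(
        # Placeholder URL: the real entry points at a hashed .pth file on download.pytorch.org.
        url="https://download.pytorch.org/models/fcn_resnet50_coco-<hash>.pth",
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            **_COMMON_META,
            # The real meta also records num_params, the training recipe URL, and the
            # mIoU/pixel-accuracy metrics on COCO-val2017 with VOC labels; those
            # numbers are omitted here rather than guessed.
        },
    )
    DEFAULT = COCO_WITH_VOC_LABELS_V1

``FCN_ResNet101_Weights`` mirrors this enum for the ResNet-101 checkpoint.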
@register_model()
@handle_legacy_interface(
    weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fcn_resnet50(
    *,
    weights: Optional[FCN_ResNet50_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> FCN:
    """Fully-Convolutional Network model with a ResNet-50 backbone from the `Fully Convolutional
    Networks for Semantic Segmentation <https://arxiv.org/abs/1411.4038>`_ paper.

    .. betastatus:: segmentation module

    Args:
        weights (:class:`~torchvision.models.segmentation.FCN_ResNet50_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.FCN_ResNet50_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background).
        aux_loss (bool, optional): If True, it uses an auxiliary loss.
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained
            weights for the backbone.
        **kwargs: parameters passed to the ``torchvision.models.segmentation.fcn.FCN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/fcn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.segmentation.FCN_ResNet50_Weights
        :members:
    """
    weights = FCN_ResNet50_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True)
    elif num_classes is None:
        num_classes = 21

    backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _fcn_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
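A minimal usage sketch for the factory above (not part of the source file): load the COCO-trained weights, preprocess one image with the preset transforms bundled with the weights, and take the per-pixel argmax. The image path is a stand-in.

import torch
from torchvision.io import read_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

weights = FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
model = fcn_resnet50(weights=weights)
model.eval()

img = read_image("example.jpg")  # placeholder path; any RGB image tensor works
batch = weights.transforms()(img).unsqueeze(0)  # applies the SemanticSegmentation preset

with torch.no_grad():
    out = model(batch)["out"]  # logits of shape (1, 21, H, W)
pred = out.argmax(dim=1)  # (1, H, W) tensor of VOC class indices per pixel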
@register_model()
@handle_legacy_interface(
    weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
)
def fcn_resnet101(
    *,
    weights: Optional[FCN_ResNet101_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> FCN:
    """Fully-Convolutional Network model with a ResNet-101 backbone from the `Fully Convolutional
    Networks for Semantic Segmentation <https://arxiv.org/abs/1411.4038>`_ paper.

    .. betastatus:: segmentation module

    Args:
        weights (:class:`~torchvision.models.segmentation.FCN_ResNet101_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.FCN_ResNet101_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background).
        aux_loss (bool, optional): If True, it uses an auxiliary loss.
        weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained
            weights for the backbone.
        **kwargs: parameters passed to the ``torchvision.models.segmentation.fcn.FCN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/fcn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.segmentation.FCN_ResNet101_Weights
        :members:
    """
    weights = FCN_ResNet101_Weights.verify(weights)
    weights_backbone = ResNet101_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True)
    elif num_classes is None:
        num_classes = 21

    backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _fcn_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
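For fine-tuning, the same factory composes a pretrained backbone with freshly initialized heads. A hedged sketch, where the 3-class task is an arbitrary example:

from torchvision.models import ResNet101_Weights
from torchvision.models.segmentation import fcn_resnet101

model = fcn_resnet101(
    weights=None,  # no segmentation checkpoint: both FCNHeads start random
    weights_backbone=ResNet101_Weights.IMAGENET1K_V1,  # ImageNet-pretrained features
    num_classes=3,
    aux_loss=True,  # also attach the auxiliary FCNHead on the layer3 output
)
# In train mode the model returns {"out": ..., "aux": ...}; the reference training
# recipe combines them as loss = loss_out + 0.5 * loss_aux.

Note that ``handle_legacy_interface`` keeps the deprecated ``pretrained=True`` / ``pretrained_backbone=True`` keyword arguments working by remapping them onto ``weights`` and ``weights_backbone``.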