Source code for torchvision.models.segmentation.deeplabv3
from functools import partial
from typing import Any, Optional, Sequence

import torch
from torch import nn
from torch.nn import functional as F

from ...transforms._presets import SemanticSegmentation
from .._api import register_model, Weights, WeightsEnum
from .._meta import _VOC_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3
from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights
from ._utils import _SimpleSegmentationModel
from .fcn import FCNHead


__all__ = [
    "DeepLabV3",
    "DeepLabV3_ResNet50_Weights",
    "DeepLabV3_ResNet101_Weights",
    "DeepLabV3_MobileNet_V3_Large_Weights",
    "deeplabv3_mobilenet_v3_large",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
]


class DeepLabV3(_SimpleSegmentationModel):
    """
    Implements DeepLabV3 model from
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """

    pass


class DeepLabHead(nn.Sequential):
    """DeepLabV3 prediction head.

    Layout: ``ASPP`` -> 3x3 conv (no bias) -> BatchNorm -> ReLU -> 1x1 conv to
    ``num_classes`` channels.

    Args:
        in_channels: number of channels of the incoming feature map.
        num_classes: number of output channels (one per class).
        atrous_rates: dilation rates of the ASPP atrous branches.
        hidden_channels: channel width used inside the head. This is both the ASPP
            output width and the width of the following 3x3 conv. The default of
            256 reproduces the previously hard-coded width, so child modules and
            their parameter shapes (and therefore state_dict compatibility) are
            unchanged when it is left at its default.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        atrous_rates: Sequence[int] = (12, 24, 36),
        hidden_channels: int = 256,
    ) -> None:
        super().__init__(
            # Pass the width explicitly so ASPP and the layers below can never disagree.
            ASPP(in_channels, atrous_rates, out_channels=hidden_channels),
            nn.Conv2d(hidden_channels, hidden_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, num_classes, 1),
        )


class ASPPConv(nn.Sequential):
    """One atrous ASPP branch: dilated 3x3 conv (no bias) -> BatchNorm -> ReLU.

    ``padding`` equals ``dilation``, so with the default stride of 1 the spatial
    size of the input is preserved.
    """

    def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        super().__init__(*modules)


class ASPPPooling(nn.Sequential):
    """Image-level ASPP branch.

    Global average pool -> 1x1 conv (no bias) -> BatchNorm -> ReLU. The result is
    bilinearly upsampled back to the spatial size of the input.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Remember the input's spatial size (last two dims) so the pooled
        # 1x1 map can be broadcast back to it.
        size = x.shape[-2:]
        # Explicit loop over the Sequential's children (kept as a plain loop
        # rather than a comprehension for TorchScript friendliness).
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs parallel branches on the same input:

    * a 1x1 conv -> BN -> ReLU;
    * one :class:`ASPPConv` per entry of ``atrous_rates``;
    * one :class:`ASPPPooling`.

    The branch outputs are concatenated along dim 1 and projected back to
    ``out_channels`` with 1x1 conv -> BN -> ReLU -> Dropout(0.5).
    """

    def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None:
        super().__init__()
        modules = []
        modules.append(
            nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU())
        )

        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        # Every branch emits out_channels channels, so the concatenation has
        # len(self.convs) * out_channels channels.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Plain loop over the ModuleList (rather than a comprehension) to stay
        # TorchScript-friendly.
        _res = []
        for conv in self.convs:
            _res.append(conv(x))
        res = torch.cat(_res, dim=1)
        return self.project(res)


def _deeplabv3_resnet(
    backbone: ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Wrap a ResNet backbone into a DeepLabV3 model.

    ``layer4`` feeds the main :class:`DeepLabHead`. When ``aux`` is truthy,
    ``layer3`` additionally feeds an auxiliary :class:`FCNHead`; otherwise no
    auxiliary classifier is built.
    """
    return_layers = {"layer4": "out"}
    if aux:
        return_layers["layer3"] = "aux"
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    # NOTE(review): 1024 / 2048 are presumably the layer3 / layer4 output widths
    # of the Bottleneck ResNets (ResNet-50/101) used in this file; they would not
    # match BasicBlock variants — confirm before reusing with other backbones.
    aux_classifier = FCNHead(1024, num_classes) if aux else None
    classifier = DeepLabHead(2048, num_classes)
    return DeepLabV3(backbone, classifier, aux_classifier)


_COMMON_META = {
    "categories": _VOC_CATEGORIES,
    "min_size": (1, 1),
    "_docs": """
        These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC
        dataset.
    """,
}
def _deeplabv3_mobilenetv3(
    backbone: MobileNetV3,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Wrap the ``features`` trunk of a MobileNetV3 into a DeepLabV3 model.

    Stage boundaries are located via the ``_is_cn`` marker on feature blocks.
    The last block feeds the main :class:`DeepLabHead`. When ``aux`` is truthy,
    the fourth-from-last stage boundary also feeds an auxiliary :class:`FCNHead`.
    """
    features = backbone.features

    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    strided_blocks = [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
    stage_indices = [0, *strided_blocks, len(features) - 1]

    out_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    out_inplanes = features[out_pos].out_channels
    aux_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    aux_inplanes = features[aux_pos].out_channels

    # IntermediateLayerGetter addresses children by their string index.
    taps = {str(out_pos): "out"}
    if aux:
        taps[str(aux_pos)] = "aux"
    extractor = IntermediateLayerGetter(features, return_layers=taps)

    aux_head = FCNHead(aux_inplanes, num_classes) if aux else None
    main_head = DeepLabHead(out_inplanes, num_classes)
    return DeepLabV3(extractor, main_head, aux_head)
@register_model()
@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet50(
    *,
    weights: Optional[DeepLabV3_ResNet50_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    .. betastatus:: segmentation module

    Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.

    Args:
        weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights` below for
            more details, and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for the
            backbone
        **kwargs: unused

    .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet50_Weights
        :members:
    """
    weights = DeepLabV3_ResNet50_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is None:
        # Training from scratch (or from an ImageNet backbone): default to 21 classes.
        if num_classes is None:
            num_classes = 21
    else:
        # Full-model weights supersede the backbone weights and pin the head configuration.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True)

    # Dilate the last two stages instead of striding them.
    dilated_resnet = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(dilated_resnet, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
@register_model()
@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet101(
    *,
    weights: Optional[DeepLabV3_ResNet101_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    .. betastatus:: segmentation module

    Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.

    Args:
        weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights` below for
            more details, and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained weights for the
            backbone
        **kwargs: unused

    .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet101_Weights
        :members:
    """
    weights = DeepLabV3_ResNet101_Weights.verify(weights)
    weights_backbone = ResNet101_Weights.verify(weights_backbone)

    if weights is None:
        # No full-model checkpoint: fall back to 21 classes unless the caller chose otherwise.
        if num_classes is None:
            num_classes = 21
    else:
        # A full-model checkpoint replaces backbone weights and fixes the head configuration.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True)

    # Dilate the last two stages instead of striding them.
    dilated_resnet = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(dilated_resnet, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
@register_model()
@handle_legacy_interface(
    weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def deeplabv3_mobilenet_v3_large(
    *,
    weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

    Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.

    Args:
        weights (:class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights` below for
            more details, and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained weights
            for the backbone
        **kwargs: unused

    .. autoclass:: torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights
        :members:
    """
    weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    if weights is None:
        # No full-model checkpoint: default to 21 classes unless overridden.
        if num_classes is None:
            num_classes = 21
    else:
        # A full-model checkpoint replaces backbone weights and fixes the head configuration.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True)

    dilated_mobilenet = mobilenet_v3_large(weights=weights_backbone, dilated=True)
    model = _deeplabv3_mobilenetv3(dilated_mobilenet, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.