Source code for torchvision.models.segmentation.lraspp
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Optional

from torch import nn, Tensor
from torch.nn import functional as F

from ...transforms._presets import SemanticSegmentation
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _VOC_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3


__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"]


class LRASPP(nn.Module):
    """
    Implements a Lite R-ASPP Network for semantic segmentation from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "high" for the high level feature map and "low" for the low level feature map.
        low_channels (int): the number of channels of the low level features.
        high_channels (int): the number of channels of the high level features.
        num_classes (int, optional): number of output classes of the model (including the background).
        inter_channels (int, optional): the number of channels for intermediate computations.
    """

    def __init__(
        self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.backbone = backbone
        self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)

    def forward(self, input: Tensor) -> Dict[str, Tensor]:
        features = self.backbone(input)
        out = self.classifier(features)
        out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False)

        result = OrderedDict()
        result["out"] = out

        return result


class LRASPPHead(nn.Module):
    def __init__(self, low_channels: int, high_channels: int, num_classes: int, inter_channels: int) -> None:
        super().__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
        )
        self.scale = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
        self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)

    def forward(self, input: Dict[str, Tensor]) -> Tensor:
        low = input["low"]
        high = input["high"]

        x = self.cbr(high)
        s = self.scale(high)
        x = x * s
        x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False)

        return self.low_classifier(low) + self.high_classifier(x)


def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP:
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    low_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    high_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    low_channels = backbone[low_pos].out_channels
    high_channels = backbone[high_pos].out_channels
    backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"})

    return LRASPP(backbone, low_channels, high_channels, num_classes)
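For reference, a minimal sketch of the feature-dict contract that ``LRASPPHead`` expects: the "high" features are gated by a globally pooled sigmoid attention branch, upsampled to the "low" resolution, and the two per-class logits are summed. The channel counts and spatial sizes below are illustrative (they match a dilated MobileNetV3-Large backbone, but any values consistent with the constructor work):

    import torch

    # Illustrative shapes; low_channels/high_channels are just for this sketch.
    head = LRASPPHead(low_channels=40, high_channels=960, num_classes=21, inter_channels=128)
    features = OrderedDict(
        low=torch.randn(1, 40, 64, 64),    # C2-level features, output_stride = 8
        high=torch.randn(1, 960, 32, 32),  # C5-level features, output_stride = 16
    )
    logits = head(features)
    print(logits.shape)  # torch.Size([1, 21, 64, 64]) -- logits at the "low" resolution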
class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
    COCO_WITH_VOC_LABELS_V1 = Weights(
        url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            "num_params": 3221538,
            "categories": _VOC_CATEGORIES,
            "min_size": (1, 1),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
            "_metrics": {
                "COCO-val2017-VOC-labels": {
                    "miou": 57.9,
                    "pixel_acc": 91.2,
                }
            },
            "_ops": 2.086,
            "_file_size": 12.49,
            "_docs": """
                These weights were trained on a subset of COCO, using only the 20 categories that are
                present in the Pascal VOC dataset.
            """,
        },
    )
    DEFAULT = COCO_WITH_VOC_LABELS_V1
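A weight entry bundles its preprocessing preset and metadata, so both can be queried without hard-coding values; a short sketch:

    # Sketch: inspecting the weight entry's metadata and bundled preprocessing.
    weights = LRASPP_MobileNet_V3_Large_Weights.DEFAULT  # alias for COCO_WITH_VOC_LABELS_V1
    print(len(weights.meta["categories"]))  # 21 (background + 20 VOC classes)
    print(weights.meta["_metrics"]["COCO-val2017-VOC-labels"]["miou"])  # 57.9
    preprocess = weights.transforms()  # SemanticSegmentation preset with resize_size=520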
@register_model()
@handle_legacy_interface(
    weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def lraspp_mobilenet_v3_large(
    *,
    weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> LRASPP:
    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone from
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_ paper.

    .. betastatus:: segmentation module

    Args:
        weights (:class:`~torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights` below for
            more details, and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background).
        aux_loss (bool, optional): This model does not support an auxiliary loss; passing a truthy value
            raises ``NotImplementedError``.
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained
            weights for the backbone.
        **kwargs: parameters passed to the ``torchvision.models.segmentation.LRASPP``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/lraspp.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights
        :members:
    """
    if kwargs.pop("aux_loss", False):
        raise NotImplementedError("This model does not use auxiliary loss")

    weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 21

    backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
    model = _lraspp_mobilenetv3(backbone, num_classes)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
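A typical end-to-end usage sketch of the builder, following the standard torchvision inference pattern (the image path is a placeholder):

    import torch
    from torchvision.io import read_image
    from torchvision.models.segmentation import lraspp_mobilenet_v3_large, LRASPP_MobileNet_V3_Large_Weights

    weights = LRASPP_MobileNet_V3_Large_Weights.DEFAULT
    model = lraspp_mobilenet_v3_large(weights=weights).eval()
    preprocess = weights.transforms()

    img = read_image("example.jpg")        # hypothetical input image (uint8, CHW)
    batch = preprocess(img).unsqueeze(0)   # resize/normalize, add batch dim

    with torch.inference_mode():
        out = model(batch)["out"]          # logits of shape (1, 21, H, W)
    prediction = out.argmax(dim=1)         # per-pixel class indices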