Source code for torchvision.models.segmentation.lraspp
from collections import OrderedDict
from typing import Any, Dict

from torch import nn, Tensor
from torch.nn import functional as F

from ...utils import _log_api_usage_once
from .. import mobilenetv3
from .._utils import IntermediateLayerGetter
from ._utils import _load_weights


__all__ = ["LRASPP", "lraspp_mobilenet_v3_large"]


model_urls = {
    "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
}


class LRASPP(nn.Module):
    """
    Implements a Lite R-ASPP Network for semantic segmentation from
    `"Searching for MobileNetV3"
    <https://arxiv.org/abs/1905.02244>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "high" for the high level feature map and "low" for the low level feature map.
        low_channels (int): the number of channels of the low level features.
        high_channels (int): the number of channels of the high level features.
        num_classes (int): number of output classes of the model (including the background).
        inter_channels (int, optional): the number of channels for intermediate computations.
    """

    def __init__(
        self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.backbone = backbone
        self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)

    def forward(self, input: Tensor) -> Dict[str, Tensor]:
        """Run the backbone and head, then resize the scores to the input size.

        Args:
            input (Tensor): the batch fed to the backbone; its last two
                dimensions are used as the target size of the output.

        Returns:
            Dict[str, Tensor]: an ``OrderedDict`` with a single key ``"out"``
            holding the bilinearly upsampled classifier output.
        """
        features = self.backbone(input)
        out = self.classifier(features)
        # Bring the head output back to the spatial size of the input.
        out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False)

        result = OrderedDict()
        result["out"] = out

        return result


class LRASPPHead(nn.Module):
    """Segmentation head used by :class:`LRASPP`.

    The high level features go through two parallel branches: a 1x1
    conv + batch norm + ReLU branch (``cbr``) and a globally average-pooled
    1x1 conv + sigmoid branch (``scale``) whose output gates the first one.
    The gated result is upsampled to the resolution of the low level features,
    and both feature maps are projected to ``num_classes`` channels by 1x1
    convolutions and summed.

    Args:
        low_channels (int): the number of channels of the low level features.
        high_channels (int): the number of channels of the high level features.
        num_classes (int): number of output channels of the two classifiers.
        inter_channels (int): the number of channels produced by the ``cbr``
            and ``scale`` branches.
    """

    def __init__(self, low_channels: int, high_channels: int, num_classes: int, inter_channels: int) -> None:
        super().__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
        )
        self.scale = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
        self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)

    def forward(self, input: Dict[str, Tensor]) -> Tensor:
        """Combine the "low" and "high" feature maps into per-class scores.

        Args:
            input (Dict[str, Tensor]): mapping with the keys ``"low"`` and
                ``"high"``.

        Returns:
            Tensor: the sum of both classifier outputs, at the spatial size of
            the ``"low"`` feature map.
        """
        low = input["low"]
        high = input["high"]

        x = self.cbr(high)
        s = self.scale(high)
        # `s` has spatial size 1x1 (AdaptiveAvgPool2d(1)), so this multiplication
        # broadcasts one gate value per channel over the whole feature map.
        x = x * s
        # Match the resolution of the low level features before summing.
        x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False)

        return self.low_classifier(low) + self.high_classifier(x)


def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP:
    """Build an :class:`LRASPP` model on top of a MobileNetV3 feature extractor.

    Args:
        backbone (mobilenetv3.MobileNetV3): the network whose ``features``
            sub-module is used as backbone.
        num_classes (int): number of output classes of the model.

    Returns:
        LRASPP: the model wrapping the selected "low" and "high" layers.
    """
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    low_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
    high_pos = stage_indices[-1]  # use C5 which has output_stride = 16
    low_channels = backbone[low_pos].out_channels
    high_channels = backbone[high_pos].out_channels
    # Expose only the two selected layers, under the keys the head reads.
    backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"})

    return LRASPP(backbone, low_channels, high_channels, num_classes)
def lraspp_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    pretrained_backbone: bool = True,
    **kwargs: Any,
) -> LRASPP:
    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
            Overridden to False when ``pretrained`` is True.
        **kwargs: only ``aux_loss`` is inspected; any other keyword argument is
            accepted but not used by this function.

    Returns:
        LRASPP: the constructed model.

    Raises:
        NotImplementedError: if a truthy ``aux_loss`` keyword argument is passed.
    """
    if kwargs.pop("aux_loss", False):
        raise NotImplementedError("This model does not use auxiliary loss")
    if pretrained:
        # Skip the separate backbone download; presumably the segmentation
        # checkpoint loaded below already contains the backbone weights.
        pretrained_backbone = False

    backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
    model = _lraspp_mobilenetv3(backbone, num_classes)

    if pretrained:
        arch = "lraspp_mobilenet_v3_large_coco"
        _load_weights(arch, model, model_urls.get(arch, None), progress)
    return model
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.