Source code for torchvision.models.detection.ssdlite
import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import nn, Tensor

from ...ops.misc import Conv2dNormActivation
from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once
from .. import mobilenet
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
from . import _utils as det_utils
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .ssd import SSD, SSDScoringHead


__all__ = [
    "SSDLite320_MobileNet_V3_Large_Weights",
    "ssdlite320_mobilenet_v3_large",
]


# Building blocks of SSDlite as described in section 6.2 of the MobileNetV2 paper
def _prediction_block(
    in_channels: int, out_channels: int, kernel_size: int, norm_layer: Callable[..., nn.Module]
) -> nn.Sequential:
    return nn.Sequential(
        # 3x3 depthwise with stride 1 and padding 1
        Conv2dNormActivation(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            norm_layer=norm_layer,
            activation_layer=nn.ReLU6,
        ),
        # 1x1 projection to output channels
        nn.Conv2d(in_channels, out_channels, 1),
    )


def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
    activation = nn.ReLU6
    intermediate_channels = out_channels // 2
    return nn.Sequential(
        # 1x1 projection to half output channels
        Conv2dNormActivation(
            in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
        ),
        # 3x3 depthwise with stride 2 and padding 1
        Conv2dNormActivation(
            intermediate_channels,
            intermediate_channels,
            kernel_size=3,
            stride=2,
            groups=intermediate_channels,
            norm_layer=norm_layer,
            activation_layer=activation,
        ),
        # 1x1 projection to output channels
        Conv2dNormActivation(
            intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
        ),
    )
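A quick way to see what these two blocks do: _prediction_block keeps the spatial resolution (stride-1 depthwise) while mapping to per-anchor outputs, and _extra_block halves it (stride-2 depthwise). The sketch below is illustrative and not part of the module: it imports the private helpers by their current names, and the channel/anchor counts are made up.

# Standalone shape check (assumes the private helper names above are stable).
from functools import partial

import torch
from torch import nn
from torchvision.models.detection.ssdlite import _extra_block, _prediction_block

norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

# Prediction block: stride-1 depthwise, so the spatial size is preserved.
pred = _prediction_block(480, 4 * 6, 3, norm_layer)  # 4 box coords * 6 anchors
x = torch.rand(1, 480, 20, 20)
print(pred(x).shape)   # torch.Size([1, 24, 20, 20])

# Extra block: stride-2 depthwise halves the resolution.
extra = _extra_block(480, 512, norm_layer)
print(extra(x).shape)  # torch.Size([1, 512, 10, 10])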
def _normal_init(conv: nn.Module):
    for layer in conv.modules():
        if isinstance(layer, nn.Conv2d):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03)
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0.0)


class SSDLiteHead(nn.Module):
    def __init__(
        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
    ):
        super().__init__()
        self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
        self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)

    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
        return {
            "bbox_regression": self.regression_head(x),
            "cls_logits": self.classification_head(x),
        }


class SSDLiteClassificationHead(SSDScoringHead):
    def __init__(
        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
    ):
        cls_logits = nn.ModuleList()
        for channels, anchors in zip(in_channels, num_anchors):
            cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
        _normal_init(cls_logits)
        super().__init__(cls_logits, num_classes)


class SSDLiteRegressionHead(SSDScoringHead):
    def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]):
        bbox_reg = nn.ModuleList()
        for channels, anchors in zip(in_channels, num_anchors):
            bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer))
        _normal_init(bbox_reg)
        super().__init__(bbox_reg, 4)
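To make the head's output layout concrete: the inherited SSDScoringHead logic flattens each (N, A*K, H, W) prediction map to (N, H*W*A, K) and concatenates across feature levels. The sketch below feeds two dummy feature maps through SSDLiteHead; the channel and anchor counts are invented for illustration, and SSDLiteHead is not part of the public API.

# Illustrative head sketch with made-up feature-map sizes.
from functools import partial

import torch
from torch import nn
from torchvision.models.detection.ssdlite import SSDLiteHead

norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
head = SSDLiteHead(in_channels=[672, 480], num_anchors=[6, 6], num_classes=91, norm_layer=norm_layer)

feats = [torch.rand(1, 672, 20, 20), torch.rand(1, 480, 10, 10)]
out = head(feats)
# 6 * (20*20 + 10*10) = 3000 anchor positions across the two levels
print(out["bbox_regression"].shape)  # torch.Size([1, 3000, 4])
print(out["cls_logits"].shape)       # torch.Size([1, 3000, 91])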
class SSDLiteFeatureExtractorMobileNet(nn.Module):
    def __init__(
        self,
        backbone: nn.Module,
        c4_pos: int,
        norm_layer: Callable[..., nn.Module],
        width_mult: float = 1.0,
        min_depth: int = 16,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if backbone[c4_pos].use_res_connect:
            raise ValueError("backbone[c4_pos].use_res_connect should be False")

        self.features = nn.Sequential(
            # As described in section 6.3 of the MobileNetV3 paper
            nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]),  # from start until C4 expansion layer
            nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1 :]),  # from C4 depthwise until end
        )

        get_depth = lambda d: max(min_depth, int(d * width_mult))  # noqa: E731
        extra = nn.ModuleList(
            [
                _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
                _extra_block(get_depth(512), get_depth(256), norm_layer),
                _extra_block(get_depth(256), get_depth(256), norm_layer),
                _extra_block(get_depth(256), get_depth(128), norm_layer),
            ]
        )
        _normal_init(extra)
        self.extra = extra

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
        output = []
        for block in self.features:
            x = block(x)
            output.append(x)

        for block in self.extra:
            x = block(x)
            output.append(x)

        return OrderedDict([(str(i), v) for i, v in enumerate(output)])


def _mobilenet_extractor(
    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
    trainable_layers: int,
    norm_layer: Callable[..., nn.Module],
):
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)

    # find the index of the layer from which we won't freeze
    if not 0 <= trainable_layers <= num_stages:
        raise ValueError(f"trainable_layers should be in the range [0, {num_stages}], instead got {trainable_layers}")
    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

    for b in backbone[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer)
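The extractor therefore yields six feature maps: two from the backbone (the C4 expansion output and the final layer) plus one per extra block. The sketch below wires it up by hand, mirroring the call the model builder makes further down; it relies on private helpers, so treat it as illustrative rather than a stable recipe.

# Illustrative: build the feature extractor directly (private API, may change).
from functools import partial

import torch
from torch import nn
from torchvision.models import mobilenet_v3_large
from torchvision.models.detection.ssdlite import _mobilenet_extractor, _normal_init

norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
backbone = mobilenet_v3_large(weights=None, norm_layer=norm_layer, reduced_tail=True)
_normal_init(backbone)  # same init the builder applies when no pretrained backbone is used
extractor = _mobilenet_extractor(backbone, 6, norm_layer)

feature_maps = extractor(torch.rand(1, 3, 320, 320))
for name, feat in feature_maps.items():
    print(name, tuple(feat.shape)[-2:])
# Six maps of decreasing resolution: 20x20, 10x10, 5x5, 3x3, 2x2, 1x1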
class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
        transforms=ObjectDetection,
        meta={
            "num_params": 3440060,
            "categories": _COCO_CATEGORIES,
            "min_size": (1, 1),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 21.3,
                }
            },
            "_ops": 0.583,
            "_file_size": 13.418,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1
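The metadata bundled with the enum can be read programmatically, e.g. to fetch the reported COCO metric or the preprocessing transforms that match the checkpoint:

from torchvision.models.detection import SSDLite320_MobileNet_V3_Large_Weights

weights = SSDLite320_MobileNet_V3_Large_Weights.DEFAULT  # alias for COCO_V1
print(weights.meta["_metrics"]["COCO-val2017"]["box_map"])  # 21.3
preprocess = weights.transforms()  # the ObjectDetection preset declared above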
@register_model()
@handle_legacy_interface(
    weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def ssdlite320_mobilenet_v3_large(
    *,
    weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
    **kwargs: Any,
) -> SSD:
    """SSDlite model architecture with input size 320x320 and a MobileNetV3 Large backbone, as
    described at `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__ and
    `MobileNetV2: Inverted Residuals and Linear Bottlenecks <https://arxiv.org/abs/1801.04381>`__.

    .. betastatus:: detection module

    See :func:`~torchvision.models.detection.ssd300_vgg16` for more details.

    Example:

        >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(weights=SSDLite320_MobileNet_V3_Large_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights` below for
            more details, and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        num_classes (int, optional): number of output classes of the model (including the background).
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained
            weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting
            from the final block. Valid values are between 0 and 6, with 6 meaning all backbone
            layers are trainable. If ``None`` is passed (the default) this value is set to 6.
        norm_layer (callable, optional): Module specifying the normalization layer to use.
        **kwargs: parameters passed to the ``torchvision.models.detection.ssd.SSD``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/ssdlite.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights
        :members:
    """

    weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    if "size" in kwargs:
        warnings.warn("The size of the model is already fixed; ignoring the parameter.")

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    trainable_backbone_layers = _validate_trainable_layers(
        weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6
    )
    # Enable reduced tail if no pretrained backbone is selected. See Table 6 of the MobileNetV3 paper.
    reduce_tail = weights_backbone is None

    if norm_layer is None:
        norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

    backbone = mobilenet_v3_large(
        weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
    )
    if weights_backbone is None:
        # Change the default initialization scheme if not pretrained
        _normal_init(backbone)
    backbone = _mobilenet_extractor(
        backbone,
        trainable_backbone_layers,
        norm_layer,
    )

    size = (320, 320)
    anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
    out_channels = det_utils.retrieve_out_channels(backbone, size)
    num_anchors = anchor_generator.num_anchors_per_location()
    if len(out_channels) != len(anchor_generator.aspect_ratios):
        raise ValueError(
            f"The length of the output channels from the backbone {len(out_channels)} does not match "
            f"the length of the anchor generator aspect ratios {len(anchor_generator.aspect_ratios)}"
        )

    defaults = {
        "score_thresh": 0.001,
        "nms_thresh": 0.55,
        "detections_per_img": 300,
        "topk_candidates": 300,
        # Rescale the input in a way compatible with the backbone:
        # the following mean/std rescale the data from [0, 1] to [-1, 1]
        "image_mean": [0.5, 0.5, 0.5],
        "image_std": [0.5, 0.5, 0.5],
    }
    kwargs: Any = {**defaults, **kwargs}
    model = SSD(
        backbone,
        anchor_generator,
        size,
        num_classes,
        head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer),
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
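Beyond the inference example in the docstring, a common variation is fine-tuning with a pretrained backbone but fresh detection heads. The sketch below shows the training-mode call; the number of classes and the target are made up for illustration.

import torch
from torchvision.models import MobileNet_V3_Large_Weights
from torchvision.models.detection import ssdlite320_mobilenet_v3_large

# No COCO weights, so num_classes may differ from 91 (here: 2 classes + background).
model = ssdlite320_mobilenet_v3_large(
    weights=None,
    weights_backbone=MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    num_classes=3,
)
model.train()
images = [torch.rand(3, 320, 320)]
targets = [{"boxes": torch.tensor([[10.0, 10.0, 100.0, 100.0]]), "labels": torch.tensor([1])}]
loss_dict = model(images, targets)  # {'bbox_regression': ..., 'classification': ...}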