Source code for torchvision.models.detection.ssdlite
import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import nn, Tensor

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation
from ...utils import _log_api_usage_once
from .. import mobilenet
from . import _utils as det_utils
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .ssd import SSD, SSDScoringHead


__all__ = ["ssdlite320_mobilenet_v3_large"]

model_urls = {
    "ssdlite320_mobilenet_v3_large_coco": "https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth"
}


# Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper
def _prediction_block(
    in_channels: int, out_channels: int, kernel_size: int, norm_layer: Callable[..., nn.Module]
) -> nn.Sequential:
    return nn.Sequential(
        # 3x3 depthwise with stride 1 and padding 1
        ConvNormActivation(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            norm_layer=norm_layer,
            activation_layer=nn.ReLU6,
        ),
        # 1x1 projection to output channels
        nn.Conv2d(in_channels, out_channels, 1),
    )


def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
    activation = nn.ReLU6
    intermediate_channels = out_channels // 2
    return nn.Sequential(
        # 1x1 projection to half output channels
        ConvNormActivation(
            in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
        ),
        # 3x3 depthwise with stride 2 and padding 1
        ConvNormActivation(
            intermediate_channels,
            intermediate_channels,
            kernel_size=3,
            stride=2,
            groups=intermediate_channels,
            norm_layer=norm_layer,
            activation_layer=activation,
        ),
        # 1x1 projection to output channels
        ConvNormActivation(
            intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
        ),
    )


def _normal_init(conv: nn.Module):
    for layer in conv.modules():
        if isinstance(layer, nn.Conv2d):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03)
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0.0)


class SSDLiteHead(nn.Module):
    def __init__(
        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
    ):
        super().__init__()
        self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
        self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)

    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
        return {
            "bbox_regression": self.regression_head(x),
            "cls_logits": self.classification_head(x),
        }


class SSDLiteClassificationHead(SSDScoringHead):
    def __init__(
        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
    ):
        cls_logits = nn.ModuleList()
        for channels, anchors in zip(in_channels, num_anchors):
            cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
        _normal_init(cls_logits)
        super().__init__(cls_logits, num_classes)


class SSDLiteRegressionHead(SSDScoringHead):
    def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]):
        bbox_reg = nn.ModuleList()
        for channels, anchors in zip(in_channels, num_anchors):
            bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer))
        _normal_init(bbox_reg)
        super().__init__(bbox_reg, 4)


class SSDLiteFeatureExtractorMobileNet(nn.Module):
    def __init__(
        self,
        backbone: nn.Module,
        c4_pos: int,
        norm_layer: Callable[..., nn.Module],
        width_mult: float = 1.0,
        min_depth: int = 16,
    ):
        super().__init__()
        _log_api_usage_once(self)

        assert not backbone[c4_pos].use_res_connect
        self.features = nn.Sequential(
            # As described in section 6.3 of MobileNetV3 paper
            nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]),  # from start until C4 expansion layer
            nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1 :]),  # from C4 depthwise until end
        )

        get_depth = lambda d: max(min_depth, int(d * width_mult))  # noqa: E731
        extra = nn.ModuleList(
            [
                _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
                _extra_block(get_depth(512), get_depth(256), norm_layer),
                _extra_block(get_depth(256), get_depth(256), norm_layer),
                _extra_block(get_depth(256), get_depth(128), norm_layer),
            ]
        )
        _normal_init(extra)
        self.extra = extra

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
        output = []
        for block in self.features:
            x = block(x)
            output.append(x)

        for block in self.extra:
            x = block(x)
            output.append(x)

        return OrderedDict([(str(i), v) for i, v in enumerate(output)])


def _mobilenet_extractor(
    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
    trainable_layers: int,
    norm_layer: Callable[..., nn.Module],
):
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)

    # find the index of the layer from which we won't freeze
    assert 0 <= trainable_layers <= num_stages
    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

    for b in backbone[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer)
def ssdlite320_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 91,
    pretrained_backbone: bool = False,
    trainable_backbone_layers: Optional[int] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
    **kwargs: Any,
):
    """Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone, as
    described at `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_ and
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    See :func:`~torchvision.models.detection.ssd300_vgg16` for more details.

    Example:

        >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on ImageNet
        trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from the final
            block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None``
            is passed (the default) this value is set to 6.
        norm_layer (callable, optional): Module specifying the normalization layer to use.
    """
    if "size" in kwargs:
        warnings.warn("The size of the model is already fixed; ignoring the argument.")

    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6
    )

    if pretrained:
        pretrained_backbone = False

    # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper.
    reduce_tail = not pretrained_backbone

    if norm_layer is None:
        norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

    backbone = mobilenet.mobilenet_v3_large(
        pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
    )
    if not pretrained_backbone:
        # Change the default initialization scheme if not pretrained
        _normal_init(backbone)
    backbone = _mobilenet_extractor(
        backbone,
        trainable_backbone_layers,
        norm_layer,
    )

    size = (320, 320)
    anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
    out_channels = det_utils.retrieve_out_channels(backbone, size)
    num_anchors = anchor_generator.num_anchors_per_location()
    assert len(out_channels) == len(anchor_generator.aspect_ratios)

    defaults = {
        "score_thresh": 0.001,
        "nms_thresh": 0.55,
        "detections_per_img": 300,
        "topk_candidates": 300,
        # Rescale the input in a way compatible with the backbone:
        # The following mean/std rescale the data from [0, 1] to [-1, 1]
        "image_mean": [0.5, 0.5, 0.5],
        "image_std": [0.5, 0.5, 0.5],
    }
    kwargs = {**defaults, **kwargs}
    model = SSD(
        backbone,
        anchor_generator,
        size,
        num_classes,
        head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer),
        **kwargs,
    )

    if pretrained:
        weights_name = "ssdlite320_mobilenet_v3_large_coco"
        if model_urls.get(weights_name, None) is None:
            raise ValueError(f"No checkpoint is available for model {weights_name}")
        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
        model.load_state_dict(state_dict)

    return model
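A minimal usage sketch based on the docstring example above; it assumes only that torchvision with this builder is installed, and uses pretrained=False so no weights are downloaded:

import torch
import torchvision

# Build the detector without COCO weights; pass pretrained=True to load the released checkpoint instead.
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=91)
model.eval()

# Inference takes a list of 3xHxW tensors with values in [0, 1]; the images may have different sizes,
# since the model resizes them to 320x320 internally.
images = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
with torch.no_grad():
    predictions = model(images)

# One dict per input image, with "boxes", "labels" and "scores" tensors.
print(predictions[0]["boxes"].shape)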