Source code for torchvision.models.detection.keypoint_rcnn
import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import misc as misc_nn_ops
from ..resnet import resnet50
from ._utils import overwrite_eps
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .faster_rcnn import FasterRCNN


__all__ = ["KeypointRCNN", "keypointrcnn_resnet50_fpn"]


class KeypointRCNN(FasterRCNN):
    """
    Implements Keypoint R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box
        - keypoints (FloatTensor[N, K, 3]): the K keypoint locations for each of the N instances, in the
          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction
        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): during inference, only return proposals with a classification score
            greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes
        keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes, which will be used for the keypoint head.
        keypoint_head (nn.Module): module that takes the cropped feature maps as input
        keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
            heatmap logits

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import KeypointRCNN
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>>
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
        >>> # KeypointRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                          output_size=14,
        >>>                                                          sampling_ratio=2)
        >>> # put the pieces together inside a KeypointRCNN model
        >>> model = KeypointRCNN(backbone,
        >>>                      num_classes=2,
        >>>                      rpn_anchor_generator=anchor_generator,
        >>>                      box_roi_pool=roi_pooler,
        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=None,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # keypoint parameters
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
        num_keypoints=None,
    ):

        assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
        if min_size is None:
            min_size = (640, 672, 704, 736, 768, 800)

        if num_keypoints is not None:
            if keypoint_predictor is not None:
                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
        else:
            num_keypoints = 17

        out_channels = backbone.out_channels

        if keypoint_roi_pool is None:
            keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)

        if keypoint_head is None:
            keypoint_layers = tuple(512 for _ in range(8))
            keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)

        if keypoint_predictor is None:
            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
            keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)

        super().__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
        )

        self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
        self.roi_heads.keypoint_head = keypoint_head
        self.roi_heads.keypoint_predictor = keypoint_predictor


class KeypointRCNNHeads(nn.Sequential):
    def __init__(self, in_channels, layers):
        d = []
        next_feature = in_channels
        for out_channels in layers:
            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
            d.append(nn.ReLU(inplace=True))
            next_feature = out_channels
        super().__init__(*d)
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)


class KeypointRCNNPredictor(nn.Module):
    def __init__(self, in_channels, num_keypoints):
        super().__init__()
        input_features = in_channels
        deconv_kernel = 4
        self.kps_score_lowres = nn.ConvTranspose2d(
            input_features,
            num_keypoints,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = 2
        self.out_channels = num_keypoints

    def forward(self, x):
        x = self.kps_score_lowres(x)
        return torch.nn.functional.interpolate(
            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
        )


model_urls = {
    # legacy model for BC reasons, see https://github.com/pytorch/vision/issues/1606
    "keypointrcnn_resnet50_fpn_coco_legacy": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
    "keypointrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
}


def keypointrcnn_resnet50_fpn(
    pretrained=False,
    progress=True,
    num_classes=2,
    num_keypoints=17,
    pretrained_backbone=True,
    trainable_backbone_layers=None,
    **kwargs,
):
    """
    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

    Reference: `"Mask R-CNN" <https://arxiv.org/abs/1703.06870>`_.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoint locations for each of the ``N`` instances, in the
          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detected instances:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each instance
        - scores (``Tensor[N]``): the scores of each instance
        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Keypoint R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.

    Example::

        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        num_keypoints (int): number of keypoints, default 17
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
        trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
    """
    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
    )

    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False

    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
    if pretrained:
        key = "keypointrcnn_resnet50_fpn_coco"
        if pretrained == "legacy":
            key += "_legacy"
        state_dict = load_state_dict_from_url(model_urls[key], progress=progress)
        model.load_state_dict(state_dict)
        overwrite_eps(model, 0.0)
    return model
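
The default keypoint branch above crops 14 x 14 regions, passes them through eight 3 x 3 convolutions, then a stride-2 transposed convolution and a 2x bilinear upsampling. As a rough sketch of how the spatial size evolves, the snippet below applies the standard output-size formulas for convolution and transposed convolution (assuming dilation 1 and no output padding). The helper names `conv2d_out`, `conv_transpose2d_out` and `keypoint_heatmap_size` are hypothetical and are not part of torchvision; only the layer hyperparameters are taken from the listing.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # Output length of a convolution along one axis (dilation 1).
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size, kernel, stride=1, padding=0):
    # Output length of a transposed convolution along one axis
    # (dilation 1, output_padding 0).
    return (size - 1) * stride - 2 * padding + kernel


def keypoint_heatmap_size(roi_size=14, num_convs=8, deconv_kernel=4, up_scale=2):
    # Hypothetical helper: mirrors the default hyperparameters in the
    # listing (output_size=14, eight 3x3 convs with padding 1,
    # ConvTranspose2d with kernel 4 and stride 2, interpolation by 2).
    size = roi_size
    for _ in range(num_convs):
        size = conv2d_out(size, kernel=3, stride=1, padding=1)
    size = conv_transpose2d_out(
        size, deconv_kernel, stride=2, padding=deconv_kernel // 2 - 1
    )
    # Bilinear interpolation with an integer scale factor multiplies the size.
    return size * up_scale


if __name__ == "__main__":
    print(keypoint_heatmap_size())
```

Under these assumptions the 3 x 3 convolutions with padding 1 preserve the 14 x 14 size, the transposed convolution doubles it, and the interpolation doubles it again, so each keypoint gets one heatmap per detected instance at four times the RoI resolution.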