Shortcuts

Source code for torchvision.models.detection.faster_rcnn

from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
from ..resnet import resnet50, ResNet50_Weights
from ._utils import overwrite_eps
from .anchor_utils import AnchorGenerator
from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
from .generalized_rcnn import GeneralizedRCNN
from .roi_heads import RoIHeads
from .rpn import RegionProposalNetwork, RPNHead
from .transform import GeneralizedRCNNTransform


__all__ = [
    "FasterRCNN",
    "FasterRCNN_ResNet50_FPN_Weights",
    "FasterRCNN_ResNet50_FPN_V2_Weights",
    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
    "fasterrcnn_resnet50_fpn",
    "fasterrcnn_resnet50_fpn_v2",
    "fasterrcnn_mobilenet_v3_large_fpn",
    "fasterrcnn_mobilenet_v3_large_320_fpn",
]


def _default_anchorgen():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)


class FasterRCNN(GeneralizedRCNN):
    """
    Implements Faster R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending if it is in training or evaluation mode.

    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores or each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain a out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or and OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): during inference, only return proposals with a classification score
            greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import FasterRCNN
        >>> from torchvision.models.detection.rpn import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # FasterRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> # put the pieces together inside a FasterRCNN model
        >>> model = FasterRCNN(backbone,
        >>>                    num_classes=2,
        >>>                    rpn_anchor_generator=anchor_generator,
        >>>                    box_roi_pool=roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        **kwargs,
    ):

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)


class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()

        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))

        return x


class FastRCNNConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        blocks = []
        previous_channels = in_channels
        for current_channels in conv_layers:
            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
            previous_channels = current_channels
        blocks.append(nn.Flatten())
        previous_channels = previous_channels * in_height * in_width
        for current_channels in fc_layers:
            blocks.append(nn.Linear(previous_channels, current_channels))
            blocks.append(nn.ReLU(inplace=True))
            previous_channels = current_channels

        super().__init__(*blocks)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)


class FastRCNNPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

    def forward(self, x):
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        x = x.flatten(start_dim=1)
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)

        return scores, bbox_deltas


_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


[docs]class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 41755286, "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", "_metrics": { "COCO-val2017": { "box_map": 37.0, } }, "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1
[docs]class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth", transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 43712278, "recipe": "https://github.com/pytorch/vision/pull/5763", "_metrics": { "COCO-val2017": { "box_map": 46.7, } }, "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""", }, ) DEFAULT = COCO_V1
[docs]class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 19386354, "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", "_metrics": { "COCO-val2017": { "box_map": 32.8, } }, "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1
[docs]class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 19386354, "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", "_metrics": { "COCO-val2017": { "box_map": 22.8, } }, "_docs": """These weights were produced by following a similar training recipe as on the paper.""", }, ) DEFAULT = COCO_V1
[docs]@register_model() @handle_legacy_interface( weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), ) def fasterrcnn_resnet50_fpn( *, weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: """ Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__ paper. .. betastatus:: detection module The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each image, and should be in ``0-1`` range. Different images can have different sizes. The behavior of the model changes depending if it is in training or evaluation mode. During training, the model expects both the input tensors, as well as a targets (list of dictionary), containing: - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. - labels (``Int64Tensor[N]``): the class label for each ground-truth box The model returns a ``Dict[Tensor]`` during training, containing the classification and regression losses for both the RPN and the R-CNN. During inference, the model requires only the input tensors, and returns the post-processed predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as follows, where ``N`` is the number of detections: - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. - labels (``Int64Tensor[N]``): the predicted labels for each detection - scores (``Tensor[N]``): the scores of each detection For more details on the output, you may refer to :ref:`instance_seg_output`. Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size. Example:: >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT) >>> # For training >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4) >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4] >>> labels = torch.randint(1, 91, (4, 11)) >>> images = list(image for image in images) >>> targets = [] >>> for i in range(len(images)): >>> d = {} >>> d['boxes'] = boxes[i] >>> d['labels'] = labels[i] >>> targets.append(d) >>> output = model(images, targets) >>> # For inference >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) >>> >>> # optionally, if you want to export the model to ONNX: >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11) Args: weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. num_classes (int, optional): number of output classes of the model (including the background) weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for the backbone. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN`` base class. Please refer to the `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_ for more details about this class. .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights :members: """ weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights) weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1: overwrite_eps(model, 0.0) return model
[docs]@register_model() @handle_legacy_interface( weights=("pretrained", FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), ) def fasterrcnn_resnet50_fpn_v2( *, weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, weights_backbone: Optional[ResNet50_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: """ Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper. .. betastatus:: detection module It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more details. Args: weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. num_classes (int, optional): number of output classes of the model (including the background) weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for the backbone. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN`` base class. Please refer to the `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_ for more details about this class. .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights :members: """ weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights) weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) backbone = resnet50(weights=weights_backbone, progress=progress) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d) rpn_anchor_generator = _default_anchorgen() rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2) box_head = FastRCNNConvFCHead( (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d ) model = FasterRCNN( backbone, num_classes=num_classes, rpn_anchor_generator=rpn_anchor_generator, rpn_head=rpn_head, box_head=box_head, **kwargs, ) if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) return model
def _fasterrcnn_mobilenet_v3_large_fpn( *, weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]], progress: bool, num_classes: Optional[int], weights_backbone: Optional[MobileNet_V3_Large_Weights], trainable_backbone_layers: Optional[int], **kwargs: Any, ) -> FasterRCNN: if weights is not None: weights_backbone = None num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) anchor_sizes = ( ( 32, 64, 128, 256, 512, ), ) * 3 aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) model = FasterRCNN( backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs ) if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) return model
[docs]@register_model() @handle_legacy_interface( weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1), weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), ) def fasterrcnn_mobilenet_v3_large_320_fpn( *, weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: """ Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tunned for mobile use cases. .. betastatus:: detection module It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more details. Example:: >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) Args: weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. num_classes (int, optional): number of output classes of the model (including the background) weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained weights for the backbone. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN`` base class. Please refer to the `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_ for more details about this class. .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights :members: """ weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) defaults = { "min_size": 320, "max_size": 640, "rpn_pre_nms_top_n_test": 150, "rpn_post_nms_top_n_test": 150, "rpn_score_thresh": 0.05, } kwargs = {**defaults, **kwargs} return _fasterrcnn_mobilenet_v3_large_fpn( weights=weights, progress=progress, num_classes=num_classes, weights_backbone=weights_backbone, trainable_backbone_layers=trainable_backbone_layers, **kwargs, )
[docs]@register_model() @handle_legacy_interface( weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1), weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), ) def fasterrcnn_mobilenet_v3_large_fpn( *, weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: """ Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone. .. betastatus:: detection module It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more details. Example:: >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) Args: weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The pretrained weights to use. See :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. num_classes (int, optional): number of output classes of the model (including the background) weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained weights for the backbone. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN`` base class. Please refer to the `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_ for more details about this class. .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights :members: """ weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) defaults = { "rpn_score_thresh": 0.05, } kwargs = {**defaults, **kwargs} return _fasterrcnn_mobilenet_v3_large_fpn( weights=weights, progress=progress, num_classes=num_classes, weights_backbone=weights_backbone, trainable_backbone_layers=trainable_backbone_layers, **kwargs, )
# The dictionary below is internal implementation detail and will be removed in v0.15 from .._utils import _ModelURLs model_urls = _ModelURLs( { "fasterrcnn_resnet50_fpn_coco": FasterRCNN_ResNet50_FPN_Weights.COCO_V1.url, "fasterrcnn_mobilenet_v3_large_320_fpn_coco": FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1.url, "fasterrcnn_mobilenet_v3_large_fpn_coco": FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1.url, } )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources