Shortcuts

Source code for torchvision.ops.roi_pool

from typing import List, Union

import torch
import torch.fx
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops

from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


@torch.fx.wrap
def roi_pool(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
) -> Tensor:
    """
    Apply the Region of Interest (RoI) Pool operator introduced in Fast R-CNN.

    Args:
        input (Tensor[N, C, H, W]): Batch of ``N`` elements, each holding
            ``C`` feature maps of size ``H x W``.
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): Regions to pool from, given
            as (x1, y1, x2, y2) coordinates satisfying ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``. When a single Tensor is given, its first column
            holds the index of the batch element each box belongs to, i.e. a
            value in ``[0, N - 1]``. When a list of Tensors is given, each
            Tensor holds the boxes for the batch element at the same position.
        output_size (int or Tuple[int, int]): Size of the output after the
            cropping is performed, as (height, width).
        spatial_scale (float): Factor that maps box coordinates onto input
            coordinates. For instance, with boxes expressed on a 224x224 image
            and a 112x112 feature map as input (a 0.5x scaling of the original
            image), this should be 0.5. Default: 1.0

    Returns:
        Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.
    """
    # API-usage logging is skipped while the function is being scripted or traced.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_pool)
    _assert_has_ops()
    check_roi_boxes_shape(boxes)

    # A list of per-image box tensors is converted to the single-tensor RoI
    # format; a tensor is passed to the operator unchanged.
    if isinstance(boxes, torch.Tensor):
        rois = boxes
    else:
        rois = convert_boxes_to_roi_format(boxes)

    # Normalize an int-or-pair size so height and width can be indexed.
    pooled_size = _pair(output_size)

    # The operator returns two values; only the first one is returned to callers.
    pooled, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, pooled_size[0], pooled_size[1])
    return pooled
class RoIPool(nn.Module):
    """
    Module wrapper around :func:`roi_pool`.

    The output size and spatial scale are stored at construction time and
    forwarded to :func:`roi_pool` on every call. See :func:`roi_pool` for the
    meaning of the arguments.
    """

    def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
        """Store the pooling configuration that ``forward`` passes to ``roi_pool``."""
        super().__init__()
        _log_api_usage_once(self)
        self.spatial_scale = spatial_scale
        self.output_size = output_size

    def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
        """Pool ``rois`` from ``input`` using the stored output size and spatial scale."""
        pooled = roi_pool(input, rois, self.output_size, self.spatial_scale)
        return pooled

    def __repr__(self) -> str:
        """Return ``ClassName(output_size=..., spatial_scale=...)``."""
        return f"{self.__class__.__name__}(output_size={self.output_size}, spatial_scale={self.spatial_scale})"

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources