Shortcuts

Source code for torchvision.ops.deform_conv

import math

import torch
from torch import nn, Tensor
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
from typing import Optional, Tuple
from torchvision.extension import _assert_has_ops


def deform_conv2d(
    input: Tensor,
    offset: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    dilation: Tuple[int, int] = (1, 1),
    mask: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Performs Deformable Convolution v2, described in
    `Deformable ConvNets v2: More Deformable, Better Results
    <https://arxiv.org/abs/1811.11168>`__, if :attr:`mask` is not ``None``;
    otherwise performs Deformable Convolution, described in
    `Deformable Convolutional Networks
    <https://arxiv.org/abs/1703.06211>`__.

    Args:
        input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
        offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
            out_height, out_width]): offsets to be applied for each position in the
            convolution kernel.
        weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]):
            convolution weights, split into groups of size (in_channels // groups)
        bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None
        stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
        padding (int or Tuple[int, int]): height/width of padding of zeroes around
            each image. Default: 0
        dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
        mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
            out_height, out_width]): masks to be applied for each position in the
            convolution kernel. Default: None

    Returns:
        output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution

    Raises:
        RuntimeError: if ``offset.shape[1]`` is not a positive multiple of
            ``2 * weight.size[2] * weight.size[3]``.

    Examples::
        >>> input = torch.rand(4, 3, 10, 10)
        >>> kh, kw = 3, 3
        >>> weight = torch.rand(5, 3, kh, kw)
        >>> # offset and mask should have the same spatial size as the output
        >>> # of the convolution. In this case, for an input of 10, stride of 1
        >>> # and kernel size of 3, without padding, the output size is 8
        >>> offset = torch.rand(4, 2 * kh * kw, 8, 8)
        >>> mask = torch.rand(4, kh * kw, 8, 8)
        >>> out = deform_conv2d(input, offset, weight, mask=mask)
        >>> print(out.shape)
        torch.Size([4, 5, 8, 8])
    """
    _assert_has_ops()

    out_channels = weight.shape[0]

    use_mask = mask is not None
    if mask is None:
        # The op's signature always takes a mask tensor, so pass an empty
        # (batch, 0) placeholder together with use_mask=False.
        mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype)

    if bias is None:
        bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)

    stride_h, stride_w = _pair(stride)
    pad_h, pad_w = _pair(padding)
    dil_h, dil_w = _pair(dilation)
    weights_h, weights_w = weight.shape[-2:]
    # Unpacking into four names also rejects input that is not 4-D.
    _, n_in_channels, in_h, in_w = input.shape

    # Each offset group contributes one (dy, dx) pair per kernel position.
    offset_group_size = 2 * weights_h * weights_w
    n_offset_grps = offset.shape[1] // offset_group_size
    n_weight_grps = n_in_channels // weight.shape[1]

    # Reject both too-small and non-multiple channel counts; floor division
    # alone would silently drop trailing offset channels.
    if n_offset_grps == 0 or offset.shape[1] % offset_group_size != 0:
        raise RuntimeError(
            "the shape of the offset tensor at dimension 1 is not valid. It should "
            "be a multiple of 2 * weight.size[2] * weight.size[3].\n"
            "Got offset.shape[1]={}, while 2 * weight.size[2] * weight.size[3]={}".format(
                offset.shape[1], 2 * weights_h * weights_w))

    return torch.ops.torchvision.deform_conv2d(
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h, stride_w,
        pad_h, pad_w,
        dil_h, dil_w,
        n_weight_grps,
        n_offset_grps,
        use_mask,)
class DeformConv2d(nn.Module):
    """
    Module wrapper around :func:`deform_conv2d` that owns the convolution
    weight and optional bias.

    Args:
        in_channels (int): number of channels in the input tensor
        out_channels (int): number of channels produced by the convolution
        kernel_size (int or Tuple[int, int]): size of the convolution kernel
        stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
        padding (int or Tuple[int, int]): height/width of padding of zeroes around
            each image. Default: 0
        dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
        groups (int): number of channel groups; the weight holds
            ``in_channels // groups`` input channels per output channel. Default: 1
        bias (bool): if ``True``, a learnable bias of shape (out_channels,) is
            created. Default: ``True``

    Raises:
        ValueError: if ``in_channels`` or ``out_channels`` is not divisible by ``groups``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')

        self.in_channels = in_channels
        self.out_channels = out_channels
        # _pair turns scalar arguments into (h, w) tuples.
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups

        self.weight = Parameter(torch.empty(out_channels, in_channels // groups,
                                            self.kernel_size[0], self.kernel_size[1]))

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            # Registering None keeps `self.bias` a valid attribute.
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize weight with Kaiming-uniform and bias uniformly in +/- 1/sqrt(fan_in)."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor, offset: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
            offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
                out_height, out_width]): offsets to be applied for each position in the
                convolution kernel.
            mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
                out_height, out_width]): masks to be applied for each position in the
                convolution kernel. Default: None
        """
        return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
                             padding=self.padding, dilation=self.dilation, mask=mask)

    def __repr__(self) -> str:
        parts = [
            str(self.in_channels),
            str(self.out_channels),
            f'kernel_size={self.kernel_size}',
            f'stride={self.stride}',
        ]
        # Non-default settings only, mirroring the nn.Conv2d repr style.
        if self.padding != (0, 0):
            parts.append(f'padding={self.padding}')
        if self.dilation != (1, 1):
            parts.append(f'dilation={self.dilation}')
        if self.groups != 1:
            parts.append(f'groups={self.groups}')
        if self.bias is None:
            parts.append('bias=False')
        return f"{self.__class__.__name__}({', '.join(parts)})"

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources