import functools
from typing import List, Union

import torch
import torch.fx
from torch import nn, Tensor
from torch._dynamo.utils import is_compile_supported
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops, _has_ops

from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


def lazy_compile(**compile_kwargs):
    """Lazily wrap a function with torch.compile on the first call.

    This avoids eagerly importing dynamo.
    """

    def decorate_fn(fn):
        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            compiled_fn = torch.compile(fn, **compile_kwargs)
            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
            return compiled_fn(*args, **kwargs)

        return compile_hook

    return decorate_fn
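# A minimal usage sketch of ``lazy_compile`` (the ``_double`` function below is
# hypothetical, not part of this module). The first call pays the torch.compile
# cost and rebinds the module-level name to the compiled function, so every
# later call dispatches to the compiled version directly:
#
#     @lazy_compile(dynamic=True)
#     def _double(x):
#         return x * 2
#
#     _double(torch.ones(3))  # compiles on this first call
#     _double(torch.ones(5))  # hits the already-compiled function
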
# NB: all inputs are tensors
def _bilinear_interpolate(
    input,  # [N, C, H, W]
    roi_batch_ind,  # [K]
    y,  # [K, PH, IY]
    x,  # [K, PW, IX]
    ymask,  # [K, IY]
    xmask,  # [K, IX]
):
    _, channels, height, width = input.size()

    # deal with inverse element out of feature map boundary
    y = y.clamp(min=0)
    x = x.clamp(min=0)
    y_low = y.int()
    x_low = x.int()
    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
    y = torch.where(y_low >= height - 1, y.to(input.dtype), y)

    x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
    x_low = torch.where(x_low >= width - 1, width - 1, x_low)
    x = torch.where(x_low >= width - 1, x.to(input.dtype), x)

    ly = y - y_low
    lx = x - x_low
    hy = 1.0 - ly
    hx = 1.0 - lx

    # do bilinear interpolation, but respect the masking!
    # TODO: It's possible the masking here is unnecessary if y and
    # x were clamped appropriately; hard to tell
    def masked_index(
        y,  # [K, PH, IY]
        x,  # [K, PW, IX]
    ):
        if ymask is not None:
            assert xmask is not None
            y = torch.where(ymask[:, None, :], y, 0)
            x = torch.where(xmask[:, None, :], x, 0)
        return input[
            roi_batch_ind[:, None, None, None, None, None],
            torch.arange(channels, device=input.device)[None, :, None, None, None, None],
            y[:, None, :, None, :, None],  # prev [K, PH, IY]
            x[:, None, None, :, None, :],  # prev [K, PW, IX]
        ]  # [K, C, PH, PW, IY, IX]

    v1 = masked_index(y_low, x_low)
    v2 = masked_index(y_low, x_high)
    v3 = masked_index(y_high, x_low)
    v4 = masked_index(y_high, x_high)

    # all ws preemptively [K, C, PH, PW, IY, IX]
    def outer_prod(y, x):
        return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

    w1 = outer_prod(hy, hx)
    w2 = outer_prod(hy, lx)
    w3 = outer_prod(ly, hx)
    w4 = outer_prod(ly, lx)

    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
    return val


# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
def maybe_cast(tensor):
    if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double:
        return tensor.float()
    else:
        return tensor


# This is a pure Python and differentiable implementation of roi_align. When
# run in eager mode, it uses a lot of memory, but when compiled it has
# acceptable memory usage. The main point of this implementation is that
# its backward pass is deterministic.
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@lazy_compile(dynamic=True)
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    orig_dtype = input.dtype

    input = maybe_cast(input)
    rois = maybe_cast(rois)

    _, _, height, width = input.size()

    ph = torch.arange(pooled_height, device=input.device)  # [PH]
    pw = torch.arange(pooled_width, device=input.device)  # [PW]

    # input: [N, C, H, W]
    # rois: [K, 5]
    roi_batch_ind = rois[:, 0].int()  # [K]
    offset = 0.5 if aligned else 0.0
    roi_start_w = rois[:, 1] * spatial_scale - offset  # [K]
    roi_start_h = rois[:, 2] * spatial_scale - offset  # [K]
    roi_end_w = rois[:, 3] * spatial_scale - offset  # [K]
    roi_end_h = rois[:, 4] * spatial_scale - offset  # [K]

    roi_width = roi_end_w - roi_start_w  # [K]
    roi_height = roi_end_h - roi_start_h  # [K]
    if not aligned:
        roi_width = torch.clamp(roi_width, min=1.0)  # [K]
        roi_height = torch.clamp(roi_height, min=1.0)  # [K]

    bin_size_h = roi_height / pooled_height  # [K]
    bin_size_w = roi_width / pooled_width  # [K]

    exact_sampling = sampling_ratio > 0

    roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height)  # scalar or [K]
    roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width)  # scalar or [K]

    # iy, ix = dims(2)

    if exact_sampling:
        count = max(roi_bin_grid_h * roi_bin_grid_w, 1)  # scalar
        iy = torch.arange(roi_bin_grid_h, device=input.device)  # [IY]
        ix = torch.arange(roi_bin_grid_w, device=input.device)  # [IX]
        ymask = None
        xmask = None
    else:
        count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)  # [K]
        # When doing adaptive sampling, the number of samples we need to do
        # is data-dependent based on how big the ROIs are. This is a bit
        # awkward because first-class dims can't actually handle this.
        # So instead, we inefficiently suppose that we needed to sample ALL
        # the points and mask out things that turned out to be unnecessary
        iy = torch.arange(height, device=input.device)  # [IY]
        ix = torch.arange(width, device=input.device)  # [IX]
        ymask = iy[None, :] < roi_bin_grid_h[:, None]  # [K, IY]
        xmask = ix[None, :] < roi_bin_grid_w[:, None]  # [K, IX]

    def from_K(t):
        return t[:, None, None]

    y = (
        from_K(roi_start_h)
        + ph[None, :, None] * from_K(bin_size_h)
        + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
    )  # [K, PH, IY]
    x = (
        from_K(roi_start_w)
        + pw[None, :, None] * from_K(bin_size_w)
        + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
    )  # [K, PW, IX]
    val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]

    # Mask out samples that weren't actually adaptively needed
    if not exact_sampling:
        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
        val = torch.where(xmask[:, None, None, None, None, :], val, 0)

    output = val.sum((-1, -2))  # remove IY, IX ~> [K, C, PH, PW]
    if isinstance(count, torch.Tensor):
        output /= count[:, None, None, None]
    else:
        output /= count

    output = output.to(orig_dtype)

    return output
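# A tiny illustrative sketch (shapes and box values are made up) of calling the
# pure Python implementation directly and backpropagating through it. Per the
# comment above, the gradient computed this way is deterministic across runs,
# unlike the native CUDA kernel's backward:
#
#     x = torch.rand(1, 3, 32, 32, requires_grad=True)
#     rois = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0]])  # one box in image 0
#     out = _roi_align(x, rois, 1.0, 7, 7, 2, True)       # -> [1, 3, 7, 7]
#     out.sum().backward()                                # deterministic x.grad
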
@torch.fx.wrap
def roi_align(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
    sampling_ratio: int = -1,
    aligned: bool = False,
) -> Tensor:
    """
    Performs Region of Interest (RoI) Align operator with average pooling, as described in Mask R-CNN.

    Args:
        input (Tensor[N, C, H, W]): The input tensor, i.e. a batch with ``N`` elements. Each element
            contains ``C`` feature maps of dimensions ``H x W``.
            If the tensor is quantized, we expect a batch size of ``N == 1``.
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
            format where the regions will be taken from.
            The coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            If a single Tensor is passed, then the first column should
            contain the index of the corresponding element in the batch, i.e. a number in ``[0, N - 1]``.
            If a list of Tensors is passed, then each Tensor will correspond to the boxes for an element ``i``
            in the batch.
        output_size (int or Tuple[int, int]): the size of the output (in bins or pixels) after the pooling
            is performed, as (height, width).
        spatial_scale (float): a scaling factor that maps the box coordinates to
            the input coordinates. For example, if your boxes are defined on the scale
            of a 224x224 image and your input is a 112x112 feature map (resulting from a 0.5x scaling of
            the original image), you'll want to set this to 0.5. Default: 1.0
        sampling_ratio (int): number of sampling points in the interpolation grid
            used to compute the output value of each pooled output bin. If > 0,
            then exactly ``sampling_ratio x sampling_ratio`` sampling points per bin are used. If
            <= 0, then an adaptive number of grid points is used (computed as
            ``ceil(roi_width / output_width)``, and likewise for height). Default: -1
        aligned (bool): If False, use the legacy implementation.
            If True, shift the box coordinates by -0.5 pixel for a better alignment with the two
            neighboring pixel indices. This version is used in Detectron2.

    Returns:
        Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_align)
    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    if not torch.jit.is_scripting():
        if (
            not _has_ops()
            or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps or input.is_xpu))
        ) and is_compile_supported(input.device.type):
            return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
    _assert_has_ops()
    return torch.ops.torchvision.roi_align(
        input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
    )
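# A short usage sketch for roi_align (shapes and coordinates are illustrative):
# pool one box from a single feature map into a 7x7 grid. With a Tensor box
# input, column 0 holds the batch index of each box:
#
#     feats = torch.rand(1, 256, 50, 50)
#     boxes = torch.tensor([[0.0, 10.0, 10.0, 40.0, 40.0]])  # (batch_idx, x1, y1, x2, y2)
#     pooled = roi_align(feats, boxes, output_size=(7, 7), spatial_scale=1.0,
#                        sampling_ratio=2, aligned=True)
#     # pooled.shape == torch.Size([1, 256, 7, 7])
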
class RoIAlign(nn.Module):
    """
    See :func:`roi_align`.
    """

    def __init__(
        self,
        output_size: BroadcastingList2[int],
        spatial_scale: float,
        sampling_ratio: int,
        aligned: bool = False,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

    def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
        return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"output_size={self.output_size}"
            f", spatial_scale={self.spatial_scale}"
            f", sampling_ratio={self.sampling_ratio}"
            f", aligned={self.aligned}"
            f")"
        )
        return s
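# Equivalent module-style usage sketch (hyperparameters are illustrative):
# construct the pooler once, e.g. inside a detection head, and reuse it.
# Reusing ``feats`` and ``boxes`` from the sketch above:
#
#     pooler = RoIAlign(output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2, aligned=True)
#     pooled = pooler(feats, boxes)  # same semantics as the functional call above
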