Source code for torchvision.transforms.v2.functional._geometry
import math
import numbers
import warnings
from typing import Any, List, Optional, Sequence, Tuple, Union

import PIL.Image
import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad

from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
from torchvision.transforms.functional import (
    _compute_resized_output_size as __compute_resized_output_size,
    _get_perspective_coeffs,
    _interpolation_modes_from_int,
    InterpolationMode,
    pil_modes_mapping,
    pil_to_tensor,
    to_pil_image,
)
from torchvision.utils import _log_api_usage_once

from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format

from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise ValueError(
            f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
            f"but got {interpolation}."
        )
    return interpolation
[docs]def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomHorizontalFlip` for details."""
    if torch.jit.is_scripting():
        return horizontal_flip_image(inpt)

    _log_api_usage_once(horizontal_flip)

    kernel = _get_kernel(horizontal_flip, type(inpt))
    return kernel(inpt)
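# --- Illustrative usage sketch (not part of the torchvision source). It shows how the
# --- dispatcher above routes a call to the kernel registered for the input type; the
# --- tensor values are made up for the example.
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (3, 32, 64), dtype=torch.uint8)
boxes = tv_tensors.BoundingBoxes([[4, 8, 20, 28]], format="XYXY", canvas_size=(32, 64))

flipped_img = F.horizontal_flip(img)      # dispatches to horizontal_flip_image
flipped_boxes = F.horizontal_flip(boxes)  # dispatches to the bounding-box kernel
# For XYXY boxes the kernel below mirrors x-coordinates as new_x = canvas_width - old_x,
# so [4, 8, 20, 28] becomes [44, 8, 60, 28] on a 64-pixel-wide canvas.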
@_register_kernel_internal(horizontal_flip,torch.Tensor)@_register_kernel_internal(horizontal_flip,tv_tensors.Image)defhorizontal_flip_image(image:torch.Tensor)->torch.Tensor:returnimage.flip(-1)@_register_kernel_internal(horizontal_flip,PIL.Image.Image)def_horizontal_flip_image_pil(image:PIL.Image.Image)->PIL.Image.Image:return_FP.hflip(image)@_register_kernel_internal(horizontal_flip,tv_tensors.Mask)defhorizontal_flip_mask(mask:torch.Tensor)->torch.Tensor:returnhorizontal_flip_image(mask)defhorizontal_flip_bounding_boxes(bounding_boxes:torch.Tensor,format:tv_tensors.BoundingBoxFormat,canvas_size:Tuple[int,int])->torch.Tensor:shape=bounding_boxes.shapebounding_boxes=bounding_boxes.clone().reshape(-1,4)ifformat==tv_tensors.BoundingBoxFormat.XYXY:bounding_boxes[:,[2,0]]=bounding_boxes[:,[0,2]].sub_(canvas_size[1]).neg_()elifformat==tv_tensors.BoundingBoxFormat.XYWH:bounding_boxes[:,0].add_(bounding_boxes[:,2]).sub_(canvas_size[1]).neg_()else:# format == tv_tensors.BoundingBoxFormat.CXCYWH:bounding_boxes[:,0].sub_(canvas_size[1]).neg_()returnbounding_boxes.reshape(shape)@_register_kernel_internal(horizontal_flip,tv_tensors.BoundingBoxes,tv_tensor_wrapper=False)def_horizontal_flip_bounding_boxes_dispatch(inpt:tv_tensors.BoundingBoxes)->tv_tensors.BoundingBoxes:output=horizontal_flip_bounding_boxes(inpt.as_subclass(torch.Tensor),format=inpt.format,canvas_size=inpt.canvas_size)returntv_tensors.wrap(output,like=inpt)@_register_kernel_internal(horizontal_flip,tv_tensors.Video)defhorizontal_flip_video(video:torch.Tensor)->torch.Tensor:returnhorizontal_flip_image(video)
[docs]def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details."""
    if torch.jit.is_scripting():
        return vertical_flip_image(inpt)

    _log_api_usage_once(vertical_flip)

    kernel = _get_kernel(vertical_flip, type(inpt))
    return kernel(inpt)
@_register_kernel_internal(vertical_flip,torch.Tensor)@_register_kernel_internal(vertical_flip,tv_tensors.Image)defvertical_flip_image(image:torch.Tensor)->torch.Tensor:returnimage.flip(-2)@_register_kernel_internal(vertical_flip,PIL.Image.Image)def_vertical_flip_image_pil(image:PIL.Image.Image)->PIL.Image.Image:return_FP.vflip(image)@_register_kernel_internal(vertical_flip,tv_tensors.Mask)defvertical_flip_mask(mask:torch.Tensor)->torch.Tensor:returnvertical_flip_image(mask)defvertical_flip_bounding_boxes(bounding_boxes:torch.Tensor,format:tv_tensors.BoundingBoxFormat,canvas_size:Tuple[int,int])->torch.Tensor:shape=bounding_boxes.shapebounding_boxes=bounding_boxes.clone().reshape(-1,4)ifformat==tv_tensors.BoundingBoxFormat.XYXY:bounding_boxes[:,[1,3]]=bounding_boxes[:,[3,1]].sub_(canvas_size[0]).neg_()elifformat==tv_tensors.BoundingBoxFormat.XYWH:bounding_boxes[:,1].add_(bounding_boxes[:,3]).sub_(canvas_size[0]).neg_()else:# format == tv_tensors.BoundingBoxFormat.CXCYWH:bounding_boxes[:,1].sub_(canvas_size[0]).neg_()returnbounding_boxes.reshape(shape)@_register_kernel_internal(vertical_flip,tv_tensors.BoundingBoxes,tv_tensor_wrapper=False)def_vertical_flip_bounding_boxes_dispatch(inpt:tv_tensors.BoundingBoxes)->tv_tensors.BoundingBoxes:output=vertical_flip_bounding_boxes(inpt.as_subclass(torch.Tensor),format=inpt.format,canvas_size=inpt.canvas_size)returntv_tensors.wrap(output,like=inpt)@_register_kernel_internal(vertical_flip,tv_tensors.Video)defvertical_flip_video(video:torch.Tensor)->torch.Tensor:returnvertical_flip_image(video)# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are# prevalent and well understood. Thus, we just alias them without deprecating the old names.hflip=horizontal_flipvflip=vertical_flipdef_compute_resized_output_size(canvas_size:Tuple[int,int],size:Optional[List[int]],max_size:Optional[int]=None)->List[int]:ifisinstance(size,int):size=[size]elifmax_sizeisnotNoneandsizeisnotNoneandlen(size)!=1:raiseValueError("max_size should only be passed if size is None or specifies the length of the smaller edge, ""i.e. size should be an int or a sequence of length 1 in torchscript mode.")return__compute_resized_output_size(canvas_size,size=size,max_size=max_size,allow_size_none=True)
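# --- Illustrative sketch (not part of the torchvision source) of the smaller-edge
# --- semantics handled by the output-size helper above; the numbers are example values.
import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 32, 64)  # (C, H, W)

# A single-element size resizes the smaller edge to 16 and keeps the aspect ratio:
# (32, 64) -> (16, 32).
assert F.resize(img, size=[16]).shape[-2:] == (16, 32)

# With max_size, the result is additionally capped so the longer edge does not exceed it:
# (16, 32) would violate max_size=24, so the output is rescaled to (12, 24).
assert F.resize(img, size=[16], max_size=24).shape[-2:] == (12, 24)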
[docs]def resize(
    inpt: torch.Tensor,
    size: Optional[List[int]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.Resize` for details."""
    if torch.jit.is_scripting():
        return resize_image(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)

    _log_api_usage_once(resize)

    kernel = _get_kernel(resize, type(inpt))
    return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
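# --- Illustrative usage sketch (not part of the torchvision source): resize dispatches on
# --- the input type, so images are interpolated while bounding boxes are rescaled.
import torch
from torchvision import tv_tensors
from torchvision.transforms import InterpolationMode
from torchvision.transforms.v2 import functional as F

img = tv_tensors.Image(torch.randint(0, 256, (3, 32, 64), dtype=torch.uint8))
boxes = tv_tensors.BoundingBoxes([[0, 0, 10, 10]], format="XYXY", canvas_size=(32, 64))

small_img = F.resize(img, size=[16, 32], interpolation=InterpolationMode.BILINEAR, antialias=True)
small_boxes = F.resize(boxes, size=[16, 32])  # coordinates scaled by (0.5, 0.5) -> [[0, 0, 5, 5]]
assert small_img.shape[-2:] == (16, 32) and small_boxes.canvas_size == (16, 32)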
# This is an internal helper method for resize_image. We should put it here instead of keeping it# inside resize_image due to torchscript.# uint8 dtype support for bilinear and bicubic is limited to cpu and# according to our benchmarks on eager, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilineardef_do_native_uint8_resize_on_cpu(interpolation:InterpolationMode)->bool:ifinterpolation==InterpolationMode.BILINEAR:iftorch.compiler.is_compiling():returnTrueelse:return"AVX2"intorch.backends.cpu.get_cpu_capability()returninterpolation==InterpolationMode.BICUBIC@_register_kernel_internal(resize,torch.Tensor)@_register_kernel_internal(resize,tv_tensors.Image)defresize_image(image:torch.Tensor,size:Optional[List[int]],interpolation:Union[InterpolationMode,int]=InterpolationMode.BILINEAR,max_size:Optional[int]=None,antialias:Optional[bool]=True,)->torch.Tensor:interpolation=_check_interpolation(interpolation)antialias=FalseifantialiasisNoneelseantialiasalign_corners:Optional[bool]=Noneifinterpolation==InterpolationMode.BILINEARorinterpolation==InterpolationMode.BICUBIC:align_corners=Falseelse:# The default of antialias is True from 0.17, so we don't warn or# error if other interpolation modes are used. This is documented.antialias=Falseshape=image.shapenumel=image.numel()num_channels,old_height,old_width=shape[-3:]new_height,new_width=_compute_resized_output_size((old_height,old_width),size=size,max_size=max_size)if(new_height,new_width)==(old_height,old_width):returnimageelifnumel>0:dtype=image.dtypeacceptable_dtypes=[torch.float32,torch.float64]ifinterpolation==InterpolationMode.NEARESTorinterpolation==InterpolationMode.NEAREST_EXACT:# uint8 dtype can be included for cpu and cuda input if nearest modeacceptable_dtypes.append(torch.uint8)elifimage.device.type=="cpu":if_do_native_uint8_resize_on_cpu(interpolation):acceptable_dtypes.append(torch.uint8)image=image.reshape(-1,num_channels,old_height,old_width)strides=image.stride()ifimage.is_contiguous(memory_format=torch.channels_last)andimage.shape[0]==1andnumel!=strides[0]:# There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as# contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).# In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim# to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as# channels_last, thus preserving the memory format of the input. 
This is not just for format consistency:# for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.# TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),# we should be able to remove this hack.new_strides=list(strides)new_strides[0]=numelimage=image.as_strided((1,num_channels,old_height,old_width),new_strides)need_cast=dtypenotinacceptable_dtypesifneed_cast:image=image.to(dtype=torch.float32)image=interpolate(image,size=[new_height,new_width],mode=interpolation.value,align_corners=align_corners,antialias=antialias,)ifneed_cast:ifinterpolation==InterpolationMode.BICUBICanddtype==torch.uint8:# This path is hit on non-AVX archs, or on GPU.image=image.clamp_(min=0,max=255)ifdtypein(torch.uint8,torch.int8,torch.int16,torch.int32,torch.int64):image=image.round_()image=image.to(dtype=dtype)returnimage.reshape(shape[:-3]+(num_channels,new_height,new_width))def_resize_image_pil(image:PIL.Image.Image,size:Union[Sequence[int],int],interpolation:Union[InterpolationMode,int]=InterpolationMode.BILINEAR,max_size:Optional[int]=None,)->PIL.Image.Image:old_height,old_width=image.height,image.widthnew_height,new_width=_compute_resized_output_size((old_height,old_width),size=size,# type: ignore[arg-type]max_size=max_size,)interpolation=_check_interpolation(interpolation)if(new_height,new_width)==(old_height,old_width):returnimagereturnimage.resize((new_width,new_height),resample=pil_modes_mapping[interpolation])@_register_kernel_internal(resize,PIL.Image.Image)def__resize_image_pil_dispatch(image:PIL.Image.Image,size:Union[Sequence[int],int],interpolation:Union[InterpolationMode,int]=InterpolationMode.BILINEAR,max_size:Optional[int]=None,antialias:Optional[bool]=True,)->PIL.Image.Image:ifantialiasisFalse:warnings.warn("Anti-alias option is always applied for PIL Image input. 
Argument antialias is ignored.")return_resize_image_pil(image,size=size,interpolation=interpolation,max_size=max_size)defresize_mask(mask:torch.Tensor,size:Optional[List[int]],max_size:Optional[int]=None)->torch.Tensor:ifmask.ndim<3:mask=mask.unsqueeze(0)needs_squeeze=Trueelse:needs_squeeze=Falseoutput=resize_image(mask,size=size,interpolation=InterpolationMode.NEAREST,max_size=max_size)ifneeds_squeeze:output=output.squeeze(0)returnoutput@_register_kernel_internal(resize,tv_tensors.Mask,tv_tensor_wrapper=False)def_resize_mask_dispatch(inpt:tv_tensors.Mask,size:List[int],max_size:Optional[int]=None,**kwargs:Any)->tv_tensors.Mask:output=resize_mask(inpt.as_subclass(torch.Tensor),size,max_size=max_size)returntv_tensors.wrap(output,like=inpt)defresize_bounding_boxes(bounding_boxes:torch.Tensor,canvas_size:Tuple[int,int],size:Optional[List[int]],max_size:Optional[int]=None,)->Tuple[torch.Tensor,Tuple[int,int]]:old_height,old_width=canvas_sizenew_height,new_width=_compute_resized_output_size(canvas_size,size=size,max_size=max_size)if(new_height,new_width)==(old_height,old_width):returnbounding_boxes,canvas_sizew_ratio=new_width/old_widthh_ratio=new_height/old_heightratios=torch.tensor([w_ratio,h_ratio,w_ratio,h_ratio],device=bounding_boxes.device)return(bounding_boxes.mul(ratios).to(bounding_boxes.dtype),(new_height,new_width),)@_register_kernel_internal(resize,tv_tensors.BoundingBoxes,tv_tensor_wrapper=False)def_resize_bounding_boxes_dispatch(inpt:tv_tensors.BoundingBoxes,size:Optional[List[int]],max_size:Optional[int]=None,**kwargs:Any)->tv_tensors.BoundingBoxes:output,canvas_size=resize_bounding_boxes(inpt.as_subclass(torch.Tensor),inpt.canvas_size,size,max_size=max_size)returntv_tensors.wrap(output,like=inpt,canvas_size=canvas_size)@_register_kernel_internal(resize,tv_tensors.Video)defresize_video(video:torch.Tensor,size:Optional[List[int]],interpolation:Union[InterpolationMode,int]=InterpolationMode.BILINEAR,max_size:Optional[int]=None,antialias:Optional[bool]=True,)->torch.Tensor:returnresize_image(video,size=size,interpolation=interpolation,max_size=max_size,antialias=antialias)
[docs]def affine(
    inpt: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomAffine` for details."""
    if torch.jit.is_scripting():
        return affine_image(
            inpt,
            angle=angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )

    _log_api_usage_once(affine)

    kernel = _get_kernel(affine, type(inpt))
    return kernel(
        inpt,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )
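# --- Illustrative usage sketch (not part of the torchvision source): a single affine call
# --- combining rotation, translation, scaling and shear; the parameter values are
# --- arbitrary example numbers.
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
out = F.affine(
    img,
    angle=15.0,             # degrees
    translate=[5.0, -3.0],  # horizontal / vertical shift in pixels
    scale=1.1,
    shear=[0.0, 0.0],
    interpolation=InterpolationMode.BILINEAR,
    fill=0,                 # value used for the area uncovered by the transform
)
assert out.shape == img.shape  # affine_image keeps the input canvas size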
def_affine_parse_args(angle:Union[int,float],translate:List[float],scale:float,shear:List[float],interpolation:InterpolationMode=InterpolationMode.NEAREST,center:Optional[List[float]]=None,)->Tuple[float,List[float],List[float],Optional[List[float]]]:ifnotisinstance(angle,(int,float)):raiseTypeError("Argument angle should be int or float")ifnotisinstance(translate,(list,tuple)):raiseTypeError("Argument translate should be a sequence")iflen(translate)!=2:raiseValueError("Argument translate should be a sequence of length 2")ifscale<=0.0:raiseValueError("Argument scale should be positive")ifnotisinstance(shear,(numbers.Number,(list,tuple))):raiseTypeError("Shear should be either a single value or a sequence of two values")ifnotisinstance(interpolation,InterpolationMode):raiseTypeError("Argument interpolation should be a InterpolationMode")ifisinstance(angle,int):angle=float(angle)ifisinstance(translate,tuple):translate=list(translate)ifisinstance(shear,numbers.Number):shear=[shear,0.0]ifisinstance(shear,tuple):shear=list(shear)iflen(shear)==1:shear=[shear[0],shear[0]]iflen(shear)!=2:raiseValueError(f"Shear should be a sequence containing two values. Got {shear}")ifcenterisnotNone:ifnotisinstance(center,(list,tuple)):raiseTypeError("Argument center should be a sequence")else:center=[float(c)forcincenter]returnangle,translate,shear,centerdef_get_inverse_affine_matrix(center:List[float],angle:float,translate:List[float],scale:float,shear:List[float],inverted:bool=True)->List[float]:# Helper method to compute inverse matrix for affine transformation# Pillow requires inverse affine transformation matrix:# Affine matrix is : M = T * C * RotateScaleShear * C^-1## where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]# RotateScaleShear is rotation with scale and shear matrix## RotateScaleShear(a, s, (sx, sy)) =# = R(a) * S(s) * SHy(sy) * SHx(sx)# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]# [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]# [ 0 , 0 , 1 ]# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]# [0, 1 ] [-tan(s), 1]## Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1rot=math.radians(angle)sx=math.radians(shear[0])sy=math.radians(shear[1])cx,cy=centertx,ty=translate# Cached resultscos_sy=math.cos(sy)tan_sx=math.tan(sx)rot_minus_sy=rot-sycx_plus_tx=cx+txcy_plus_ty=cy+ty# Rotate Scale Shear (RSS) without scalinga=math.cos(rot_minus_sy)/cos_syb=-(a*tan_sx+math.sin(rot))c=math.sin(rot_minus_sy)/cos_syd=math.cos(rot)-c*tan_sxifinverted:# Inverted rotation matrix with scale and shear# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1matrix=[d/scale,-b/scale,0.0,-c/scale,a/scale,0.0]# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1# and then apply center translation: C * RSS^-1 * C^-1 * T^-1matrix[2]+=cx-matrix[0]*cx_plus_tx-matrix[1]*cy_plus_tymatrix[5]+=cy-matrix[3]*cx_plus_tx-matrix[4]*cy_plus_tyelse:matrix=[a*scale,b*scale,0.0,c*scale,d*scale,0.0]# Apply inverse of center translation: RSS * C^-1# and then apply translation and center : T * C * RSS * 
C^-1matrix[2]+=cx_plus_tx-matrix[0]*cx-matrix[1]*cymatrix[5]+=cy_plus_ty-matrix[3]*cx-matrix[4]*cyreturnmatrixdef_compute_affine_output_size(matrix:List[float],w:int,h:int)->Tuple[int,int]:iftorch.compiler.is_compiling()andnottorch.jit.is_scripting():return_compute_affine_output_size_python(matrix,w,h)else:return_compute_affine_output_size_tensor(matrix,w,h)def_compute_affine_output_size_tensor(matrix:List[float],w:int,h:int)->Tuple[int,int]:# Inspired of PIL implementation:# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.# Points are shifted due to affine matrix torch convention about# the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)half_w=0.5*whalf_h=0.5*hpts=torch.tensor([[-half_w,-half_h,1.0],[-half_w,half_h,1.0],[half_w,half_h,1.0],[half_w,-half_h,1.0],])theta=torch.tensor(matrix,dtype=torch.float).view(2,3)new_pts=torch.matmul(pts,theta.T)min_vals,max_vals=new_pts.aminmax(dim=0)# shift points to [0, w] and [0, h] interval to match PIL resultshalfs=torch.tensor((half_w,half_h))min_vals.add_(halfs)max_vals.add_(halfs)# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0tol=1e-4inv_tol=1.0/tolcmax=max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()cmin=min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()size=cmax.sub_(cmin)returnint(size[0]),int(size[1])# w, hdef_compute_affine_output_size_python(matrix:List[float],w:int,h:int)->Tuple[int,int]:# Mostly copied from PIL implementation:# The only difference is with transformed points as input matrix has zero translation part here and# PIL has a centered translation part.# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054a,b,c,d,e,f=matrixxx=[]yy=[]half_w=0.5*whalf_h=0.5*hforx,yin((-half_w,-half_h),(half_w,-half_h),(half_w,half_h),(-half_w,half_h)):nx=a*x+b*y+cny=d*x+e*y+fxx.append(nx+half_w)yy.append(ny+half_h)nw=math.ceil(max(xx))-math.floor(min(xx))nh=math.ceil(max(yy))-math.floor(min(yy))returnint(nw),int(nh)# w, hdef_apply_grid_transform(img:torch.Tensor,grid:torch.Tensor,mode:str,fill:_FillTypeJIT)->torch.Tensor:input_shape=img.shapeoutput_height,output_width=grid.shape[1],grid.shape[2]num_channels,input_height,input_width=input_shape[-3:]output_shape=input_shape[:-3]+(num_channels,output_height,output_width)ifimg.numel()==0:returnimg.reshape(output_shape)img=img.reshape(-1,num_channels,input_height,input_width)squashed_batch_size=img.shape[0]# We are using context knowledge that grid should have float dtypefp=img.dtype==grid.dtypefloat_img=imgiffpelseimg.to(grid.dtype)ifsquashed_batch_size>1:# Apply same grid to a batch of imagesgrid=grid.expand(squashed_batch_size,-1,-1,-1)# Append a dummy mask for customized fill colors, should be faster than grid_sample() twiceiffillisnotNone:mask=torch.ones((squashed_batch_size,1,input_height,input_width),dtype=float_img.dtype,device=float_img.device)float_img=torch.cat((float_img,mask),dim=1)float_img=grid_sample(float_img,grid,mode=mode,padding_mode="zeros",align_corners=False)# Fill with required coloriffillisnotNone:float_img,mask=torch.tensor_split(float_img,indices=(-1,),dim=-3)mask=mask.expand_as(float_img)fill_list=fillifisinstance(fill,(tuple,list))else[float(fill)]# type: 
ignore[arg-type]fill_img=torch.tensor(fill_list,dtype=float_img.dtype,device=float_img.device).view(1,-1,1,1)ifmode=="nearest":float_img=torch.where(mask<0.5,fill_img.expand_as(float_img),float_img)else:# 'bilinear'# The following is mathematically equivalent to:# img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fillfloat_img=float_img.sub_(fill_img).mul_(mask).add_(fill_img)img=float_img.round_().to(img.dtype)ifnotfpelsefloat_imgreturnimg.reshape(output_shape)def_assert_grid_transform_inputs(image:torch.Tensor,matrix:Optional[List[float]],interpolation:str,fill:_FillTypeJIT,supported_interpolation_modes:List[str],coeffs:Optional[List[float]]=None,)->None:ifmatrixisnotNone:ifnotisinstance(matrix,list):raiseTypeError("Argument matrix should be a list")eliflen(matrix)!=6:raiseValueError("Argument matrix should have 6 float values")ifcoeffsisnotNoneandlen(coeffs)!=8:raiseValueError("Argument coeffs should have 8 float values")iffillisnotNone:ifisinstance(fill,(tuple,list)):length=len(fill)num_channels=image.shape[-3]iflength>1andlength!=num_channels:raiseValueError("The number of elements in 'fill' cannot broadcast to match the number of "f"channels of the image ({length} != {num_channels})")elifnotisinstance(fill,(int,float)):raiseValueError("Argument fill should be either int, float, tuple or list")ifinterpolationnotinsupported_interpolation_modes:raiseValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")def_affine_grid(theta:torch.Tensor,w:int,h:int,ow:int,oh:int,)->torch.Tensor:# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/# AffineGridGenerator.cpp#L18# Difference with AffineGridGenerator is that:# 1) we normalize grid values after applying theta# 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotatedtype=theta.dtypedevice=theta.devicebase_grid=torch.empty(1,oh,ow,3,dtype=dtype,device=device)x_grid=torch.linspace((1.0-ow)*0.5,(ow-1.0)*0.5,steps=ow,device=device)base_grid[...,0].copy_(x_grid)y_grid=torch.linspace((1.0-oh)*0.5,(oh-1.0)*0.5,steps=oh,device=device).unsqueeze_(-1)base_grid[...,1].copy_(y_grid)base_grid[...,2].fill_(1)rescaled_theta=theta.transpose(1,2).div_(torch.tensor([0.5*w,0.5*h],dtype=dtype,device=device))output_grid=base_grid.view(1,oh*ow,3).bmm(rescaled_theta)returnoutput_grid.view(1,oh,ow,2)@_register_kernel_internal(affine,torch.Tensor)@_register_kernel_internal(affine,tv_tensors.Image)defaffine_image(image:torch.Tensor,angle:Union[int,float],translate:List[float],scale:float,shear:List[float],interpolation:Union[InterpolationMode,int]=InterpolationMode.NEAREST,fill:_FillTypeJIT=None,center:Optional[List[float]]=None,)->torch.Tensor:interpolation=_check_interpolation(interpolation)angle,translate,shear,center=_affine_parse_args(angle,translate,scale,shear,interpolation,center)height,width=image.shape[-2:]center_f=[0.0,0.0]ifcenterisnotNone:# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image 
center.center_f=[(c-s*0.5)forc,sinzip(center,[width,height])]translate_f=[float(t)fortintranslate]matrix=_get_inverse_affine_matrix(center_f,angle,translate_f,scale,shear)_assert_grid_transform_inputs(image,matrix,interpolation.value,fill,["nearest","bilinear"])dtype=image.dtypeiftorch.is_floating_point(image)elsetorch.float32theta=torch.tensor(matrix,dtype=dtype,device=image.device).reshape(1,2,3)grid=_affine_grid(theta,w=width,h=height,ow=width,oh=height)return_apply_grid_transform(image,grid,interpolation.value,fill=fill)@_register_kernel_internal(affine,PIL.Image.Image)def_affine_image_pil(image:PIL.Image.Image,angle:Union[int,float],translate:List[float],scale:float,shear:List[float],interpolation:Union[InterpolationMode,int]=InterpolationMode.NEAREST,fill:_FillTypeJIT=None,center:Optional[List[float]]=None,)->PIL.Image.Image:interpolation=_check_interpolation(interpolation)angle,translate,shear,center=_affine_parse_args(angle,translate,scale,shear,interpolation,center)# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)# it is visually better to estimate the center without 0.5 offset# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affineifcenterisNone:height,width=_get_size_image_pil(image)center=[width*0.5,height*0.5]matrix=_get_inverse_affine_matrix(center,angle,translate,scale,shear)return_FP.affine(image,matrix,interpolation=pil_modes_mapping[interpolation],fill=fill)def_affine_bounding_boxes_with_expand(bounding_boxes:torch.Tensor,format:tv_tensors.BoundingBoxFormat,canvas_size:Tuple[int,int],angle:Union[int,float],translate:List[float],scale:float,shear:List[float],center:Optional[List[float]]=None,expand:bool=False,)->Tuple[torch.Tensor,Tuple[int,int]]:ifbounding_boxes.numel()==0:returnbounding_boxes,canvas_sizeoriginal_shape=bounding_boxes.shapeoriginal_dtype=bounding_boxes.dtypebounding_boxes=bounding_boxes.clone()ifbounding_boxes.is_floating_point()elsebounding_boxes.float()dtype=bounding_boxes.dtypedevice=bounding_boxes.devicebounding_boxes=(convert_bounding_box_format(bounding_boxes,old_format=format,new_format=tv_tensors.BoundingBoxFormat.XYXY,inplace=True)).reshape(-1,4)angle,translate,shear,center=_affine_parse_args(angle,translate,scale,shear,InterpolationMode.NEAREST,center)ifcenterisNone:height,width=canvas_sizecenter=[width*0.5,height*0.5]affine_vector=_get_inverse_affine_matrix(center,angle,translate,scale,shear,inverted=False)transposed_affine_matrix=(torch.tensor(affine_vector,dtype=dtype,device=device,).reshape(2,3).T)# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).# Tensor of points has shape (N * 4, 3), where N is the number of bboxes# Single point structure is similar to# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]points=bounding_boxes[:,[[0,1],[2,1],[2,3],[0,3]]].reshape(-1,2)points=torch.cat([points,torch.ones(points.shape[0],1,device=device,dtype=dtype)],dim=-1)# 2) Now let's transform the points using affine matrixtransformed_points=torch.matmul(points,transposed_affine_matrix)# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]# and compute bounding box from 4 transformed points:transformed_points=transformed_points.reshape(-1,4,2)out_bbox_mins,out_bbox_maxs=torch.aminmax(transformed_points,dim=1)out_bboxes=torch.cat([out_bbox_mins,out_bbox_maxs],dim=1)ifexpand:# Compute minimum point for transformed image frame:# Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right 
points.height,width=canvas_sizepoints=torch.tensor([[0.0,0.0,1.0],[0.0,float(height),1.0],[float(width),float(height),1.0],[float(width),0.0,1.0],],dtype=dtype,device=device,)new_points=torch.matmul(points,transposed_affine_matrix)tr=torch.amin(new_points,dim=0,keepdim=True)# Translate bounding boxesout_bboxes.sub_(tr.repeat((1,2)))# Estimate meta-data for image with inverted=Trueaffine_vector=_get_inverse_affine_matrix(center,angle,translate,scale,shear)new_width,new_height=_compute_affine_output_size(affine_vector,width,height)canvas_size=(new_height,new_width)out_bboxes=clamp_bounding_boxes(out_bboxes,format=tv_tensors.BoundingBoxFormat.XYXY,canvas_size=canvas_size)out_bboxes=convert_bounding_box_format(out_bboxes,old_format=tv_tensors.BoundingBoxFormat.XYXY,new_format=format,inplace=True).reshape(original_shape)out_bboxes=out_bboxes.to(original_dtype)returnout_bboxes,canvas_sizedefaffine_bounding_boxes(bounding_boxes:torch.Tensor,format:tv_tensors.BoundingBoxFormat,canvas_size:Tuple[int,int],angle:Union[int,float],translate:List[float],scale:float,shear:List[float],center:Optional[List[float]]=None,)->torch.Tensor:out_box,_=_affine_bounding_boxes_with_expand(bounding_boxes,format=format,canvas_size=canvas_size,angle=angle,translate=translate,scale=scale,shear=shear,center=center,expand=False,)returnout_box@_register_kernel_internal(affine,tv_tensors.BoundingBoxes,tv_tensor_wrapper=False)def_affine_bounding_boxes_dispatch(inpt:tv_tensors.BoundingBoxes,angle:Union[int,float],translate:List[float],scale:float,shear:List[float],center:Optional[List[float]]=None,**kwargs,)->tv_tensors.BoundingBoxes:output=affine_bounding_boxes(inpt.as_subclass(torch.Tensor),format=inpt.format,canvas_size=inpt.canvas_size,angle=angle,translate=translate,scale=scale,shear=shear,center=center,)returntv_tensors.wrap(output,like=inpt)defaffine_mask(mask:torch.Tensor,angle:Union[int,float],translate:List[float],scale:float,shear:List[float],fill:_FillTypeJIT=None,center:Optional[List[float]]=None,)->torch.Tensor:ifmask.ndim<3:mask=mask.unsqueeze(0)needs_squeeze=Trueelse:needs_squeeze=Falseoutput=affine_image(mask,angle=angle,translate=translate,scale=scale,shear=shear,interpolation=InterpolationMode.NEAREST,fill=fill,center=center,)ifneeds_squeeze:output=output.squeeze(0)returnoutput@_register_kernel_internal(affine,tv_tensors.Mask,tv_tensor_wrapper=False)def_affine_mask_dispatch(inpt:tv_tensors.Mask,angle:Union[int,float],translate:List[float],scale:float,shear:List[float],fill:_FillTypeJIT=None,center:Optional[List[float]]=None,**kwargs,)->tv_tensors.Mask:output=affine_mask(inpt.as_subclass(torch.Tensor),angle=angle,translate=translate,scale=scale,shear=shear,fill=fill,center=center,)returntv_tensors.wrap(output,like=inpt)@_register_kernel_internal(affine,tv_tensors.Video)defaffine_video(video:torch.Tensor,angle:Union[int,float],translate:List[float],scale:float,shear:List[float],interpolation:Union[InterpolationMode,int]=InterpolationMode.NEAREST,fill:_FillTypeJIT=None,center:Optional[List[float]]=None,)->torch.Tensor:returnaffine_image(video,angle=angle,translate=translate,scale=scale,shear=shear,interpolation=interpolation,fill=fill,center=center,)
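# --- Illustrative sketch (not part of the torchvision source): for a pure translation
# --- (angle=0, scale=1, no shear) the bounding-box kernel above reduces to shifting the
# --- box by (translate_x, translate_y); the numbers below are example values.
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

boxes = tv_tensors.BoundingBoxes([[10, 10, 20, 20]], format="XYXY", canvas_size=(32, 32))
moved = F.affine(boxes, angle=0.0, translate=[5.0, 3.0], scale=1.0, shear=[0.0, 0.0])
# The affine matrix degenerates to x -> x + 5, y -> y + 3 (before clamping to the canvas):
assert moved.tolist() == [[15, 13, 25, 23]]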
[docs]def rotate(
    inpt: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomRotation` for details."""
    if torch.jit.is_scripting():
        return rotate_image(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center)

    _log_api_usage_once(rotate)

    kernel = _get_kernel(rotate, type(inpt))
    return kernel(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
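# --- Illustrative usage sketch (not part of the torchvision source): with expand=False the
# --- canvas size is kept and the corners are cut off, with expand=True the output grows to
# --- contain the whole rotated image (roughly 32 * sqrt(2) ~ 46 pixels per side here).
import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 32, 32)
same_size = F.rotate(img, angle=45.0)              # spatial size stays (32, 32)
expanded = F.rotate(img, angle=45.0, expand=True)  # spatial size grows to fit the rotation
assert same_size.shape == (3, 32, 32) and expanded.shape[-1] > 32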
@_register_kernel_internal(rotate,torch.Tensor)@_register_kernel_internal(rotate,tv_tensors.Image)defrotate_image(image:torch.Tensor,angle:float,interpolation:Union[InterpolationMode,int]=InterpolationMode.NEAREST,expand:bool=False,center:Optional[List[float]]=None,fill:_FillTypeJIT=None,)->torch.Tensor:angle=angle%360# shift angle to [0, 360) range# fast path: transpose without affine transformifcenterisNone:ifangle==0:returnimage.clone()ifangle==180:returntorch.rot90(image,k=2,dims=(-2,-1))ifexpandorimage.shape[-1]==image.shape[-2]:ifangle==90:returntorch.rot90(image,k=1,dims=(-2,-1))ifangle==270:returntorch.rot90(image,k=3,dims=(-2,-1))interpolation=_check_interpolation(interpolation)input_height,input_width=image.shape[-2:]center_f=[0.0,0.0]ifcenterisnotNone:# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.center_f=[(c-s*0.5)forc,sinzip(center,[input_width,input_height])]# due to current incoherence of rotation angle direction between affine and rotate implementations# we need to set -angle.matrix=_get_inverse_affine_matrix(center_f,-angle,[0.0,0.0],1.0,[0.0,0.0])_assert_grid_transform_inputs(image,matrix,interpolation.value,fill,["nearest","bilinear"])output_width,output_height=(_compute_affine_output_size(matrix,input_width,input_height)ifexpandelse(input_width,input_height))dtype=image.dtypeiftorch.is_floating_point(image)elsetorch.float32theta=torch.tensor(matrix,dtype=dtype,device=image.device).reshape(1,2,3)grid=_affine_grid(theta,w=input_width,h=input_height,ow=output_width,oh=output_height)return_apply_grid_transform(image,grid,interpolation.value,fill=fill)@_register_kernel_internal(rotate,PIL.Image.Image)def_rotate_image_pil(image:PIL.Image.Image,angle:float,interpolation:Union[InterpolationMode,int]=InterpolationMode.NEAREST,expand:bool=False,center:Optional[List[float]]=None,fill:_FillTypeJIT=None,)->PIL.Image.Image:interpolation=_check_interpolation(interpolation)return_FP.rotate(image,angle,interpolation=pil_modes_mapping[interpolation],expand=expand,fill=fill,center=center)defrotate_bounding_boxes(bounding_boxes:torch.Tensor,format:tv_tensors.BoundingBoxFormat,canvas_size:Tuple[int,int],angle:float,expand:bool=False,center:Optional[List[float]]=None,)->Tuple[torch.Tensor,Tuple[int,int]]:return_affine_bounding_boxes_with_expand(bounding_boxes,format=format,canvas_size=canvas_size,angle=-angle,translate=[0.0,0.0],scale=1.0,shear=[0.0,0.0],center=center,expand=expand,)@_register_kernel_internal(rotate,tv_tensors.BoundingBoxes,tv_tensor_wrapper=False)def_rotate_bounding_boxes_dispatch(inpt:tv_tensors.BoundingBoxes,angle:float,expand:bool=False,center:Optional[List[float]]=None,**kwargs)->tv_tensors.BoundingBoxes:output,canvas_size=rotate_bounding_boxes(inpt.as_subclass(torch.Tensor),format=inpt.format,canvas_size=inpt.canvas_size,angle=angle,expand=expand,center=center,)returntv_tensors.wrap(output,like=inpt,canvas_size=canvas_size)defrotate_mask(mask:torch.Tensor,angle:float,expand:bool=False,center:Optional[List[float]]=None,fill:_FillTypeJIT=None,)->torch.Tensor:ifmask.ndim<3:mask=mask.unsqueeze(0)needs_squeeze=Trueelse:needs_squeeze=Falseoutput=rotate_image(mask,angle=angle,expand=expand,interpolation=InterpolationMode.NEAREST,fill=fill,center=center,)ifneeds_squeeze:output=output.squeeze(0)returnoutput@_register_kernel_internal(rotate,tv_tensors.Mask,tv_tensor_wrapper=False)def_rotate_mask_dispatch(inpt:tv_tensors.Mask,angle:float,expand:bool=False,center:Optional[List[float]]=None,fill:_FillTypeJIT=None,**kwargs,
)->tv_tensors.Mask:output=rotate_mask(inpt.as_subclass(torch.Tensor),angle=angle,expand=expand,fill=fill,center=center)returntv_tensors.wrap(output,like=inpt)@_register_kernel_internal(rotate,tv_tensors.Video)defrotate_video(video:torch.Tensor,angle:float,interpolation:Union[InterpolationMode,int]=InterpolationMode.NEAREST,expand:bool=False,center:Optional[List[float]]=None,fill:_FillTypeJIT=None,)->torch.Tensor:returnrotate_image(video,angle,interpolation=interpolation,expand=expand,fill=fill,center=center)
[docs]def pad(
    inpt: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.Pad` for details."""
    if torch.jit.is_scripting():
        return pad_image(inpt, padding=padding, fill=fill, padding_mode=padding_mode)

    _log_api_usage_once(pad)

    kernel = _get_kernel(pad, type(inpt))
    return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
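# --- Illustrative usage sketch (not part of the torchvision source): `padding` follows the
# --- PIL convention [left, top, right, bottom], so the output below is 3 x (32+2+4) x (64+1+3).
import torch
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (3, 32, 64), dtype=torch.uint8)
padded = F.pad(img, padding=[1, 2, 3, 4], fill=0, padding_mode="constant")
assert padded.shape == (3, 38, 68)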
def_parse_pad_padding(padding:Union[int,List[int]])->List[int]:ifisinstance(padding,int):pad_left=pad_right=pad_top=pad_bottom=paddingelifisinstance(padding,(tuple,list)):iflen(padding)==1:pad_left=pad_right=pad_top=pad_bottom=padding[0]eliflen(padding)==2:pad_left=pad_right=padding[0]pad_top=pad_bottom=padding[1]eliflen(padding)==4:pad_left=padding[0]pad_top=padding[1]pad_right=padding[2]pad_bottom=padding[3]else:raiseValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")else:raiseTypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")return[pad_left,pad_right,pad_top,pad_bottom]@_register_kernel_internal(pad,torch.Tensor)@_register_kernel_internal(pad,tv_tensors.Image)defpad_image(image:torch.Tensor,padding:List[int],fill:Optional[Union[int,float,List[float]]]=None,padding_mode:str="constant",)->torch.Tensor:# Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses# `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad`# internally.torch_padding=_parse_pad_padding(padding)ifpadding_modenotin("constant","edge","reflect","symmetric"):raiseValueError(f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "f"but got `'{padding_mode}'`.")iffillisNone:fill=0ifisinstance(fill,(int,float)):return_pad_with_scalar_fill(image,torch_padding,fill=fill,padding_mode=padding_mode)eliflen(fill)==1:return_pad_with_scalar_fill(image,torch_padding,fill=fill[0],padding_mode=padding_mode)else:return_pad_with_vector_fill(image,torch_padding,fill=fill,padding_mode=padding_mode)def_pad_with_scalar_fill(image:torch.Tensor,torch_padding:List[int],fill:Union[int,float],padding_mode:str,)->torch.Tensor:shape=image.shapenum_channels,height,width=shape[-3:]batch_size=1forsinshape[:-3]:batch_size*=simage=image.reshape(batch_size,num_channels,height,width)ifpadding_mode=="edge":# Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map# the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`# name.padding_mode="replicate"ifpadding_mode=="constant":image=torch_pad(image,torch_padding,mode=padding_mode,value=float(fill))elifpadding_modein("reflect","replicate"):# `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.# TODO: See https://github.com/pytorch/pytorch/issues/40763dtype=image.dtypeifnotimage.is_floating_point():needs_cast=Trueimage=image.to(torch.float32)else:needs_cast=Falseimage=torch_pad(image,torch_padding,mode=padding_mode)ifneeds_cast:image=image.to(dtype)else:# padding_mode == "symmetric"image=_pad_symmetric(image,torch_padding)new_height,new_width=image.shape[-2:]returnimage.reshape(shape[:-3]+(num_channels,new_height,new_width))# TODO: This should be removed once torch_pad supports non-scalar padding valuesdef_pad_with_vector_fill(image:torch.Tensor,torch_padding:List[int],fill:List[float],padding_mode:str,)->torch.Tensor:ifpadding_mode!="constant":raiseValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")output=_pad_with_scalar_fill(image,torch_padding,fill=0,padding_mode="constant")left,right,top,bottom=torch_padding# We are creating the tensor in the autodetected dtype first and convert to the right one after to avoid an implicit# float -> int conversion. 
That happens for example for the valid input of a uint8 image with floating point fill# value.fill=torch.tensor(fill,device=image.device).to(dtype=image.dtype).reshape(-1,1,1)iftop>0:output[...,:top,:]=fillifleft>0:output[...,:,:left]=fillifbottom>0:output[...,-bottom:,:]=fillifright>0:output[...,:,-right:]=fillreturnoutput_pad_image_pil=_register_kernel_internal(pad,PIL.Image.Image)(_FP.pad)@_register_kernel_internal(pad,tv_tensors.Mask)defpad_mask(mask:torch.Tensor,padding:List[int],fill:Optional[Union[int,float,List[float]]]=None,padding_mode:str="constant",)->torch.Tensor:iffillisNone:fill=0ifisinstance(fill,(tuple,list)):raiseValueError("Non-scalar fill value is not supported")ifmask.ndim<3:mask=mask.unsqueeze(0)needs_squeeze=Trueelse:needs_squeeze=Falseoutput=pad_image(mask,padding=padding,fill=fill,padding_mode=padding_mode)ifneeds_squeeze:output=output.squeeze(0)returnoutputdefpad_bounding_boxes(bounding_boxes:torch.Tensor,format:tv_tensors.BoundingBoxFormat,canvas_size:Tuple[int,int],padding:List[int],padding_mode:str="constant",)->Tuple[torch.Tensor,Tuple[int,int]]:ifpadding_modenotin["constant"]:# TODO: add support of other padding modesraiseValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")left,right,top,bottom=_parse_pad_padding(padding)ifformat==tv_tensors.BoundingBoxFormat.XYXY:pad=[left,top,left,top]else:pad=[left,top,0,0]bounding_boxes=bounding_boxes+torch.tensor(pad,dtype=bounding_boxes.dtype,device=bounding_boxes.device)height,width=canvas_sizeheight+=top+bottomwidth+=left+rightcanvas_size=(height,width)returnclamp_bounding_boxes(bounding_boxes,format=format,canvas_size=canvas_size),canvas_size@_register_kernel_internal(pad,tv_tensors.BoundingBoxes,tv_tensor_wrapper=False)def_pad_bounding_boxes_dispatch(inpt:tv_tensors.BoundingBoxes,padding:List[int],padding_mode:str="constant",**kwargs)->tv_tensors.BoundingBoxes:output,canvas_size=pad_bounding_boxes(inpt.as_subclass(torch.Tensor),format=inpt.format,canvas_size=inpt.canvas_size,padding=padding,padding_mode=padding_mode,)returntv_tensors.wrap(output,like=inpt,canvas_size=canvas_size)@_register_kernel_internal(pad,tv_tensors.Video)defpad_video(video:torch.Tensor,padding:List[int],fill:Optional[Union[int,float,List[float]]]=None,padding_mode:str="constant",)->torch.Tensor:returnpad_image(video,padding,fill=fill,padding_mode=padding_mode)
[docs]def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomCrop` for details."""
    if torch.jit.is_scripting():
        return crop_image(inpt, top=top, left=left, height=height, width=width)

    _log_api_usage_once(crop)

    kernel = _get_kernel(crop, type(inpt))
    return kernel(inpt, top=top, left=left, height=height, width=width)
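# --- Illustrative usage sketch (not part of the torchvision source): a crop window that
# --- reaches outside the image is implicitly zero-padded by crop_image below.
import torch
from torchvision.transforms.v2 import functional as F

img = torch.randint(1, 256, (3, 32, 64), dtype=torch.uint8)
inside = F.crop(img, top=4, left=8, height=16, width=16)     # plain slicing
outside = F.crop(img, top=-4, left=-4, height=16, width=16)  # out-of-bounds area filled with 0
assert inside.shape == outside.shape == (3, 16, 16)
assert (outside[:, :4, :] == 0).all() and (outside[:, :, :4] == 0).all()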
@_register_kernel_internal(crop,torch.Tensor)@_register_kernel_internal(crop,tv_tensors.Image)defcrop_image(image:torch.Tensor,top:int,left:int,height:int,width:int)->torch.Tensor:h,w=image.shape[-2:]right=left+widthbottom=top+heightifleft<0ortop<0orright>worbottom>h:image=image[...,max(top,0):bottom,max(left,0):right]torch_padding=[max(min(right,0)-left,0),max(right-max(w,left),0),max(min(bottom,0)-top,0),max(bottom-max(h,top),0),]return_pad_with_scalar_fill(image,torch_padding,fill=0,padding_mode="constant")returnimage[...,top:bottom,left:right]_crop_image_pil=_FP.crop_register_kernel_internal(crop,PIL.Image.Image)(_crop_image_pil)defcrop_bounding_boxes(bounding_boxes:torch.Tensor,format:tv_tensors.BoundingBoxFormat,top:int,left:int,height:int,width:int,)->Tuple[torch.Tensor,Tuple[int,int]]:# Crop or implicit pad if left and/or top have negative values:ifformat==tv_tensors.BoundingBoxFormat.XYXY:sub=[left,top,left,top]else:sub=[left,top,0,0]bounding_boxes=bounding_boxes-torch.tensor(sub,dtype=bounding_boxes.dtype,device=bounding_boxes.device)canvas_size=(height,width)returnclamp_bounding_boxes(bounding_boxes,format=format,canvas_size=canvas_size),canvas_size@_register_kernel_internal(crop,tv_tensors.BoundingBoxes,tv_tensor_wrapper=False)def_crop_bounding_boxes_dispatch(inpt:tv_tensors.BoundingBoxes,top:int,left:int,height:int,width:int)->tv_tensors.BoundingBoxes:output,canvas_size=crop_bounding_boxes(inpt.as_subclass(torch.Tensor),format=inpt.format,top=top,left=left,height=height,width=width)returntv_tensors.wrap(output,like=inpt,canvas_size=canvas_size)@_register_kernel_internal(crop,tv_tensors.Mask)defcrop_mask(mask:torch.Tensor,top:int,left:int,height:int,width:int)->torch.Tensor:ifmask.ndim<3:mask=mask.unsqueeze(0)needs_squeeze=Trueelse:needs_squeeze=Falseoutput=crop_image(mask,top,left,height,width)ifneeds_squeeze:output=output.squeeze(0)returnoutput@_register_kernel_internal(crop,tv_tensors.Video)defcrop_video(video:torch.Tensor,top:int,left:int,height:int,width:int)->torch.Tensor:returncrop_image(video,top,left,height,width)
[docs]def perspective(
    inpt: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomPerspective` for details."""
    if torch.jit.is_scripting():
        return perspective_image(
            inpt,
            startpoints=startpoints,
            endpoints=endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )

    _log_api_usage_once(perspective)

    kernel = _get_kernel(perspective, type(inpt))
    return kernel(
        inpt,
        startpoints=startpoints,
        endpoints=endpoints,
        interpolation=interpolation,
        fill=fill,
        coefficients=coefficients,
    )
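# --- Illustrative usage sketch (not part of the torchvision source): the four start/end
# --- points are the corners of the original and of the distorted quadrilateral; the
# --- coordinates below are arbitrary example values.
import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 32, 32)
startpoints = [[0, 0], [31, 0], [31, 31], [0, 31]]  # top-left, top-right, bottom-right, bottom-left
endpoints = [[2, 1], [29, 3], [30, 28], [1, 30]]
warped = F.perspective(img, startpoints=startpoints, endpoints=endpoints, fill=0)
assert warped.shape == img.shape
# Alternatively, precomputed `coefficients` can be passed instead of the two point lists.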
def_perspective_grid(coeffs:List[float],ow:int,oh:int,dtype:torch.dtype,device:torch.device)->torch.Tensor:# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/# src/libImaging/Geometry.c#L394## x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)#theta1=torch.tensor([[[coeffs[0],coeffs[1],coeffs[2]],[coeffs[3],coeffs[4],coeffs[5]]]],dtype=dtype,device=device)theta2=torch.tensor([[[coeffs[6],coeffs[7],1.0],[coeffs[6],coeffs[7],1.0]]],dtype=dtype,device=device)d=0.5base_grid=torch.empty(1,oh,ow,3,dtype=dtype,device=device)x_grid=torch.linspace(d,ow+d-1.0,steps=ow,device=device,dtype=dtype)base_grid[...,0].copy_(x_grid)y_grid=torch.linspace(d,oh+d-1.0,steps=oh,device=device,dtype=dtype).unsqueeze_(-1)base_grid[...,1].copy_(y_grid)base_grid[...,2].fill_(1)rescaled_theta1=theta1.transpose(1,2).div_(torch.tensor([0.5*ow,0.5*oh],dtype=dtype,device=device))shape=(1,oh*ow,3)output_grid1=base_grid.view(shape).bmm(rescaled_theta1)output_grid2=base_grid.view(shape).bmm(theta2.transpose(1,2))output_grid=output_grid1.div_(output_grid2).sub_(1.0)returnoutput_grid.view(1,oh,ow,2)def_perspective_coefficients(startpoints:Optional[List[List[int]]],endpoints:Optional[List[List[int]]],coefficients:Optional[List[float]],)->List[float]:ifcoefficientsisnotNone:ifstartpointsisnotNoneandendpointsisnotNone:raiseValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")eliflen(coefficients)!=8:raiseValueError("Argument coefficients should have 8 float values")returncoefficientselifstartpointsisnotNoneandendpointsisnotNone:return_get_perspective_coeffs(startpoints,endpoints)else:raiseValueError("Either the startpoints/endpoints or the coefficients must have non `None` 
values.")@_register_kernel_internal(perspective,torch.Tensor)@_register_kernel_internal(perspective,tv_tensors.Image)defperspective_image(image:torch.Tensor,startpoints:Optional[List[List[int]]],endpoints:Optional[List[List[int]]],interpolation:Union[InterpolationMode,int]=InterpolationMode.BILINEAR,fill:_FillTypeJIT=None,coefficients:Optional[List[float]]=None,)->torch.Tensor:perspective_coeffs=_perspective_coefficients(startpoints,endpoints,coefficients)interpolation=_check_interpolation(interpolation)_assert_grid_transform_inputs(image,matrix=None,interpolation=interpolation.value,fill=fill,supported_interpolation_modes=["nearest","bilinear"],coeffs=perspective_coeffs,)oh,ow=image.shape[-2:]dtype=image.dtypeiftorch.is_floating_point(image)elsetorch.float32grid=_perspective_grid(perspective_coeffs,ow=ow,oh=oh,dtype=dtype,device=image.device)return_apply_grid_transform(image,grid,interpolation.value,fill=fill)@_register_kernel_internal(perspective,PIL.Image.Image)def_perspective_image_pil(image:PIL.Image.Image,startpoints:Optional[List[List[int]]],endpoints:Optional[List[List[int]]],interpolation:Union[InterpolationMode,int]=InterpolationMode.BILINEAR,fill:_FillTypeJIT=None,coefficients:Optional[List[float]]=None,)->PIL.Image.Image:perspective_coeffs=_perspective_coefficients(startpoints,endpoints,coefficients)interpolation=_check_interpolation(interpolation)return_FP.perspective(image,perspective_coeffs,interpolation=pil_modes_mapping[interpolation],fill=fill)defperspective_bounding_boxes(bounding_boxes:torch.Tensor,format:tv_tensors.BoundingBoxFormat,canvas_size:Tuple[int,int],startpoints:Optional[List[List[int]]],endpoints:Optional[List[List[int]]],coefficients:Optional[List[float]]=None,)->torch.Tensor:ifbounding_boxes.numel()==0:returnbounding_boxesperspective_coeffs=_perspective_coefficients(startpoints,endpoints,coefficients)original_shape=bounding_boxes.shape# TODO: first cast to float if bbox is int64 before convert_bounding_box_formatbounding_boxes=(convert_bounding_box_format(bounding_boxes,old_format=format,new_format=tv_tensors.BoundingBoxFormat.XYXY)).reshape(-1,4)dtype=bounding_boxes.dtypeiftorch.is_floating_point(bounding_boxes)elsetorch.float32device=bounding_boxes.device# perspective_coeffs are computed as endpoint -> start point# We have to invert perspective_coeffs for bboxes:# (x, y) - end point and (x_out, y_out) - start point# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)# and we would like to get:# x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])# / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)# y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])# / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)# and compute inv_coeffs in terms of coeffsdenom=perspective_coeffs[0]*perspective_coeffs[4]-perspective_coeffs[1]*perspective_coeffs[3]ifdenom==0:raiseRuntimeError(f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. 
"f"Denominator is zero, denom={denom}")inv_coeffs=[(perspective_coeffs[4]-perspective_coeffs[5]*perspective_coeffs[7])/denom,(-perspective_coeffs[1]+perspective_coeffs[2]*perspective_coeffs[7])/denom,(perspective_coeffs[1]*perspective_coeffs[5]-perspective_coeffs[2]*perspective_coeffs[4])/denom,(-perspective_coeffs[3]+perspective_coeffs[5]*perspective_coeffs[6])/denom,(perspective_coeffs[0]-perspective_coeffs[2]*perspective_coeffs[6])/denom,(-perspective_coeffs[0]*perspective_coeffs[5]+perspective_coeffs[2]*perspective_coeffs[3])/denom,(-perspective_coeffs[4]*perspective_coeffs[6]+perspective_coeffs[3]*perspective_coeffs[7])/denom,(-perspective_coeffs[0]*perspective_coeffs[7]+perspective_coeffs[1]*perspective_coeffs[6])/denom,]theta1=torch.tensor([[inv_coeffs[0],inv_coeffs[1],inv_coeffs[2]],[inv_coeffs[3],inv_coeffs[4],inv_coeffs[5]]],dtype=dtype,device=device,)theta2=torch.tensor([[inv_coeffs[6],inv_coeffs[7],1.0],[inv_coeffs[6],inv_coeffs[7],1.0]],dtype=dtype,device=device)# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).# Tensor of points has shape (N * 4, 3), where N is the number of bboxes# Single point structure is similar to# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]points=bounding_boxes[:,[[0,1],[2,1],[2,3],[0,3]]].reshape(-1,2)points=torch.cat([points,torch.ones(points.shape[0],1,device=points.device)],dim=-1)# 2) Now let's transform the points using perspective matrices# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)numer_points=torch.matmul(points,theta1.T)denom_points=torch.matmul(points,theta2.T)transformed_points=numer_points.div_(denom_points)# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]# and compute bounding box from 4 transformed points:transformed_points=transformed_points.reshape(-1,4,2)out_bbox_mins,out_bbox_maxs=torch.aminmax(transformed_points,dim=1)out_bboxes=clamp_bounding_boxes(torch.cat([out_bbox_mins,out_bbox_maxs],dim=1).to(bounding_boxes.dtype),format=tv_tensors.BoundingBoxFormat.XYXY,canvas_size=canvas_size,)# out_bboxes should be of shape [N boxes, 
4]returnconvert_bounding_box_format(out_bboxes,old_format=tv_tensors.BoundingBoxFormat.XYXY,new_format=format,inplace=True).reshape(original_shape)@_register_kernel_internal(perspective,tv_tensors.BoundingBoxes,tv_tensor_wrapper=False)def_perspective_bounding_boxes_dispatch(inpt:tv_tensors.BoundingBoxes,startpoints:Optional[List[List[int]]],endpoints:Optional[List[List[int]]],coefficients:Optional[List[float]]=None,**kwargs,)->tv_tensors.BoundingBoxes:output=perspective_bounding_boxes(inpt.as_subclass(torch.Tensor),format=inpt.format,canvas_size=inpt.canvas_size,startpoints=startpoints,endpoints=endpoints,coefficients=coefficients,)returntv_tensors.wrap(output,like=inpt)defperspective_mask(mask:torch.Tensor,startpoints:Optional[List[List[int]]],endpoints:Optional[List[List[int]]],fill:_FillTypeJIT=None,coefficients:Optional[List[float]]=None,)->torch.Tensor:ifmask.ndim<3:mask=mask.unsqueeze(0)needs_squeeze=Trueelse:needs_squeeze=Falseoutput=perspective_image(mask,startpoints,endpoints,interpolation=InterpolationMode.NEAREST,fill=fill,coefficients=coefficients)ifneeds_squeeze:output=output.squeeze(0)returnoutput@_register_kernel_internal(perspective,tv_tensors.Mask,tv_tensor_wrapper=False)def_perspective_mask_dispatch(inpt:tv_tensors.Mask,startpoints:Optional[List[List[int]]],endpoints:Optional[List[List[int]]],fill:_FillTypeJIT=None,coefficients:Optional[List[float]]=None,**kwargs,)->tv_tensors.Mask:output=perspective_mask(inpt.as_subclass(torch.Tensor),startpoints=startpoints,endpoints=endpoints,fill=fill,coefficients=coefficients,)returntv_tensors.wrap(output,like=inpt)@_register_kernel_internal(perspective,tv_tensors.Video)defperspective_video(video:torch.Tensor,startpoints:Optional[List[List[int]]],endpoints:Optional[List[List[int]]],interpolation:Union[InterpolationMode,int]=InterpolationMode.BILINEAR,fill:_FillTypeJIT=None,coefficients:Optional[List[float]]=None,)->torch.Tensor:returnperspective_image(video,startpoints,endpoints,interpolation=interpolation,fill=fill,coefficients=coefficients)
[docs]def elastic(
    inpt: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.ElasticTransform` for details."""
    if torch.jit.is_scripting():
        return elastic_image(inpt, displacement=displacement, interpolation=interpolation, fill=fill)

    _log_api_usage_once(elastic)

    kernel = _get_kernel(elastic, type(inpt))
    return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
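# --- Illustrative usage sketch (not part of the torchvision source): the displacement
# --- field must have shape (1, H, W, 2) and is added to the normalized sampling grid;
# --- the magnitude 0.05 below is an arbitrary example value.
import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 32, 64)
displacement = 0.05 * torch.randn(1, 32, 64, 2)  # (1, H, W, 2), in normalized grid units
warped = F.elastic(img, displacement=displacement)
assert warped.shape == img.shape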
elastic_transform = elastic


@_register_kernel_internal(elastic, torch.Tensor)
@_register_kernel_internal(elastic, tv_tensors.Image)
def elastic_image(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    interpolation = _check_interpolation(interpolation)

    height, width = image.shape[-2:]
    device = image.device
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    # Patch: elastic transform should support (cpu,f16) input
    is_cpu_half = device.type == "cpu" and dtype == torch.float16
    if is_cpu_half:
        image = image.to(torch.float32)
        dtype = torch.float32

    # We are aware that if input image dtype is uint8 and displacement is float64 then
    # displacement will be cast to float32 and all computations will be done with float32
    # We can fix this later if needed

    expected_shape = (1, height, width, 2)
    if expected_shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    grid = _create_identity_grid((height, width), device=device, dtype=dtype).add_(
        displacement.to(dtype=dtype, device=device)
    )
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if is_cpu_half:
        output = output.to(torch.float16)

    return output


@_register_kernel_internal(elastic, PIL.Image.Image)
def _elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    t_img = pil_to_tensor(image)
    output = elastic_image(t_img, displacement, interpolation=interpolation, fill=fill)
    return to_pil_image(output, mode=image.mode)


def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    sy, sx = size
    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
    base_grid[..., 0].copy_(x_grid)

    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)

    return base_grid


def elastic_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    displacement: torch.Tensor,
) -> torch.Tensor:
    expected_shape = (1, canvas_size[0], canvas_size[1], 2)
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")
    elif displacement.shape != expected_shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    if bounding_boxes.numel() == 0:
        return bounding_boxes

    # TODO: add in docstring about approximation we are doing for grid inversion
    device = bounding_boxes.device
    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
    # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
    # This is not an exact inverse of the grid
    inv_grid = id_grid.sub_(displacement)

    # Get points from bboxes
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()
    index_xy = points.to(dtype=torch.long)
    index_x, index_y = index_xy[:, 0], index_xy[:, 1]

    # Transform points:
    t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    return convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)


@_register_kernel_internal(elastic, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _elastic_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs
) -> tv_tensors.BoundingBoxes:
    output = elastic_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
    )
    return tv_tensors.wrap(output, like=inpt)


def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = elastic_image(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(elastic, tv_tensors.Mask, tv_tensor_wrapper=False)
def _elastic_mask_dispatch(
    inpt: tv_tensors.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> tv_tensors.Mask:
    output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(elastic, tv_tensors.Video)
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    return elastic_image(video, displacement, interpolation=interpolation, fill=fill)
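# Illustrative only -- not part of the torchvision source. A small sketch of the grid
# inversion approximation used by `elastic_bounding_boxes` above: with an all-zero
# displacement the approximate inverse grid is the identity grid, so the call should
# recover the original XYXY boxes (here exactly, since the corner coordinates are
# integral and inside the canvas). The box coordinates and canvas size are assumptions.
def _example_elastic_bounding_boxes_identity() -> torch.Tensor:
    canvas_size = (32, 32)
    boxes = torch.tensor([[4.0, 4.0, 20.0, 24.0]])
    displacement = torch.zeros(1, canvas_size[0], canvas_size[1], 2)
    return elastic_bounding_boxes(
        boxes,
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
        displacement=displacement,
    )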
[docs]def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.CenterCrop` for details."""
    if torch.jit.is_scripting():
        return center_crop_image(inpt, output_size=output_size)

    _log_api_usage_once(center_crop)

    kernel = _get_kernel(center_crop, type(inpt))
    return kernel(inpt, output_size=output_size)
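# Illustrative only -- not part of the torchvision source. A minimal sketch of calling
# the `center_crop` dispatcher above on a plain tensor; the input shape and output size
# are made-up values.
def _example_center_crop_usage() -> torch.Tensor:
    image = torch.rand(3, 48, 64)
    return center_crop(image, output_size=[32, 32])  # spatial size becomes (32, 32)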
[docs]def resized_crop(
    inpt: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomResizedCrop` for details."""
    if torch.jit.is_scripting():
        return resized_crop_image(
            inpt,
            top=top,
            left=left,
            height=height,
            width=width,
            size=size,
            interpolation=interpolation,
            antialias=antialias,
        )

    _log_api_usage_once(resized_crop)

    kernel = _get_kernel(resized_crop, type(inpt))
    return kernel(
        inpt,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation=interpolation,
        antialias=antialias,
    )
@_register_kernel_internal(resized_crop, torch.Tensor)
@_register_kernel_internal(resized_crop, tv_tensors.Image)
def resized_crop_image(
    image: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    image = crop_image(image, top, left, height, width)
    return resize_image(image, size, interpolation=interpolation, antialias=antialias)


def _resized_crop_image_pil(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    image = _crop_image_pil(image, top, left, height, width)
    return _resize_image_pil(image, size, interpolation=interpolation)


@_register_kernel_internal(resized_crop, PIL.Image.Image)
def _resized_crop_image_pil_dispatch(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> PIL.Image.Image:
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return _resized_crop_image_pil(
        image,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation=interpolation,
    )


def resized_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
    return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size)


@_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resized_crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = resized_crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


def resized_crop_mask(
    mask: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> torch.Tensor:
    mask = crop_mask(mask, top, left, height, width)
    return resize_mask(mask, size)


@_register_kernel_internal(resized_crop, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resized_crop_mask_dispatch(
    inpt: tv_tensors.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.Mask:
    output = resized_crop_mask(
        inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(resized_crop, tv_tensors.Video)
def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    return resized_crop_image(
        video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
    )
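# Illustrative only -- not part of the torchvision source. `resized_crop_image` above is
# simply a crop followed by a resize, so the two code paths in this hypothetical sketch
# should produce identical tensors (same kernels, default bilinear interpolation, CPU
# input); the input shape, crop window, and target size are made-up values, and the
# sketch relies on the `crop_image` and `resize_image` kernels defined earlier in this
# module.
def _example_resized_crop_equivalence() -> bool:
    image = torch.rand(3, 64, 64)
    direct = resized_crop_image(image, top=8, left=8, height=32, width=32, size=[16, 16], antialias=True)
    staged = resize_image(crop_image(image, 8, 8, 32, 32), [16, 16], antialias=True)
    return bool(torch.equal(direct, staged))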
[docs]def five_crop(
    inpt: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """See :class:`~torchvision.transforms.v2.FiveCrop` for details."""
    if torch.jit.is_scripting():
        return five_crop_image(inpt, size=size)

    _log_api_usage_once(five_crop)

    kernel = _get_kernel(five_crop, type(inpt))
    return kernel(inpt, size=size)
def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
        s = int(size)
        size = [s, s]
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        s = size[0]
        size = [s, s]

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Image)
def five_crop_image(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = image.shape[-2:]

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    tl = crop_image(image, 0, 0, crop_height, crop_width)
    tr = crop_image(image, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image(image, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image(image, [crop_height, crop_width])

    return tl, tr, bl, br, center


@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image)
def _five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = _get_size_image_pil(image)

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    tl = _crop_image_pil(image, 0, 0, crop_height, crop_width)
    tr = _crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
    bl = _crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width)
    br = _crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = _center_crop_image_pil(image, [crop_height, crop_width])

    return tl, tr, bl, br, center


@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Video)
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    return five_crop_image(video, size)
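# Illustrative only -- not part of the torchvision source. A sketch of the five crops
# produced by `five_crop_image` above (four corners plus center); the input shape and
# crop size are assumptions.
def _example_five_crop_usage() -> Tuple[int, ...]:
    image = torch.rand(3, 64, 64)
    tl, tr, bl, br, center = five_crop_image(image, size=[32, 32])
    # Each of the five crops has spatial size (32, 32).
    return tuple(c.shape[-1] for c in (tl, tr, bl, br, center))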
[docs]def ten_crop(
    inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """See :class:`~torchvision.transforms.v2.TenCrop` for details."""
    if torch.jit.is_scripting():
        return ten_crop_image(inpt, size=size, vertical_flip=vertical_flip)

    _log_api_usage_once(ten_crop)

    kernel = _get_kernel(ten_crop, type(inpt))
    return kernel(inpt, size=size, vertical_flip=vertical_flip)
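# Illustrative only -- not part of the torchvision source. `ten_crop` returns the five
# crops of the input plus the five crops of a flipped copy (horizontal flip by default,
# vertical when vertical_flip=True); the input shape and crop size below are assumptions.
def _example_ten_crop_usage() -> int:
    image = torch.rand(3, 64, 64)
    crops = ten_crop(image, size=[32, 32], vertical_flip=False)
    return len(crops)  # 10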