torch.nn.functional.grid_sample
torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)
Compute grid sample.
Given an input and a flow-field grid, computes the output using input values and pixel locations from grid.

Currently, only spatial (4-D) and volumetric (5-D) inputs are supported.

In the spatial (4-D) case, for input with shape (N, C, H_in, W_in) and grid with shape (N, H_out, W_out, 2), the output will have shape (N, C, H_out, W_out).

For each output location output[n, :, h, w], the size-2 vector grid[n, h, w] specifies the input pixel locations x and y, which are used to interpolate the output value output[n, :, h, w]. In the case of 5-D inputs, grid[n, d, h, w] specifies the x, y, z pixel locations for interpolating output[n, :, d, h, w]. The mode argument specifies the interpolation method (such as nearest or bilinear) used to sample the input pixels.

grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, most of its values should lie in the range [-1, 1]. For example, the values x = -1, y = -1 refer to the top-left pixel of input, and the values x = 1, y = 1 refer to the bottom-right pixel of input.

If grid has values outside the range [-1, 1], the corresponding outputs are handled as defined by padding_mode. Options are:
- padding_mode="zeros": use 0 for out-of-bound grid locations,
- padding_mode="border": use border values for out-of-bound grid locations,
- padding_mode="reflection": use values at locations reflected by the border for out-of-bound grid locations. A location far away from the border keeps being reflected until it is in bound, e.g., the (normalized) pixel location x = -3.5 reflects by border -1 and becomes x' = 1.5, then reflects by border 1 and becomes x'' = 0.5.
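A minimal sketch of the shape rules, the (x, y) ordering of the last grid dimension, the mode choice, and the padding modes described above. It assumes CPU tensors and, for the single-location probes, align_corners=True so that normalized -1 and 1 land exactly on the outermost pixel centers. The helper sample_x is hypothetical, written only for this illustration.

```python
import torch
import torch.nn.functional as F

# Output shapes for the 4-D and 5-D cases
out_4d = F.grid_sample(torch.randn(2, 3, 8, 8),          # (N, C, H_in, W_in)
                       torch.rand(2, 5, 7, 2) * 2 - 1,    # (N, H_out, W_out, 2)
                       align_corners=False)
out_5d = F.grid_sample(torch.randn(2, 3, 4, 8, 8),       # (N, C, D_in, H_in, W_in)
                       torch.rand(2, 2, 5, 7, 3) * 2 - 1, # (N, D_out, H_out, W_out, 3)
                       align_corners=False)
print(out_4d.shape, out_5d.shape)
# torch.Size([2, 3, 5, 7]) torch.Size([2, 3, 2, 5, 7])

# A single row of four pixels, probed one location at a time
row = torch.tensor([[[[10.0, 20.0, 30.0, 40.0]]]])  # (N=1, C=1, H=1, W=4)

def sample_x(x, padding_mode='zeros', mode='bilinear'):
    """Hypothetical helper: sample `row` at normalized (x, y=0).

    grid[..., 0] is x (along W) and grid[..., 1] is y (along H).
    """
    grid = torch.tensor([[[[x, 0.0]]]])  # (N=1, H_out=1, W_out=1, 2)
    out = F.grid_sample(row, grid, mode=mode,
                        padding_mode=padding_mode, align_corners=True)
    return out.item()

# x = -0.5 maps to pixel 0.75: bilinear blends pixels 0 and 1, nearest picks pixel 1
print(sample_x(-0.5, mode='bilinear'), sample_x(-0.5, mode='nearest'))  # 17.5 20.0

print(sample_x(3.0, 'zeros'))   # 0.0: every neighbor is out of bounds
print(sample_x(3.0, 'border'))  # 40.0: clamped to the right-most pixel
# Reflection: -3.5 -> 1.5 (about -1) -> 0.5 (about 1), so both calls agree
print(sample_x(-3.5, 'reflection'), sample_x(0.5, 'reflection'))
```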
Note
This function is often used in conjunction with affine_grid() to build Spatial Transformer Networks.
Note
When using the CUDA backend, this operation may induce nondeterministic behaviour in its backward pass that is not easily switched off. Please see the notes on Reproducibility for background.
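Returning to the first note, grid_sample() is commonly paired with affine_grid(). A minimal sketch, assuming an arbitrary random batch: an identity theta should reproduce the input, and negating the x entry of theta should mirror it horizontally, provided both calls use the same align_corners value. The helper warp is hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.randn(2, 3, 16, 16)  # (N, C, H, W)

# One 2x3 affine matrix per batch element: the identity transform
theta_id = torch.eye(2, 3).unsqueeze(0).repeat(2, 1, 1)  # (N, 2, 3)
# Setting theta[:, 0, 0] = -1 maps x -> -x, i.e. a horizontal mirror
theta_flip = theta_id.clone()
theta_flip[:, 0, 0] = -1.0

def warp(x, theta):
    # Hypothetical helper: build the sampling grid, then sample it.
    # Both calls use align_corners=False so their conventions match.
    grid = F.affine_grid(theta, size=list(x.shape), align_corners=False)  # (N, H, W, 2)
    return F.grid_sample(x, grid, mode='bilinear', align_corners=False)

print(torch.allclose(warp(inp, theta_id), inp, atol=1e-4))             # True
print(torch.allclose(warp(inp, theta_flip), inp.flip(-1), atol=1e-4))  # True
```

In a real Spatial Transformer Network, theta would be predicted by a small localization network rather than fixed by hand.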
Note
NaN values in grid are interpreted as -1.
- Parameters
input (Tensor) – input of shape (N, C, H_in, W_in) (4-D case) or (N, C, D_in, H_in, W_in) (5-D case)
grid (Tensor) – flow-field of shape (N, H_out, W_out, 2) (4-D case) or (N, D_out, H_out, W_out, 3) (5-D case)
mode (str) – interpolation mode to calculate output values: 'bilinear' | 'nearest' | 'bicubic'. Default: 'bilinear'. Note: mode='bicubic' supports only 4-D input. When mode='bilinear' and the input is 5-D, the interpolation mode used internally will actually be trilinear. When the input is 4-D, the interpolation mode is genuinely bilinear.
padding_mode (str) – padding mode for outside grid values: 'zeros' | 'border' | 'reflection'. Default: 'zeros'
align_corners (bool, optional) – Geometrically, we consider the pixels of the input as squares rather than points. If set to True, the extrema (-1 and 1) are considered as referring to the center points of the input’s corner pixels. If set to False, they are instead considered as referring to the corner points of the input’s corner pixels, making the sampling more resolution agnostic. This option parallels the align_corners option in interpolate(), and so whichever option is used here should also be used there to resize the input image before grid sampling. Default: False
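To make the align_corners distinction concrete, a minimal sketch that samples a 2x2 input at the normalized top-left extremum (-1, -1) under both settings; the input values are arbitrary.

```python
import torch
import torch.nn.functional as F

inp = torch.tensor([[[[10.0, 20.0],
                      [30.0, 40.0]]]])   # (N=1, C=1, H=2, W=2)
grid = torch.tensor([[[[-1.0, -1.0]]]])  # (N=1, H_out=1, W_out=1, 2)

# True: -1 is the center of the top-left pixel, so its value comes back exactly.
a = F.grid_sample(inp, grid, align_corners=True)
# False: -1 is the outer corner of the top-left pixel (pixel coordinate -0.5),
# so three of the four bilinear neighbors are out of bounds and read as 0.
b = F.grid_sample(inp, grid, align_corners=False)
# Same location with border padding: out-of-bound reads repeat the edge pixel.
c = F.grid_sample(inp, grid, align_corners=False, padding_mode='border')

print(a.item(), b.item(), c.item())  # 10.0 2.5 10.0
```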
- Returns
output tensor of shape (N, C, H_out, W_out) (4-D case) or (N, C, D_out, H_out, W_out) (5-D case)
- Return type
Tensor
Warning
When align_corners = True, the grid positions depend on the pixel size relative to the input image size, and so the locations sampled by grid_sample() will differ for the same input given at different resolutions (that is, after being upsampled or downsampled). The default behavior up to version 1.2.0 was align_corners = True. Since then, the default behavior has been changed to align_corners = False, in order to bring it in line with the default for interpolate().
Note
mode='bicubic' is implemented using the cubic convolution algorithm with α = -0.75. The constant α might differ from package to package. For example, PIL and OpenCV use -0.5 and -0.75 respectively. This algorithm may “overshoot” the range of values it’s interpolating. For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. Clamp the results with torch.clamp() to ensure they are within the valid range.
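A minimal sketch of the overshoot described above, assuming a synthetic hard edge from 0 to 255 (the test image is illustrative only). Sampling halfway between pixels on either side of the edge with mode='bicubic' yields one value below 0 and one above 255, which torch.clamp() brings back into range.

```python
import torch
import torch.nn.functional as F

# A hard edge: three dark pixels then three bright pixels, repeated over 4 rows
row = torch.tensor([0.0, 0.0, 0.0, 255.0, 255.0, 255.0])
img = row.repeat(4, 1).reshape(1, 1, 4, 6)  # (N=1, C=1, H=4, W=6)

# With align_corners=True, normalized x maps to pixel (x + 1) / 2 * (W - 1).
# Choose pixel positions 1.5 (just before the edge) and 3.5 (just after it).
xs = torch.tensor([1.5, 3.5]) / 5 * 2 - 1
grid = torch.stack([xs, torch.zeros_like(xs)], dim=-1).reshape(1, 1, 2, 2)

out = F.grid_sample(img, grid, mode='bicubic', align_corners=True)
print(out.flatten())      # first value dips below 0, second exceeds 255

clamped = torch.clamp(out, 0, 255)
print(clamped.flatten())  # both values now lie within [0, 255]
```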