Program Listing for File upsampling.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/functional/upsampling.h)
#pragma once
#include <c10/util/irange.h>
#include <torch/nn/functional/pooling.h>
#include <torch/nn/options/upsampling.h>
#include <cmath>
#include <utility>
namespace torch {
namespace nn {
namespace functional {
// Computes the output size that `interpolate` hands to the resampling
// routines.
//
// `closed_over_args` is the tuple
// (input, size, scale_factor, recompute_scale_factor).
// Exactly one of `size` / `scale_factor` must hold a value:
//   - when `size` is set, it is returned as-is;
//   - otherwise entry i of the result is
//     floor(input.size(i + 2) * scale_factor[i]) for i in [0, dim).
//
// `dim` is the number of entries `scale_factor` is required to have.
//
// Fails via TORCH_CHECK when neither or both of `size` / `scale_factor` are
// set, or when `scale_factor` does not have exactly `dim` entries. Emits a
// TORCH_WARN when `recompute_scale_factor` is unset and at least one scale
// factor is not a whole number.
inline std::vector<int64_t> _interp_output_size(
    int64_t dim,
    std::tuple<
        Tensor,
        c10::optional<std::vector<int64_t>>,
        c10::optional<std::vector<double>>,
        c10::optional<bool>> closed_over_args) {
  // `closed_over_args` is already this function's own by-value copy, so bind
  // to its elements by reference instead of copying the tensor handle and both
  // vectors a second time.
  auto& [input, size, scale_factor, recompute_scale_factor] = closed_over_args;

  TORCH_CHECK(
      size.has_value() || scale_factor.has_value(),
      "either size or scale_factor should be defined");
  TORCH_CHECK(
      !(size.has_value() && scale_factor.has_value()),
      "only one of size or scale_factor should be defined");

  if (scale_factor.has_value()) {
    TORCH_CHECK(
        static_cast<int64_t>(scale_factor->size()) == dim,
        "scale_factor shape must match input shape. ",
        "Input is ",
        dim,
        "D, scale_factor size is ",
        torch::ArrayRef<double>(*scale_factor));
  }

  if (size.has_value()) {
    // The tuple is local to this call, so the vector can be moved out.
    return std::move(*size);
  }

  TORCH_INTERNAL_ASSERT(scale_factor != c10::nullopt);
  const auto& scale_factors = *scale_factor;

  if (!recompute_scale_factor.has_value()) {
    // only warn when the scales have floating values since
    // the result for ints is the same with/without recompute_scale_factor
    bool is_float_scale_factor = false;
    for (const double scale : scale_factors) {
      if (std::floor(scale) != scale) {
        is_float_scale_factor = true;
        break;
      }
    }
    if (is_float_scale_factor) {
      TORCH_WARN(
          "The default behavior for interpolate/upsample with float scale_factor changed "
          "in 1.6.0 to align with other frameworks/libraries, and uses scale_factor directly, "
          "instead of relying on the computed output size. "
          "If you wish to keep the old behavior, please set recompute_scale_factor=True. "
          "See the documentation of nn.Upsample for details. ");
    }
  }

  std::vector<int64_t> ret;
  // The length check above guarantees scale_factors.size() == dim here.
  ret.reserve(scale_factors.size());
  for (const auto i : c10::irange(dim)) {
    // Dimension i + 2 of `input` is scaled by scale_factors[i].
    ret.emplace_back(static_cast<int64_t>(
        std::floor(static_cast<double>(input.size(i + 2)) * scale_factors[i])));
  }
  return ret;
}
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Implementation of `interpolate` with every option passed as its own
// argument.
//
// Steps, in order:
//   1. Validate `align_corners` against `mode`: it must be unset for the
//      nearest / nearest-exact / area modes; for every other mode an unset
//      value triggers a warning and is replaced by `false`.
//   2. Require a 3D, 4D or 5D `input`.
//   3. Build the per-dimension scale list forwarded to the kernels: the
//      entries of `scale_factor` when it is set and `recompute_scale_factor`
//      is not true, otherwise `input.dim() - 2` empty optionals.
//   4. Require 4D input with bilinear or bicubic mode when `antialias` is set.
//   5. Dispatch on (input.dim(), mode). The output size for each call comes
//      from `_interp_output_size`.
//
// Every failure is reported through TORCH_CHECK.
inline Tensor interpolate(
    const Tensor& input,
    const c10::optional<std::vector<int64_t>>& size,
    const c10::optional<std::vector<double>>& scale_factor,
    InterpolateFuncOptions::mode_t mode,
    c10::optional<bool> align_corners,
    c10::optional<bool> recompute_scale_factor,
    bool antialias) {
  // nearest-exact is grouped with nearest and area: the error text below
  // lists linear | bilinear | bicubic | trilinear as the only modes that
  // accept align_corners, and without this the "default changed to
  // align_corners=False" warning would fire for nearest-exact.
  if (std::holds_alternative<enumtype::kNearest>(mode) ||
      std::holds_alternative<enumtype::kNearestExact>(mode) ||
      std::get_if<enumtype::kArea>(&mode)) {
    if (align_corners != c10::nullopt) {
      TORCH_CHECK(
          false,
          "align_corners option can only be set with the "
          "interpolating modes: linear | bilinear | bicubic | trilinear");
    }
  } else {
    if (align_corners == c10::nullopt) {
      TORCH_WARN(
          "Default upsampling behavior when mode=",
          enumtype::get_enum_name(mode),
          " is changed "
          "to align_corners=False since 0.4.0. Please specify "
          "align_corners=True if the old behavior is desired. "
          "See the documentation of nn.Upsample for details.");
      align_corners = false;
    }
  }

  TORCH_CHECK(
      input.dim() >= 3 && input.dim() <= 5,
      "Input Error: Only 3D, 4D and 5D input Tensors supported "
      "(got ",
      input.dim(),
      "D) for the modes: nearest | linear | bilinear | bicubic | trilinear "
      "(got ",
      enumtype::get_enum_name(mode),
      ")");

  // One optional scale per dimension after the first two. The entries stay
  // empty unless explicit scale factors are forwarded to the kernels.
  auto scale_factor_len = input.dim() - 2;
  std::vector<c10::optional<double>> scale_factor_list(
      scale_factor_len, c10::nullopt);
  if (scale_factor != c10::nullopt && !recompute_scale_factor.value_or(false)) {
    auto _scale_factor_repeated = *scale_factor;
    scale_factor_list = {};
    for (const auto& elem : _scale_factor_repeated) {
      scale_factor_list.emplace_back(elem);
    }
  }

  if (antialias &&
      !(input.dim() == 4 &&
        (std::get_if<enumtype::kBilinear>(&mode) ||
         std::get_if<enumtype::kBicubic>(&mode)))) {
    TORCH_CHECK(
        false,
        "Anti-alias option is only supported for bilinear and bicubic modes");
  }

  // Bundled once; exactly one branch below runs, and it moves the tuple into
  // `_interp_output_size`.
  auto closed_over_args =
      std::make_tuple(input, size, scale_factor, recompute_scale_factor);

  // --- nearest ---
  if (input.dim() == 3 && std::get_if<enumtype::kNearest>(&mode)) {
    return torch::upsample_nearest1d(
        input,
        _interp_output_size(1, std::move(closed_over_args)),
        scale_factor_list.at(0));
  } else if (input.dim() == 4 && std::get_if<enumtype::kNearest>(&mode)) {
    return torch::upsample_nearest2d(
        input,
        _interp_output_size(2, std::move(closed_over_args)),
        scale_factor_list.at(0),
        scale_factor_list.at(1));
  } else if (input.dim() == 5 && std::get_if<enumtype::kNearest>(&mode)) {
    return torch::upsample_nearest3d(
        input,
        _interp_output_size(3, std::move(closed_over_args)),
        scale_factor_list.at(0),
        scale_factor_list.at(1),
        scale_factor_list.at(2));
    // --- nearest-exact ---
  } else if (input.dim() == 3 && std::get_if<enumtype::kNearestExact>(&mode)) {
    return torch::_upsample_nearest_exact1d(
        input,
        _interp_output_size(1, std::move(closed_over_args)),
        scale_factor_list.at(0));
  } else if (input.dim() == 4 && std::get_if<enumtype::kNearestExact>(&mode)) {
    return torch::_upsample_nearest_exact2d(
        input,
        _interp_output_size(2, std::move(closed_over_args)),
        scale_factor_list.at(0),
        scale_factor_list.at(1));
  } else if (input.dim() == 5 && std::get_if<enumtype::kNearestExact>(&mode)) {
    return torch::_upsample_nearest_exact3d(
        input,
        _interp_output_size(3, std::move(closed_over_args)),
        scale_factor_list.at(0),
        scale_factor_list.at(1),
        scale_factor_list.at(2));
    // --- area: delegated to adaptive average pooling ---
  } else if (input.dim() == 3 && std::get_if<enumtype::kArea>(&mode)) {
    return detail::adaptive_avg_pool1d(
        input, _interp_output_size(1, std::move(closed_over_args)));
  } else if (input.dim() == 4 && std::get_if<enumtype::kArea>(&mode)) {
    return detail::adaptive_avg_pool2d(
        input, _interp_output_size(2, std::move(closed_over_args)));
  } else if (input.dim() == 5 && std::get_if<enumtype::kArea>(&mode)) {
    return detail::adaptive_avg_pool3d(
        input, _interp_output_size(3, std::move(closed_over_args)));
    // --- linear family: each mode is tied to one input rank ---
  } else if (input.dim() == 3 && std::get_if<enumtype::kLinear>(&mode)) {
    TORCH_CHECK(
        align_corners != c10::nullopt, "align_corners should be specified.");
    return torch::upsample_linear1d(
        input,
        _interp_output_size(1, std::move(closed_over_args)),
        *align_corners,
        scale_factor_list.at(0));
  } else if (input.dim() == 3 && std::get_if<enumtype::kBilinear>(&mode)) {
    TORCH_CHECK(false, "Got 3D input, but bilinear mode needs 4D input");
  } else if (input.dim() == 3 && std::get_if<enumtype::kTrilinear>(&mode)) {
    TORCH_CHECK(false, "Got 3D input, but trilinear mode needs 5D input");
  } else if (input.dim() == 4 && std::get_if<enumtype::kLinear>(&mode)) {
    TORCH_CHECK(false, "Got 4D input, but linear mode needs 3D input");
  } else if (input.dim() == 4 && std::get_if<enumtype::kBilinear>(&mode)) {
    TORCH_CHECK(
        align_corners != c10::nullopt, "align_corners should be specified.");
    if (antialias) {
      return torch::_upsample_bilinear2d_aa(
          input,
          _interp_output_size(2, std::move(closed_over_args)),
          *align_corners,
          scale_factor_list.at(0),
          scale_factor_list.at(1));
    }
    return torch::upsample_bilinear2d(
        input,
        _interp_output_size(2, std::move(closed_over_args)),
        *align_corners,
        scale_factor_list.at(0),
        scale_factor_list.at(1));
  } else if (input.dim() == 4 && std::get_if<enumtype::kTrilinear>(&mode)) {
    TORCH_CHECK(false, "Got 4D input, but trilinear mode needs 5D input");
  } else if (input.dim() == 5 && std::get_if<enumtype::kLinear>(&mode)) {
    TORCH_CHECK(false, "Got 5D input, but linear mode needs 3D input");
  } else if (input.dim() == 5 && std::get_if<enumtype::kBilinear>(&mode)) {
    TORCH_CHECK(false, "Got 5D input, but bilinear mode needs 4D input");
  } else if (input.dim() == 5 && std::get_if<enumtype::kTrilinear>(&mode)) {
    TORCH_CHECK(
        align_corners != c10::nullopt, "align_corners should be specified.");
    return torch::upsample_trilinear3d(
        input,
        _interp_output_size(3, std::move(closed_over_args)),
        *align_corners,
        scale_factor_list.at(0),
        scale_factor_list.at(1),
        scale_factor_list.at(2));
    // --- bicubic: 4D input only ---
  } else if (input.dim() == 4 && std::get_if<enumtype::kBicubic>(&mode)) {
    TORCH_CHECK(
        align_corners != c10::nullopt, "align_corners should be specified.");
    if (antialias) {
      return torch::_upsample_bicubic2d_aa(
          input,
          _interp_output_size(2, std::move(closed_over_args)),
          *align_corners,
          scale_factor_list.at(0),
          scale_factor_list.at(1));
    }
    return torch::upsample_bicubic2d(
        input,
        _interp_output_size(2, std::move(closed_over_args)),
        *align_corners,
        scale_factor_list.at(0),
        scale_factor_list.at(1));
  } else {
    // Any (rank, mode) pair not handled above.
    TORCH_CHECK(
        false,
        "Input Error: Only 3D, 4D and 5D input Tensors supported "
        "(got ",
        input.dim(),
        "D) for the modes: nearest | linear | bilinear | bicubic | trilinear "
        "(got ",
        enumtype::get_enum_name(mode),
        ")");
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Resamples `input` according to `options`.
///
/// This overload only unpacks `options` (size, scale_factor, mode,
/// align_corners, recompute_scale_factor, antialias) and forwards the values,
/// in that order, to `detail::interpolate`, whose result is returned.
inline Tensor interpolate(
    const Tensor& input,
    const InterpolateFuncOptions& options = {}) {
  const auto& target_size = options.size();
  const auto& scales = options.scale_factor();
  const auto& resample_mode = options.mode();
  const auto& corner_alignment = options.align_corners();
  const auto& recompute_scales = options.recompute_scale_factor();
  const auto& use_antialias = options.antialias();

  return detail::interpolate(
      input,
      target_size,
      scales,
      resample_mode,
      corner_alignment,
      recompute_scales,
      use_antialias);
}
} // namespace functional
} // namespace nn
} // namespace torch