Program Listing for File dropout.h
Return to documentation for file (torch/csrc/api/include/torch/nn/functional/dropout.h)
#pragma once
#include <torch/nn/options/dropout.h>
#include <utility>
namespace torch {
namespace nn {
namespace functional {
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Validates `p` and dispatches to the in-place (`torch::dropout_`) or
// out-of-place (`torch::dropout`) operator depending on `inplace`.
//
// Fails via TORCH_CHECK when `p` is not within [0, 1].
inline Tensor dropout(Tensor input, double p, bool training, bool inplace) {
  const bool p_in_range = (p >= 0.) && (p <= 1.);
  TORCH_CHECK(
      p_in_range,
      "dropout probability has to be between 0 and 1, but got ",
      p);
  return inplace ? torch::dropout_(input, p, training)
                 : torch::dropout(input, p, training);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Functional dropout entry point.
///
/// Reads `p`, `training` and `inplace` from `options` and forwards them,
/// together with `input`, to `detail::dropout`.
inline Tensor dropout(Tensor input, const DropoutFuncOptions& options = {}) {
  const double p = options.p();
  const bool training = options.training();
  const bool inplace = options.inplace();
  return detail::dropout(std::move(input), p, training, inplace);
}
// ============================================================================
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Shared implementation behind dropout2d / dropout3d.
//
// Steps:
//   1. Reject `p` outside [0, 1] via TORCH_CHECK.
//   2. If `input.dim()` differs from `batched_dim`, unsqueeze dim 0 so that
//      the tensor handed to feature dropout has one extra leading dimension.
//   3. Run `torch::feature_dropout_` (in-place) or `torch::feature_dropout`.
//   4. If step 2 added a dimension, squeeze dim 0 again before returning.
//
// Every step uses the in-place tensor method when `inplace` is true and the
// out-of-place one otherwise.
//
// NOTE: the `unbatched_dim` template parameter and the `fn_name` argument are
// accepted but not referenced anywhere in this body.
template <int64_t unbatched_dim, int64_t batched_dim>
inline Tensor _dropoutNd_helper(
    Tensor input,
    double p,
    bool training,
    bool inplace,
    const char* fn_name) {
  TORCH_CHECK(
      p >= 0. && p <= 1.,
      "dropout probability has to be between 0 and 1, but got ",
      p);

  // Anything that is not exactly `batched_dim`-dimensional gets a temporary
  // leading dimension for the duration of the call.
  const bool needs_batch_dim = input.dim() != batched_dim;

  if (inplace) {
    if (needs_batch_dim) {
      input = input.unsqueeze_(0);
    }
    Tensor output = torch::feature_dropout_(input, p, training);
    if (needs_batch_dim) {
      output = output.squeeze_(0);
    }
    return output;
  }

  if (needs_batch_dim) {
    input = input.unsqueeze(0);
  }
  Tensor output = torch::feature_dropout(input, p, training);
  if (needs_batch_dim) {
    output = output.squeeze(0);
  }
  return output;
}
// 2-D variant: instantiates `_dropoutNd_helper` with unbatched_dim = 3 and
// batched_dim = 4, tagging the call with the name "dropout2d".
inline Tensor dropout2d(Tensor input, double p, bool training, bool inplace) {
  constexpr int64_t kUnbatchedDim = 3;
  constexpr int64_t kBatchedDim = 4;
  return _dropoutNd_helper<kUnbatchedDim, kBatchedDim>(
      std::move(input), p, training, inplace, "dropout2d");
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Functional 2-D dropout entry point.
///
/// Reads `p`, `training` and `inplace` from `options` and forwards them,
/// together with `input`, to `detail::dropout2d`.
inline Tensor dropout2d(
    Tensor input,
    const Dropout2dFuncOptions& options = {}) {
  const double p = options.p();
  const bool training = options.training();
  const bool inplace = options.inplace();
  return detail::dropout2d(std::move(input), p, training, inplace);
}
// ============================================================================
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// 3-D variant: instantiates `_dropoutNd_helper` with unbatched_dim = 4 and
// batched_dim = 5, tagging the call with the name "dropout3d".
inline Tensor dropout3d(Tensor input, double p, bool training, bool inplace) {
  constexpr int64_t kUnbatchedDim = 4;
  constexpr int64_t kBatchedDim = 5;
  return _dropoutNd_helper<kUnbatchedDim, kBatchedDim>(
      std::move(input), p, training, inplace, "dropout3d");
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Functional 3-D dropout entry point.
///
/// Reads `p`, `training` and `inplace` from `options` and forwards them,
/// together with `input`, to `detail::dropout3d`.
inline Tensor dropout3d(
    Tensor input,
    const Dropout3dFuncOptions& options = {}) {
  const double p = options.p();
  const bool training = options.training();
  const bool inplace = options.inplace();
  return detail::dropout3d(std::move(input), p, training, inplace);
}
// ============================================================================
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Validates `p` and dispatches to the in-place (`torch::alpha_dropout_`) or
// out-of-place (`torch::alpha_dropout`) operator depending on `inplace`.
//
// Fails via TORCH_CHECK when `p` is not within [0, 1].
inline Tensor alpha_dropout(
    Tensor input,
    double p,
    bool training,
    bool inplace) {
  // Expressed as a positive range test, like detail::dropout in this file.
  // The negated form (`p < 0. || p > 1.`) is false for NaN and would let a
  // NaN probability slip past this check; this form rejects it.
  TORCH_CHECK(
      p >= 0. && p <= 1.,
      "dropout probability has to be between 0 and 1, but got ",
      p);
  return inplace ? torch::alpha_dropout_(input, p, training)
                 : torch::alpha_dropout(input, p, training);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Functional alpha-dropout entry point.
///
/// Reads `p`, `training` and `inplace` from `options` and forwards them,
/// together with `input`, to `detail::alpha_dropout`.
inline Tensor alpha_dropout(
    Tensor input,
    const AlphaDropoutFuncOptions& options = {}) {
  const double p = options.p();
  const bool training = options.training();
  const bool inplace = options.inplace();
  return detail::alpha_dropout(std::move(input), p, training, inplace);
}
// ============================================================================
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Validates `p` and dispatches to the in-place
// (`torch::feature_alpha_dropout_`) or out-of-place
// (`torch::feature_alpha_dropout`) operator depending on `inplace`.
//
// Fails via TORCH_CHECK when `p` is not within [0, 1].
inline Tensor feature_alpha_dropout(
    Tensor input,
    double p,
    bool training,
    bool inplace) {
  // Expressed as a positive range test, like detail::dropout in this file.
  // The negated form (`p < 0. || p > 1.`) is false for NaN and would let a
  // NaN probability slip past this check; this form rejects it.
  TORCH_CHECK(
      p >= 0. && p <= 1.,
      "dropout probability has to be between 0 and 1, but got ",
      p);
  return inplace ? torch::feature_alpha_dropout_(input, p, training)
                 : torch::feature_alpha_dropout(input, p, training);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Functional feature-alpha-dropout entry point.
///
/// Reads `p`, `training` and `inplace` from `options` and forwards them,
/// together with `input`, to `detail::feature_alpha_dropout`.
inline Tensor feature_alpha_dropout(
    Tensor input,
    const FeatureAlphaDropoutFuncOptions& options = {}) {
  const double p = options.p();
  const bool training = options.training();
  const bool inplace = options.inplace();
  return detail::feature_alpha_dropout(
      std::move(input), p, training, inplace);
}
} // namespace functional
} // namespace nn
} // namespace torch