Shortcuts

Program Listing for File dropout.h

Return to documentation for file (torch/csrc/api/include/torch/nn/functional/dropout.h)

#pragma once

#include <torch/nn/options/dropout.h>

#include <utility>

namespace torch {
namespace nn {
namespace functional {

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

/// Validates the dropout probability, then dispatches to the ATen kernel.
///
/// Throws (via TORCH_CHECK) unless `0 <= p <= 1`; because the check is a
/// positive range test, a NaN `p` is rejected as well.
/// When `inplace` is true, the result of `torch::dropout_(input, ...)` is
/// returned; otherwise the out-of-place `torch::dropout` is used.
inline Tensor dropout(Tensor input, double p, bool training, bool inplace) {
  TORCH_CHECK(
      p >= 0. && p <= 1.,
      "dropout probability has to be between 0 and 1, but got ",
      p);
  if (!inplace) {
    return torch::dropout(input, p, training);
  }
  return torch::dropout_(input, p, training);
}

} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Functional dropout: unpacks `options` (probability, training flag,
/// in-place flag) and forwards them to `detail::dropout`.
/// `input` is taken by value and moved into the implementation.
inline Tensor dropout(Tensor input, const DropoutFuncOptions& options = {}) {
  const auto prob = options.p();
  const auto is_training = options.training();
  const auto in_place = options.inplace();
  return detail::dropout(std::move(input), prob, is_training, in_place);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

/// Shared implementation of channel-wise dropout (`dropout2d` / `dropout3d`),
/// built on `torch::feature_dropout` / `torch::feature_dropout_`.
///
/// Template parameters:
///   unbatched_dim - rank of an input that has no batch dimension.
///   batched_dim   - rank of an input that has a leading batch dimension.
///
/// An input whose rank equals `batched_dim` is passed straight through. Any
/// other input gets a leading size-1 dimension before the kernel call, and
/// that dimension is removed from the result afterwards. With `inplace` set,
/// the in-place variants (`unsqueeze_`, `feature_dropout_`, `squeeze_`) are used.
///
/// @param fn_name  public function name, used in diagnostics.
/// Throws (via TORCH_CHECK) unless `0 <= p <= 1` (NaN is rejected too).
template <int64_t unbatched_dim, int64_t batched_dim>
inline Tensor _dropoutNd_helper(
    Tensor input,
    double p,
    bool training,
    bool inplace,
    const char* fn_name) {
  TORCH_CHECK(
      p >= 0. && p <= 1.,
      "dropout probability has to be between 0 and 1, but got ",
      p);

  const auto inp_dim = input.dim();
  // Previously `fn_name` was never used, and inputs of any unexpected rank
  // were silently treated as unbatched. Warn (once per instantiation) so the
  // caller can notice the mismatch; the computation itself is unchanged.
  if (inp_dim != unbatched_dim && inp_dim != batched_dim) {
    TORCH_WARN_ONCE(
        fn_name,
        ": Expected a ",
        unbatched_dim,
        "-D (unbatched) or ",
        batched_dim,
        "-D (batched) input, but received a ",
        inp_dim,
        "-D input; it will be treated as unbatched. "
        "To apply element-wise dropout instead, use dropout.");
  }

  const auto is_batched = inp_dim == batched_dim;
  if (!is_batched) {
    // Add a temporary batch dimension; removed again below.
    if (inplace) {
      input = input.unsqueeze_(0);
    } else {
      input = input.unsqueeze(0);
    }
  }

  Tensor result;
  if (inplace) {
    result = torch::feature_dropout_(input, p, training);
  } else {
    result = torch::feature_dropout(input, p, training);
  }

  if (!is_batched) {
    // Drop the temporary batch dimension added above.
    if (inplace) {
      result = result.squeeze_(0);
    } else {
      result = result.squeeze(0);
    }
  }
  return result;
}

/// 2-D channel dropout. A 4-D input is treated as batched; any other rank
/// (normally 3-D) is treated as unbatched by `_dropoutNd_helper`.
inline Tensor dropout2d(Tensor input, double p, bool training, bool inplace) {
  constexpr int64_t kUnbatchedDim = 3;
  constexpr int64_t kBatchedDim = 4;
  return _dropoutNd_helper<kUnbatchedDim, kBatchedDim>(
      std::move(input), p, training, inplace, "dropout2d");
}

} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Functional 2-D channel dropout: unpacks `options` and forwards to
/// `detail::dropout2d`. `input` is moved into the implementation.
inline Tensor dropout2d(
    Tensor input,
    const Dropout2dFuncOptions& options = {}) {
  const auto prob = options.p();
  const auto is_training = options.training();
  const auto in_place = options.inplace();
  return detail::dropout2d(std::move(input), prob, is_training, in_place);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

/// 3-D channel dropout. A 5-D input is treated as batched; any other rank
/// (normally 4-D) is treated as unbatched by `_dropoutNd_helper`.
inline Tensor dropout3d(Tensor input, double p, bool training, bool inplace) {
  constexpr int64_t kUnbatchedDim = 4;
  constexpr int64_t kBatchedDim = 5;
  return _dropoutNd_helper<kUnbatchedDim, kBatchedDim>(
      std::move(input), p, training, inplace, "dropout3d");
}

} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Functional 3-D channel dropout: unpacks `options` and forwards to
/// `detail::dropout3d`. `input` is moved into the implementation.
inline Tensor dropout3d(
    Tensor input,
    const Dropout3dFuncOptions& options = {}) {
  const auto prob = options.p();
  const auto is_training = options.training();
  const auto in_place = options.inplace();
  return detail::dropout3d(std::move(input), prob, is_training, in_place);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

/// Validates `p`, then dispatches to the in-place (`torch::alpha_dropout_`)
/// or out-of-place (`torch::alpha_dropout`) ATen kernel.
///
/// Throws (via TORCH_CHECK) unless `0 <= p <= 1`.
inline Tensor alpha_dropout(
    Tensor input,
    double p,
    bool training,
    bool inplace) {
  // Positive range test: the previous `p < 0. || p > 1.` form let a NaN
  // probability through (both comparisons are false for NaN), unlike
  // `detail::dropout`, which rejects it.
  TORCH_CHECK(
      p >= 0. && p <= 1.,
      "dropout probability has to be between 0 and 1, but got ",
      p);
  return inplace ? torch::alpha_dropout_(input, p, training)
                 : torch::alpha_dropout(input, p, training);
}

} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Functional alpha dropout: unpacks `options` and forwards to
/// `detail::alpha_dropout`. `input` is moved into the implementation.
inline Tensor alpha_dropout(
    Tensor input,
    const AlphaDropoutFuncOptions& options = {}) {
  const auto prob = options.p();
  const auto is_training = options.training();
  const auto in_place = options.inplace();
  return detail::alpha_dropout(std::move(input), prob, is_training, in_place);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

/// Validates `p`, then dispatches to the in-place
/// (`torch::feature_alpha_dropout_`) or out-of-place
/// (`torch::feature_alpha_dropout`) ATen kernel.
///
/// Throws (via TORCH_CHECK) unless `0 <= p <= 1`.
inline Tensor feature_alpha_dropout(
    Tensor input,
    double p,
    bool training,
    bool inplace) {
  // Positive range test: the previous `p < 0. || p > 1.` form let a NaN
  // probability through (both comparisons are false for NaN), unlike
  // `detail::dropout`, which rejects it.
  TORCH_CHECK(
      p >= 0. && p <= 1.,
      "dropout probability has to be between 0 and 1, but got ",
      p);
  return inplace ? torch::feature_alpha_dropout_(input, p, training)
                 : torch::feature_alpha_dropout(input, p, training);
}

} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Functional feature alpha dropout: unpacks `options` and forwards to
/// `detail::feature_alpha_dropout`. `input` is moved into the implementation.
inline Tensor feature_alpha_dropout(
    Tensor input,
    const FeatureAlphaDropoutFuncOptions& options = {}) {
  const auto prob = options.p();
  const auto is_training = options.training();
  const auto in_place = options.inplace();
  return detail::feature_alpha_dropout(
      std::move(input), prob, is_training, in_place);
}

} // namespace functional
} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources