Shortcuts

Program Listing for File dropout.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/dropout.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/options/dropout.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/Export.h>

#include <cstddef>
#include <vector>

namespace torch {
namespace nn {

namespace detail {

/// Shared base for the dropout module implementations declared in this
/// header.
///
/// Holds the `DropoutOptions` a module was constructed with and checks that
/// the configured probability lies in the closed interval [0, 1] every time
/// `reset()` runs, which includes construction.
///
/// \tparam Derived The concrete module type; it is forwarded to
/// `torch::nn::Cloneable<Derived>`.
template <typename Derived>
class _DropoutNd : public torch::nn::Cloneable<Derived> {
 public:
  /// Convenience constructor: wraps the bare probability `p` in a
  /// `DropoutOptions` and delegates to the options-based constructor, so the
  /// range check below applies to it as well.
  _DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)) {}

  /// Stores a copy of `options_` (default-constructed options when omitted)
  /// and validates it right away through `reset()`.
  explicit _DropoutNd(const DropoutOptions& options_ = {}) : options(options_) {
    // `reset()` is virtual; inside a constructor the call resolves to this
    // class's version, which is what the suppression below acknowledges.
    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
    reset();
  }

  /// Validates `options.p()`: the `TORCH_CHECK` fails, reporting the
  /// offending value, unless 0 <= p <= 1. No other state is touched.
  void reset() override {
    TORCH_CHECK(
        options.p() >= 0. && options.p() <= 1.,
        "dropout probability has to be between 0 and 1, but got ",
        options.p());
  }

  /// The options this module was constructed with.
  DropoutOptions options;
};

} // namespace detail

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `Dropout` module.
///
/// Construction (from a bare probability or a `DropoutOptions`), the public
/// `options` member and the probability range check all come from
/// `detail::_DropoutNd`. This class only adds the two declarations below;
/// their definitions are not part of this header.
class TORCH_API DropoutImpl : public detail::_DropoutNd<DropoutImpl> {
 public:
  // Inherit the constructors of the shared base class.
  using detail::_DropoutNd<DropoutImpl>::_DropoutNd;

  /// Runs the module on `input` and returns the resulting tensor.
  /// NOTE(review): presumably element-wise dropout driven by `options`; the
  /// definition lives outside this header - confirm there.
  Tensor forward(Tensor input);

  /// Writes a description of this module to `stream`.
  /// NOTE(review): the exact format is determined by the out-of-line
  /// definition.
  void pretty_print(std::ostream& stream) const override;
};

// Declares the `Dropout` module type on top of `DropoutImpl`.
// NOTE(review): TORCH_MODULE is defined outside this file (presumably in
// torch/nn/pimpl.h, which is included above) - see it for what is generated.
TORCH_MODULE(Dropout);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `Dropout2d` module.
///
/// Construction (from a bare probability or a `DropoutOptions`), the public
/// `options` member and the probability range check all come from
/// `detail::_DropoutNd`. This class only adds the two declarations below;
/// their definitions are not part of this header.
class TORCH_API Dropout2dImpl : public detail::_DropoutNd<Dropout2dImpl> {
 public:
  // Inherit the constructors of the shared base class.
  using detail::_DropoutNd<Dropout2dImpl>::_DropoutNd;

  /// Runs the module on `input` and returns the resulting tensor.
  /// NOTE(review): presumably drops whole 2-D feature maps (channels) rather
  /// than single elements; the definition lives outside this header - confirm
  /// there, including the expected input rank.
  Tensor forward(Tensor input);

  /// Writes a description of this module to `stream`.
  /// NOTE(review): the exact format is determined by the out-of-line
  /// definition.
  void pretty_print(std::ostream& stream) const override;
};

// Declares the `Dropout2d` module type on top of `Dropout2dImpl`.
// NOTE(review): TORCH_MODULE is defined outside this file (presumably in
// torch/nn/pimpl.h, which is included above) - see it for what is generated.
TORCH_MODULE(Dropout2d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `Dropout3d` module.
///
/// Construction (from a bare probability or a `DropoutOptions`), the public
/// `options` member and the probability range check all come from
/// `detail::_DropoutNd`. This class only adds the two declarations below;
/// their definitions are not part of this header.
class TORCH_API Dropout3dImpl : public detail::_DropoutNd<Dropout3dImpl> {
 public:
  // Inherit the constructors of the shared base class.
  using detail::_DropoutNd<Dropout3dImpl>::_DropoutNd;

  /// Runs the module on `input` and returns the resulting tensor.
  /// NOTE(review): presumably drops whole 3-D feature volumes (channels)
  /// rather than single elements; the definition lives outside this header -
  /// confirm there, including the expected input rank.
  Tensor forward(Tensor input);

  /// Writes a description of this module to `stream`.
  /// NOTE(review): the exact format is determined by the out-of-line
  /// definition.
  void pretty_print(std::ostream& stream) const override;
};

// Declares the `Dropout3d` module type on top of `Dropout3dImpl`.
// NOTE(review): TORCH_MODULE is defined outside this file (presumably in
// torch/nn/pimpl.h, which is included above) - see it for what is generated.
TORCH_MODULE(Dropout3d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `AlphaDropout` module.
///
/// Construction (from a bare probability or a `DropoutOptions`), the public
/// `options` member and the probability range check all come from
/// `detail::_DropoutNd`. This class only adds the two declarations below;
/// their definitions are not part of this header.
class TORCH_API AlphaDropoutImpl : public detail::_DropoutNd<AlphaDropoutImpl> {
 public:
  // Inherit the constructors of the shared base class.
  using detail::_DropoutNd<AlphaDropoutImpl>::_DropoutNd;

  /// Runs the module on `input` (taken by const reference) and returns the
  /// resulting tensor.
  /// NOTE(review): presumably the "alpha dropout" variant intended for
  /// self-normalizing (SELU) networks; the definition lives outside this
  /// header - confirm there.
  Tensor forward(const Tensor& input);

  /// Writes a description of this module to `stream`.
  /// NOTE(review): the exact format is determined by the out-of-line
  /// definition.
  void pretty_print(std::ostream& stream) const override;
};

// Declares the `AlphaDropout` module type on top of `AlphaDropoutImpl`.
// NOTE(review): TORCH_MODULE is defined outside this file (presumably in
// torch/nn/pimpl.h, which is included above) - see it for what is generated.
TORCH_MODULE(AlphaDropout);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureAlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `FeatureAlphaDropout` module.
///
/// Construction (from a bare probability or a `DropoutOptions`), the public
/// `options` member and the probability range check all come from
/// `detail::_DropoutNd`. This class only adds the two declarations below;
/// their definitions are not part of this header.
class TORCH_API FeatureAlphaDropoutImpl
    : public detail::_DropoutNd<FeatureAlphaDropoutImpl> {
 public:
  // Inherit the constructors of the shared base class.
  using detail::_DropoutNd<FeatureAlphaDropoutImpl>::_DropoutNd;

  /// Runs the module on `input` (taken by const reference) and returns the
  /// resulting tensor.
  /// NOTE(review): presumably alpha dropout applied to whole feature maps
  /// (channels); the definition lives outside this header - confirm there.
  Tensor forward(const Tensor& input);

  /// Writes a description of this module to `stream`.
  /// NOTE(review): the exact format is determined by the out-of-line
  /// definition.
  void pretty_print(std::ostream& stream) const override;
};

// Declares the `FeatureAlphaDropout` module type on top of
// `FeatureAlphaDropoutImpl`.
// NOTE(review): TORCH_MODULE is defined outside this file (presumably in
// torch/nn/pimpl.h, which is included above) - see it for what is generated.
TORCH_MODULE(FeatureAlphaDropout);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources