// Program listing for file torch/csrc/api/include/torch/nn/modules/dropout.h
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/options/dropout.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <torch/csrc/Export.h>
namespace torch::nn {
namespace detail {
/// Shared base for the dropout-style modules declared in this header.
///
/// Keeps the `DropoutOptions` a module was constructed with and validates the
/// probability `p` as part of construction.
///
/// \tparam Derived concrete module type, forwarded to `torch::nn::Cloneable`.
template <typename Derived>
class _DropoutNd : public torch::nn::Cloneable<Derived> {
 public:
  /// Options this module was constructed with.
  DropoutOptions options;

  /// Builds the module from a bare probability by delegating to the
  /// options-based constructor. Not `explicit`, so a `double` converts
  /// implicitly.
  _DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)) {}

  /// Stores a copy of `options_` and validates it right away.
  explicit _DropoutNd(const DropoutOptions& options_ = {})
      : options(options_) {
    // Qualified call: binds statically to this class's `reset()`.
    _DropoutNd::reset();
  }

  /// Checks that `options.p()` lies in the closed interval [0, 1]; otherwise
  /// `TORCH_CHECK` fails with a message that includes the offending value.
  void reset() override {
    TORCH_CHECK(
        options.p() >= 0. && options.p() <= 1.,
        "dropout probability has to be between 0 and 1, but got ",
        options.p());
  }
};
} // namespace detail
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `Dropout` module.
///
/// Derives from `detail::_DropoutNd<DropoutImpl>`, which stores the
/// `DropoutOptions` in its public `options` member and rejects any `p`
/// outside [0, 1] through `TORCH_CHECK` in `reset()`.
class TORCH_API DropoutImpl : public detail::_DropoutNd<DropoutImpl> {
public:
/// Makes the base-class constructors (from a `double p` or from a
/// `DropoutOptions`) available on this class.
using detail::_DropoutNd<DropoutImpl>::_DropoutNd;
/// Runs the module on `input` and returns the resulting tensor.
/// Declared only; the definition is not part of this listing.
/// NOTE(review): `input` is taken by value, presumably so an in-place
/// variant can modify it — confirm against the definition.
Tensor forward(Tensor input);
/// Pretty-prints this module into `stream`. Declared only; defined elsewhere.
void pretty_print(std::ostream& stream) const override;
};
/// Declares the `Dropout` module type associated with `DropoutImpl`.
/// NOTE(review): presumably a holder wrapper generated by the macro from
/// torch/nn/pimpl.h — confirm there.
TORCH_MODULE(Dropout);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `Dropout2d` module.
///
/// Derives from `detail::_DropoutNd<Dropout2dImpl>`, which stores the
/// `DropoutOptions` in its public `options` member and rejects any `p`
/// outside [0, 1] through `TORCH_CHECK` in `reset()`.
/// NOTE(review): the name suggests channel-wise dropout over 2-D feature
/// maps — confirm against the out-of-line `forward` definition.
class TORCH_API Dropout2dImpl : public detail::_DropoutNd<Dropout2dImpl> {
public:
/// Makes the base-class constructors (from a `double p` or from a
/// `DropoutOptions`) available on this class.
using detail::_DropoutNd<Dropout2dImpl>::_DropoutNd;
/// Runs the module on `input` (taken by value) and returns the resulting
/// tensor. Declared only; the definition is not part of this listing.
Tensor forward(Tensor input);
/// Pretty-prints this module into `stream`. Declared only; defined elsewhere.
void pretty_print(std::ostream& stream) const override;
};
/// Declares the `Dropout2d` module type associated with `Dropout2dImpl`.
/// NOTE(review): presumably a holder wrapper generated by the macro from
/// torch/nn/pimpl.h — confirm there.
TORCH_MODULE(Dropout2d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `Dropout3d` module.
///
/// Derives from `detail::_DropoutNd<Dropout3dImpl>`, which stores the
/// `DropoutOptions` in its public `options` member and rejects any `p`
/// outside [0, 1] through `TORCH_CHECK` in `reset()`.
/// NOTE(review): the name suggests channel-wise dropout over 3-D feature
/// maps — confirm against the out-of-line `forward` definition.
class TORCH_API Dropout3dImpl : public detail::_DropoutNd<Dropout3dImpl> {
public:
/// Makes the base-class constructors (from a `double p` or from a
/// `DropoutOptions`) available on this class.
using detail::_DropoutNd<Dropout3dImpl>::_DropoutNd;
/// Runs the module on `input` (taken by value) and returns the resulting
/// tensor. Declared only; the definition is not part of this listing.
Tensor forward(Tensor input);
/// Pretty-prints this module into `stream`. Declared only; defined elsewhere.
void pretty_print(std::ostream& stream) const override;
};
/// Declares the `Dropout3d` module type associated with `Dropout3dImpl`.
/// NOTE(review): presumably a holder wrapper generated by the macro from
/// torch/nn/pimpl.h — confirm there.
TORCH_MODULE(Dropout3d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `AlphaDropout` module.
///
/// Derives from `detail::_DropoutNd<AlphaDropoutImpl>`, which stores the
/// `DropoutOptions` in its public `options` member and rejects any `p`
/// outside [0, 1] through `TORCH_CHECK` in `reset()`.
class TORCH_API AlphaDropoutImpl : public detail::_DropoutNd<AlphaDropoutImpl> {
public:
/// Makes the base-class constructors (from a `double p` or from a
/// `DropoutOptions`) available on this class.
using detail::_DropoutNd<AlphaDropoutImpl>::_DropoutNd;
/// Runs the module on `input` and returns the resulting tensor. Unlike the
/// `Dropout`/`Dropout2d`/`Dropout3d` declarations in this header, `input` is
/// taken by const reference. Declared only; defined elsewhere.
Tensor forward(const Tensor& input);
/// Pretty-prints this module into `stream`. Declared only; defined elsewhere.
void pretty_print(std::ostream& stream) const override;
};
/// Declares the `AlphaDropout` module type associated with `AlphaDropoutImpl`.
/// NOTE(review): presumably a holder wrapper generated by the macro from
/// torch/nn/pimpl.h — confirm there.
TORCH_MODULE(AlphaDropout);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureAlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `FeatureAlphaDropout` module.
///
/// Derives from `detail::_DropoutNd<FeatureAlphaDropoutImpl>`, which stores
/// the `DropoutOptions` in its public `options` member and rejects any `p`
/// outside [0, 1] through `TORCH_CHECK` in `reset()`.
class TORCH_API FeatureAlphaDropoutImpl
: public detail::_DropoutNd<FeatureAlphaDropoutImpl> {
public:
/// Makes the base-class constructors (from a `double p` or from a
/// `DropoutOptions`) available on this class.
using detail::_DropoutNd<FeatureAlphaDropoutImpl>::_DropoutNd;
/// Runs the module on `input` (taken by const reference) and returns the
/// resulting tensor. Declared only; the definition is not part of this
/// listing.
Tensor forward(const Tensor& input);
/// Pretty-prints this module into `stream`. Declared only; defined elsewhere.
void pretty_print(std::ostream& stream) const override;
};
/// Declares the `FeatureAlphaDropout` module type associated with
/// `FeatureAlphaDropoutImpl`.
/// NOTE(review): presumably a holder wrapper generated by the macro from
/// torch/nn/pimpl.h — confirm there.
TORCH_MODULE(FeatureAlphaDropout);
} // namespace torch::nn