Program Listing for File activation.h
(torch/csrc/api/include/torch/nn/modules/activation.h)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/functional/activation.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/options/activation.h>
#include <torch/csrc/Export.h>
namespace torch::nn {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API ELUImpl : public torch::nn::Cloneable<ELUImpl> {
public:
explicit ELUImpl(const ELUOptions& options_ = {});
Tensor forward(Tensor input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
ELUOptions options;
};
TORCH_MODULE(ELU);
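// Usage sketch (not part of the original header): assumes ELUOptions, declared
// in torch/nn/options/activation.h, exposes alpha() and inplace() setters.
//
//   ELU model(ELUOptions().alpha(1.0).inplace(false));
//   torch::Tensor y = model(torch::randn({2, 3}));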
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API SELUImpl : public torch::nn::Cloneable<SELUImpl> {
public:
explicit SELUImpl(const SELUOptions& options_ = {});
Tensor forward(Tensor input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
SELUOptions options;
};
TORCH_MODULE(SELU);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API HardshrinkImpl : public torch::nn::Cloneable<HardshrinkImpl> {
public:
explicit HardshrinkImpl(const HardshrinkOptions& options_ = {});
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
HardshrinkOptions options;
};
TORCH_MODULE(Hardshrink);
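// Usage sketch (assumes HardshrinkOptions provides a lambda() setter for the
// shrinkage threshold; see torch/nn/options/activation.h):
//
//   Hardshrink model(HardshrinkOptions().lambda(0.5));
//   torch::Tensor y = model(torch::randn({2, 3}));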
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardtanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API HardtanhImpl : public torch::nn::Cloneable<HardtanhImpl> {
public:
explicit HardtanhImpl(const HardtanhOptions& options_ = {});
Tensor forward(Tensor input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
HardtanhOptions options;
};
TORCH_MODULE(Hardtanh);
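// Usage sketch (assumes HardtanhOptions provides min_val()/max_val() setters
// for the clamping range):
//
//   Hardtanh model(HardtanhOptions().min_val(-2.0).max_val(2.0));
//   torch::Tensor y = model(torch::randn({2, 3}));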
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LeakyReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API LeakyReLUImpl : public torch::nn::Cloneable<LeakyReLUImpl> {
public:
explicit LeakyReLUImpl(const LeakyReLUOptions& options_ = {});
Tensor forward(Tensor input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
LeakyReLUOptions options;
};
TORCH_MODULE(LeakyReLU);
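// Usage sketch (assumes LeakyReLUOptions provides a negative_slope() setter):
//
//   LeakyReLU model(LeakyReLUOptions().negative_slope(0.01));
//   torch::Tensor y = model(torch::randn({2, 3}));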
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API LogSigmoidImpl : public torch::nn::Cloneable<LogSigmoidImpl> {
public:
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
};
TORCH_MODULE(LogSigmoid);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API SoftmaxImpl : public torch::nn::Cloneable<SoftmaxImpl> {
public:
explicit SoftmaxImpl(int64_t dim) : SoftmaxImpl(SoftmaxOptions(dim)) {}
explicit SoftmaxImpl(const SoftmaxOptions& options_);
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
SoftmaxOptions options;
};
TORCH_MODULE(Softmax);
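// Usage sketch: SoftmaxOptions is constructed from the dimension along which
// softmax is computed, matching the SoftmaxImpl(int64_t dim) constructor above.
//
//   Softmax model(SoftmaxOptions(/*dim=*/1));
//   torch::Tensor y = model(torch::randn({2, 3}));  // y sums to 1 along dim 1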
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmin ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API SoftminImpl : public torch::nn::Cloneable<SoftminImpl> {
public:
explicit SoftminImpl(int64_t dim) : SoftminImpl(SoftminOptions(dim)) {}
explicit SoftminImpl(const SoftminOptions& options_);
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
SoftminOptions options;
};
TORCH_MODULE(Softmin);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSoftmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API LogSoftmaxImpl : public torch::nn::Cloneable<LogSoftmaxImpl> {
public:
explicit LogSoftmaxImpl(int64_t dim)
: LogSoftmaxImpl(LogSoftmaxOptions(dim)) {}
explicit LogSoftmaxImpl(const LogSoftmaxOptions& options_);
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
LogSoftmaxOptions options;
};
TORCH_MODULE(LogSoftmax);
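// Usage sketch, analogous to Softmax but producing log-probabilities (commonly
// paired with a negative log-likelihood loss):
//
//   LogSoftmax model(LogSoftmaxOptions(/*dim=*/1));
//   torch::Tensor log_probs = model(torch::randn({2, 3}));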
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API Softmax2dImpl : public torch::nn::Cloneable<Softmax2dImpl> {
public:
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
};
TORCH_MODULE(Softmax2d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API PReLUImpl : public torch::nn::Cloneable<PReLUImpl> {
public:
explicit PReLUImpl(const PReLUOptions& options_ = {});
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
PReLUOptions options;
Tensor weight;
};
TORCH_MODULE(PReLU);
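// Usage sketch: unlike the modules above, PReLU holds a learnable `weight`
// tensor. Assumes PReLUOptions provides num_parameters()/init() setters and
// that num_parameters matches the input's channel dimension.
//
//   PReLU model(PReLUOptions().num_parameters(64).init(0.25));
//   torch::Tensor y = model(torch::randn({8, 64}));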
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API ReLUImpl : public torch::nn::Cloneable<ReLUImpl> {
public:
explicit ReLUImpl(const ReLUOptions& options_ = {});
Tensor forward(Tensor input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
ReLUOptions options;
};
TORCH_MODULE(ReLU);
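// Usage sketch (ReLUOptions is assumed to carry only an inplace() flag):
//
//   ReLU model;  // default options
//   torch::Tensor y = model(torch::randn({2, 3}));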
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU6 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API ReLU6Impl : public torch::nn::Cloneable<ReLU6Impl> {
public:
explicit ReLU6Impl(const ReLU6Options& options_ = {});
Tensor forward(Tensor input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
ReLU6Options options;
};
TORCH_MODULE(ReLU6);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API RReLUImpl : public torch::nn::Cloneable<RReLUImpl> {
public:
explicit RReLUImpl(const RReLUOptions& options_ = {});
Tensor forward(Tensor input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
RReLUOptions options;
};
TORCH_MODULE(RReLU);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API CELUImpl : public torch::nn::Cloneable<CELUImpl> {
public:
explicit CELUImpl(const CELUOptions& options_ = {});
Tensor forward(Tensor input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
CELUOptions options;
};
TORCH_MODULE(CELU);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API GLUImpl : public torch::nn::Cloneable<GLUImpl> {
public:
explicit GLUImpl(const GLUOptions& options_ = {});
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
GLUOptions options;
};
TORCH_MODULE(GLU);
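// Usage sketch: GLU splits the input in half along the configured dimension
// and gates one half with the sigmoid of the other, so that dimension must be
// even and is halved in the output. Assumes GLUOptions takes the dimension.
//
//   GLU model(GLUOptions(/*dim=*/1));
//   torch::Tensor y = model(torch::randn({4, 6}));  // y has shape {4, 3}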
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API GELUImpl : public torch::nn::Cloneable<GELUImpl> {
public:
explicit GELUImpl(GELUOptions options_ = {});
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
GELUOptions options;
};
TORCH_MODULE(GELU);
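// Usage sketch (assumes GELUOptions exposes an approximate() setter choosing
// between the exact and tanh-based formulations, as in the Python module):
//
//   GELU model(GELUOptions().approximate("tanh"));
//   torch::Tensor y = model(torch::randn({2, 3}));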
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SiLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API SiLUImpl : public torch::nn::Cloneable<SiLUImpl> {
public:
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
};
TORCH_MODULE(SiLU);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Mish ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API MishImpl : public torch::nn::Cloneable<MishImpl> {
public:
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
};
TORCH_MODULE(Mish);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API SigmoidImpl : public torch::nn::Cloneable<SigmoidImpl> {
public:
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
};
TORCH_MODULE(Sigmoid);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softplus ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API SoftplusImpl : public torch::nn::Cloneable<SoftplusImpl> {
public:
explicit SoftplusImpl(const SoftplusOptions& options_ = {});
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
SoftplusOptions options;
};
TORCH_MODULE(Softplus);
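// Usage sketch (assumes SoftplusOptions provides beta()/threshold() setters):
//
//   Softplus model(SoftplusOptions().beta(0.5).threshold(3.0));
//   torch::Tensor y = model(torch::randn({2, 3}));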
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API SoftshrinkImpl : public torch::nn::Cloneable<SoftshrinkImpl> {
public:
explicit SoftshrinkImpl(const SoftshrinkOptions& options_ = {});
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
SoftshrinkOptions options;
};
TORCH_MODULE(Softshrink);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softsign ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API SoftsignImpl : public torch::nn::Cloneable<SoftsignImpl> {
public:
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
};
TORCH_MODULE(Softsign);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API TanhImpl : public torch::nn::Cloneable<TanhImpl> {
public:
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
};
TORCH_MODULE(Tanh);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanhshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API TanhshrinkImpl : public torch::nn::Cloneable<TanhshrinkImpl> {
public:
Tensor forward(const Tensor& input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
};
TORCH_MODULE(Tanhshrink);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Threshold ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API ThresholdImpl : public torch::nn::Cloneable<ThresholdImpl> {
public:
ThresholdImpl(double threshold, double value)
: ThresholdImpl(ThresholdOptions(threshold, value)) {}
explicit ThresholdImpl(const ThresholdOptions& options_);
Tensor forward(Tensor input);
void reset() override;
void pretty_print(std::ostream& stream) const override;
ThresholdOptions options;
};
TORCH_MODULE(Threshold);
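// Usage sketch: the two-argument constructor above lets the module holder be
// built directly from (threshold, value); elements not greater than the
// threshold are replaced by value.
//
//   Threshold model(/*threshold=*/0.1, /*value=*/0.0);
//   torch::Tensor y = model(torch::randn({2, 3}));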
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiheadAttention ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API MultiheadAttentionImpl
: public torch::nn::Cloneable<MultiheadAttentionImpl> {
public:
MultiheadAttentionImpl(int64_t embed_dim, int64_t num_heads)
: MultiheadAttentionImpl(
MultiheadAttentionOptions(embed_dim, num_heads)) {}
explicit MultiheadAttentionImpl(const MultiheadAttentionOptions& options_);
std::tuple<Tensor, Tensor> forward(
const Tensor& query,
const Tensor& key,
const Tensor& value,
const Tensor& key_padding_mask = {},
bool need_weights = true,
const Tensor& attn_mask = {},
bool average_attn_weights = true);
protected:
FORWARD_HAS_DEFAULT_ARGS(
{3, AnyValue(Tensor())},
{4, AnyValue(true)},
{5, AnyValue(Tensor())},
{6, AnyValue(true)})
public:
void reset() override;
void _reset_parameters();
MultiheadAttentionOptions options;
bool _qkv_same_embed_dim{};
Tensor in_proj_weight;
Tensor in_proj_bias;
Tensor bias_k;
Tensor bias_v;
Linear out_proj = nullptr;
Tensor q_proj_weight;
Tensor k_proj_weight;
Tensor v_proj_weight;
int64_t head_dim{};
};
TORCH_MODULE(MultiheadAttention);
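// Usage sketch: forward() returns (attn_output, attn_weights). The shapes
// below follow the (seq_len, batch, embed_dim) layout used by the Python
// module with batch_first=false; options beyond embed_dim/num_heads are
// declared in torch/nn/options/activation.h and assumed here.
//
//   MultiheadAttention mha(MultiheadAttentionOptions(/*embed_dim=*/64, /*num_heads=*/8));
//   auto q = torch::randn({10, 4, 64});               // (L, N, E)
//   auto [attn_output, attn_weights] = mha(q, q, q);  // self-attention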
} // namespace torch::nn