Shortcuts

Program Listing for File activation.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/activation.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/functional/activation.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/options/activation.h>

#include <torch/csrc/Export.h>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the ELU activation. Only the interface is declared here; the
/// member functions are defined out of line (not visible in this header).
class TORCH_API ELUImpl : public torch::nn::Cloneable<ELUImpl> {
 public:
  /// @param options_ module configuration; defaults to a value-initialized
  ///   `ELUOptions` (`= {}`), so `ELUImpl()` is a valid construction.
  explicit ELUImpl(const ELUOptions& options_ = {});

  /// Applies the activation to `input` and returns the result.
  /// `input` is taken by value, unlike the `const Tensor&` forward() of
  /// several other modules in this file. Presumably this lets an in-place
  /// option modify and return it — confirm in the definition.
  Tensor forward(Tensor input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  ELUOptions options;
};

// Presumably declares the user-facing `ELU` holder type wrapping `ELUImpl`
// (the macro is defined outside this header).
TORCH_MODULE(ELU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the SELU activation. Interface only; the member functions are
/// defined out of line.
class TORCH_API SELUImpl : public torch::nn::Cloneable<SELUImpl> {
 public:
  /// @param options_ module configuration; defaults to `SELUOptions{}`, so
  ///   the module is default-constructible.
  explicit SELUImpl(const SELUOptions& options_ = {});

  /// Applies the activation to `input` and returns the result.
  /// Takes `input` by value — presumably to support an in-place option;
  /// confirm in the definition.
  Tensor forward(Tensor input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  SELUOptions options;
};

// Presumably declares the `SELU` holder type wrapping `SELUImpl` (macro
// defined outside this header).
TORCH_MODULE(SELU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the Hardshrink activation. Interface only; the member functions
/// are defined out of line.
class TORCH_API HardshrinkImpl : public torch::nn::Cloneable<HardshrinkImpl> {
 public:
  /// @param options_ module configuration; defaults to `HardshrinkOptions{}`.
  explicit HardshrinkImpl(const HardshrinkOptions& options_ = {});

  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  HardshrinkOptions options;
};

// Presumably declares the `Hardshrink` holder type wrapping `HardshrinkImpl`
// (macro defined outside this header).
TORCH_MODULE(Hardshrink);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardtanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the Hardtanh activation. Interface only; the member functions
/// are defined out of line.
class TORCH_API HardtanhImpl : public torch::nn::Cloneable<HardtanhImpl> {
 public:
  /// @param options_ module configuration; defaults to `HardtanhOptions{}`.
  explicit HardtanhImpl(const HardtanhOptions& options_ = {});

  /// Applies the activation to `input` and returns the result.
  /// Takes `input` by value — presumably to support an in-place option;
  /// confirm in the definition.
  Tensor forward(Tensor input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  /// NOTE(review): may validate the option values — confirm in the .cpp.
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  HardtanhOptions options;
};

// Presumably declares the `Hardtanh` holder type wrapping `HardtanhImpl`
// (macro defined outside this header).
TORCH_MODULE(Hardtanh);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LeakyReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the LeakyReLU activation. Interface only; the member functions
/// are defined out of line.
class TORCH_API LeakyReLUImpl : public torch::nn::Cloneable<LeakyReLUImpl> {
 public:
  /// @param options_ module configuration; defaults to `LeakyReLUOptions{}`.
  explicit LeakyReLUImpl(const LeakyReLUOptions& options_ = {});

  /// Applies the activation to `input` and returns the result.
  /// Takes `input` by value — presumably to support an in-place option;
  /// confirm in the definition.
  Tensor forward(Tensor input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  LeakyReLUOptions options;
};

// Presumably declares the `LeakyReLU` holder type wrapping `LeakyReLUImpl`
// (macro defined outside this header).
TORCH_MODULE(LeakyReLU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the LogSigmoid activation. It declares no constructor and has
/// no options member, so it is configured by nothing but its type.
/// Interface only; the member functions are defined out of line.
class TORCH_API LogSigmoidImpl : public torch::nn::Cloneable<LogSigmoidImpl> {
 public:
  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;
};

// Presumably declares the `LogSigmoid` holder type wrapping `LogSigmoidImpl`
// (macro defined outside this header).
TORCH_MODULE(LogSigmoid);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for Softmax. Interface only; the member functions are defined out
/// of line. There is no default constructor: both constructors require
/// `dim`, directly or inside the options.
class TORCH_API SoftmaxImpl : public torch::nn::Cloneable<SoftmaxImpl> {
 public:
  /// Convenience constructor; wraps `dim` in `SoftmaxOptions(dim)` and
  /// delegates to the options constructor.
  explicit SoftmaxImpl(int64_t dim) : SoftmaxImpl(SoftmaxOptions(dim)) {}
  /// @param options_ module configuration (no default value).
  explicit SoftmaxImpl(const SoftmaxOptions& options_);

  /// Applies the operation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  SoftmaxOptions options;
};

// Presumably declares the `Softmax` holder type wrapping `SoftmaxImpl`
// (macro defined outside this header).
TORCH_MODULE(Softmax);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmin ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for Softmin. Interface only; the member functions are defined out
/// of line. There is no default constructor: `dim` must be supplied.
class TORCH_API SoftminImpl : public torch::nn::Cloneable<SoftminImpl> {
 public:
  /// Convenience constructor; wraps `dim` in `SoftminOptions(dim)` and
  /// delegates to the options constructor.
  explicit SoftminImpl(int64_t dim) : SoftminImpl(SoftminOptions(dim)) {}
  /// @param options_ module configuration (no default value).
  explicit SoftminImpl(const SoftminOptions& options_);

  /// Applies the operation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  SoftminOptions options;
};

// Presumably declares the `Softmin` holder type wrapping `SoftminImpl`
// (macro defined outside this header).
TORCH_MODULE(Softmin);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSoftmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for LogSoftmax. Interface only; the member functions are defined
/// out of line. There is no default constructor: `dim` must be supplied.
class TORCH_API LogSoftmaxImpl : public torch::nn::Cloneable<LogSoftmaxImpl> {
 public:
  /// Convenience constructor; wraps `dim` in `LogSoftmaxOptions(dim)` and
  /// delegates to the options constructor.
  explicit LogSoftmaxImpl(int64_t dim)
      : LogSoftmaxImpl(LogSoftmaxOptions(dim)) {}
  /// @param options_ module configuration (no default value).
  explicit LogSoftmaxImpl(const LogSoftmaxOptions& options_);

  /// Applies the operation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  LogSoftmaxOptions options;
};

// Presumably declares the `LogSoftmax` holder type wrapping `LogSoftmaxImpl`
// (macro defined outside this header).
TORCH_MODULE(LogSoftmax);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for Softmax2d. It declares no constructor and holds no options
/// member. Interface only; the member functions are defined out of line.
/// NOTE(review): the expected input rank is not stated in this header —
/// check the definition before documenting it.
class TORCH_API Softmax2dImpl : public torch::nn::Cloneable<Softmax2dImpl> {
 public:
  /// Applies the operation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;
};

// Presumably declares the `Softmax2d` holder type wrapping `Softmax2dImpl`
// (macro defined outside this header).
TORCH_MODULE(Softmax2d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the PReLU activation. Unlike most modules in this file it
/// carries tensor state (`weight`). Interface only; the member functions are
/// defined out of line.
class TORCH_API PReLUImpl : public torch::nn::Cloneable<PReLUImpl> {
 public:
  /// @param options_ module configuration; defaults to `PReLUOptions{}`.
  explicit PReLUImpl(const PReLUOptions& options_ = {});

  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  /// Presumably (re)creates `weight` from `options` — confirm in the .cpp.
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  PReLUOptions options;

  /// Per-module weight tensor; presumably a learnable parameter registered
  /// during reset() — verify against the definition.
  Tensor weight;
};

// Presumably declares the `PReLU` holder type wrapping `PReLUImpl` (macro
// defined outside this header).
TORCH_MODULE(PReLU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the ReLU activation. Interface only; the member functions are
/// defined out of line.
class TORCH_API ReLUImpl : public torch::nn::Cloneable<ReLUImpl> {
 public:
  /// @param options_ module configuration; defaults to `ReLUOptions{}`.
  explicit ReLUImpl(const ReLUOptions& options_ = {});

  /// Applies the activation to `input` and returns the result.
  /// Takes `input` by value — presumably to support an in-place option;
  /// confirm in the definition.
  Tensor forward(Tensor input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  ReLUOptions options;
};

// Presumably declares the `ReLU` holder type wrapping `ReLUImpl` (macro
// defined outside this header).
TORCH_MODULE(ReLU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU6 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the ReLU6 activation. Interface only; the member functions are
/// defined out of line.
class TORCH_API ReLU6Impl : public torch::nn::Cloneable<ReLU6Impl> {
 public:
  /// @param options_ module configuration; defaults to `ReLU6Options{}`.
  explicit ReLU6Impl(const ReLU6Options& options_ = {});

  /// Applies the activation to `input` and returns the result.
  /// Takes `input` by value — presumably to support an in-place option;
  /// confirm in the definition.
  Tensor forward(Tensor input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  ReLU6Options options;
};

// Presumably declares the `ReLU6` holder type wrapping `ReLU6Impl` (macro
// defined outside this header).
TORCH_MODULE(ReLU6);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the RReLU activation. Interface only; the member functions are
/// defined out of line.
/// NOTE(review): whether forward() behaves differently in training vs.
/// evaluation mode is not visible here — check the definition.
class TORCH_API RReLUImpl : public torch::nn::Cloneable<RReLUImpl> {
 public:
  /// @param options_ module configuration; defaults to `RReLUOptions{}`.
  explicit RReLUImpl(const RReLUOptions& options_ = {});

  /// Applies the activation to `input` and returns the result.
  /// Takes `input` by value — presumably to support an in-place option;
  /// confirm in the definition.
  Tensor forward(Tensor input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  RReLUOptions options;
};

// Presumably declares the `RReLU` holder type wrapping `RReLUImpl` (macro
// defined outside this header).
TORCH_MODULE(RReLU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the CELU activation. Interface only; the member functions are
/// defined out of line.
class TORCH_API CELUImpl : public torch::nn::Cloneable<CELUImpl> {
 public:
  /// @param options_ module configuration; defaults to `CELUOptions{}`.
  explicit CELUImpl(const CELUOptions& options_ = {});

  /// Applies the activation to `input` and returns the result.
  /// Takes `input` by value — presumably to support an in-place option;
  /// confirm in the definition.
  Tensor forward(Tensor input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  CELUOptions options;
};

// Presumably declares the `CELU` holder type wrapping `CELUImpl` (macro
// defined outside this header).
TORCH_MODULE(CELU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the GLU (gated linear unit) activation. Interface only; the
/// member functions are defined out of line.
class TORCH_API GLUImpl : public torch::nn::Cloneable<GLUImpl> {
 public:
  /// @param options_ module configuration; defaults to `GLUOptions{}`.
  explicit GLUImpl(const GLUOptions& options_ = {});

  /// Applies the activation to `input` (read-only reference) and returns the
  /// result. NOTE(review): shape requirements on `input` are not stated in
  /// this header — confirm in the definition.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  GLUOptions options;
};

// Presumably declares the `GLU` holder type wrapping `GLUImpl` (macro
// defined outside this header).
TORCH_MODULE(GLU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the GELU activation. Interface only; the member functions are
/// defined out of line.
class TORCH_API GELUImpl : public torch::nn::Cloneable<GELUImpl> {
 public:
  /// @param options_ module configuration; defaults to `GELUOptions{}`.
  /// Unlike the other constructors in this file, this one takes the options
  /// by value rather than `const&`. Presumably the definition moves from it.
  /// Do not change the parameter type here alone: it must match the
  /// out-of-line definition.
  explicit GELUImpl(GELUOptions options_ = {});

  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably taken from `options_`.
  GELUOptions options;
};

// Presumably declares the `GELU` holder type wrapping `GELUImpl` (macro
// defined outside this header).
TORCH_MODULE(GELU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SiLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the SiLU activation. It declares no constructor and holds no
/// options member. Interface only; the member functions are defined out of
/// line.
class TORCH_API SiLUImpl : public torch::nn::Cloneable<SiLUImpl> {
 public:
  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;
};

// Presumably declares the `SiLU` holder type wrapping `SiLUImpl` (macro
// defined outside this header).
TORCH_MODULE(SiLU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Mish ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the Mish activation. It declares no constructor and holds no
/// options member. Interface only; the member functions are defined out of
/// line.
class TORCH_API MishImpl : public torch::nn::Cloneable<MishImpl> {
 public:
  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;
};

// Presumably declares the `Mish` holder type wrapping `MishImpl` (macro
// defined outside this header).
TORCH_MODULE(Mish);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the Sigmoid activation. It declares no constructor and holds
/// no options member. Interface only; the member functions are defined out
/// of line.
class TORCH_API SigmoidImpl : public torch::nn::Cloneable<SigmoidImpl> {
 public:
  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;
};

// Presumably declares the `Sigmoid` holder type wrapping `SigmoidImpl`
// (macro defined outside this header).
TORCH_MODULE(Sigmoid);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softplus ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the Softplus activation. Interface only; the member functions
/// are defined out of line.
class TORCH_API SoftplusImpl : public torch::nn::Cloneable<SoftplusImpl> {
 public:
  /// @param options_ module configuration; defaults to `SoftplusOptions{}`.
  explicit SoftplusImpl(const SoftplusOptions& options_ = {});

  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  SoftplusOptions options;
};

// Presumably declares the `Softplus` holder type wrapping `SoftplusImpl`
// (macro defined outside this header).
TORCH_MODULE(Softplus);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the Softshrink activation. Interface only; the member functions
/// are defined out of line.
class TORCH_API SoftshrinkImpl : public torch::nn::Cloneable<SoftshrinkImpl> {
 public:
  /// @param options_ module configuration; defaults to `SoftshrinkOptions{}`.
  explicit SoftshrinkImpl(const SoftshrinkOptions& options_ = {});

  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  /// NOTE(review): may validate the option values — confirm in the .cpp.
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  SoftshrinkOptions options;
};

// Presumably declares the `Softshrink` holder type wrapping `SoftshrinkImpl`
// (macro defined outside this header).
TORCH_MODULE(Softshrink);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softsign ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the Softsign activation. It declares no constructor and holds
/// no options member. Interface only; the member functions are defined out
/// of line.
class TORCH_API SoftsignImpl : public torch::nn::Cloneable<SoftsignImpl> {
 public:
  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;
};

// Presumably declares the `Softsign` holder type wrapping `SoftsignImpl`
// (macro defined outside this header).
TORCH_MODULE(Softsign);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the Tanh activation. It declares no constructor and holds no
/// options member. Interface only; the member functions are defined out of
/// line.
class TORCH_API TanhImpl : public torch::nn::Cloneable<TanhImpl> {
 public:
  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;
};

// Presumably declares the `Tanh` holder type wrapping `TanhImpl` (macro
// defined outside this header).
TORCH_MODULE(Tanh);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanhshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the Tanhshrink activation. It declares no constructor and
/// holds no options member. Interface only; the member functions are defined
/// out of line.
class TORCH_API TanhshrinkImpl : public torch::nn::Cloneable<TanhshrinkImpl> {
 public:
  /// Applies the activation to `input` (read-only reference) and returns the
  /// result.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;
};

// Presumably declares the `Tanhshrink` holder type wrapping `TanhshrinkImpl`
// (macro defined outside this header).
TORCH_MODULE(Tanhshrink);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Threshold ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module for the Threshold activation. Interface only; the member functions
/// are defined out of line. There is no default constructor: `threshold` and
/// `value` must be supplied, directly or via the options.
class TORCH_API ThresholdImpl : public torch::nn::Cloneable<ThresholdImpl> {
 public:
  /// Convenience constructor; forwards both values to
  /// `ThresholdOptions(threshold, value)`. It is not `explicit`, so a braced
  /// pair such as `{threshold, value}` converts to this type implicitly.
  ThresholdImpl(double threshold, double value)
      : ThresholdImpl(ThresholdOptions(threshold, value)) {}
  /// @param options_ module configuration (no default value).
  explicit ThresholdImpl(const ThresholdOptions& options_);

  /// Applies the activation to `input` and returns the result.
  /// Takes `input` by value — presumably to support an in-place option;
  /// confirm in the definition.
  Tensor forward(Tensor input);

  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Writes a human-readable description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module, presumably copied from `options_`.
  ThresholdOptions options;
};

// Presumably declares the `Threshold` holder type wrapping `ThresholdImpl`
// (macro defined outside this header).
TORCH_MODULE(Threshold);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiheadAttention ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Multi-head attention module. Only the interface is declared here; the
/// member functions are defined out of line (not visible in this header).
class TORCH_API MultiheadAttentionImpl
    : public torch::nn::Cloneable<MultiheadAttentionImpl> {
 public:
  /// Convenience constructor; forwards `embed_dim` and `num_heads` to
  /// `MultiheadAttentionOptions` and delegates to the options constructor.
  MultiheadAttentionImpl(int64_t embed_dim, int64_t num_heads)
      : MultiheadAttentionImpl(
            MultiheadAttentionOptions(embed_dim, num_heads)) {}
  /// @param options_ module configuration (no default value).
  explicit MultiheadAttentionImpl(const MultiheadAttentionOptions& options_);

  /// Computes attention of `query` over `key` / `value`.
  /// @param key_padding_mask optional mask; an undefined Tensor by default.
  /// @param need_weights whether attention weights are requested
  ///   (default `true`).
  /// @param attn_mask optional attention mask; an undefined Tensor by default.
  /// @param average_attn_weights default `true`; presumably averages the
  ///   returned weights over heads — confirm in the definition.
  /// @return two tensors; presumably (attention output, attention weights)
  ///   — confirm in the definition.
  std::tuple<Tensor, Tensor> forward(
      const Tensor& query,
      const Tensor& key,
      const Tensor& value,
      const Tensor& key_padding_mask = {},
      bool need_weights = true,
      const Tensor& attn_mask = {},
      bool average_attn_weights = true);

 protected:
  // Restates forward()'s C++ default arguments by 0-based argument index:
  // 3 = key_padding_mask, 4 = need_weights, 5 = attn_mask,
  // 6 = average_attn_weights. Presumably this lets type-erased callers
  // (e.g. AnyModule / Sequential) omit them. Keep these in sync with the
  // signature above.
  FORWARD_HAS_DEFAULT_ARGS(
      {3, AnyValue(Tensor())},
      {4, AnyValue(true)},
      {5, AnyValue(Tensor())},
      {6, AnyValue(true)})

 public:
  /// Override of the base-class `reset()` hook (definition not visible here).
  void reset() override;

  /// Helper declared alongside reset(); presumably initializes the parameter
  /// tensors below — confirm in the definition.
  void _reset_parameters();

  /// Configuration of this module, presumably copied from `options_`.
  MultiheadAttentionOptions options;

  // Presumably selects between the packed `in_proj_weight` and the separate
  // q/k/v projection weights below — verify against the definition.
  // Default member initializers give these scalars a defined value before
  // the out-of-line code assigns them. Previously they were left
  // indeterminate, and reading them earlier would be undefined behavior.
  bool _qkv_same_embed_dim = false;
  Tensor in_proj_weight;
  Tensor in_proj_bias;
  Tensor bias_k;
  Tensor bias_v;
  /// Output projection; starts as an empty (null) holder and is presumably
  /// created out of line (e.g. in reset()).
  Linear out_proj = nullptr;
  Tensor q_proj_weight;
  Tensor k_proj_weight;
  Tensor v_proj_weight;
  // Per-head dimension; zero until the out-of-line code computes it.
  int64_t head_dim = 0;
};

// Presumably declares the `MultiheadAttention` holder type wrapping
// `MultiheadAttentionImpl` (macro defined outside this header).
TORCH_MODULE(MultiheadAttention);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources