Shortcuts

Program Listing for File activation.h

Return to documentation for file (torch/csrc/api/include/torch/nn/options/activation.h)

#pragma once

#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>

namespace torch::nn {

struct TORCH_API ELUOptions {
  TORCH_ARG(double, alpha) = 1.0;

  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
using ELUFuncOptions = ELUOptions;
} // namespace functional

// ============================================================================

/// Options for the SELU activation.
///
/// Intentionally implicitly constructible from a `bool`, so a bare flag can
/// be passed wherever a `SELUOptions` is expected.
struct TORCH_API SELUOptions {
  /* implicit */ SELUOptions(bool inplace = false);

  /// Set through the constructor (default false). By its name, requests
  /// in-place computation — confirm against the implementation.
  TORCH_ARG(bool, inplace);
};

namespace functional {
/// Functional-API name for `SELUOptions` (a type alias, not a separate type).
using SELUFuncOptions = SELUOptions;
} // namespace functional

// ============================================================================

/// Options for the GLU activation.
///
/// Intentionally implicitly constructible from an `int64_t` dimension.
struct TORCH_API GLUOptions {
  /* implicit */ GLUOptions(int64_t dim = -1);

  /// Dimension the operation is applied along; set through the constructor
  /// (default -1).
  TORCH_ARG(int64_t, dim);
};

namespace functional {
/// Functional-API name for `GLUOptions` (a type alias, not a separate type).
using GLUFuncOptions = GLUOptions;
} // namespace functional

// ============================================================================

struct TORCH_API GELUOptions {
  TORCH_ARG(std::string, approximate) = "none";
};

namespace functional {
using GELUFuncOptions = GELUOptions;
} // namespace functional

// ============================================================================

/// Options for the Hardshrink activation.
///
/// Intentionally implicitly constructible from a `double` lambda.
struct TORCH_API HardshrinkOptions {
  /* implicit */ HardshrinkOptions(double lambda = 0.5);

  /// The shrink threshold `lambda`; set through the constructor (default 0.5).
  TORCH_ARG(double, lambda);
};

namespace functional {
/// Functional-API name for `HardshrinkOptions` (a type alias).
using HardshrinkFuncOptions = HardshrinkOptions;
} // namespace functional

// ============================================================================

struct TORCH_API HardtanhOptions {
  TORCH_ARG(double, min_val) = -1.0;

  TORCH_ARG(double, max_val) = 1.0;

  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
using HardtanhFuncOptions = HardtanhOptions;
} // namespace functional

// ============================================================================

struct TORCH_API LeakyReLUOptions {
  TORCH_ARG(double, negative_slope) = 1e-2;

  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
using LeakyReLUFuncOptions = LeakyReLUOptions;
} // namespace functional

// ============================================================================

/// Options for the Softmax module. `dim` is required; there is no default
/// constructor.
///
/// NOTE(review): this single-argument constructor is not `explicit`, so an
/// `int64_t` converts implicitly to `SoftmaxOptions`. Callers may rely on
/// that, so confirm before marking it `explicit`.
struct TORCH_API SoftmaxOptions {
  SoftmaxOptions(int64_t dim);

  /// Dimension along which softmax is computed; set through the constructor.
  TORCH_ARG(int64_t, dim);
};

// ============================================================================

namespace functional {

/// Options for the functional softmax. Unlike `SoftmaxOptions`, this type
/// also carries an optional output `dtype`.
struct TORCH_API SoftmaxFuncOptions {
  SoftmaxFuncOptions(int64_t dim);

  /// Dimension along which softmax is computed; set through the constructor.
  TORCH_ARG(int64_t, dim);

  /// Optional dtype; empty (std::nullopt) by default. Presumably the input is
  /// cast to this dtype before computing — confirm against the implementation.
  TORCH_ARG(std::optional<torch::Dtype>, dtype){std::nullopt};
};

} // namespace functional

// ============================================================================

/// Options for the Softmin module. `dim` is required; there is no default
/// constructor.
///
/// NOTE(review): the constructor is not `explicit`, so `int64_t` converts
/// implicitly. Confirm whether that is intended before changing it.
struct TORCH_API SoftminOptions {
  SoftminOptions(int64_t dim);

  /// Dimension along which softmin is computed; set through the constructor.
  TORCH_ARG(int64_t, dim);
};

// ============================================================================

namespace functional {

/// Options for the functional softmin. Same as `SoftminOptions`, plus an
/// optional `dtype`.
struct TORCH_API SoftminFuncOptions {
  SoftminFuncOptions(int64_t dim);

  /// Dimension along which softmin is computed; set through the constructor.
  TORCH_ARG(int64_t, dim);

  /// Optional dtype; std::nullopt by default.
  TORCH_ARG(std::optional<torch::Dtype>, dtype){std::nullopt};
};

} // namespace functional

// ============================================================================

/// Options for the LogSoftmax module. `dim` is required; there is no default
/// constructor.
///
/// NOTE(review): the constructor is not `explicit`, so `int64_t` converts
/// implicitly. Confirm whether that is intended before changing it.
struct TORCH_API LogSoftmaxOptions {
  LogSoftmaxOptions(int64_t dim);

  /// Dimension along which log-softmax is computed; set through the
  /// constructor.
  TORCH_ARG(int64_t, dim);
};

// ============================================================================

namespace functional {

/// Options for the functional log-softmax. Same as `LogSoftmaxOptions`, plus
/// an optional `dtype`.
struct TORCH_API LogSoftmaxFuncOptions {
  LogSoftmaxFuncOptions(int64_t dim);

  /// Dimension along which log-softmax is computed; set through the
  /// constructor.
  TORCH_ARG(int64_t, dim);

  /// Optional dtype; std::nullopt by default.
  TORCH_ARG(std::optional<torch::Dtype>, dtype){std::nullopt};
};

} // namespace functional

// ============================================================================

struct TORCH_API PReLUOptions {
  TORCH_ARG(int64_t, num_parameters) = 1;

  TORCH_ARG(double, init) = 0.25;
};

// ============================================================================

/// Options for the ReLU activation.
///
/// Intentionally implicitly constructible from a `bool`.
struct TORCH_API ReLUOptions {
  /* implicit */ ReLUOptions(bool inplace = false);

  /// Set through the constructor (default false). By its name, requests
  /// in-place computation.
  TORCH_ARG(bool, inplace);
};

namespace functional {
/// Functional-API name for `ReLUOptions` (a type alias).
using ReLUFuncOptions = ReLUOptions;
} // namespace functional

// ============================================================================

/// Options for the ReLU6 activation.
///
/// Intentionally implicitly constructible from a `bool`.
struct TORCH_API ReLU6Options {
  /* implicit */ ReLU6Options(bool inplace = false);

  /// Set through the constructor (default false). By its name, requests
  /// in-place computation.
  TORCH_ARG(bool, inplace);
};

namespace functional {
/// Functional-API name for `ReLU6Options` (a type alias).
using ReLU6FuncOptions = ReLU6Options;
} // namespace functional

// ============================================================================

struct TORCH_API RReLUOptions {
  TORCH_ARG(double, lower) = 1.0 / 8.0;

  TORCH_ARG(double, upper) = 1.0 / 3.0;

  TORCH_ARG(bool, inplace) = false;
};

// ============================================================================

namespace functional {

struct TORCH_API RReLUFuncOptions {
  TORCH_ARG(double, lower) = 1.0 / 8.0;

  TORCH_ARG(double, upper) = 1.0 / 3.0;

  TORCH_ARG(bool, training) = false;

  TORCH_ARG(bool, inplace) = false;
};

} // namespace functional

// ============================================================================

struct TORCH_API CELUOptions {
  TORCH_ARG(double, alpha) = 1.0;

  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
using CELUFuncOptions = CELUOptions;
} // namespace functional

// ============================================================================

struct TORCH_API SoftplusOptions {
  TORCH_ARG(double, beta) = 1.0;

  TORCH_ARG(double, threshold) = 20.0;
};

namespace functional {
using SoftplusFuncOptions = SoftplusOptions;
} // namespace functional

// ============================================================================

/// Options for the Softshrink activation.
///
/// Intentionally implicitly constructible from a `double` lambda.
struct TORCH_API SoftshrinkOptions {
  /* implicit */ SoftshrinkOptions(double lambda = 0.5);

  /// The shrink threshold `lambda`; set through the constructor (default 0.5).
  TORCH_ARG(double, lambda);
};

namespace functional {
/// Functional-API name for `SoftshrinkOptions` (a type alias).
using SoftshrinkFuncOptions = SoftshrinkOptions;
} // namespace functional

// ============================================================================

struct TORCH_API ThresholdOptions {
  ThresholdOptions(double threshold, double value)
      : threshold_(threshold), value_(value) {}

  TORCH_ARG(double, threshold);

  TORCH_ARG(double, value);

  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
using ThresholdFuncOptions = ThresholdOptions;
} // namespace functional

// ============================================================================

namespace functional {

/// Options for the functional Gumbel-softmax.
///
/// - `tau`: defaults to 1.0 (by Gumbel-softmax convention, the temperature).
/// - `hard`: defaults to false.
/// - `dim`: defaults to -1. Stored as `int64_t`, like every other `dim` option
///   in this header, so 64-bit values passed to the setter are not narrowed.
///
/// NOTE(review): code that binds a non-const `int&` to `dim()` must switch to
/// `int64_t&`. If the downstream helper still takes `int`, the narrowing moves
/// to that call site — confirm and widen it as well.
struct TORCH_API GumbelSoftmaxFuncOptions {
  TORCH_ARG(double, tau) = 1.0;

  TORCH_ARG(bool, hard) = false;

  TORCH_ARG(int64_t, dim) = -1;
};

} // namespace functional

// ============================================================================

struct TORCH_API MultiheadAttentionOptions {
  MultiheadAttentionOptions(int64_t embed_dim, int64_t num_heads);

  TORCH_ARG(int64_t, embed_dim);

  TORCH_ARG(int64_t, num_heads);

  TORCH_ARG(double, dropout) = 0.0;

  TORCH_ARG(bool, bias) = true;

  TORCH_ARG(bool, add_bias_kv) = false;

  TORCH_ARG(bool, add_zero_attn) = false;

  TORCH_ARG(int64_t, kdim);

  TORCH_ARG(int64_t, vdim);
};

// ============================================================================

namespace functional {

struct TORCH_API MultiheadAttentionForwardFuncOptions {
  MultiheadAttentionForwardFuncOptions(
      int64_t embed_dim_to_check,
      int64_t num_heads,
      Tensor in_proj_weight,
      Tensor in_proj_bias,
      Tensor bias_k,
      Tensor bias_v,
      bool add_zero_attn,
      double dropout_p,
      Tensor out_proj_weight,
      Tensor out_proj_bias);

  TORCH_ARG(int64_t, embed_dim_to_check);

  TORCH_ARG(int64_t, num_heads);

  TORCH_ARG(Tensor, in_proj_weight);

  TORCH_ARG(Tensor, in_proj_bias);

  TORCH_ARG(Tensor, bias_k);

  TORCH_ARG(Tensor, bias_v);

  TORCH_ARG(bool, add_zero_attn);

  TORCH_ARG(double, dropout_p);

  TORCH_ARG(Tensor, out_proj_weight);

  TORCH_ARG(Tensor, out_proj_bias);

  TORCH_ARG(bool, training) = true;

  TORCH_ARG(Tensor, key_padding_mask) = {};

  TORCH_ARG(bool, need_weights) = true;

  TORCH_ARG(Tensor, attn_mask) = {};

  TORCH_ARG(bool, use_separate_proj_weight) = false;

  TORCH_ARG(Tensor, q_proj_weight) = {};

  TORCH_ARG(Tensor, k_proj_weight) = {};

  TORCH_ARG(Tensor, v_proj_weight) = {};

  TORCH_ARG(Tensor, static_k) = {};

  TORCH_ARG(Tensor, static_v) = {};

  TORCH_ARG(bool, average_attn_weights) = true;
};

} // namespace functional

} // namespace torch::nn

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources