Program Listing for File activation.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/options/activation.h)
#pragma once
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>
namespace torch::nn {
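/// Options for the `ELU` module.
///
/// Example (an illustrative sketch: module types such as `ELU` are declared
/// elsewhere in `torch::nn`, and `TORCH_ARG` generates the chainable setters
/// used below):
/// ```
/// ELU model(ELUOptions().alpha(42.42).inplace(true));
/// ```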
struct TORCH_API ELUOptions {
  /// The `alpha` value for the ELU formulation. Default: 1.0
  TORCH_ARG(double, alpha) = 1.0;
  /// can optionally do the operation in-place. Default: false
  TORCH_ARG(bool, inplace) = false;
};
namespace functional {
using ELUFuncOptions = ELUOptions;
} // namespace functional
// ============================================================================
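/// Options for the `SELU` module.
///
/// Example (illustrative):
/// ```
/// SELU model(SELUOptions(false));
/// ```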
struct TORCH_API SELUOptions {
  /* implicit */ SELUOptions(bool inplace = false);
  /// can optionally do the operation in-place. Default: false
  TORCH_ARG(bool, inplace);
};
namespace functional {
using SELUFuncOptions = SELUOptions;
} // namespace functional
// ============================================================================
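/// Options for the `GLU` module.
///
/// Example (illustrative):
/// ```
/// GLU model(GLUOptions(1));
/// ```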
struct TORCH_API GLUOptions {
  /* implicit */ GLUOptions(int64_t dim = -1);
  /// the dimension on which to split the input. Default: -1
  TORCH_ARG(int64_t, dim);
};
namespace functional {
using GLUFuncOptions = GLUOptions;
} // namespace functional
// ============================================================================
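/// Options for the `GELU` module.
///
/// Example (illustrative):
/// ```
/// GELU model(GELUOptions().approximate("none"));
/// ```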
struct TORCH_API GELUOptions {
  /// Specifies the approximation to apply to the output: "none" or "tanh". Default: "none"
  TORCH_ARG(std::string, approximate) = "none";
};
namespace functional {
using GELUFuncOptions = GELUOptions;
} // namespace functional
// ============================================================================
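/// Options for the `Hardshrink` module.
///
/// Example (illustrative):
/// ```
/// Hardshrink model(HardshrinkOptions().lambda(42.42));
/// ```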
struct TORCH_API HardshrinkOptions {
  /* implicit */ HardshrinkOptions(double lambda = 0.5);
  /// the `lambda` value for the Hardshrink formulation. Default: 0.5
  TORCH_ARG(double, lambda);
};
namespace functional {
using HardshrinkFuncOptions = HardshrinkOptions;
} // namespace functional
// ============================================================================
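/// Options for the `Hardtanh` module.
///
/// Example (illustrative):
/// ```
/// Hardtanh model(HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true));
/// ```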
struct TORCH_API HardtanhOptions {
  /// minimum value of the linear region range. Default: -1.0
  TORCH_ARG(double, min_val) = -1.0;
  /// maximum value of the linear region range. Default: 1.0
  TORCH_ARG(double, max_val) = 1.0;
  /// can optionally do the operation in-place. Default: false
  TORCH_ARG(bool, inplace) = false;
};
namespace functional {
using HardtanhFuncOptions = HardtanhOptions;
} // namespace functional
// ============================================================================
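/// Options for the `LeakyReLU` module.
///
/// Example (illustrative):
/// ```
/// LeakyReLU model(LeakyReLUOptions().negative_slope(0.42).inplace(true));
/// ```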
struct TORCH_API LeakyReLUOptions {
  /// Controls the angle of the negative slope. Default: 1e-2
  TORCH_ARG(double, negative_slope) = 1e-2;
  /// can optionally do the operation in-place. Default: false
  TORCH_ARG(bool, inplace) = false;
};
namespace functional {
using LeakyReLUFuncOptions = LeakyReLUOptions;
} // namespace functional
// ============================================================================
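/// Options for the `Softmax` module.
///
/// Example (illustrative):
/// ```
/// Softmax model(SoftmaxOptions(1));
/// ```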
struct TORCH_API SoftmaxOptions {
  SoftmaxOptions(int64_t dim);
  /// Dimension along which Softmax will be computed.
  TORCH_ARG(int64_t, dim);
};
// ============================================================================
namespace functional {
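/// Options for `torch::nn::functional::softmax`.
///
/// Example (illustrative):
/// ```
/// namespace F = torch::nn::functional;
/// F::softmax(input, F::SoftmaxFuncOptions(1));
/// ```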
struct TORCH_API SoftmaxFuncOptions {
  SoftmaxFuncOptions(int64_t dim);
  /// Dimension along which Softmax will be computed.
  TORCH_ARG(int64_t, dim);
  /// the desired data type of the returned tensor; if set, the input is cast
  /// to `dtype` before the operation to prevent overflows. Default: nullopt
  TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
};
} // namespace functional
// ============================================================================
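/// Options for the `Softmin` module.
///
/// Example (illustrative):
/// ```
/// Softmin model(SoftminOptions(1));
/// ```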
struct TORCH_API SoftminOptions {
  SoftminOptions(int64_t dim);
  /// Dimension along which Softmin will be computed.
  TORCH_ARG(int64_t, dim);
};
// ============================================================================
namespace functional {
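/// Options for `torch::nn::functional::softmin`.
///
/// Example (illustrative):
/// ```
/// namespace F = torch::nn::functional;
/// F::softmin(input, F::SoftminFuncOptions(1));
/// ```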
struct TORCH_API SoftminFuncOptions {
  SoftminFuncOptions(int64_t dim);
  /// Dimension along which Softmin will be computed.
  TORCH_ARG(int64_t, dim);
  /// the desired data type of the returned tensor; if set, the input is cast
  /// to `dtype` before the operation to prevent overflows. Default: nullopt
  TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
};
} // namespace functional
// ============================================================================
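/// Options for the `LogSoftmax` module.
///
/// Example (illustrative):
/// ```
/// LogSoftmax model(LogSoftmaxOptions(1));
/// ```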
struct TORCH_API LogSoftmaxOptions {
  LogSoftmaxOptions(int64_t dim);
  /// Dimension along which LogSoftmax will be computed.
  TORCH_ARG(int64_t, dim);
};
// ============================================================================
namespace functional {
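/// Options for `torch::nn::functional::log_softmax`.
///
/// Example (illustrative):
/// ```
/// namespace F = torch::nn::functional;
/// F::log_softmax(input, F::LogSoftmaxFuncOptions(1));
/// ```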
struct TORCH_API LogSoftmaxFuncOptions {
  LogSoftmaxFuncOptions(int64_t dim);
  /// Dimension along which LogSoftmax will be computed.
  TORCH_ARG(int64_t, dim);
  /// the desired data type of the returned tensor; if set, the input is cast
  /// to `dtype` before the operation to prevent overflows. Default: nullopt
  TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
};
} // namespace functional
// ============================================================================
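/// Options for the `PReLU` module.
///
/// Example (illustrative):
/// ```
/// PReLU model(PReLUOptions().num_parameters(42));
/// ```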
struct TORCH_API PReLUOptions {
  /// number of `a` to learn: 1 or the number of input channels. Default: 1
  TORCH_ARG(int64_t, num_parameters) = 1;
  /// the initial value of `a`. Default: 0.25
  TORCH_ARG(double, init) = 0.25;
};
// ============================================================================
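/// Options for the `ReLU` module.
///
/// Example (illustrative):
/// ```
/// ReLU model(ReLUOptions().inplace(true));
/// ```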
struct TORCH_API ReLUOptions {
  /* implicit */ ReLUOptions(bool inplace = false);
  /// can optionally do the operation in-place. Default: false
  TORCH_ARG(bool, inplace);
};
namespace functional {
using ReLUFuncOptions = ReLUOptions;
} // namespace functional
// ============================================================================
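/// Options for the `ReLU6` module.
///
/// Example (illustrative):
/// ```
/// ReLU6 model(ReLU6Options().inplace(true));
/// ```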
struct TORCH_API ReLU6Options {
  /* implicit */ ReLU6Options(bool inplace = false);
  /// can optionally do the operation in-place. Default: false
  TORCH_ARG(bool, inplace);
};
namespace functional {
using ReLU6FuncOptions = ReLU6Options;
} // namespace functional
// ============================================================================
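/// Options for the `RReLU` module.
///
/// Example (illustrative):
/// ```
/// RReLU model(RReLUOptions().lower(0.24).upper(0.42).inplace(true));
/// ```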
struct TORCH_API RReLUOptions {
  /// lower bound of the uniform distribution. Default: 1/8
  TORCH_ARG(double, lower) = 1.0 / 8.0;
  /// upper bound of the uniform distribution. Default: 1/3
  TORCH_ARG(double, upper) = 1.0 / 3.0;
  /// can optionally do the operation in-place. Default: false
  TORCH_ARG(bool, inplace) = false;
};
// ============================================================================
namespace functional {
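/// Options for `torch::nn::functional::rrelu`.
///
/// Example (illustrative):
/// ```
/// namespace F = torch::nn::functional;
/// F::rrelu(x, F::RReLUFuncOptions().lower(0.1).upper(0.4).inplace(true));
/// ```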
struct TORCH_API RReLUFuncOptions {
  /// lower bound of the uniform distribution. Default: 1/8
  TORCH_ARG(double, lower) = 1.0 / 8.0;
  /// upper bound of the uniform distribution. Default: 1/3
  TORCH_ARG(double, upper) = 1.0 / 3.0;
  /// whether to run in training mode (slope sampled randomly). Default: false
  TORCH_ARG(bool, training) = false;
  /// can optionally do the operation in-place. Default: false
  TORCH_ARG(bool, inplace) = false;
};
} // namespace functional
// ============================================================================
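/// Options for the `CELU` module.
///
/// Example (illustrative):
/// ```
/// CELU model(CELUOptions().alpha(42.42).inplace(true));
/// ```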
struct TORCH_API CELUOptions {
  /// The `alpha` value for the CELU formulation. Default: 1.0
  TORCH_ARG(double, alpha) = 1.0;
  /// can optionally do the operation in-place. Default: false
  TORCH_ARG(bool, inplace) = false;
};
namespace functional {
using CELUFuncOptions = CELUOptions;
} // namespace functional
// ============================================================================
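/// Options for the `Softplus` module.
///
/// Example (illustrative):
/// ```
/// Softplus model(SoftplusOptions().beta(0.24).threshold(42.42));
/// ```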
struct TORCH_API SoftplusOptions {
  /// the `beta` value for the Softplus formulation. Default: 1.0
  TORCH_ARG(double, beta) = 1.0;
  /// values above this revert to a linear function. Default: 20.0
  TORCH_ARG(double, threshold) = 20.0;
};
namespace functional {
using SoftplusFuncOptions = SoftplusOptions;
} // namespace functional
// ============================================================================
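/// Options for the `Softshrink` module.
///
/// Example (illustrative):
/// ```
/// Softshrink model(SoftshrinkOptions(42.42));
/// ```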
struct TORCH_API SoftshrinkOptions {
  /* implicit */ SoftshrinkOptions(double lambda = 0.5);
  /// the `lambda` value for the Softshrink formulation. Default: 0.5
  TORCH_ARG(double, lambda);
};
namespace functional {
using SoftshrinkFuncOptions = SoftshrinkOptions;
} // namespace functional
// ============================================================================
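/// Options for the `Threshold` module.
///
/// Example (illustrative):
/// ```
/// Threshold model(ThresholdOptions(42.42, 24.24).inplace(true));
/// ```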
struct TORCH_API ThresholdOptions {
  ThresholdOptions(double threshold, double value)
      : threshold_(threshold), value_(value) {}
  /// the value to threshold at
  TORCH_ARG(double, threshold);
  /// the value to replace with
  TORCH_ARG(double, value);
  /// can optionally do the operation in-place. Default: false
  TORCH_ARG(bool, inplace) = false;
};
namespace functional {
using ThresholdFuncOptions = ThresholdOptions;
} // namespace functional
// ============================================================================
namespace functional {
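/// Options for `torch::nn::functional::gumbel_softmax`.
///
/// Example (illustrative):
/// ```
/// namespace F = torch::nn::functional;
/// F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(-1));
/// ```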
struct TORCH_API GumbelSoftmaxFuncOptions {
  /// non-negative scalar temperature. Default: 1.0
  TORCH_ARG(double, tau) = 1.0;
  /// discretize samples as one-hot vectors (still soft in autograd). Default: false
  TORCH_ARG(bool, hard) = false;
  /// dimension along which softmax is computed. Default: -1
  TORCH_ARG(int, dim) = -1;
};
} // namespace functional
// ============================================================================
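/// Options for the `MultiheadAttention` module.
///
/// Example (illustrative):
/// ```
/// MultiheadAttention model(MultiheadAttentionOptions(20, 10).bias(false));
/// ```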
struct TORCH_API MultiheadAttentionOptions {
  MultiheadAttentionOptions(int64_t embed_dim, int64_t num_heads);
  /// total dimension of the model
  TORCH_ARG(int64_t, embed_dim);
  /// number of parallel attention heads
  TORCH_ARG(int64_t, num_heads);
  /// dropout probability on attention weights. Default: 0.0
  TORCH_ARG(double, dropout) = 0.0;
  /// add bias as a module parameter. Default: true
  TORCH_ARG(bool, bias) = true;
  /// add bias to the key and value sequences at dim=0. Default: false
  TORCH_ARG(bool, add_bias_kv) = false;
  /// add a new batch of zeros to the key and value sequences at dim=1. Default: false
  TORCH_ARG(bool, add_zero_attn) = false;
  /// total number of features in key; defaults to `embed_dim`
  TORCH_ARG(int64_t, kdim);
  /// total number of features in value; defaults to `embed_dim`
  TORCH_ARG(int64_t, vdim);
};
// ============================================================================
namespace functional {
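/// Options for `torch::nn::functional::multi_head_attention_forward`.
///
/// The projection weights and biases are required constructor arguments; the
/// remaining masks and flags can be set through the generated setters.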
struct TORCH_API MultiheadAttentionForwardFuncOptions {
  MultiheadAttentionForwardFuncOptions(
      int64_t embed_dim_to_check,
      int64_t num_heads,
      Tensor in_proj_weight,
      Tensor in_proj_bias,
      Tensor bias_k,
      Tensor bias_v,
      bool add_zero_attn,
      double dropout_p,
      Tensor out_proj_weight,
      Tensor out_proj_bias);
  TORCH_ARG(int64_t, embed_dim_to_check);
  TORCH_ARG(int64_t, num_heads);
  TORCH_ARG(Tensor, in_proj_weight);
  TORCH_ARG(Tensor, in_proj_bias);
  TORCH_ARG(Tensor, bias_k);
  TORCH_ARG(Tensor, bias_v);
  TORCH_ARG(bool, add_zero_attn);
  TORCH_ARG(double, dropout_p);
  TORCH_ARG(Tensor, out_proj_weight);
  TORCH_ARG(Tensor, out_proj_bias);
  TORCH_ARG(bool, training) = true;
  TORCH_ARG(Tensor, key_padding_mask) = {};
  TORCH_ARG(bool, need_weights) = true;
  TORCH_ARG(Tensor, attn_mask) = {};
  TORCH_ARG(bool, use_separate_proj_weight) = false;
  TORCH_ARG(Tensor, q_proj_weight) = {};
  TORCH_ARG(Tensor, k_proj_weight) = {};
  TORCH_ARG(Tensor, v_proj_weight) = {};
  TORCH_ARG(Tensor, static_k) = {};
  TORCH_ARG(Tensor, static_v) = {};
  TORCH_ARG(bool, average_attn_weights) = true;
};
} // namespace functional
} // namespace torch::nn