Shortcuts

Program Listing for File activation.h

Return to documentation for file (torch/csrc/api/include/torch/nn/functional/activation.h)

#pragma once

#include <ATen/Dispatch.h>
#include <torch/nn/functional/dropout.h>
#include <torch/nn/functional/linear.h>
#include <torch/nn/options/activation.h>
#include <torch/nn/options/dropout.h>
#include <torch/nn/options/linear.h>
#include <torch/types.h>
#include <limits>
#include <utility>

namespace torch {
namespace nn {
namespace functional {

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// ELU dispatcher: the out-of-place torch::elu is the default path; the
// `_`-suffixed torch::elu_ is used only when `inplace` is requested.
inline Tensor elu(Tensor input, double alpha, bool inplace) {
  if (!inplace) {
    return torch::elu(input, alpha);
  }
  return torch::elu_(input, alpha);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies ELU element-wise to `input`.
/// `options.alpha()` and `options.inplace()` are forwarded unchanged; see
/// `ELUFuncOptions` for the available settings.
inline Tensor elu(Tensor input, const ELUFuncOptions& options = {}) {
  const auto alpha = options.alpha();
  const auto inplace = options.inplace();
  return detail::elu(std::move(input), alpha, inplace);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// SELU dispatcher: torch::selu by default, torch::selu_ when `inplace` is set.
inline Tensor selu(Tensor input, bool inplace) {
  if (!inplace) {
    return torch::selu(input);
  }
  return torch::selu_(input);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies SELU element-wise; `options.inplace()` selects the in-place path
/// (see `SELUFuncOptions`).
inline Tensor selu(Tensor input, const SELUFuncOptions& options = {}) {
  const auto inplace = options.inplace();
  return detail::selu(std::move(input), inplace);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Thin forwarder to torch::hardshrink with an explicit `lambda`.
inline auto hardshrink(const Tensor& input, double lambda) -> Tensor {
  return torch::hardshrink(input, lambda);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies hard shrinkage element-wise, passing `options.lambda()` through
/// to torch::hardshrink (see `HardshrinkFuncOptions`).
inline Tensor hardshrink(
    const Tensor& input,
    const HardshrinkFuncOptions& options = {}) {
  const auto lambda = options.lambda();
  return detail::hardshrink(input, lambda);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// HardTanh dispatcher: forwards the clamp bounds `min_val` / `max_val` to
// torch::hardtanh, or to torch::hardtanh_ when `inplace` is set.
inline Tensor hardtanh(
    Tensor input,
    double min_val,
    double max_val,
    bool inplace) {
  if (!inplace) {
    return torch::hardtanh(input, min_val, max_val);
  }
  return torch::hardtanh_(input, min_val, max_val);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies HardTanh element-wise using `min_val`, `max_val` and `inplace`
/// from `options` (see `HardtanhFuncOptions`).
inline Tensor hardtanh(Tensor input, const HardtanhFuncOptions& options = {}) {
  const auto min_val = options.min_val();
  const auto max_val = options.max_val();
  const auto inplace = options.inplace();
  return detail::hardtanh(std::move(input), min_val, max_val, inplace);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// LeakyReLU dispatcher: torch::leaky_relu by default, torch::leaky_relu_
// when `inplace` is set; `negative_slope` is forwarded unchanged.
inline Tensor leaky_relu(Tensor input, double negative_slope, bool inplace) {
  if (!inplace) {
    return torch::leaky_relu(input, negative_slope);
  }
  return torch::leaky_relu_(input, negative_slope);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies LeakyReLU element-wise with `options.negative_slope()`;
/// `options.inplace()` selects the in-place path (see `LeakyReLUFuncOptions`).
inline Tensor leaky_relu(
    Tensor input,
    const LeakyReLUFuncOptions& options = {}) {
  const auto negative_slope = options.negative_slope();
  const auto inplace = options.inplace();
  return detail::leaky_relu(std::move(input), negative_slope, inplace);
}

// ============================================================================

/// Element-wise log-sigmoid of `input`; forwards to torch::log_sigmoid.
inline auto logsigmoid(const Tensor& input) -> Tensor {
  return torch::log_sigmoid(input);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Gumbel-softmax sample over `logits` along `dim` with temperature `tau`.
// When `hard` is true, the forward value is the one-hot argmax of the soft
// sample, while `y_hard - y_soft.detach() + y_soft` keeps the soft sample's
// gradient (straight-through trick).
inline Tensor gumbel_softmax(
    const Tensor& logits,
    double tau,
    bool hard,
    int dim) {
  // -log(E) with E drawn by exponential_() on a fresh tensor: ~Gumbel(0,1).
  const torch::Tensor noise = -torch::empty_like(logits).exponential_().log();
  // Perturbed, temperature-scaled logits (~Gumbel(logits, tau)), normalized.
  torch::Tensor y_soft = ((logits + noise) / tau).softmax(dim);
  if (!hard) {
    return y_soft;
  }

  // Straight through.
  const torch::Tensor argmax_index =
      std::get<1>(y_soft.max(dim, /*keepdim=*/true));
  const torch::Tensor y_hard =
      torch::zeros_like(logits).scatter_(dim, argmax_index, 1.0);
  return y_hard - y_soft.detach() + y_soft;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Draws a Gumbel-softmax sample from `logits`, reading `tau`, `hard` and
/// `dim` from `options` (see `GumbelSoftmaxFuncOptions`).
inline Tensor gumbel_softmax(
    const Tensor& logits,
    const GumbelSoftmaxFuncOptions& options = {}) {
  const auto tau = options.tau();
  const auto hard = options.hard();
  const auto dim = options.dim();
  return detail::gumbel_softmax(logits, tau, hard, dim);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Softmax of `input` along `dim`. The `dtype` overload is only used when a
// dtype was actually supplied.
inline Tensor softmax(
    const Tensor& input,
    int64_t dim,
    std::optional<torch::Dtype> dtype) {
  if (!dtype.has_value()) {
    return input.softmax(dim);
  }
  return input.softmax(dim, dtype);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Softmax of `input` along `options.dim()`, optionally in `options.dtype()`.
/// `options` is required here (no default argument); see `SoftmaxFuncOptions`.
inline Tensor softmax(const Tensor& input, const SoftmaxFuncOptions& options) {
  const auto dim = options.dim();
  const auto& dtype = options.dtype();
  return detail::softmax(input, dim, dtype);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Softmin of `input` along `dim`, computed as the softmax of `-input`.
// The `dtype` overload is only used when a dtype was supplied.
inline Tensor softmin(
    const Tensor& input,
    int64_t dim,
    std::optional<torch::Dtype> dtype) {
  const Tensor negated = -input;
  return dtype.has_value() ? negated.softmax(dim, dtype)
                           : negated.softmax(dim);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Softmin of `input` along `options.dim()`, optionally in `options.dtype()`.
/// `options` is required (no default argument); see `SoftminFuncOptions`.
inline Tensor softmin(const Tensor& input, const SoftminFuncOptions& options) {
  const auto dim = options.dim();
  const auto& dtype = options.dtype();
  return detail::softmin(input, dim, dtype);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Log-softmax of `input` along `dim`; the `dtype` overload is used only when
// a dtype is present in the optional.
inline Tensor log_softmax(
    const Tensor& input,
    int64_t dim,
    std::optional<torch::Dtype> dtype) {
  Tensor result =
      dtype ? input.log_softmax(dim, dtype) : input.log_softmax(dim);
  return result;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Log-softmax of `input` along `options.dim()`, optionally in
/// `options.dtype()`. `options` is required; see `LogSoftmaxFuncOptions`.
inline Tensor log_softmax(
    const Tensor& input,
    const LogSoftmaxFuncOptions& options) {
  const auto dim = options.dim();
  const auto& dtype = options.dtype();
  return detail::log_softmax(input, dim, dtype);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Gated linear unit along `dim`. Rejects 0-dim (scalar) inputs with a
// TORCH_CHECK before delegating to torch::glu.
inline Tensor glu(const Tensor& input, int64_t dim) {
  // A scalar has no dimension that could be split into two halves.
  TORCH_CHECK(
      input.dim() != 0,
      "glu does not support scalars because halving size must be even");
  return torch::glu(input, dim);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Gated linear unit of `input` along `options.dim()` (see `GLUFuncOptions`).
/// 0-dim inputs are rejected by the detail implementation.
inline Tensor glu(const Tensor& input, const GLUFuncOptions& options = {}) {
  const auto dim = options.dim();
  return detail::glu(input, dim);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Forwards to torch::gelu. `approximate` names the approximation mode
// (presumably "none" or "tanh" — TODO confirm the accepted set against
// torch::gelu). It is taken by const reference so the caller's string
// (e.g. GELUFuncOptions::approximate()) is not copied on every call, and it
// is spelled `std::string` explicitly instead of relying on an unqualified
// `string` leaking in through transitive using-declarations.
inline Tensor gelu(const Tensor& input, const std::string& approximate) {
  return torch::gelu(input, approximate);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies GELU element-wise using the approximation mode stored in
/// `options.approximate()` (see `GELUFuncOptions`).
inline Tensor gelu(const Tensor& input, const GELUFuncOptions& options = {}) {
  const auto& approximate = options.approximate();
  return detail::gelu(input, approximate);
}

// ============================================================================

/// Element-wise SiLU of `input`; forwards to torch::silu.
inline auto silu(const Tensor& input) -> Tensor {
  return torch::silu(input);
}

// ============================================================================

/// Element-wise Mish of `input`; forwards to torch::mish.
inline auto mish(const Tensor& input) -> Tensor {
  return torch::mish(input);
}

// ============================================================================

/// PReLU of `input` with learnable slopes `weight`; forwards to torch::prelu.
inline auto prelu(const Tensor& input, const Tensor& weight) -> Tensor {
  return torch::prelu(input, weight);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// ReLU dispatcher: torch::relu by default, torch::relu_ when `inplace` is set.
inline Tensor relu(Tensor input, bool inplace) {
  if (!inplace) {
    return torch::relu(input);
  }
  return torch::relu_(input);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies ReLU element-wise; `options.inplace()` selects the in-place path
/// (see `ReLUFuncOptions`).
inline Tensor relu(Tensor input, const ReLUFuncOptions& options = {}) {
  const auto inplace = options.inplace();
  return detail::relu(std::move(input), inplace);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// ReLU6 dispatcher: torch::relu6 by default, torch::relu6_ when `inplace`.
inline Tensor relu6(Tensor input, bool inplace) {
  if (!inplace) {
    return torch::relu6(input);
  }
  return torch::relu6_(input);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies ReLU6 element-wise; `options.inplace()` selects the in-place path
/// (see `ReLU6FuncOptions`).
inline Tensor relu6(Tensor input, const ReLU6FuncOptions& options = {}) {
  const auto inplace = options.inplace();
  return detail::relu6(std::move(input), inplace);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// RReLU dispatcher: forwards `lower`, `upper` and `training` to torch::rrelu,
// or to torch::rrelu_ when `inplace` is set.
inline Tensor rrelu(
    Tensor input,
    double lower,
    double upper,
    bool training,
    bool inplace) {
  if (!inplace) {
    return torch::rrelu(input, lower, upper, training);
  }
  return torch::rrelu_(input, lower, upper, training);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies RReLU element-wise using `lower`, `upper`, `training` and
/// `inplace` from `options` (see `RReLUFuncOptions`).
inline Tensor rrelu(Tensor input, const RReLUFuncOptions& options = {}) {
  const auto lower = options.lower();
  const auto upper = options.upper();
  const auto training = options.training();
  const auto inplace = options.inplace();
  return detail::rrelu(std::move(input), lower, upper, training, inplace);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// CELU dispatcher: torch::celu by default, torch::celu_ when `inplace` is set;
// `alpha` is forwarded unchanged.
inline Tensor celu(Tensor input, double alpha, bool inplace) {
  if (!inplace) {
    return torch::celu(input, alpha);
  }
  return torch::celu_(input, alpha);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies CELU element-wise with `options.alpha()`; `options.inplace()`
/// selects the in-place path (see `CELUFuncOptions`).
inline Tensor celu(Tensor input, const CELUFuncOptions& options = {}) {
  const auto alpha = options.alpha();
  const auto inplace = options.inplace();
  return detail::celu(std::move(input), alpha, inplace);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Thin forwarder to torch::softplus with explicit `beta` and `threshold`.
inline auto softplus(const Tensor& input, double beta, double threshold)
    -> Tensor {
  return torch::softplus(input, beta, threshold);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies Softplus element-wise using `options.beta()` and
/// `options.threshold()` (see `SoftplusFuncOptions`).
inline Tensor softplus(
    const Tensor& input,
    const SoftplusFuncOptions& options = {}) {
  const auto beta = options.beta();
  const auto thresh = options.threshold();
  return detail::softplus(input, beta, thresh);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Thin forwarder to torch::softshrink with an explicit `lambda`.
inline auto softshrink(const Tensor& input, double lambda) -> Tensor {
  return torch::softshrink(input, lambda);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies soft shrinkage element-wise, passing `options.lambda()` through
/// to torch::softshrink (see `SoftshrinkFuncOptions`).
inline Tensor softshrink(
    const Tensor& input,
    const SoftshrinkFuncOptions& options = {}) {
  const auto lambda = options.lambda();
  return detail::softshrink(input, lambda);
}

// ============================================================================

/// Softsign, computed element-wise as input / (|input| + 1).
inline Tensor softsign(const Tensor& input) {
  Tensor denominator = input.abs() + 1;
  return input / denominator;
}

// ============================================================================

/// Tanhshrink, computed element-wise as input - tanh(input).
inline Tensor tanhshrink(const Tensor& input) {
  Tensor squashed = input.tanh();
  return input - squashed;
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Threshold dispatcher: forwards `threshold` and `value` to torch::threshold,
// or to torch::threshold_ when `inplace` is set. (The parameter `threshold`
// shadows this function's name; the calls below are fully qualified.)
inline Tensor threshold(
    Tensor input,
    double threshold,
    double value,
    bool inplace) {
  if (!inplace) {
    return torch::threshold(input, threshold, value);
  }
  return torch::threshold_(input, threshold, value);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies Threshold element-wise using `threshold`, `value` and `inplace`
/// from `options`. `options` is required (no default argument); see
/// `ThresholdFuncOptions`.
inline Tensor threshold(Tensor input, const ThresholdFuncOptions& options) {
  const auto thresh = options.threshold();
  const auto fill_value = options.value();
  const auto inplace = options.inplace();
  return detail::threshold(std::move(input), thresh, fill_value, inplace);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Multi-head attention forward pass.
// Returns (attn_output, attn_output_weights); the second element is an
// undefined Tensor() when `need_weights` is false.
//
// Shapes established by the code below:
//   - `query` is read as (tgt_len, bsz, embed_dim) via query.sizes()[0..2].
//   - After projection, q/k/v are viewed as (len, bsz * num_heads, head_dim)
//     and transposed to (bsz * num_heads, len, head_dim), with
//     head_dim = embed_dim / num_heads.
//   - attn_output is viewed as (tgt_len, bsz, embed_dim) before the final
//     F::linear with out_proj_weight / out_proj_bias.
//   - attn_output_weights, when returned, is (bsz, num_heads, tgt_len,
//     src_len), or (bsz, tgt_len, src_len) when `average_attn_weights` is set.
//     The returned weights are taken after dropout has been applied.
//
// Parameters:
//   embed_dim_to_check   must equal query.size(2) (internal assert).
//   num_heads            must divide embed_dim (TORCH_CHECK).
//   in_proj_weight/bias  packed projection, E = embed_dim: rows [0, E) -> q,
//                        [E, 2E) -> k, [2E, end) -> v. The bias may be
//                        undefined.
//   bias_k/bias_v        if both are defined, concatenated onto k/v along
//                        dim 0 after repeat({1, bsz, 1}), and both masks get
//                        one zero column. Rejected together with
//                        static_k/static_v.
//   add_zero_attn        appends one all-zero key/value slot (and pads masks).
//   dropout_p/training   forwarded to F::dropout on the attention weights.
//   key_padding_mask     checked to be (bsz, src_len); positions it selects
//                        are filled with -inf before the softmax.
//   attn_mask            unsqueezed at dim 0 and added to the raw scores.
//   use_separate_proj_weight  use q/k/v_proj_weight instead of slicing
//                        in_proj_weight. Each is checked to be
//                        (E, last dim of its input).
//   static_k/static_v    replace the reshaped k/v. They must have
//                        size(0) == bsz * num_heads and size(2) == head_dim.
inline std::tuple<Tensor, Tensor> multi_head_attention_forward(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    int64_t embed_dim_to_check,
    int64_t num_heads,
    const Tensor& in_proj_weight,
    const Tensor& in_proj_bias,
    const Tensor& bias_k,
    const Tensor& bias_v,
    bool add_zero_attn,
    double dropout_p,
    const Tensor& out_proj_weight,
    const Tensor& out_proj_bias,
    bool training = true,
    const Tensor& key_padding_mask = {},
    bool need_weights = true,
    const Tensor& attn_mask = {},
    bool use_separate_proj_weight = false,
    const Tensor& q_proj_weight = {},
    const Tensor& k_proj_weight = {},
    const Tensor& v_proj_weight = {},
    const Tensor& static_k = {},
    const Tensor& static_v = {},
    bool average_attn_weights = true) {
  namespace F = torch::nn::functional;

  // `query_sizes` is a non-owning IntArrayRef over query's sizes; the
  // references below point into it while `query` is alive.
  const auto query_sizes = query.sizes();
  const auto& tgt_len = query_sizes[0];
  const auto& bsz = query_sizes[1];
  const auto& embed_dim = query_sizes[2];
  TORCH_INTERNAL_ASSERT(embed_dim == embed_dim_to_check);
  TORCH_INTERNAL_ASSERT(key.sizes() == value.sizes());

  const auto head_dim = embed_dim / num_heads;
  TORCH_CHECK(
      head_dim * num_heads == embed_dim,
      "embed_dim must be divisible by num_heads");
  // head_dim is an integer, so std::sqrt uses its integral overload, which
  // returns double. This is a floating-point 1/sqrt(head_dim), not integer
  // division.
  const auto scaling = 1 / std::sqrt(head_dim);

  Tensor q, k, v;
  if (!use_separate_proj_weight) {
    // torch::equal compares tensor contents, so the fused path below is
    // taken whenever query, key and value hold equal values.
    if (torch::equal(query, key) && torch::equal(key, value)) {
      // self-attention
      // One linear with the full packed weight, then split into q, k, v
      // along the last dim.
      const auto chunks =
          F::linear(query, in_proj_weight, in_proj_bias).chunk(3, /*dim=*/-1);
      q = chunks[0];
      k = chunks[1];
      v = chunks[2];
    } else if (torch::equal(key, value)) {
      // encoder-decoder attention
      // This is inline in_proj function with in_proj_weight and in_proj_bias
      // q uses rows [0, embed_dim) of the packed weight/bias.
      auto _b = in_proj_bias;
      // NOTE(review): `_start` deduces to int while embed_dim is int64_t, so
      // the later `_start = embed_dim` narrows. Presumably harmless for
      // realistic embedding sizes.
      auto _start = 0;
      auto _end = embed_dim;
      auto _w = in_proj_weight.slice(/*dim=*/0, _start, _end);
      if (_b.defined()) {
        _b = _b.slice(/*dim=*/0, _start, _end);
      }
      q = F::linear(query, _w, _b);

      // NOTE(review): torch::equal(key, value) above presumably fails on an
      // undefined key, which would make the first branch unreachable.
      // Confirm before relying on it.
      if (!key.defined()) {
        TORCH_INTERNAL_ASSERT(!value.defined());
        k.reset();
        v.reset();
      } else {
        // This is inline in_proj function with in_proj_weight and in_proj_bias
        // Rows [embed_dim, end) hold the k and v projections: one linear,
        // then a split into two chunks along the last dim.
        _b = in_proj_bias;
        _start = embed_dim;
        _w = in_proj_weight.slice(/*dim=*/0, _start);
        if (_b.defined()) {
          _b = _b.slice(/*dim=*/0, _start);
        }
        const auto chunks = F::linear(key, _w, _b).chunk(2, /*dim=*/-1);
        k = chunks[0];
        v = chunks[1];
      }
    } else {
      // query, key and value all differ: project each with its own row
      // range of the packed weight/bias.
      // This is inline in_proj function with in_proj_weight and in_proj_bias
      // q: rows [0, embed_dim).
      auto _b = in_proj_bias;
      auto _start = 0;
      auto _end = embed_dim;
      auto _w = in_proj_weight.slice(/*dim=*/0, _start, _end);
      if (_b.defined()) {
        _b = _b.slice(/*dim=*/0, _start, _end);
      }
      q = F::linear(query, _w, _b);

      // This is inline in_proj function with in_proj_weight and in_proj_bias
      // k: rows [embed_dim, 2 * embed_dim).
      _b = in_proj_bias;
      _start = embed_dim;
      _end = embed_dim * 2;
      _w = in_proj_weight.slice(/*dim=*/0, _start, _end);
      if (_b.defined()) {
        _b = _b.slice(/*dim=*/0, _start, _end);
      }
      k = F::linear(key, _w, _b);

      // This is inline in_proj function with in_proj_weight and in_proj_bias
      // v: rows [2 * embed_dim, end).
      _b = in_proj_bias;
      _start = embed_dim * 2;
      _w = in_proj_weight.slice(/*dim=*/0, _start);
      if (_b.defined()) {
        _b = _b.slice(0, _start);
      }
      v = F::linear(value, _w, _b);
    }
  } else {
    // Separate q/k/v weights: each must be (embed_dim, last dim of the
    // corresponding input).
    const auto& q_proj_weight_non_opt = q_proj_weight;
    {
      const auto sizes = q_proj_weight_non_opt.sizes();
      const auto len1 = sizes[0];
      const auto len2 = sizes[1];
      TORCH_CHECK(len1 == embed_dim && len2 == query.size(-1));
    }

    const auto& k_proj_weight_non_opt = k_proj_weight;
    {
      const auto sizes = k_proj_weight_non_opt.sizes();
      const auto len1 = sizes[0];
      const auto len2 = sizes[1];
      TORCH_CHECK(len1 == embed_dim && len2 == key.size(-1));
    }

    const auto& v_proj_weight_non_opt = v_proj_weight;
    {
      const auto sizes = v_proj_weight_non_opt.sizes();
      const auto len1 = sizes[0];
      const auto len2 = sizes[1];
      TORCH_CHECK(len1 == embed_dim && len2 == value.size(-1));
    }

    // The packed in_proj_bias is still sliced into q/k/v thirds when present.
    if (in_proj_bias.defined()) {
      q = F::linear(
          query,
          q_proj_weight_non_opt,
          in_proj_bias.slice(/*dim=*/0, 0, embed_dim));
      k = F::linear(
          key,
          k_proj_weight_non_opt,
          in_proj_bias.slice(/*dim=*/0, embed_dim, (embed_dim * 2)));
      v = F::linear(
          value,
          v_proj_weight_non_opt,
          in_proj_bias.slice(/*dim=*/0, (embed_dim * 2)));
    } else {
      // in_proj_bias is undefined here and is passed through as the bias
      // (presumably treated as "no bias" by F::linear).
      q = F::linear(query, q_proj_weight_non_opt, in_proj_bias);
      k = F::linear(key, k_proj_weight_non_opt, in_proj_bias);
      v = F::linear(value, v_proj_weight_non_opt, in_proj_bias);
    }
  }
  // Scale queries by 1/sqrt(head_dim) before the dot product.
  q = q * scaling;
  // Local handles so the masks can be padded without touching the caller's
  // tensors (torch::cat below produces new tensors).
  Tensor attn_mask_ = attn_mask;
  Tensor key_padding_mask_ = key_padding_mask;
  if (bias_k.defined() && bias_v.defined()) {
    if (!static_k.defined() && !static_v.defined()) {
      // Concatenate bias_k / bias_v (after repeat({1, bsz, 1})) onto k / v
      // along dim 0. Then pad each mask with one zero column along dim 1 so
      // its width keeps matching the number of keys.
      k = torch::cat({k, bias_k.repeat({1, bsz, 1})});
      v = torch::cat({v, bias_v.repeat({1, bsz, 1})});
      if (attn_mask_.defined()) {
        attn_mask_ = torch::cat(
            {attn_mask_,
             torch::zeros(
                 {attn_mask_.size(0), 1},
                 at::TensorOptions(attn_mask_.dtype())
                     .device(attn_mask_.device()))},
            /*dim=*/1);
      }
      if (key_padding_mask_.defined()) {
        key_padding_mask_ = torch::cat(
            {key_padding_mask_,
             torch::zeros(
                 {key_padding_mask_.size(0), 1},
                 at::TensorOptions(key_padding_mask_.dtype())
                     .device(key_padding_mask_.device()))},
            /*dim=*/1);
      }
    } else {
      TORCH_CHECK(!static_k.defined(), "bias cannot be added to static key.");
      TORCH_CHECK(!static_v.defined(), "bias cannot be added to static value.");
    }
  } else {
    // Reached when neither or only one of bias_k / bias_v is defined; the
    // checks reject the "only one" case.
    TORCH_CHECK(!bias_k.defined());
    TORCH_CHECK(!bias_v.defined());
  }
  // Split heads: view as (len, bsz * num_heads, head_dim), then swap the
  // first two dims to get (bsz * num_heads, len, head_dim).
  q = q.contiguous().view({tgt_len, bsz * num_heads, head_dim}).transpose(0, 1);
  if (k.defined()) {
    k = k.contiguous().view({-1, bsz * num_heads, head_dim}).transpose(0, 1);
  }
  if (v.defined()) {
    v = v.contiguous().view({-1, bsz * num_heads, head_dim}).transpose(0, 1);
  }
  // Precomputed keys/values override the projected ones. They must already
  // be laid out per head.
  if (static_k.defined()) {
    TORCH_CHECK(static_k.size(0) == bsz * num_heads);
    TORCH_CHECK(static_k.size(2) == head_dim);
    k = static_k;
  }
  if (static_v.defined()) {
    TORCH_CHECK(static_v.size(0) == bsz * num_heads);
    TORCH_CHECK(static_v.size(2) == head_dim);
    v = static_v;
  }
  // Number of key positions after any bias_k concatenation / static_k swap.
  auto src_len = k.size(1);
  if (key_padding_mask_.defined()) {
    TORCH_CHECK(key_padding_mask_.size(0) == bsz);
    TORCH_CHECK(key_padding_mask_.size(1) == src_len);
  }
  if (add_zero_attn) {
    // Append one all-zero key and value slot along dim 1 (the sequence dim
    // after the head split) and pad both masks to match.
    src_len += 1;
    auto k_sizes = k.sizes().vec();
    k_sizes[1] = 1;
    k = torch::cat(
        {k,
         torch::zeros(
             k_sizes, at::TensorOptions(k.dtype()).device(k.device()))},
        /*dim=*/1);
    auto v_sizes = v.sizes().vec();
    v_sizes[1] = 1;
    v = torch::cat(
        {v,
         torch::zeros(
             v_sizes, at::TensorOptions(v.dtype()).device(v.device()))},
        /*dim=*/1);
    if (attn_mask_.defined()) {
      attn_mask_ = torch::cat(
          {attn_mask_,
           torch::zeros(
               {attn_mask_.size(0), 1},
               at::TensorOptions(attn_mask_.dtype())
                   .device(attn_mask_.device()))},
          /*dim=*/1);
    }
    if (key_padding_mask_.defined()) {
      key_padding_mask_ = torch::cat(
          {key_padding_mask_,
           torch::zeros(
               {key_padding_mask_.size(0), 1},
               at::TensorOptions(key_padding_mask_.dtype())
                   .device(key_padding_mask_.device()))},
          /*dim=*/1);
    }
  }
  // Raw attention scores: (bsz * num_heads, tgt_len, src_len), checked below.
  auto attn_output_weights = torch::bmm(q, k.transpose(1, 2));
  TORCH_CHECK(
      attn_output_weights.sizes() ==
      IntArrayRef({bsz * num_heads, tgt_len, src_len}));
  if (attn_mask_.defined()) {
    // Additive mask; unsqueeze(0) adds a leading dim for the
    // bsz * num_heads batch.
    attn_mask_ = attn_mask_.unsqueeze(0);
    attn_output_weights += attn_mask_;
  }
  if (key_padding_mask_.defined()) {
    // Expose the head dim so the (bsz, src_len) mask can be broadcast as
    // (bsz, 1, 1, src_len). Selected positions are filled with -inf in the
    // weights' own scalar type. AT_DISPATCH_FLOATING_TYPES presumably covers
    // only float/double, so other dtypes would raise here (TODO confirm).
    attn_output_weights =
        attn_output_weights.view({bsz, num_heads, tgt_len, src_len});
    attn_output_weights = AT_DISPATCH_FLOATING_TYPES(
        attn_output_weights.scalar_type(),
        "attn_output_weights.masked_fill",
        [&]() {
          return attn_output_weights.masked_fill(
              key_padding_mask_.unsqueeze(1).unsqueeze(2),
              -std::numeric_limits<scalar_t>::infinity());
        });
    attn_output_weights =
        attn_output_weights.view({bsz * num_heads, tgt_len, src_len});
  }
  // Normalize over the key dimension, then apply dropout to the weights.
  // NOLINTNEXTLINE(bugprone-argument-comment)
  attn_output_weights = F::softmax(attn_output_weights, /*dim=*/-1);
  attn_output_weights = F::dropout(
      attn_output_weights,
      F::DropoutFuncOptions().p(dropout_p).training(training));
  // Weighted sum of values, merge heads back into (tgt_len, bsz, embed_dim),
  // then apply the output projection.
  auto attn_output = torch::bmm(attn_output_weights, v);
  TORCH_CHECK(
      attn_output.sizes() == IntArrayRef({bsz * num_heads, tgt_len, head_dim}));
  attn_output =
      attn_output.transpose(0, 1).contiguous().view({tgt_len, bsz, embed_dim});
  attn_output = F::linear(attn_output, out_proj_weight, out_proj_bias);
  if (need_weights) {
    // Per-head weights as (bsz, num_heads, tgt_len, src_len), or their mean
    // over heads (sum / num_heads) when average_attn_weights is set.
    attn_output_weights =
        attn_output_weights.view({bsz, num_heads, tgt_len, src_len});
    if (average_attn_weights) {
      // average attention weights over heads
      attn_output_weights = attn_output_weights.sum(/*dim=*/1) / num_heads;
    }
    return std::make_tuple(attn_output, attn_output_weights);
  } else {
    return std::make_tuple(attn_output, Tensor());
  }
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

inline std::tuple<Tensor, Tensor> multi_head_attention_forward(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const MultiheadAttentionForwardFuncOptions& options) {
  return detail::multi_head_attention_forward(
      query,
      key,
      value,
      options.embed_dim_to_check(),
      options.num_heads(),
      options.in_proj_weight(),
      options.in_proj_bias(),
      options.bias_k(),
      options.bias_v(),
      options.add_zero_attn(),
      options.dropout_p(),
      options.out_proj_weight(),
      options.out_proj_bias(),
      options.training(),
      options.key_padding_mask(),
      options.need_weights(),
      options.attn_mask(),
      options.use_separate_proj_weight(),
      options.q_proj_weight(),
      options.k_proj_weight(),
      options.v_proj_weight(),
      options.static_k(),
      options.static_v(),
      options.average_attn_weights());
}

} // namespace functional
} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources