Program Listing for File embedding.h

Return to documentation for file (torch/csrc/api/include/torch/nn/functional/embedding.h)

#pragma once

#include <torch/nn/options/embedding.h>

namespace torch {
namespace nn {
namespace functional {

inline Tensor one_hot(const Tensor& tensor, int64_t num_classes = -1) {
  return torch::one_hot(tensor, num_classes);
}
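
// Usage sketch: one_hot expands an index tensor into a new trailing dimension
// of size num_classes (inferred from the data when num_classes is -1).
//
//   torch::Tensor idx = torch::tensor({0, 2, 1}, torch::kLong);
//   torch::Tensor oh = torch::nn::functional::one_hot(idx, /*num_classes=*/3);
//   // oh is a 3x3 kLong tensor: {{1, 0, 0}, {0, 0, 1}, {0, 1, 0}}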

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
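// Renormalizes, in place and without recording autograd history, the rows of
// `weight` referenced by `input` so that each has at most `max_norm` under the
// `norm_type` p-norm.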
inline void _no_grad_embedding_renorm_(Tensor weight, const Tensor& input, float max_norm, float norm_type) {
  torch::NoGradGuard no_grad;
  torch::embedding_renorm_(weight, input, max_norm, norm_type);
}

inline Tensor embedding(const Tensor& input,
                        const Tensor& weight,
                        c10::optional<int64_t> padding_idx,
                        c10::optional<double> max_norm,
                        double norm_type,
                        bool scale_grad_by_freq,
                        bool sparse) {
  auto input_ = input;

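  // Validate padding_idx against the number of embeddings and normalize a
  // negative value to its non-negative equivalent; -1 signals "no padding row".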
  if (padding_idx != c10::nullopt) {
    if (*padding_idx > 0) {
      TORCH_CHECK(*padding_idx < weight.size(0), "Padding_idx must be within num_embeddings");
    }
    else if (*padding_idx < 0) {
      TORCH_CHECK(*padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings");
      padding_idx = weight.size(0) + *padding_idx;
    }
  } else {
    padding_idx = -1;
  }

  if (max_norm != c10::nullopt) {
    input_ = input_.contiguous();
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type);
  }
  return torch::embedding(weight, input_, *padding_idx, scale_grad_by_freq, sparse);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

inline Tensor embedding(const Tensor& input, const Tensor& weight, const EmbeddingFuncOptions& options = {}) {
  return detail::embedding(
    input,
    weight,
    options.padding_idx(),
    options.max_norm(),
    options.norm_type(),
    options.scale_grad_by_freq(),
    options.sparse());
}
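
// Usage sketch: looks up rows of `weight` by index; the options mirror the
// keyword arguments of the Python torch.nn.functional.embedding.
//
//   namespace F = torch::nn::functional;
//   torch::Tensor weight = torch::randn({10, 3});  // 10 embeddings of dimension 3
//   torch::Tensor input = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kLong).reshape({2, 4});
//   torch::Tensor out = F::embedding(input, weight,
//       F::EmbeddingFuncOptions().padding_idx(2));
//   // out has shape {2, 4, 3}; index 2 is treated as padding (its row receives no gradient).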

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor embedding_bag(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& offsets,
    c10::optional<double> max_norm,
    double norm_type,
    bool scale_grad_by_freq,
    EmbeddingBagMode mode,
    bool sparse,
    const Tensor& per_sample_weights,
    bool include_last_offset,
    c10::optional<int64_t> padding_idx) {
  auto input_ = input;
  auto offsets_ = offsets;
  auto per_sample_weights_ = per_sample_weights;
  TORCH_CHECK(!per_sample_weights_.defined() || input_.sizes() == per_sample_weights_.sizes(),
    "embedding_bag: If per_sample_weights (", per_sample_weights_.sizes(), ") is not null, then it must have the same shape as the input (", input_.sizes(), ")");
  if (input_.dim() == 2) {
    TORCH_CHECK(!offsets_.defined(),
      "If input is 2D, then offsets has to be null, as input is treated is a mini-batch of fixed length sequences. However, found offsets of type Tensor");
    offsets_ = torch::arange(0, input_.numel(), input_.size(1), torch::TensorOptions().dtype(torch::kLong).device(input_.device()));
    input_ = input_.reshape(-1);
    if (per_sample_weights_.defined()) {
      per_sample_weights_ = per_sample_weights_.reshape(-1);
    }
  } else if (input_.dim() == 1) {
    TORCH_CHECK(offsets_.defined(), "offsets has to be a 1D Tensor but got null");
    TORCH_CHECK(offsets_.dim() == 1, "offsets has to be a 1D Tensor");
    TORCH_CHECK(offsets_[0].item<int64_t>() == 0, "offsets[0] has to be 0, i.e., the first sequence in the mini-batch has to start from position 0. However, got ",
     offsets_[0].item<int64_t>());
    TORCH_CHECK(offsets_[-1].item<int64_t>() <= input_.size(0), "offsets[-1] cannot be greater than input's length({",
              input_.size(0), "}), but got offsets[-1] of {", offsets_[-1].item<int64_t>(), "}");
  } else {
    TORCH_CHECK(false, "input has to be 1D or 2D Tensor, but got Tensor of dimension ", input_.dim());
  }

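  // Map the EmbeddingBagMode variant onto the integer code expected by the
  // underlying ATen call: 0 = sum, 1 = mean, 2 = max.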
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int mode_enum;
  if (c10::get_if<enumtype::kSum>(&mode)) {
    mode_enum = 0;
  } else if (c10::get_if<enumtype::kMean>(&mode)) {
    mode_enum = 1;
  } else if (c10::get_if<enumtype::kMax>(&mode)) {
    mode_enum = 2;
    TORCH_CHECK(!scale_grad_by_freq, "max mode does not support scaling the gradient by the frequency");
    TORCH_CHECK(!sparse, "max mode does not support sparse weights");
  } else {
    TORCH_CHECK(false, "mode has to be one of sum, mean or max");
  }

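  // When max_norm is set, renorm the referenced rows of `weight` in place
  // before performing the bag lookup.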
  if (max_norm != c10::nullopt) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type);
  }

  TORCH_CHECK(
    !per_sample_weights_.defined() || c10::get_if<enumtype::kSum>(&mode),
    "embedding_bag: per_sample_weights was not null. ",
    "per_sample_weights is only supported for mode='kSum' (got mode='",
    torch::enumtype::get_enum_name(mode), "'). Please open a feature request on GitHub.");

  return std::get<0>(
    torch::embedding_bag(
      weight,
      input_,
      offsets_,
      scale_grad_by_freq,
      mode_enum,
      sparse,
      per_sample_weights_,
      include_last_offset,
      padding_idx));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

inline Tensor embedding_bag(const Tensor& input, const Tensor& weight, const EmbeddingBagFuncOptions& options = {}) {
  return detail::embedding_bag(
    input,
    weight,
    options.offsets(),
    options.max_norm(),
    options.norm_type(),
    options.scale_grad_by_freq(),
    options.mode(),
    options.sparse(),
    options.per_sample_weights(),
    options.include_last_offset(),
    options.padding_idx());
}
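
// Usage sketch: reduces bags of embeddings (sum, mean, or max) without
// materializing the intermediate lookup. With a 1D input, `offsets` marks
// where each bag starts.
//
//   namespace F = torch::nn::functional;
//   torch::Tensor weight = torch::randn({10, 3});
//   torch::Tensor input = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kLong);
//   torch::Tensor offsets = torch::tensor({0, 4}, torch::kLong);  // two bags of four indices
//   torch::Tensor out = F::embedding_bag(input, weight,
//       F::EmbeddingBagFuncOptions().offsets(offsets).mode(torch::kSum));
//   // out has shape {2, 3}: one reduced embedding per bag.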

} // namespace functional
} // namespace nn
} // namespace torch
