Shortcuts

Program Listing for File embedding.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/embedding.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/functional/embedding.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/options/embedding.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <cstddef>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Embedding
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module owning an embedding table (`weight`) configured by
/// `EmbeddingOptions`. Deriving from `Cloneable<EmbeddingImpl>` makes it a
/// cloneable `torch::nn` module. All non-inline member functions declared
/// here are defined out of line (not visible in this header).
// NOLINTNEXTLINE(bugprone-exception-escape)
class TORCH_API EmbeddingImpl : public torch::nn::Cloneable<EmbeddingImpl> {
 public:
  /// Convenience constructor; equivalent to
  /// `EmbeddingImpl(EmbeddingOptions(num_embeddings, embedding_dim))`.
  EmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim)
      : EmbeddingImpl(EmbeddingOptions(num_embeddings, embedding_dim)) {}
  /// Constructs the module from a full set of options.
  explicit EmbeddingImpl(const EmbeddingOptions& options_);

  /// Overrides `Cloneable::reset()`. NOTE(review): presumably (re)creates
  /// `weight` from `options`; the definition is out of line — confirm.
  void reset() override;

  /// NOTE(review): judging by the name, re-initializes the values held in
  /// `weight`; the definition is out of line — confirm.
  void reset_parameters();

  /// Prints a short description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Looks up embeddings for `indices`. NOTE(review): presumably delegates
  /// to the functional API from <torch/nn/functional/embedding.h> using
  /// `weight` and `options` — confirm against the out-of-line definition.
  Tensor forward(const Tensor& indices);

  /// Configuration of this module (presumably a copy of the options passed
  /// to the constructor — confirm).
  EmbeddingOptions options;

  /// The embedding table. `Embedding::from_pretrained` supplies its initial
  /// value through `EmbeddingOptions::_weight` and then toggles its
  /// `requires_grad` flag.
  Tensor weight;
};

/// `ModuleHolder` wrapper around `EmbeddingImpl`. It inherits all of
/// `ModuleHolder`'s constructors, so it accepts the same constructor
/// arguments as `EmbeddingImpl`.
class Embedding : public torch::nn::ModuleHolder<EmbeddingImpl> {
 public:
  using torch::nn::ModuleHolder<EmbeddingImpl>::ModuleHolder;

  /// Creates an `Embedding` whose weight is initialized from `embeddings`.
  ///
  /// \param embeddings 2-D tensor. Its size along dimension 0 becomes
  ///   `num_embeddings` and its size along dimension 1 becomes
  ///   `embedding_dim`. It is passed to the module as `_weight`.
  /// \param options `padding_idx`, `max_norm`, `norm_type`,
  ///   `scale_grad_by_freq` and `sparse` are forwarded to `EmbeddingOptions`.
  ///   `freeze` decides whether the resulting weight requires grad.
  /// \return the newly constructed module.
  /// Fails via `TORCH_CHECK` if `embeddings` is not 2-dimensional.
  static Embedding from_pretrained(
      const torch::Tensor& embeddings,
      const EmbeddingFromPretrainedOptions& options = {}) {
    TORCH_CHECK(
        embeddings.dim() == 2,
        "Embeddings parameter is expected to be 2-dimensional");

    // Initialized at declaration, so no init-variables suppression is needed.
    const int64_t rows = embeddings.size(0);
    const int64_t cols = embeddings.size(1);

    Embedding embedding(EmbeddingOptions(rows, cols)
                            ._weight(embeddings)
                            .padding_idx(options.padding_idx())
                            .max_norm(options.max_norm())
                            .norm_type(options.norm_type())
                            .scale_grad_by_freq(options.scale_grad_by_freq())
                            .sparse(options.sparse()));
    // A frozen embedding keeps its pretrained values: no gradient is tracked
    // for the weight.
    embedding->weight.set_requires_grad(!options.freeze());
    return embedding;
  }
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ EmbeddingBag
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Module owning an embedding table (`weight`) configured by
/// `EmbeddingBagOptions`. Its `forward` takes indices plus optional
/// `offsets` and `per_sample_weights`. Deriving from
/// `Cloneable<EmbeddingBagImpl>` makes it a cloneable `torch::nn` module.
/// All non-inline member functions declared here are defined out of line
/// (not visible in this header).
// NOLINTNEXTLINE(bugprone-exception-escape)
class TORCH_API EmbeddingBagImpl
    : public torch::nn::Cloneable<EmbeddingBagImpl> {
 public:
  /// Convenience constructor; equivalent to
  /// `EmbeddingBagImpl(EmbeddingBagOptions(num_embeddings, embedding_dim))`.
  EmbeddingBagImpl(int64_t num_embeddings, int64_t embedding_dim)
      : EmbeddingBagImpl(EmbeddingBagOptions(num_embeddings, embedding_dim)) {}
  /// Constructs the module from a full set of options.
  explicit EmbeddingBagImpl(const EmbeddingBagOptions& options_);

  /// Overrides `Cloneable::reset()`. NOTE(review): presumably (re)creates
  /// `weight` from `options`; the definition is out of line — confirm.
  void reset() override;

  /// NOTE(review): judging by the name, re-initializes the values held in
  /// `weight`; the definition is out of line — confirm.
  void reset_parameters();

  /// Prints a short description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration of this module (presumably a copy of the options passed
  /// to the constructor — confirm).
  EmbeddingBagOptions options;
  /// The embedding table. `EmbeddingBag::from_pretrained` supplies its
  /// initial value through `EmbeddingBagOptions::_weight` and then toggles
  /// its `requires_grad` flag.
  Tensor weight;

  /// Computes bag embeddings for `input`.
  ///
  /// \param input indices into the embedding table.
  /// \param offsets optional; defaults to an empty (undefined) `Tensor`.
  ///   NOTE(review): presumably marks where each bag starts in `input` —
  ///   confirm against the out-of-line definition.
  /// \param per_sample_weights optional; defaults to an empty (undefined)
  ///   `Tensor`.
  Tensor forward(
      const Tensor& input,
      const Tensor& offsets = {},
      const Tensor& per_sample_weights = {});

 protected:
  // Registers empty-`Tensor` defaults for forward() arguments 1 (`offsets`)
  // and 2 (`per_sample_weights`), matching the C++ default arguments above.
  // NOTE(review): presumably lets type-erased wrappers (e.g. AnyModule /
  // Sequential) omit these arguments — see FORWARD_HAS_DEFAULT_ARGS.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})
};

/// `ModuleHolder` wrapper around `EmbeddingBagImpl`. It inherits all of
/// `ModuleHolder`'s constructors, so it accepts the same constructor
/// arguments as `EmbeddingBagImpl`.
class EmbeddingBag : public torch::nn::ModuleHolder<EmbeddingBagImpl> {
 public:
  using torch::nn::ModuleHolder<EmbeddingBagImpl>::ModuleHolder;

  /// Creates an `EmbeddingBag` whose weight is initialized from `embeddings`.
  ///
  /// \param embeddings 2-D tensor. Its size along dimension 0 becomes
  ///   `num_embeddings` and its size along dimension 1 becomes
  ///   `embedding_dim`. It is passed to the module as `_weight`.
  /// \param options every module setting it carries (`max_norm`,
  ///   `norm_type`, `scale_grad_by_freq`, `mode`, `sparse`,
  ///   `include_last_offset`, `padding_idx`) is forwarded to
  ///   `EmbeddingBagOptions`. `freeze` decides whether the resulting weight
  ///   requires grad.
  /// \return the newly constructed module.
  /// Fails via `TORCH_CHECK` if `embeddings` is not 2-dimensional.
  static EmbeddingBag from_pretrained(
      const torch::Tensor& embeddings,
      const EmbeddingBagFromPretrainedOptions& options = {}) {
    TORCH_CHECK(
        embeddings.dim() == 2,
        "Embeddings parameter is expected to be 2-dimensional");

    // Initialized at declaration, so no init-variables suppression is needed.
    const int64_t rows = embeddings.size(0);
    const int64_t cols = embeddings.size(1);

    EmbeddingBag embeddingbag(
        EmbeddingBagOptions(rows, cols)
            ._weight(embeddings)
            .max_norm(options.max_norm())
            .norm_type(options.norm_type())
            .scale_grad_by_freq(options.scale_grad_by_freq())
            .mode(options.mode())
            .sparse(options.sparse())
            // Previously dropped, so a caller-requested
            // include_last_offset was silently ignored.
            .include_last_offset(options.include_last_offset())
            .padding_idx(options.padding_idx()));
    // A frozen embedding keeps its pretrained values: no gradient is tracked
    // for the weight.
    embeddingbag->weight.set_requires_grad(!options.freeze());
    return embeddingbag;
  }
};
} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources