Shortcuts

Program Listing for File embedding.h

Return to documentation for file (torch/csrc/api/include/torch/nn/options/embedding.h)

#pragma once

#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>

namespace torch {
namespace nn {

/// Options for the `Embedding` module.
///
/// Field names mirror the keyword arguments of Python's `torch.nn.Embedding`;
/// the per-field notes below paraphrase the Python documentation and should
/// be confirmed against the C++ implementation.
///
/// `num_embeddings` and `embedding_dim` have no in-class default and are set
/// by the two-argument constructor (declared here, defined out of line). Every
/// other option starts at the value written after its `TORCH_ARG(...)`.
/// Presumably each `TORCH_ARG` also generates a chainable setter/getter pair
/// (see torch/arg.h — confirm), e.g. `EmbeddingOptions(10, 2).padding_idx(3)`.
struct TORCH_API EmbeddingOptions {
  EmbeddingOptions(int64_t num_embeddings, int64_t embedding_dim);

  /// Size of the embedding dictionary (number of embedding vectors).
  TORCH_ARG(int64_t, num_embeddings);
  /// Size of each embedding vector.
  TORCH_ARG(int64_t, embedding_dim);
  /// Index whose entry does not contribute to the gradient; unset by default.
  TORCH_ARG(c10::optional<int64_t>, padding_idx) = c10::nullopt;
  /// If set, embedding vectors whose norm exceeds this value are renormalized
  /// to it; unset by default.
  TORCH_ARG(c10::optional<double>, max_norm) = c10::nullopt;
  /// The `p` of the p-norm used together with `max_norm`; defaults to 2.
  TORCH_ARG(double, norm_type) = 2.;
  /// Whether to scale gradients by the inverse frequency of indices in the
  /// mini-batch; defaults to false.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// Whether the gradient w.r.t. the weight is a sparse tensor; defaults to
  /// false.
  TORCH_ARG(bool, sparse) = false;
  /// Optional pre-existing weight. Defaults to a default-constructed
  /// `Tensor()` (presumably meaning "no weight supplied" — confirm).
  TORCH_ARG(torch::Tensor, _weight) = Tensor();
};

// ============================================================================

/// Options for building an `Embedding` module from a pretrained weight tensor
/// (counterpart of Python's `torch.nn.Embedding.from_pretrained` — confirm
/// against the C++ implementation).
///
/// There is no constructor and there are no size fields here (presumably the
/// sizes come from the pretrained tensor). Every option has the default
/// written after its `TORCH_ARG(...)`.
struct TORCH_API EmbeddingFromPretrainedOptions {
  /// Whether the pretrained weight is kept fixed during training; defaults to
  /// true.
  TORCH_ARG(bool, freeze) = true;
  /// Index whose entry does not contribute to the gradient; unset by default.
  TORCH_ARG(c10::optional<int64_t>, padding_idx) = c10::nullopt;
  /// If set, embedding vectors whose norm exceeds this value are renormalized
  /// to it; unset by default.
  TORCH_ARG(c10::optional<double>, max_norm) = c10::nullopt;
  /// The `p` of the p-norm used together with `max_norm`; defaults to 2.
  TORCH_ARG(double, norm_type) = 2.;
  /// Whether to scale gradients by the inverse frequency of indices in the
  /// mini-batch; defaults to false.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// Whether the gradient w.r.t. the weight is a sparse tensor; defaults to
  /// false.
  TORCH_ARG(bool, sparse) = false;
};

// ============================================================================

namespace functional {

/// Options for the functional embedding lookup (presumably consumed by
/// `torch::nn::functional::embedding` — confirm).
///
/// The weight is not part of these options, so the size and `_weight` fields
/// of the module options are absent. Every option has the default written
/// after its `TORCH_ARG(...)`.
struct TORCH_API EmbeddingFuncOptions {
  /// Index whose entry does not contribute to the gradient; unset by default.
  TORCH_ARG(c10::optional<int64_t>, padding_idx) = c10::nullopt;
  /// If set, embedding vectors whose norm exceeds this value are renormalized
  /// to it; unset by default.
  TORCH_ARG(c10::optional<double>, max_norm) = c10::nullopt;
  /// The `p` of the p-norm used together with `max_norm`; defaults to 2.
  TORCH_ARG(double, norm_type) = 2.;
  /// Whether to scale gradients by the inverse frequency of indices in the
  /// mini-batch; defaults to false.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// Whether the gradient w.r.t. the weight is a sparse tensor; defaults to
  /// false.
  TORCH_ARG(bool, sparse) = false;
};

} // namespace functional

// ============================================================================

/// Reduction applied to the embeddings in each bag: one of the tag types
/// `enumtype::kSum`, `enumtype::kMean` or `enumtype::kMax` (presumably
/// declared via torch/enum.h). The EmbeddingBag option structs in this header
/// default it to `torch::kMean`.
using EmbeddingBagMode =
    std::variant<enumtype::kSum, enumtype::kMean, enumtype::kMax>;

/// Options for the `EmbeddingBag` module.
///
/// Field names mirror the keyword arguments of Python's `torch.nn.EmbeddingBag`;
/// the per-field notes below paraphrase the Python documentation and should
/// be confirmed against the C++ implementation.
///
/// `num_embeddings` and `embedding_dim` have no in-class default and are set
/// by the two-argument constructor (declared here, defined out of line). Every
/// other option starts at the value written after its `TORCH_ARG(...)`.
struct TORCH_API EmbeddingBagOptions {
  EmbeddingBagOptions(int64_t num_embeddings, int64_t embedding_dim);

  /// Size of the embedding dictionary (number of embedding vectors).
  TORCH_ARG(int64_t, num_embeddings);
  /// Size of each embedding vector.
  TORCH_ARG(int64_t, embedding_dim);
  /// If set, embedding vectors whose norm exceeds this value are renormalized
  /// to it; unset by default.
  TORCH_ARG(c10::optional<double>, max_norm) = c10::nullopt;
  /// The `p` of the p-norm used together with `max_norm`; defaults to 2.
  TORCH_ARG(double, norm_type) = 2.;
  /// Whether to scale gradients by the inverse frequency of indices in the
  /// mini-batch; defaults to false.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// How the embeddings in each bag are reduced (`kSum`, `kMean` or `kMax`);
  /// defaults to `torch::kMean`.
  TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean;
  /// Whether the gradient w.r.t. the weight is a sparse tensor; defaults to
  /// false.
  TORCH_ARG(bool, sparse) = false;
  /// Optional pre-existing weight. Defaults to a default-constructed
  /// `Tensor()` (presumably meaning "no weight supplied" — confirm).
  TORCH_ARG(torch::Tensor, _weight) = Tensor();
  /// If true, `offsets` carries one extra trailing element equal to the number
  /// of indices (CSR-style layout); defaults to false.
  TORCH_ARG(bool, include_last_offset) = false;
  /// Index whose entry does not contribute to the gradient; unset by default.
  /// NOTE(review): declared last, unlike in `EmbeddingOptions` — possibly
  /// appended later. Keep the order: if `TORCH_ARG` declares data members,
  /// reordering changes the layout of this exported struct.
  TORCH_ARG(c10::optional<int64_t>, padding_idx) = c10::nullopt;
};

// ============================================================================

/// Options for building an `EmbeddingBag` module from a pretrained weight
/// tensor (counterpart of Python's `torch.nn.EmbeddingBag.from_pretrained` —
/// confirm against the C++ implementation).
///
/// There is no constructor and there are no size fields here (presumably the
/// sizes come from the pretrained tensor). Every option has the default
/// written after its `TORCH_ARG(...)`.
struct TORCH_API EmbeddingBagFromPretrainedOptions {
  /// Whether the pretrained weight is kept fixed during training; defaults to
  /// true.
  TORCH_ARG(bool, freeze) = true;
  /// If set, embedding vectors whose norm exceeds this value are renormalized
  /// to it; unset by default.
  TORCH_ARG(c10::optional<double>, max_norm) = c10::nullopt;
  /// The `p` of the p-norm used together with `max_norm`; defaults to 2.
  TORCH_ARG(double, norm_type) = 2.;
  /// Whether to scale gradients by the inverse frequency of indices in the
  /// mini-batch; defaults to false.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// How the embeddings in each bag are reduced (`kSum`, `kMean` or `kMax`);
  /// defaults to `torch::kMean`.
  TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean;
  /// Whether the gradient w.r.t. the weight is a sparse tensor; defaults to
  /// false.
  TORCH_ARG(bool, sparse) = false;
  /// If true, `offsets` carries one extra trailing element equal to the number
  /// of indices (CSR-style layout); defaults to false.
  TORCH_ARG(bool, include_last_offset) = false;
  /// Index whose entry does not contribute to the gradient; unset by default.
  TORCH_ARG(c10::optional<int64_t>, padding_idx) = c10::nullopt;
};

// ============================================================================

namespace functional {

/// Options for the functional embedding-bag reduction (presumably consumed by
/// `torch::nn::functional::embedding_bag` — confirm).
///
/// Besides the reduction settings, this struct carries the auxiliary tensors
/// `offsets` and `per_sample_weights`; both default to a default-constructed
/// `Tensor()`. Every option has the default written after its
/// `TORCH_ARG(...)`.
struct TORCH_API EmbeddingBagFuncOptions {
  /// Start position of each bag within the index tensor (per the Python docs,
  /// used when the input is 1-D — confirm); defaults to `Tensor()`.
  TORCH_ARG(torch::Tensor, offsets) = Tensor();
  /// If set, embedding vectors whose norm exceeds this value are renormalized
  /// to it; unset by default.
  TORCH_ARG(c10::optional<double>, max_norm) = c10::nullopt;
  /// The `p` of the p-norm used together with `max_norm`; defaults to 2.
  TORCH_ARG(double, norm_type) = 2.;
  /// Whether to scale gradients by the inverse frequency of indices in the
  /// mini-batch; defaults to false.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// How the embeddings in each bag are reduced (`kSum`, `kMean` or `kMax`);
  /// defaults to `torch::kMean`.
  TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean;
  /// Whether the gradient w.r.t. the weight is a sparse tensor; defaults to
  /// false.
  TORCH_ARG(bool, sparse) = false;
  /// Optional per-index weights. The default `Tensor()` presumably means every
  /// weight is 1; per the Python docs, only supported with `kSum` — confirm.
  TORCH_ARG(torch::Tensor, per_sample_weights) = Tensor();
  /// If true, `offsets` carries one extra trailing element equal to the number
  /// of indices (CSR-style layout); defaults to false.
  TORCH_ARG(bool, include_last_offset) = false;
  /// Index whose entry does not contribute to the gradient; unset by default.
  TORCH_ARG(c10::optional<int64_t>, padding_idx) = c10::nullopt;
};

} // namespace functional

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources