Program Listing for File embedding.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/modules/embedding.h)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/functional/embedding.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/options/embedding.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <cstddef>
namespace torch {
namespace nn {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Embedding
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `Embedding` module holder declared later in
/// this header. It derives from `Cloneable<EmbeddingImpl>` and exposes its
/// configuration (`options`) and its table (`weight`) as public members.
/// Only the two-integer constructor is defined here; every other member
/// function is declared in this header and defined elsewhere.
class TORCH_API EmbeddingImpl : public torch::nn::Cloneable<EmbeddingImpl> {
public:
/// Convenience constructor: wraps `num_embeddings` and `embedding_dim` in an
/// `EmbeddingOptions` and delegates to the options-taking constructor.
EmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim)
: EmbeddingImpl(EmbeddingOptions(num_embeddings, embedding_dim)) {}
/// Constructs the module from a complete `EmbeddingOptions`, taken by value.
/// Marked `explicit`, so an `EmbeddingOptions` does not implicitly convert to
/// an `EmbeddingImpl`.
explicit EmbeddingImpl(EmbeddingOptions options_);
/// Override of the base-class `reset()` hook.
/// NOTE(review): presumably (re)creates `weight` from `options` -- confirm
/// against the out-of-line definition.
void reset() override;
/// NOTE(review): presumably re-initializes the values held in `weight` --
/// the definition is not visible in this header; confirm there.
void reset_parameters();
/// Override of the base-class pretty-printing hook; receives the output
/// `stream` to write to. The exact text produced is defined out of line.
void pretty_print(std::ostream& stream) const override;
/// Runs the module on `indices` and returns a `Tensor`.
/// NOTE(review): presumably looks up rows of `weight` selected by `indices`
/// -- confirm against the definition.
Tensor forward(const Tensor& indices);
/// Options this module was configured with.
EmbeddingOptions options;
/// The module's weight tensor. `Embedding::from_pretrained` (in this header)
/// supplies it through `EmbeddingOptions::_weight` and then sets its
/// `requires_grad` flag directly on this member.
Tensor weight;
};
/// Module holder for `EmbeddingImpl`. Inherits every `ModuleHolder`
/// constructor and adds a factory that builds the module from an existing
/// weight matrix.
class Embedding : public torch::nn::ModuleHolder<EmbeddingImpl> {
 public:
  using torch::nn::ModuleHolder<EmbeddingImpl>::ModuleHolder;

  /// Creates an `Embedding` whose table is the supplied `embeddings` tensor.
  ///
  /// `embeddings` must be 2-dimensional, otherwise `TORCH_CHECK` fails. Its
  /// two sizes become the `EmbeddingOptions` constructor arguments, and the
  /// tensor itself is passed on through `_weight`. `padding_idx`, `max_norm`,
  /// `norm_type`, `scale_grad_by_freq` and `sparse` are copied from `options`.
  /// After construction, `weight.requires_grad` is set to `!options.freeze()`.
  static Embedding from_pretrained(
      const torch::Tensor& embeddings,
      const EmbeddingFromPretrainedOptions& options = {}) {
    TORCH_CHECK(
        embeddings.dim() == 2,
        "Embeddings parameter is expected to be 2-dimensional");

    // The table's shape dictates both constructor arguments.
    const int64_t num_embeddings = embeddings.size(0);
    const int64_t embedding_dim = embeddings.size(1);

    auto module_options = EmbeddingOptions(num_embeddings, embedding_dim)
                              ._weight(embeddings)
                              .padding_idx(options.padding_idx())
                              .max_norm(options.max_norm())
                              .norm_type(options.norm_type())
                              .scale_grad_by_freq(options.scale_grad_by_freq())
                              .sparse(options.sparse());

    Embedding pretrained(module_options);

    // A frozen table must not receive gradients, and vice versa.
    const bool trainable = !options.freeze();
    pretrained->weight.set_requires_grad(trainable);
    return pretrained;
  }
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ EmbeddingBag
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class behind the `EmbeddingBag` module holder declared later
/// in this header. It derives from `Cloneable<EmbeddingBagImpl>` and exposes
/// its configuration (`options`) and its table (`weight`) as public members.
/// Only the two-integer constructor is defined here; every other member
/// function is declared in this header and defined elsewhere.
class TORCH_API EmbeddingBagImpl
: public torch::nn::Cloneable<EmbeddingBagImpl> {
public:
/// Convenience constructor: wraps `num_embeddings` and `embedding_dim` in an
/// `EmbeddingBagOptions` and delegates to the options-taking constructor.
EmbeddingBagImpl(int64_t num_embeddings, int64_t embedding_dim)
: EmbeddingBagImpl(EmbeddingBagOptions(num_embeddings, embedding_dim)) {}
/// Constructs the module from a complete `EmbeddingBagOptions`, taken by
/// value. Marked `explicit`, so the options type does not implicitly convert
/// to an `EmbeddingBagImpl`.
explicit EmbeddingBagImpl(EmbeddingBagOptions options_);
/// Override of the base-class `reset()` hook.
/// NOTE(review): presumably (re)creates `weight` from `options` -- confirm
/// against the out-of-line definition.
void reset() override;
/// NOTE(review): presumably re-initializes the values held in `weight` --
/// the definition is not visible in this header; confirm there.
void reset_parameters();
/// Override of the base-class pretty-printing hook; receives the output
/// `stream` to write to. The exact text produced is defined out of line.
void pretty_print(std::ostream& stream) const override;
/// Options this module was configured with.
EmbeddingBagOptions options;
/// The module's weight tensor. `EmbeddingBag::from_pretrained` (in this
/// header) supplies it through `EmbeddingBagOptions::_weight` and then sets
/// its `requires_grad` flag directly on this member.
Tensor weight;
/// Runs the module on `input` and returns a `Tensor`. `offsets` and
/// `per_sample_weights` are optional and default to a default-constructed
/// `Tensor` (`{}`).
/// NOTE(review): how the two optional tensors influence the result is
/// decided by the out-of-line definition -- confirm there.
Tensor forward(
const Tensor& input,
const Tensor& offsets = {},
const Tensor& per_sample_weights = {});
protected:
// Declares default values (`AnyValue(Tensor())`) for the `forward` arguments
// at indices 1 and 2. With zero-based indexing these line up with `offsets`
// and `per_sample_weights`, matching the `= {}` defaults above.
// NOTE(review): the macro is not defined in this header (presumably it comes
// from torch/nn/modules/common.h, for type-erased invocation of `forward`)
// -- confirm.
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})
};
/// Module holder for `EmbeddingBagImpl`. Inherits every `ModuleHolder`
/// constructor and adds a factory that builds the module from an existing
/// weight matrix.
class EmbeddingBag : public torch::nn::ModuleHolder<EmbeddingBagImpl> {
 public:
  using torch::nn::ModuleHolder<EmbeddingBagImpl>::ModuleHolder;

  /// Creates an `EmbeddingBag` whose table is the supplied `embeddings` tensor.
  ///
  /// `embeddings` must be 2-dimensional, otherwise `TORCH_CHECK` fails. Its
  /// two sizes become the `EmbeddingBagOptions` constructor arguments, and the
  /// tensor itself is passed on through `_weight`. `max_norm`, `norm_type`,
  /// `scale_grad_by_freq`, `mode`, `sparse` and `padding_idx` are copied from
  /// `options`. After construction, `weight.requires_grad` is set to
  /// `!options.freeze()`.
  static EmbeddingBag from_pretrained(
      const torch::Tensor& embeddings,
      const EmbeddingBagFromPretrainedOptions& options = {}) {
    TORCH_CHECK(
        embeddings.dim() == 2,
        "Embeddings parameter is expected to be 2-dimensional");

    // The table's shape dictates both constructor arguments.
    const int64_t num_embeddings = embeddings.size(0);
    const int64_t embedding_dim = embeddings.size(1);

    auto module_options = EmbeddingBagOptions(num_embeddings, embedding_dim)
                              ._weight(embeddings)
                              .max_norm(options.max_norm())
                              .norm_type(options.norm_type())
                              .scale_grad_by_freq(options.scale_grad_by_freq())
                              .mode(options.mode())
                              .sparse(options.sparse())
                              .padding_idx(options.padding_idx());

    EmbeddingBag pretrained(module_options);

    // A frozen table must not receive gradients, and vice versa.
    const bool trainable = !options.freeze();
    pretrained->weight.set_requires_grad(trainable);
    return pretrained;
  }
};
} // namespace nn
} // namespace torch