Struct EmbeddingBagFuncOptions

Struct Documentation

struct torch::nn::functional::EmbeddingBagFuncOptions

Options for torch::nn::functional::embedding_bag.

Example:

namespace F = torch::nn::functional;
F::embedding_bag(input, weight, F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets));

Public Functions

auto offsets(const torch::Tensor &new_offsets) -> decltype(*this)

Only used when input is 1D.

offsets determines the starting index position of each bag (sequence) in input.
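To make the role of offsets concrete, here is a minimal standard-library sketch of how a flat 1D input can be cut into bags. The helper name split_into_bags is hypothetical and is not part of libtorch; it assumes offsets is non-decreasing and every entry is at most input.size().

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative helper (not a libtorch function): offsets[i] is the position
// in `input` at which bag i starts.
std::vector<std::vector<std::int64_t>> split_into_bags(
    const std::vector<std::int64_t>& input,
    const std::vector<std::size_t>& offsets) {
  std::vector<std::vector<std::int64_t>> bags;
  for (std::size_t i = 0; i < offsets.size(); ++i) {
    // A bag ends where the next one starts; the last bag runs to the end.
    const std::size_t begin = offsets[i];
    const std::size_t end =
        (i + 1 < offsets.size()) ? offsets[i + 1] : input.size();
    bags.emplace_back(input.begin() + begin, input.begin() + end);
  }
  return bags;
}
```

For example, an input of eight indices with offsets {0, 4} yields two bags of four indices each.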

auto offsets(torch::Tensor &&new_offsets) -> decltype(*this)
const torch::Tensor &offsets() const noexcept
torch::Tensor &offsets() noexcept
auto max_norm(const c10::optional<double> &new_max_norm) -> decltype(*this)

If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm.

auto max_norm(c10::optional<double> &&new_max_norm) -> decltype(*this)
const c10::optional<double> &max_norm() const noexcept
c10::optional<double> &max_norm() noexcept
auto norm_type(const double &new_norm_type) -> decltype(*this)

The p of the p-norm to compute for the max_norm option. Default 2.
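The interaction between max_norm and norm_type can be sketched with plain vectors. The helpers p_norm and renorm_if_needed below are hypothetical names used only for illustration; the library's own renormalization may differ in numerical details such as a small epsilon in the denominator.

```cpp
#include <cmath>
#include <vector>

// p-norm of a vector: (sum_i |x_i|^p)^(1/p).
double p_norm(const std::vector<double>& v, double p) {
  double acc = 0.0;
  for (double x : v) acc += std::pow(std::abs(x), p);
  return std::pow(acc, 1.0 / p);
}

// Illustrative helper: rescale v in place only when its p-norm exceeds
// max_norm, so that the resulting norm equals max_norm.
void renorm_if_needed(std::vector<double>& v, double max_norm, double p = 2.0) {
  const double norm = p_norm(v, p);
  if (norm > max_norm) {
    const double scale = max_norm / norm;
    for (double& x : v) x *= scale;
  }
}
```

Vectors whose norm is already at or below max_norm are left untouched.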

auto norm_type(double &&new_norm_type) -> decltype(*this)
const double &norm_type() const noexcept
double &norm_type() noexcept
auto scale_grad_by_freq(const bool &new_scale_grad_by_freq) -> decltype(*this)

If true, gradients are scaled by the inverse of the frequency of the words in the mini-batch.

Default false. Note: this option is not supported when mode is torch::kMax.

auto scale_grad_by_freq(bool &&new_scale_grad_by_freq) -> decltype(*this)
const bool &scale_grad_by_freq() const noexcept
bool &scale_grad_by_freq() noexcept
auto mode(const EmbeddingBagMode &new_mode) -> decltype(*this)

One of torch::kSum, torch::kMean or torch::kMax.

Specifies how each bag is reduced. torch::kSum computes the weighted sum, taking per_sample_weights into account; torch::kMean computes the average of the values in the bag; torch::kMax computes the maximum value over each bag.
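The three reductions can be sketched on plain vectors as follows. The Mode enum and reduce_bag function are hypothetical stand-ins for illustration; they are not the libtorch types, and the sketch ignores per_sample_weights.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for torch::kSum / torch::kMean / torch::kMax.
enum class Mode { Sum, Mean, Max };

// Reduce the embedding rows of one non-empty bag element-wise.
std::vector<double> reduce_bag(const std::vector<std::vector<double>>& rows,
                               Mode mode) {
  std::vector<double> out = rows.front();
  for (std::size_t r = 1; r < rows.size(); ++r) {
    for (std::size_t d = 0; d < out.size(); ++d) {
      out[d] = (mode == Mode::Max) ? std::max(out[d], rows[r][d])
                                   : out[d] + rows[r][d];
    }
  }
  if (mode == Mode::Mean) {
    // Mean is the sum divided by the number of rows in the bag.
    for (double& x : out) x /= static_cast<double>(rows.size());
  }
  return out;
}
```

For the rows {1, 4} and {3, 2}, the sum is {4, 6}, the mean is {2, 3}, and the max is {3, 4}.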

auto mode(EmbeddingBagMode &&new_mode) -> decltype(*this)
const EmbeddingBagMode &mode() const noexcept
EmbeddingBagMode &mode() noexcept
auto sparse(const bool &new_sparse) -> decltype(*this)

If true, the gradient w.r.t. the weight matrix will be a sparse tensor. Note: this option is not supported when mode is torch::kMax.

auto sparse(bool &&new_sparse) -> decltype(*this)
const bool &sparse() const noexcept
bool &sparse() noexcept
auto per_sample_weights(const torch::Tensor &new_per_sample_weights) -> decltype(*this)

A tensor of float or double weights. Leave it undefined to indicate that all weights should be taken to be 1.

If specified, per_sample_weights must have exactly the same shape as input and is treated as having the same offsets, if offsets is defined.
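A weighted sum over one bag might look like the sketch below. The function weighted_sum_bag is a hypothetical helper, and an empty weights vector is used here as a stand-in for "no per_sample_weights given"; neither is part of libtorch.

```cpp
#include <cstddef>
#include <vector>

// Illustrative helper: sum the rows of one non-empty bag, multiplying row r
// by weights[r]. An empty `weights` means every row has weight 1.
std::vector<double> weighted_sum_bag(
    const std::vector<std::vector<double>>& rows,
    const std::vector<double>& weights) {
  std::vector<double> out(rows.front().size(), 0.0);
  for (std::size_t r = 0; r < rows.size(); ++r) {
    const double w = weights.empty() ? 1.0 : weights[r];
    for (std::size_t d = 0; d < out.size(); ++d) {
      out[d] += w * rows[r][d];
    }
  }
  return out;
}
```

With rows {1, 2} and {3, 4} and weights {0.5, 2.0}, the result is {6.5, 9.0}; without weights it is the plain sum {4, 6}.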

auto per_sample_weights(torch::Tensor &&new_per_sample_weights) -> decltype(*this)
const torch::Tensor &per_sample_weights() const noexcept
torch::Tensor &per_sample_weights() noexcept
auto include_last_offset(const bool &new_include_last_offset) -> decltype(*this)

If true, offsets has one additional element, where the last element is equal to the size of input (the tensor of indices).

This matches the CSR format. Note: this option is currently only supported when mode is torch::kSum.
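The difference between the two offset conventions can be sketched by computing the [begin, end) range of each bag. The function bag_ranges is a hypothetical helper for illustration only and assumes offsets is non-decreasing.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Illustrative helper: return the [begin, end) range of every bag.
// With include_last_offset, the final entry of `offsets` only marks the end
// of the last bag (CSR style) and does not start a new one.
std::vector<std::pair<std::size_t, std::size_t>> bag_ranges(
    const std::vector<std::size_t>& offsets,
    std::size_t input_size,
    bool include_last_offset) {
  std::vector<std::pair<std::size_t, std::size_t>> ranges;
  if (offsets.empty()) return ranges;

  const std::size_t num_bags =
      include_last_offset ? offsets.size() - 1 : offsets.size();
  for (std::size_t i = 0; i < num_bags; ++i) {
    const std::size_t end =
        (i + 1 < offsets.size()) ? offsets[i + 1] : input_size;
    ranges.emplace_back(offsets[i], end);
  }
  return ranges;
}
```

For an input of size 8, offsets {0, 4, 8} with include_last_offset set to true describe the same two bags as offsets {0, 4} with it set to false.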

auto include_last_offset(bool &&new_include_last_offset) -> decltype(*this)
const bool &include_last_offset() const noexcept
bool &include_last_offset() noexcept
auto padding_idx(const c10::optional<int64_t> &new_padding_idx) -> decltype(*this)

If specified, the entries at padding_idx do not contribute to the gradient; therefore, the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”. Note that the embedding vector at padding_idx is excluded from the reduction.
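Excluding the padding row from the reduction can be sketched as below for the sum case. The function sum_bag_skipping_padding is a hypothetical helper, not a libtorch API; it assumes a non-empty weight table and valid indices, and it models only the forward reduction.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative helper: sum the embedding rows selected by `bag`, skipping
// every occurrence of padding_idx so the pad row never enters the result.
std::vector<double> sum_bag_skipping_padding(
    const std::vector<std::vector<double>>& weight,
    const std::vector<std::int64_t>& bag,
    std::int64_t padding_idx) {
  std::vector<double> out(weight.front().size(), 0.0);
  for (std::int64_t idx : bag) {
    if (idx == padding_idx) continue;  // excluded from the reduction
    const std::vector<double>& row = weight[static_cast<std::size_t>(idx)];
    for (std::size_t d = 0; d < out.size(); ++d) out[d] += row[d];
  }
  return out;
}
```

Because the pad row is skipped in the forward pass, it also receives no contribution in this model, which mirrors the "fixed pad" behaviour described above.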

auto padding_idx(c10::optional<int64_t> &&new_padding_idx) -> decltype(*this)
const c10::optional<int64_t> &padding_idx() const noexcept
c10::optional<int64_t> &padding_idx() noexcept
