struct torch::nn::functional::GumbelSoftmaxFuncOptions

Options for torch::nn::functional::gumbel_softmax.

Example:


namespace F = torch::nn::functional;
F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(-1));

Public Functions

auto tau(const double &new_tau) -> decltype(*this)

Non-negative scalar temperature. Default: 1.0

auto tau(double &&new_tau) -> decltype(*this)
const double &tau() const noexcept
double &tau() noexcept
auto hard(const bool &new_hard) -> decltype(*this)

If true, the returned samples will be discretized as one-hot vectors, but will be differentiated in autograd as if they were the soft samples.

Default: false

auto hard(bool &&new_hard) -> decltype(*this)
const bool &hard() const noexcept
bool &hard() noexcept
auto dim(const int &new_dim) -> decltype(*this)

Dimension along which softmax will be computed. Default: -1

auto dim(int &&new_dim) -> decltype(*this)
const int &dim() const noexcept
int &dim() noexcept


