Class AdaptiveLogSoftmaxWithLossImpl
Defined in File adaptive.h
Inheritance Relationships
Base Type
public torch::nn::Cloneable< AdaptiveLogSoftmaxWithLossImpl >
(Template Class Cloneable)
Class Documentation
-
class AdaptiveLogSoftmaxWithLossImpl : public torch::nn::Cloneable<AdaptiveLogSoftmaxWithLossImpl>
Efficient softmax approximation as described in "Efficient softmax approximation for GPUs" by Edouard Grave, Armand Joulin, Moustapha Cissé, David Grangier, and Hervé Jégou. See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveLogSoftmaxWithLoss to learn about the exact behavior of this module.
See the documentation for the torch::nn::AdaptiveLogSoftmaxWithLossOptions class to learn what constructor arguments are supported for this module.
Example:
AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));
Public Functions
-
inline AdaptiveLogSoftmaxWithLossImpl(int64_t in_features, int64_t n_classes, std::vector<int64_t> cutoffs)
-
explicit AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions options_)
-
virtual void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
-
void reset_parameters()
-
virtual void pretty_print(std::ostream &stream) const override
Pretty prints the AdaptiveLogSoftmaxWithLoss module into the given stream.
-
Tensor _get_full_log_prob(const Tensor &input, const Tensor &head_output)
Given the input tensor and the output of the head classifier (head_output), computes the log of the full distribution.
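As a minimal sketch of how the full distribution can be assembled for a single sample, assuming the factorization used by the Python AdaptiveLogSoftmaxWithLoss: a shortlist class takes its head log-probability directly, while a class in tail cluster i gets the head log-probability of entry shortlist_size + i plus its log-probability within that cluster. The helper names log_softmax and full_log_prob are hypothetical, and the tail projections are replaced here by precomputed cluster logits.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable log-softmax over a vector of logits.
std::vector<double> log_softmax(const std::vector<double>& logits) {
  assert(!logits.empty());
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  double sum = 0.0;
  for (double x : logits) sum += std::exp(x - max_logit);
  const double log_z = max_logit + std::log(sum);
  std::vector<double> out;
  out.reserve(logits.size());
  for (double x : logits) out.push_back(x - log_z);
  return out;
}

// Hypothetical helper: assembles log p(class) for one sample.
// head_logits has shortlist_size + n_clusters entries; tail_logits[i] holds
// the logits of cluster i (one entry per class in that cluster).
std::vector<double> full_log_prob(const std::vector<double>& head_logits,
                                  const std::vector<std::vector<double>>& tail_logits,
                                  std::size_t shortlist_size) {
  assert(head_logits.size() == shortlist_size + tail_logits.size());
  const std::vector<double> head_lp = log_softmax(head_logits);
  // Shortlist classes take their head log-probability directly.
  std::vector<double> out(head_lp.begin(),
                          head_lp.begin() + static_cast<std::ptrdiff_t>(shortlist_size));
  // Tail classes: log p(cluster i) + log p(class | cluster i).
  for (std::size_t i = 0; i < tail_logits.size(); ++i) {
    const double cluster_lp = head_lp[shortlist_size + i];
    for (double lp : log_softmax(tail_logits[i])) out.push_back(cluster_lp + lp);
  }
  return out;
}
```

Because each cluster's head entry is split among that cluster's classes, the exponentiated results still sum to one.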
-
Tensor log_prob(const Tensor &input)
Computes log probabilities for all n_classes.
-
Tensor predict(const Tensor &input)
This is equivalent to log_prob(input).argmax(1) but is more efficient in some cases.
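The following hypothetical sketch illustrates one way such a shortcut can work for a single sample, assuming the head output has already been passed through log-softmax: if the head's best entry is a shortlist class, no tail class can beat it, so the full distribution is never computed. The names argmax and predict_one are illustrative only and are not part of the library.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <iterator>
#include <vector>

// Index of the largest element (the first one on ties).
std::size_t argmax(const std::vector<double>& v) {
  assert(!v.empty());
  return static_cast<std::size_t>(
      std::distance(v.begin(), std::max_element(v.begin(), v.end())));
}

// Hypothetical sketch of the predict() shortcut for one sample.
// head_log_prob: log-softmax of the head output (shortlist + cluster entries).
// full_log_prob: callback that computes the full distribution only when needed.
std::size_t predict_one(const std::vector<double>& head_log_prob,
                        std::size_t shortlist_size,
                        const std::function<std::vector<double>()>& full_log_prob) {
  const std::size_t best_head = argmax(head_log_prob);
  // A tail class's log-probability is its cluster's head entry plus a
  // non-positive term, so it cannot exceed a winning shortlist entry.
  if (best_head < shortlist_size) return best_head;
  // Otherwise fall back to the argmax over the full distribution.
  return argmax(full_log_prob());
}
```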
Public Members
-
AdaptiveLogSoftmaxWithLossOptions options
The options with which this Module was constructed.
-
std::vector<int64_t> cutoffs
Cutoffs used to assign targets to their buckets. They should form an ordered sequence of integers sorted in increasing order.
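A minimal sketch of the bucket assignment, assuming the semantics documented for the Python module: targets below cutoffs[0] go to the head shortlist, and each following interval [cutoffs[i], cutoffs[i + 1]) forms one tail cluster, with the last cluster ending at n_classes. The function bucket_of is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: returns 0 if target falls in the head shortlist, or
// i + 1 if it belongs to tail cluster i. `cutoffs` is the user-supplied list
// sorted in increasing order (without n_classes appended).
std::int64_t bucket_of(std::int64_t target, const std::vector<std::int64_t>& cutoffs,
                       std::int64_t n_classes) {
  assert(target >= 0 && target < n_classes);
  assert(std::is_sorted(cutoffs.begin(), cutoffs.end()));
  // upper_bound returns the first cutoff > target, so its index equals the
  // number of cutoffs <= target, which is the bucket index.
  return static_cast<std::int64_t>(
      std::upper_bound(cutoffs.begin(), cutoffs.end(), target) - cutoffs.begin());
}
```

With the cutoffs {4, 8} and 10 classes from the example above, targets 0 to 3 stay in the head, 4 to 7 fall in the first cluster, and 8 and 9 in the second.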
-
int64_t shortlist_size
-
int64_t n_clusters
Number of clusters.
-
int64_t head_size
Output size of head classifier.
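These sizes follow from the cutoffs. A minimal sketch, assuming the Python module's rule that n_classes is appended to the user-supplied cutoffs: the shortlist holds the first cutoffs[0] classes, every remaining interval is one cluster, and the head outputs one entry per shortlist class plus one per cluster. HeadSizes and derive_sizes are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical struct: the sizes derived from the user-supplied cutoffs.
struct HeadSizes {
  std::int64_t shortlist_size;  // classes predicted directly by the head
  std::int64_t n_clusters;      // number of tail clusters
  std::int64_t head_size;       // head output width: shortlist + one entry per cluster
};

HeadSizes derive_sizes(const std::vector<std::int64_t>& cutoffs) {
  assert(!cutoffs.empty());
  // With n_classes appended there are cutoffs.size() + 1 boundaries: the first
  // closes the shortlist and each remaining one closes a cluster.
  const std::int64_t shortlist_size = cutoffs.front();
  const std::int64_t n_clusters = static_cast<std::int64_t>(cutoffs.size());
  return {shortlist_size, n_clusters, shortlist_size + n_clusters};
}
```

For cutoffs {4, 8} this gives a shortlist of 4 classes, 2 clusters, and a head output size of 6.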
-
ModuleList tail
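Each entry of tail projects the input down before classifying within its cluster. As a sketch under an assumption taken from the Python module (the hidden size of cluster i is in_features / div_value^(i + 1), rounded down, and its output size is the number of classes in the cluster), the layer sizes can be computed as follows; tail_sizes is a hypothetical helper, not the library's code.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper: (hidden, output) sizes of each tail cluster's two-layer
// projection, assuming hidden = floor(in_features / div_value^(i + 1)) and
// output = number of classes in cluster i.
std::vector<std::pair<std::int64_t, std::int64_t>> tail_sizes(
    std::int64_t in_features, std::int64_t n_classes,
    std::vector<std::int64_t> cutoffs, double div_value) {
  assert(div_value > 0.0);
  cutoffs.push_back(n_classes);  // the last cluster ends at n_classes
  std::vector<std::pair<std::int64_t, std::int64_t>> sizes;
  for (std::size_t i = 0; i + 1 < cutoffs.size(); ++i) {
    const auto hidden = static_cast<std::int64_t>(std::floor(
        static_cast<double>(in_features) / std::pow(div_value, static_cast<double>(i + 1))));
    sizes.emplace_back(hidden, cutoffs[i + 1] - cutoffs[i]);
  }
  return sizes;
}
```

For the example AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.), this yields (4, 4) for the first cluster and (2, 2) for the second.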