Class AdaptiveLogSoftmaxWithLossImpl

Inheritance Relationships

Base Type

public torch::nn::Cloneable<AdaptiveLogSoftmaxWithLossImpl>

Class Documentation

class AdaptiveLogSoftmaxWithLossImpl : public torch::nn::Cloneable<AdaptiveLogSoftmaxWithLossImpl>

Efficient softmax approximation, as described in "Efficient softmax approximation for GPUs" by Edouard Grave, Armand Joulin, Moustapha Cissé, David Grangier, and Hervé Jégou.

See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveLogSoftmaxWithLoss to learn about the exact behavior of this module.
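The factorization behind this approximation can be summarized in the terms this page uses. This is a summary of the paper's formulation, not the library's code. It assumes that the member cutoffs are c_0 < c_1 < ... < c_K with c_K = n_classes (so K = n_clusters), that h is the head output of size c_0 + K (head_size), and that t_i is the output of tail cluster i. Under these assumptions, the log probability of a class y is:

```latex
\log p(y \mid x) =
\begin{cases}
  \log \operatorname{softmax}(h)_{y}, & 0 \le y < c_0, \\
  \log \operatorname{softmax}(h)_{c_0 + i} + \log \operatorname{softmax}(t_i)_{y - c_i}, & c_i \le y < c_{i+1},\ 0 \le i < K.
\end{cases}
```

Classes in the shortlist need only the head. A class in cluster i needs the head plus tail i, and the tail scores the cluster as a whole.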

See the documentation for torch::nn::AdaptiveLogSoftmaxWithLossOptions class to learn what constructor arguments are supported for this module.

Example:

#include <torch/torch.h>
torch::nn::AdaptiveLogSoftmaxWithLoss model(
    torch::nn::AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true));

Public Functions

inline AdaptiveLogSoftmaxWithLossImpl(int64_t in_features, int64_t n_classes, std::vector<int64_t> cutoffs)
explicit AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions options_)
ASMoutput forward(const Tensor &input, const Tensor &target)
virtual void reset() override

reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.

void reset_parameters()
virtual void pretty_print(std::ostream &stream) const override

Pretty prints the AdaptiveLogSoftmaxWithLoss module into the given stream.

Tensor _get_full_log_prob(const Tensor &input, const Tensor &head_output)

Given the input tensor and the output of the head, computes the log of the full distribution.

Tensor log_prob(const Tensor &input)

Computes the log probabilities of all n_classes classes.

Tensor predict(const Tensor &input)

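This is equivalent to log_prob(input).argmax(1), but it can be more efficient. For examples whose most likely head output is a shortlist class, the result comes from the head alone. The full distribution is computed only for the remaining examples.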
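The following is a standalone sketch of how _get_full_log_prob and predict relate. It is written in plain C++ and works on one example's scores, which are assumed to be already log-softmaxed, rather than on Tensors. The helper names full_log_prob, argmax_index, and predict_one are hypothetical. The assembly rule assumes the shortlist/cluster factorization of the Python module linked above. The shortcut in predict_one is safe for the following reason. A tail class's log probability is its cluster's head log probability plus a non-positive log-softmax term. It therefore can never exceed the best shortlist entry when that entry is the head's overall maximum.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper: builds the full log-distribution for ONE example.
// head_logp: log-softmax of the head output, size shortlist_size + n_clusters.
// tail_logp[i]: log-softmax of tail cluster i, one entry per class in that cluster.
std::vector<double> full_log_prob(const std::vector<double>& head_logp,
                                  const std::vector<std::vector<double>>& tail_logp,
                                  std::size_t shortlist_size) {
  assert(head_logp.size() == shortlist_size + tail_logp.size());
  // Shortlist classes are scored directly by the head.
  std::vector<double> out(
      head_logp.begin(),
      head_logp.begin() + static_cast<std::ptrdiff_t>(shortlist_size));
  for (std::size_t i = 0; i < tail_logp.size(); ++i) {
    // log p(y) = log p(cluster i) + log p(y | cluster i)
    const double cluster_logp = head_logp[shortlist_size + i];
    for (double t : tail_logp[i]) out.push_back(cluster_logp + t);
  }
  return out;
}

// Index of the largest entry; the first one wins on ties, like argmax.
std::size_t argmax_index(const std::vector<double>& v) {
  return static_cast<std::size_t>(
      std::distance(v.begin(), std::max_element(v.begin(), v.end())));
}

// Same result as argmax_index(full_log_prob(...)), but tail_logp is not
// touched when the head's best entry is already a shortlist class.
std::size_t predict_one(const std::vector<double>& head_logp,
                        const std::vector<std::vector<double>>& tail_logp,
                        std::size_t shortlist_size) {
  const std::size_t best_head = argmax_index(head_logp);
  if (best_head < shortlist_size) return best_head;
  return argmax_index(full_log_prob(head_logp, tail_logp, shortlist_size));
}
```

When the head's best entry is a cluster node, the sketch still falls back to the full distribution. A shortlist class can then win, because the cluster's probability mass is split among its classes.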

Public Members

AdaptiveLogSoftmaxWithLossOptions options

The options with which this Module was constructed.

std::vector<int64_t> cutoffs

Cutoffs used to assign targets to their buckets.

It should be a sequence of integers sorted in increasing order.

int64_t shortlist_size
int64_t n_clusters

Number of clusters.

int64_t head_size

Output size of head classifier.

Linear head = nullptr
ModuleList tail
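The following standalone C++ sketch (no libtorch) makes the relationships between these members concrete. It assumes three things, as in the Python AdaptiveLogSoftmaxWithLoss. First, n_classes is appended to the user's cutoffs. Second, tail cluster i projects to floor(in_features / div_value^(i + 1)) features. Third, a target falls into the first bucket whose cutoff exceeds it. AsmLayout, make_layout, and bucket_of are hypothetical names for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical mirror of the sizes derived from the constructor options;
// an illustration only, not the libtorch implementation.
struct AsmLayout {
  std::vector<std::int64_t> cutoffs;      // user cutoffs with n_classes appended
  std::int64_t shortlist_size = 0;        // classes scored directly by the head
  std::int64_t n_clusters = 0;            // number of tail clusters
  std::int64_t head_size = 0;             // shortlist_size + n_clusters
  std::vector<std::int64_t> tail_hidden;  // projection width of each tail
  std::vector<std::int64_t> tail_out;     // number of classes in each tail
};

AsmLayout make_layout(std::int64_t in_features, std::int64_t n_classes,
                      const std::vector<std::int64_t>& cutoffs, double div_value) {
  // Cutoffs must be non-empty, strictly increasing (hence unique),
  // positive, and no larger than n_classes - 1; div_value must be non-zero.
  bool ok = !cutoffs.empty() && cutoffs.front() > 0 &&
            cutoffs.back() <= n_classes - 1 && div_value != 0.0;
  for (std::size_t i = 1; ok && i < cutoffs.size(); ++i) {
    ok = cutoffs[i - 1] < cutoffs[i];
  }
  if (!ok) throw std::invalid_argument("invalid cutoffs or div_value");

  AsmLayout l;
  l.cutoffs = cutoffs;
  l.cutoffs.push_back(n_classes);  // the last bucket ends at n_classes
  l.shortlist_size = l.cutoffs[0];
  l.n_clusters = static_cast<std::int64_t>(l.cutoffs.size()) - 1;
  l.head_size = l.shortlist_size + l.n_clusters;
  for (std::int64_t i = 0; i < l.n_clusters; ++i) {
    // Tail i projects to in_features / div_value^(i + 1), rounded down.
    l.tail_hidden.push_back(static_cast<std::int64_t>(std::floor(
        static_cast<double>(in_features) / std::pow(div_value, i + 1))));
    l.tail_out.push_back(l.cutoffs[i + 1] - l.cutoffs[i]);
  }
  return l;
}

// Bucket of a target class: 0 is the head shortlist, k >= 1 is tail cluster k - 1.
std::int64_t bucket_of(const AsmLayout& l, std::int64_t target) {
  if (target < 0) throw std::out_of_range("negative target");
  for (std::size_t k = 0; k < l.cutoffs.size(); ++k) {
    if (target < l.cutoffs[k]) return static_cast<std::int64_t>(k);
  }
  throw std::out_of_range("target >= n_classes");
}
```

Consider the options in the example near the top of this page: in_features = 8, n_classes = 10, cutoffs {4, 8}, and div_value 2. For these, the sketch yields cutoffs {4, 8, 10}, shortlist_size 4, n_clusters 2, and head_size 6. Targets 0 to 3 are therefore scored by the head alone, while targets 4 to 7 and 8 to 9 go through the two tails.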
