Shortcuts

Program Listing for File adaptive.h

Return to the documentation for file `torch/csrc/api/include/torch/nn/modules/adaptive.h`

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/functional/activation.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/modulelist.h>
#include <torch/nn/modules/container/sequential.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/options/adaptive.h>

#include <utility>
#include <vector>

namespace torch {
namespace nn {

/// Result bundle returned by `AdaptiveLogSoftmaxWithLossImpl::forward()`:
/// a tensor paired with a scalar loss value.
struct TORCH_API ASMoutput {
  /// Builds the bundle from a tensor and a loss value. Defined out of line;
  /// presumably assigns `output` and `loss` from the arguments (confirm in .cpp).
  ASMoutput(Tensor output_tensor, double loss_value);

  /// Tensor part of the forward() result.
  Tensor output;
  /// Scalar loss part of the forward() result.
  double loss;
};

// ~~~~~~~~~~~~~~~~~~~~~~ AdaptiveLogSoftmaxWithLoss ~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `AdaptiveLogSoftmaxWithLoss` module.
///
/// All computation is defined out of line; this header only declares the
/// interface and the state it keeps. NOTE(review): presumably mirrors Python's
/// `torch.nn.AdaptiveLogSoftmaxWithLoss` — confirm against the .cpp.
class TORCH_API AdaptiveLogSoftmaxWithLossImpl
    : public Cloneable<AdaptiveLogSoftmaxWithLossImpl> {
 public:
  /// Convenience constructor: packs the three arguments into an
  /// `AdaptiveLogSoftmaxWithLossOptions` and delegates to the options
  /// constructor below.
  ///
  /// @param in_features forwarded to the options object.
  /// @param n_classes   forwarded to the options object.
  /// @param cutoffs     forwarded to the options object. Taken by value and
  ///                    moved on, so it is not copied a second time.
  AdaptiveLogSoftmaxWithLossImpl(
      int64_t in_features,
      int64_t n_classes,
      std::vector<int64_t> cutoffs)
      : AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions(
            in_features,
            n_classes,
            std::move(cutoffs))) {}

  /// Primary constructor taking a fully built options object (defined out of
  /// line).
  explicit AdaptiveLogSoftmaxWithLossImpl(
      AdaptiveLogSoftmaxWithLossOptions options_);

  /// Runs the module on `input` with the given `target`.
  ///
  /// @return an `ASMoutput` bundling a tensor and a scalar `double` loss.
  ASMoutput forward(const Tensor& input, const Tensor& target);

  /// Overrides the base-class `reset()` hook (defined out of line).
  void reset() override;

  /// (Re)initializes parameters (defined out of line).
  void reset_parameters();

  /// Writes a textual description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Internal helper combining `input` with a precomputed `head_output`.
  /// NOTE(review): presumably builds full log-probabilities over all classes
  /// from the head and tail outputs — confirm in the .cpp.
  Tensor _get_full_log_prob(const Tensor& input, const Tensor& head_output);

  /// Returns a tensor of log-probabilities computed from `input`
  /// (semantics defined out of line).
  Tensor log_prob(const Tensor& input);

  /// Returns a prediction tensor computed from `input`
  /// (semantics defined out of line).
  Tensor predict(const Tensor& input);

  /// Configuration of this module (presumably the constructor's options).
  AdaptiveLogSoftmaxWithLossOptions options;

  /// Cutoff values; presumably populated from `options` out of line.
  std::vector<int64_t> cutoffs;

  // The following counters are assigned out of line (presumably in reset()).
  // They are zero-initialized here so they never hold indeterminate values.

  /// Size of the shortlist (assigned out of line).
  int64_t shortlist_size = 0;

  /// Number of clusters (assigned out of line).
  int64_t n_clusters = 0;

  /// Size of the head (assigned out of line).
  int64_t head_size = 0;

  /// Head projection; null until constructed out of line.
  Linear head = nullptr;

  /// Container of tail submodules.
  ModuleList tail;
};

// Generates the user-facing `AdaptiveLogSoftmaxWithLoss` module type for
// `AdaptiveLogSoftmaxWithLossImpl`. NOTE(review): per the TORCH_MODULE macro
// convention this is presumably a ModuleHolder wrapper — confirm in pimpl.h.
TORCH_MODULE(AdaptiveLogSoftmaxWithLoss);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources