Program Listing for File adaptive.h
Return to documentation for file (torch/csrc/api/include/torch/nn/modules/adaptive.h)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/functional/activation.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/modulelist.h>
#include <torch/nn/modules/container/sequential.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/options/adaptive.h>

#include <utility>
namespace torch {
namespace nn {
/// Result type returned by `AdaptiveLogSoftmaxWithLossImpl::forward()`:
/// a tensor paired with a scalar loss value.
struct TORCH_API ASMoutput {
/// Constructs a result from a tensor and a loss value. Only declared here;
/// the definition lives outside this header.
ASMoutput(Tensor output_, double loss_);
/// Tensor part of the result.
/// NOTE(review): presumably the per-example log-probabilities of the targets,
/// mirroring the Python `AdaptiveLogSoftmaxWithLoss` API - confirm against the
/// implementation file.
Tensor output;
/// Scalar part of the result.
/// NOTE(review): presumably the negative log-likelihood loss - confirm against
/// the implementation file.
double loss;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveLogSoftmaxWithLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class of the `AdaptiveLogSoftmaxWithLoss` module.
///
/// The class is configured through `AdaptiveLogSoftmaxWithLossOptions` and
/// owns one `Linear` submodule (`head`) plus a `ModuleList` (`tail`). All
/// member functions other than the convenience constructor are only declared
/// here; their definitions live outside this header.
///
/// NOTE(review): presumably implements the efficient softmax approximation
/// described for the Python `torch.nn.AdaptiveLogSoftmaxWithLoss` - confirm
/// against the implementation file.
class TORCH_API AdaptiveLogSoftmaxWithLossImpl
    : public Cloneable<AdaptiveLogSoftmaxWithLossImpl> {
 public:
  /// Convenience constructor that packs its arguments into an
  /// `AdaptiveLogSoftmaxWithLossOptions` and delegates to the options-based
  /// constructor.
  ///
  /// @param in_features forwarded unchanged to the options constructor.
  /// @param n_classes   forwarded unchanged to the options constructor.
  /// @param cutoffs     taken by value as a sink argument and moved into the
  ///                    options object, so the vector is not copied a second
  ///                    time.
  AdaptiveLogSoftmaxWithLossImpl(
      int64_t in_features,
      int64_t n_classes,
      std::vector<int64_t> cutoffs)
      : AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions(
            in_features,
            n_classes,
            std::move(cutoffs))) {}

  /// Constructs the module from a fully specified options object.
  explicit AdaptiveLogSoftmaxWithLossImpl(
      AdaptiveLogSoftmaxWithLossOptions options_);

  /// Runs the module on `input` with the given `target` and returns the
  /// result as an `ASMoutput` (tensor + scalar loss).
  ASMoutput forward(const Tensor& input, const Tensor& target);

  /// Overrides the base-class `reset()` hook.
  /// NOTE(review): presumably (re)creates `head`/`tail` and the derived size
  /// fields from `options` - confirm against the implementation file.
  void reset() override;

  /// NOTE(review): presumably re-initializes the parameters of `head` and of
  /// the modules in `tail` - confirm against the implementation file.
  void reset_parameters();

  /// Writes a textual description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Helper taking the module input together with an already computed
  /// `head_output`.
  /// NOTE(review): presumably returns the log-probabilities over all classes
  /// - confirm against the implementation file.
  Tensor _get_full_log_prob(const Tensor& input, const Tensor& head_output);

  /// NOTE(review): presumably returns log-probabilities for every class given
  /// `input` - confirm against the implementation file.
  Tensor log_prob(const Tensor& input);

  /// NOTE(review): presumably returns the most likely class for each example
  /// in `input` - confirm against the implementation file.
  Tensor predict(const Tensor& input);

  /// The options this module was constructed with.
  AdaptiveLogSoftmaxWithLossOptions options;

  /// Cutoff values held by the module itself, separately from `options`.
  std::vector<int64_t> cutoffs;

  // The three size fields are zero-initialized so they never hold
  // indeterminate values before the implementation assigns them.
  int64_t shortlist_size = 0;
  int64_t n_clusters = 0;
  int64_t head_size = 0;

  /// `Linear` submodule; starts out as an empty (null) holder and is assigned
  /// by the implementation outside this header.
  Linear head = nullptr;

  /// List of further submodules.
  ModuleList tail;
};
/// Declares the `AdaptiveLogSoftmaxWithLoss` module type through the
/// `TORCH_MODULE` macro, which is defined outside this file.
/// NOTE(review): per the PyTorch C++ frontend documentation the macro
/// generates a `ModuleHolder` subclass wrapping
/// `AdaptiveLogSoftmaxWithLossImpl` - confirm against torch/nn/pimpl.h.
TORCH_MODULE(AdaptiveLogSoftmaxWithLoss);
} // namespace nn
} // namespace torch