Shortcuts

Program Listing for File loss.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/loss.h)

#pragma once

#include <torch/expanding_array.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/functional/loss.h>
#include <torch/nn/options/loss.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/WindowsTorchApiMacro.h>

#include <cstddef>
#include <vector>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ L1Loss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API L1LossImpl : Cloneable<L1LossImpl> {
  explicit L1LossImpl(const L1LossOptions& options_ = {});

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input, const Tensor& target);

  L1LossOptions options;
};

TORCH_MODULE(L1Loss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ KLDivLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API KLDivLossImpl : Cloneable<KLDivLossImpl> {
  explicit KLDivLossImpl(const KLDivLossOptions& options_ = {});

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input, const Tensor& target);

  KLDivLossOptions options;
};

TORCH_MODULE(KLDivLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MSELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API MSELossImpl : Cloneable<MSELossImpl> {
  explicit MSELossImpl(const MSELossOptions& options_ = {});

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input, const Tensor& target);

  MSELossOptions options;
};

TORCH_MODULE(MSELoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API BCELossImpl : Cloneable<BCELossImpl> {
  explicit BCELossImpl(const BCELossOptions& options_ = {});

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input, const Tensor& target);

  BCELossOptions options;
};

TORCH_MODULE(BCELoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HingeEmbeddingLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API HingeEmbeddingLossImpl : Cloneable<HingeEmbeddingLossImpl> {
  explicit HingeEmbeddingLossImpl(
      const HingeEmbeddingLossOptions& options_ = {});

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input, const Tensor& target);

  HingeEmbeddingLossOptions options;
};

TORCH_MODULE(HingeEmbeddingLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `MultiMarginLoss` module declared via
/// `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.MultiMarginLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API MultiMarginLossImpl
    : public Cloneable<MultiMarginLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `MultiMarginLossOptions` is used when no argument is given.
  explicit MultiMarginLossImpl(const MultiMarginLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns the resulting
  /// `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// The options this module is configured with.
  MultiMarginLossOptions options;
};

/// Declares the `MultiMarginLoss` module type that wraps
/// `MultiMarginLossImpl` (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(MultiMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CosineEmbeddingLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `CosineEmbeddingLoss` module declared via
/// `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.CosineEmbeddingLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API CosineEmbeddingLossImpl
    : public Cloneable<CosineEmbeddingLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `CosineEmbeddingLossOptions` is used when no argument is given.
  explicit CosineEmbeddingLossImpl(
      const CosineEmbeddingLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes two input tensors (`input1`, `input2`) together with a `target`
  /// tensor and returns the resulting `Tensor`.
  Tensor forward(
      const Tensor& input1,
      const Tensor& input2,
      const Tensor& target);

  /// The options this module is configured with.
  CosineEmbeddingLossOptions options;
};

/// Declares the `CosineEmbeddingLoss` module type that wraps
/// `CosineEmbeddingLossImpl` (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(CosineEmbeddingLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SmoothL1Loss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `SmoothL1Loss` module declared via
/// `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.SmoothL1Loss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API SmoothL1LossImpl : public Cloneable<SmoothL1LossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `SmoothL1LossOptions` is used when no argument is given.
  explicit SmoothL1LossImpl(const SmoothL1LossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns the resulting
  /// `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// The options this module is configured with.
  SmoothL1LossOptions options;
};

/// Declares the `SmoothL1Loss` module type that wraps `SmoothL1LossImpl`
/// (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(SmoothL1Loss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HuberLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `HuberLoss` module declared via
/// `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.HuberLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API HuberLossImpl : public Cloneable<HuberLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `HuberLossOptions` is used when no argument is given.
  explicit HuberLossImpl(const HuberLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns the resulting
  /// `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// The options this module is configured with.
  HuberLossOptions options;
};

/// Declares the `HuberLoss` module type that wraps `HuberLossImpl`
/// (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(HuberLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `MultiLabelMarginLoss` module declared via
/// `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.MultiLabelMarginLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API MultiLabelMarginLossImpl
    : public Cloneable<MultiLabelMarginLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `MultiLabelMarginLossOptions` is used when no argument is given.
  explicit MultiLabelMarginLossImpl(
      const MultiLabelMarginLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns the resulting
  /// `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// The options this module is configured with.
  MultiLabelMarginLossOptions options;
};

/// Declares the `MultiLabelMarginLoss` module type that wraps
/// `MultiLabelMarginLossImpl` (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(MultiLabelMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SoftMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `SoftMarginLoss` module declared via
/// `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.SoftMarginLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API SoftMarginLossImpl : public Cloneable<SoftMarginLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `SoftMarginLossOptions` is used when no argument is given.
  explicit SoftMarginLossImpl(const SoftMarginLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns the resulting
  /// `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// The options this module is configured with.
  SoftMarginLossOptions options;
};

/// Declares the `SoftMarginLoss` module type that wraps `SoftMarginLossImpl`
/// (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(SoftMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelSoftMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `MultiLabelSoftMarginLoss` module declared
/// via `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.MultiLabelSoftMarginLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API MultiLabelSoftMarginLossImpl
    : public Cloneable<MultiLabelSoftMarginLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `MultiLabelSoftMarginLossOptions` is used when no argument is given.
  explicit MultiLabelSoftMarginLossImpl(
      const MultiLabelSoftMarginLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns the resulting
  /// `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// The options this module is configured with.
  MultiLabelSoftMarginLossOptions options;
};

/// Declares the `MultiLabelSoftMarginLoss` module type that wraps
/// `MultiLabelSoftMarginLossImpl` (see `TORCH_MODULE` in
/// `<torch/nn/pimpl.h>`).
TORCH_MODULE(MultiLabelSoftMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `TripletMarginLoss` module declared via
/// `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.TripletMarginLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API TripletMarginLossImpl
    : public Cloneable<TripletMarginLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `TripletMarginLossOptions` is used when no argument is given.
  explicit TripletMarginLossImpl(
      const TripletMarginLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `anchor`, `positive` and `negative` tensors and returns the
  /// resulting `Tensor`.
  Tensor forward(
      const Tensor& anchor,
      const Tensor& positive,
      const Tensor& negative);

  /// The options this module is configured with.
  TripletMarginLossOptions options;
};

/// Declares the `TripletMarginLoss` module type that wraps
/// `TripletMarginLossImpl` (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(TripletMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginWithDistanceLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `TripletMarginWithDistanceLoss` module
/// declared via `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.TripletMarginWithDistanceLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API TripletMarginWithDistanceLossImpl
    : public Cloneable<TripletMarginWithDistanceLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `TripletMarginWithDistanceLossOptions` is used when no argument is
  /// given. Note that the options are taken by value here, whereas the other
  /// loss modules in this header take them by const reference.
  explicit TripletMarginWithDistanceLossImpl(
      TripletMarginWithDistanceLossOptions options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `anchor`, `positive` and `negative` tensors and returns the
  /// resulting `Tensor`.
  Tensor forward(
      const Tensor& anchor,
      const Tensor& positive,
      const Tensor& negative);

  /// The options this module is configured with.
  TripletMarginWithDistanceLossOptions options;
};

/// Declares the `TripletMarginWithDistanceLoss` module type that wraps
/// `TripletMarginWithDistanceLossImpl` (see `TORCH_MODULE` in
/// `<torch/nn/pimpl.h>`).
TORCH_MODULE(TripletMarginWithDistanceLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CTCLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `CTCLoss` module declared via `TORCH_MODULE`
/// below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.CTCLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API CTCLossImpl : public Cloneable<CTCLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `CTCLossOptions` is used when no argument is given.
  explicit CTCLossImpl(const CTCLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes four tensors -- `log_probs`, `targets`, `input_lengths` and
  /// `target_lengths` -- and returns the resulting `Tensor`.
  Tensor forward(
      const Tensor& log_probs,
      const Tensor& targets,
      const Tensor& input_lengths,
      const Tensor& target_lengths);

  /// The options this module is configured with.
  CTCLossOptions options;
};

/// Declares the `CTCLoss` module type that wraps `CTCLossImpl`
/// (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(CTCLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PoissonNLLLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `PoissonNLLLoss` module declared via
/// `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.PoissonNLLLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API PoissonNLLLossImpl : public Cloneable<PoissonNLLLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `PoissonNLLLossOptions` is used when no argument is given.
  explicit PoissonNLLLossImpl(const PoissonNLLLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `log_input` and `targets` tensors and returns the resulting
  /// `Tensor`.
  Tensor forward(const Tensor& log_input, const Tensor& targets);

  /// The options this module is configured with.
  PoissonNLLLossOptions options;
};

/// Declares the `PoissonNLLLoss` module type that wraps `PoissonNLLLossImpl`
/// (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(PoissonNLLLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MarginRankingLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `MarginRankingLoss` module declared via
/// `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.MarginRankingLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API MarginRankingLossImpl
    : public Cloneable<MarginRankingLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `MarginRankingLossOptions` is used when no argument is given.
  explicit MarginRankingLossImpl(
      const MarginRankingLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state from `options` --
  /// confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes two input tensors (`input1`, `input2`) together with a `targets`
  /// tensor and returns the resulting `Tensor`.
  Tensor forward(
      const Tensor& input1,
      const Tensor& input2,
      const Tensor& targets);

  /// The options this module is configured with.
  MarginRankingLossOptions options;
};

/// Declares the `MarginRankingLoss` module type that wraps
/// `MarginRankingLossImpl` (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(MarginRankingLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ NLLLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `NLLLoss` module declared via `TORCH_MODULE`
/// below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.NLLLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API NLLLossImpl : public Cloneable<NLLLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `NLLLossOptions` is used when no argument is given.
  explicit NLLLossImpl(const NLLLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state (including
  /// `weight`) from `options` -- confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns the resulting
  /// `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// The options this module is configured with.
  NLLLossOptions options;

  /// Tensor-valued module state kept alongside `options`.
  /// NOTE(review): presumably a per-class rescaling weight -- confirm where
  /// it is assigned.
  Tensor weight;
};

/// Declares the `NLLLoss` module type that wraps `NLLLossImpl`
/// (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(NLLLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossEntropyLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `CrossEntropyLoss` module declared via
/// `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.CrossEntropyLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API CrossEntropyLossImpl
    : public Cloneable<CrossEntropyLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `CrossEntropyLossOptions` is used when no argument is given.
  explicit CrossEntropyLossImpl(const CrossEntropyLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state (including
  /// `weight`) from `options` -- confirm against the out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns the resulting
  /// `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// The options this module is configured with.
  CrossEntropyLossOptions options;

  /// Tensor-valued module state kept alongside `options`.
  /// NOTE(review): presumably a per-class rescaling weight -- confirm where
  /// it is assigned.
  Tensor weight;
};

/// Declares the `CrossEntropyLoss` module type that wraps
/// `CrossEntropyLossImpl` (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(CrossEntropyLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCEWithLogitsLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `BCEWithLogitsLoss` module declared via
/// `TORCH_MODULE` below.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. For the definition of the loss itself, see the
/// `torch.nn.BCEWithLogitsLoss` documentation.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_API BCEWithLogitsLossImpl
    : public Cloneable<BCEWithLogitsLossImpl> {
  /// Constructs the module from `options_`; a default-constructed
  /// `BCEWithLogitsLossOptions` is used when no argument is given.
  explicit BCEWithLogitsLossImpl(
      const BCEWithLogitsLossOptions& options_ = {});

  /// Overrides the base-class `reset()`.
  /// NOTE(review): presumably re-initializes module state (including
  /// `weight` and `pos_weight`) from `options` -- confirm against the
  /// out-of-line definition.
  void reset() override;

  /// Overrides the base-class `pretty_print()`; presumably writes a textual
  /// description of this module into `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns the resulting
  /// `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// The options this module is configured with.
  BCEWithLogitsLossOptions options;

  /// Tensor-valued module state kept alongside `options`.
  /// NOTE(review): presumably a rescaling weight for the loss -- confirm
  /// where it is assigned.
  Tensor weight;

  /// Second tensor-valued piece of module state.
  /// NOTE(review): presumably the weight applied to positive examples --
  /// confirm where it is assigned.
  Tensor pos_weight;
};

/// Declares the `BCEWithLogitsLoss` module type that wraps
/// `BCEWithLogitsLossImpl` (see `TORCH_MODULE` in `<torch/nn/pimpl.h>`).
TORCH_MODULE(BCEWithLogitsLoss);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources