Shortcuts

Program Listing for File loss.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/loss.h)

#pragma once

#include <torch/expanding_array.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/functional/loss.h>
#include <torch/nn/options/loss.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/Export.h>

#include <cstddef>
#include <vector>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ L1Loss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

struct TORCH_API L1LossImpl : Cloneable<L1LossImpl> {
  explicit L1LossImpl(L1LossOptions options_ = {});

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input, const Tensor& target);

  L1LossOptions options;
};

TORCH_MODULE(L1Loss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ KLDivLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

struct TORCH_API KLDivLossImpl : Cloneable<KLDivLossImpl> {
  explicit KLDivLossImpl(KLDivLossOptions options_ = {});

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input, const Tensor& target);

  KLDivLossOptions options;
};

TORCH_MODULE(KLDivLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MSELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

struct TORCH_API MSELossImpl : Cloneable<MSELossImpl> {
  explicit MSELossImpl(MSELossOptions options_ = {});

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input, const Tensor& target);

  MSELossOptions options;
};

TORCH_MODULE(MSELoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

struct TORCH_API BCELossImpl : Cloneable<BCELossImpl> {
  explicit BCELossImpl(BCELossOptions options_ = {});

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input, const Tensor& target);

  BCELossOptions options;
};

TORCH_MODULE(BCELoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HingeEmbeddingLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

struct TORCH_API HingeEmbeddingLossImpl : Cloneable<HingeEmbeddingLossImpl> {
  explicit HingeEmbeddingLossImpl(HingeEmbeddingLossOptions options_ = {});

  void reset() override;

  void pretty_print(std::ostream& stream) const override;

  Tensor forward(const Tensor& input, const Tensor& target);

  HingeEmbeddingLossOptions options;
};

TORCH_MODULE(HingeEmbeddingLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiMarginLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `MultiMarginLoss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): margin, weighting and reduction are presumably configured
/// through `MultiMarginLossOptions` -- confirm in the implementation file.
struct TORCH_API MultiMarginLossImpl : public Cloneable<MultiMarginLossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  explicit MultiMarginLossImpl(MultiMarginLossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts `input` and `target` and returns a `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// Publicly accessible options of this module.
  MultiMarginLossOptions options;
};

/// `TORCH_MODULE` invocation for `MultiMarginLoss`. NOTE(review): expected to
/// declare the `MultiMarginLoss` wrapper for `MultiMarginLossImpl` -- see
/// torch/nn/pimpl.h.
TORCH_MODULE(MultiMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CosineEmbeddingLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `CosineEmbeddingLoss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): margin and reduction are presumably configured through
/// `CosineEmbeddingLossOptions` -- confirm in the implementation file.
struct TORCH_API CosineEmbeddingLossImpl
    : public Cloneable<CosineEmbeddingLossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  explicit CosineEmbeddingLossImpl(CosineEmbeddingLossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts the two inputs `input1` and `input2` together with `target` and
  /// returns a `Tensor`.
  Tensor forward(
      const Tensor& input1,
      const Tensor& input2,
      const Tensor& target);

  /// Publicly accessible options of this module.
  CosineEmbeddingLossOptions options;
};

/// `TORCH_MODULE` invocation for `CosineEmbeddingLoss`. NOTE(review): expected
/// to declare the `CosineEmbeddingLoss` wrapper for `CosineEmbeddingLossImpl`
/// -- see torch/nn/pimpl.h.
TORCH_MODULE(CosineEmbeddingLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SmoothL1Loss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `SmoothL1Loss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): the transition threshold and reduction are presumably
/// configured through `SmoothL1LossOptions` -- confirm in the implementation
/// file.
struct TORCH_API SmoothL1LossImpl : public Cloneable<SmoothL1LossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  ///
  /// The parameter carries a trailing underscore so that it does not shadow
  /// the `options` data member declared below.
  explicit SmoothL1LossImpl(SmoothL1LossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts `input` and `target` and returns a `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// Publicly accessible options of this module.
  SmoothL1LossOptions options;
};

/// `TORCH_MODULE` invocation for `SmoothL1Loss`. NOTE(review): expected to
/// declare the `SmoothL1Loss` wrapper for `SmoothL1LossImpl` -- see
/// torch/nn/pimpl.h.
TORCH_MODULE(SmoothL1Loss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HuberLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `HuberLoss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): the delta threshold and reduction are presumably configured
/// through `HuberLossOptions` -- confirm in the implementation file.
struct TORCH_API HuberLossImpl : public Cloneable<HuberLossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  explicit HuberLossImpl(HuberLossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts `input` and `target` and returns a `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// Publicly accessible options of this module.
  HuberLossOptions options;
};

/// `TORCH_MODULE` invocation for `HuberLoss`. NOTE(review): expected to
/// declare the `HuberLoss` wrapper for `HuberLossImpl` -- see
/// torch/nn/pimpl.h.
TORCH_MODULE(HuberLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelMarginLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `MultiLabelMarginLoss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): the reduction is presumably configured through
/// `MultiLabelMarginLossOptions` -- confirm in the implementation file.
struct TORCH_API MultiLabelMarginLossImpl
    : public Cloneable<MultiLabelMarginLossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  explicit MultiLabelMarginLossImpl(MultiLabelMarginLossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts `input` and `target` and returns a `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// Publicly accessible options of this module.
  MultiLabelMarginLossOptions options;
};

/// `TORCH_MODULE` invocation for `MultiLabelMarginLoss`. NOTE(review):
/// expected to declare the `MultiLabelMarginLoss` wrapper for
/// `MultiLabelMarginLossImpl` -- see torch/nn/pimpl.h.
TORCH_MODULE(MultiLabelMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SoftMarginLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `SoftMarginLoss` module.
///
/// Only the interface is declared in this header; all member functions are
/// defined out of line. NOTE(review): the reduction is presumably selected
/// through `SoftMarginLossOptions` -- confirm in the implementation file.
struct TORCH_API SoftMarginLossImpl : public Cloneable<SoftMarginLossImpl> {
  /// Builds the module from `options_`; the argument defaults to `{}`.
  explicit SoftMarginLossImpl(SoftMarginLossOptions options_ = {});

  /// Override of the inherited `reset()`.
  void reset() override;

  /// Override of the inherited `pretty_print()`; `stream` is the output
  /// stream handed to it.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns a `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// Public options object of this module.
  SoftMarginLossOptions options;
};

/// Applies `TORCH_MODULE` to `SoftMarginLoss`. NOTE(review): expected to
/// declare the `SoftMarginLoss` wrapper for `SoftMarginLossImpl` -- see
/// torch/nn/pimpl.h.
TORCH_MODULE(SoftMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelSoftMarginLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `MultiLabelSoftMarginLoss` module.
///
/// Only the interface is declared in this header; all member functions are
/// defined out of line. NOTE(review): weighting and reduction are presumably
/// selected through `MultiLabelSoftMarginLossOptions` -- confirm in the
/// implementation file.
struct TORCH_API MultiLabelSoftMarginLossImpl
    : public Cloneable<MultiLabelSoftMarginLossImpl> {
  /// Builds the module from `options_`; the argument defaults to `{}`.
  explicit MultiLabelSoftMarginLossImpl(
      MultiLabelSoftMarginLossOptions options_ = {});

  /// Override of the inherited `reset()`.
  void reset() override;

  /// Override of the inherited `pretty_print()`; `stream` is the output
  /// stream handed to it.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns a `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// Public options object of this module.
  MultiLabelSoftMarginLossOptions options;
};

/// Applies `TORCH_MODULE` to `MultiLabelSoftMarginLoss`. NOTE(review):
/// expected to declare the `MultiLabelSoftMarginLoss` wrapper for
/// `MultiLabelSoftMarginLossImpl` -- see torch/nn/pimpl.h.
TORCH_MODULE(MultiLabelSoftMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `TripletMarginLoss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): margin, norm and reduction are presumably configured through
/// `TripletMarginLossOptions` -- confirm in the implementation file.
struct TORCH_API TripletMarginLossImpl
    : public Cloneable<TripletMarginLossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  explicit TripletMarginLossImpl(TripletMarginLossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts the `anchor`, `positive` and `negative` tensors and returns a
  /// `Tensor`.
  Tensor forward(
      const Tensor& anchor,
      const Tensor& positive,
      const Tensor& negative);

  /// Publicly accessible options of this module.
  TripletMarginLossOptions options;
};

/// `TORCH_MODULE` invocation for `TripletMarginLoss`. NOTE(review): expected
/// to declare the `TripletMarginLoss` wrapper for `TripletMarginLossImpl` --
/// see torch/nn/pimpl.h.
TORCH_MODULE(TripletMarginLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginWithDistanceLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `TripletMarginWithDistanceLoss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): the distance function, margin and reduction are presumably
/// configured through `TripletMarginWithDistanceLossOptions` -- confirm in the
/// implementation file.
struct TORCH_API TripletMarginWithDistanceLossImpl
    : public Cloneable<TripletMarginWithDistanceLossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  explicit TripletMarginWithDistanceLossImpl(
      TripletMarginWithDistanceLossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts the `anchor`, `positive` and `negative` tensors and returns a
  /// `Tensor`.
  Tensor forward(
      const Tensor& anchor,
      const Tensor& positive,
      const Tensor& negative);

  /// Publicly accessible options of this module.
  TripletMarginWithDistanceLossOptions options;
};

/// `TORCH_MODULE` invocation for `TripletMarginWithDistanceLoss`.
/// NOTE(review): expected to declare the `TripletMarginWithDistanceLoss`
/// wrapper for `TripletMarginWithDistanceLossImpl` -- see torch/nn/pimpl.h.
TORCH_MODULE(TripletMarginWithDistanceLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CTCLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `CTCLoss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): blank index and reduction are presumably configured through
/// `CTCLossOptions` -- confirm in the implementation file.
struct TORCH_API CTCLossImpl : public Cloneable<CTCLossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  explicit CTCLossImpl(CTCLossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts `log_probs`, `targets`, `input_lengths` and `target_lengths` and
  /// returns a `Tensor`. NOTE(review): expected shapes and dtypes are not
  /// visible from this header -- check the implementation.
  Tensor forward(
      const Tensor& log_probs,
      const Tensor& targets,
      const Tensor& input_lengths,
      const Tensor& target_lengths);

  /// Publicly accessible options of this module.
  CTCLossOptions options;
};

/// `TORCH_MODULE` invocation for `CTCLoss`. NOTE(review): expected to declare
/// the `CTCLoss` wrapper for `CTCLossImpl` -- see torch/nn/pimpl.h.
TORCH_MODULE(CTCLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PoissonNLLLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `PoissonNLLLoss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): the loss variant and reduction are presumably configured
/// through `PoissonNLLLossOptions` -- confirm in the implementation file.
struct TORCH_API PoissonNLLLossImpl : public Cloneable<PoissonNLLLossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  explicit PoissonNLLLossImpl(PoissonNLLLossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts `log_input` and `targets` and returns a `Tensor`.
  /// NOTE(review): the parameter name suggests a log-space input, but whether
  /// that is required is not visible from this header -- confirm.
  Tensor forward(const Tensor& log_input, const Tensor& targets);

  /// Publicly accessible options of this module.
  PoissonNLLLossOptions options;
};

/// `TORCH_MODULE` invocation for `PoissonNLLLoss`. NOTE(review): expected to
/// declare the `PoissonNLLLoss` wrapper for `PoissonNLLLossImpl` -- see
/// torch/nn/pimpl.h.
TORCH_MODULE(PoissonNLLLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MarginRankingLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `MarginRankingLoss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): margin and reduction are presumably configured through
/// `MarginRankingLossOptions` -- confirm in the implementation file.
struct TORCH_API MarginRankingLossImpl
    : public Cloneable<MarginRankingLossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  explicit MarginRankingLossImpl(MarginRankingLossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts the two inputs `input1` and `input2` together with `targets` and
  /// returns a `Tensor`.
  Tensor forward(
      const Tensor& input1,
      const Tensor& input2,
      const Tensor& targets);

  /// Publicly accessible options of this module.
  MarginRankingLossOptions options;
};

/// `TORCH_MODULE` invocation for `MarginRankingLoss`. NOTE(review): expected
/// to declare the `MarginRankingLoss` wrapper for `MarginRankingLossImpl` --
/// see torch/nn/pimpl.h.
TORCH_MODULE(MarginRankingLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ NLLLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `NLLLoss` module.
///
/// Only the interface is declared in this header; all member functions are
/// defined out of line. NOTE(review): ignore index and reduction are
/// presumably selected through `NLLLossOptions` -- confirm in the
/// implementation file.
struct TORCH_API NLLLossImpl : public Cloneable<NLLLossImpl> {
  /// Builds the module from `options_`; the argument defaults to `{}`.
  explicit NLLLossImpl(NLLLossOptions options_ = {});

  /// Override of the inherited `reset()`.
  void reset() override;

  /// Override of the inherited `pretty_print()`; `stream` is the output
  /// stream handed to it.
  void pretty_print(std::ostream& stream) const override;

  /// Takes the `input` and `target` tensors and returns a `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// Public options object of this module.
  NLLLossOptions options;

  /// Public tensor member. NOTE(review): presumably the rescaling weight
  /// taken from the options and set up in `reset()` -- confirm.
  Tensor weight;
};

/// Applies `TORCH_MODULE` to `NLLLoss`. NOTE(review): expected to declare the
/// `NLLLoss` wrapper for `NLLLossImpl` -- see torch/nn/pimpl.h.
TORCH_MODULE(NLLLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossEntropyLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `CrossEntropyLoss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): ignore index, smoothing and reduction are presumably
/// configured through `CrossEntropyLossOptions` -- confirm in the
/// implementation file.
struct TORCH_API CrossEntropyLossImpl : public Cloneable<CrossEntropyLossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  explicit CrossEntropyLossImpl(CrossEntropyLossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts `input` and `target` and returns a `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// Publicly accessible options of this module.
  CrossEntropyLossOptions options;

  /// Public tensor member. NOTE(review): presumably the rescaling weight
  /// taken from the options and set up in `reset()` -- confirm.
  Tensor weight;
};

/// `TORCH_MODULE` invocation for `CrossEntropyLoss`. NOTE(review): expected to
/// declare the `CrossEntropyLoss` wrapper for `CrossEntropyLossImpl` -- see
/// torch/nn/pimpl.h.
TORCH_MODULE(CrossEntropyLoss);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCEWithLogitsLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation type of the `BCEWithLogitsLoss` module.
///
/// Declaration only: every member function is defined outside this header.
/// NOTE(review): weighting and reduction are presumably configured through
/// `BCEWithLogitsLossOptions` -- confirm in the implementation file.
struct TORCH_API BCEWithLogitsLossImpl
    : public Cloneable<BCEWithLogitsLossImpl> {
  /// Constructs the module from `options_`, which defaults to `{}`.
  explicit BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_ = {});

  /// Overrides the inherited `reset()`.
  void reset() override;

  /// Overrides the inherited `pretty_print()`; `stream` is the stream passed
  /// in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Accepts `input` and `target` and returns a `Tensor`.
  Tensor forward(const Tensor& input, const Tensor& target);

  /// Publicly accessible options of this module.
  BCEWithLogitsLossOptions options;

  /// Public tensor member. NOTE(review): presumably the rescaling weight
  /// taken from the options and set up in `reset()` -- confirm.
  Tensor weight;

  /// Public tensor member. NOTE(review): presumably the positive-example
  /// weight taken from the options and set up in `reset()` -- confirm.
  Tensor pos_weight;
};

/// `TORCH_MODULE` invocation for `BCEWithLogitsLoss`. NOTE(review): expected
/// to declare the `BCEWithLogitsLoss` wrapper for `BCEWithLogitsLossImpl` --
/// see torch/nn/pimpl.h.
TORCH_MODULE(BCEWithLogitsLoss);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources