Shortcuts

Program Listing for File normalization.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/normalization.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/functional/normalization.h>
#include <torch/nn/modules/_functions.h>
#include <torch/nn/options/normalization.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <cstddef>
#include <utility>
#include <vector>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LayerNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Layer-normalization module implementation. The user-facing holder type
/// `LayerNorm` is generated from it by `TORCH_MODULE` (from
/// `torch/nn/pimpl.h`) right after the class.
///
/// This header only declares the interface. Parameter set-up, the forward
/// computation and printing are all defined out of line.
class TORCH_API LayerNormImpl : public torch::nn::Cloneable<LayerNormImpl> {
 public:
  /// Convenience constructor: wraps `normalized_shape` in `LayerNormOptions`
  /// and delegates to the options constructor.
  ///
  /// The vector is a by-value "sink" parameter. Moving it into the options
  /// (instead of copying it a second time) avoids an extra allocation and
  /// copy whenever `LayerNormOptions` accepts the vector by value or by
  /// rvalue reference.
  ///
  /// NOTE(review): not `explicit`, so a `std::vector<int64_t>` converts
  /// implicitly to this type. It is left unchanged so existing callers keep
  /// compiling.
  LayerNormImpl(std::vector<int64_t> normalized_shape)
      : LayerNormImpl(LayerNormOptions(std::move(normalized_shape))) {}

  /// Constructs the module from a fully specified options object.
  explicit LayerNormImpl(LayerNormOptions options_);

  /// Override of `Cloneable::reset()`; defined out of line.
  void reset() override;

  /// Declared separately from `reset()`. Presumably it re-initializes
  /// `weight`/`bias`; it is defined out of line, so confirm there.
  void reset_parameters();

  /// Prints a description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Applies the module to `input` and returns the result; defined out of
  /// line.
  Tensor forward(const Tensor& input);

  /// Configuration this module was built with.
  LayerNormOptions options;

  /// Parameter tensors, initialized by the out-of-line `reset()` /
  /// `reset_parameters()`.
  /// NOTE(review): confirm whether they can stay undefined when the options
  /// disable the affine transform.
  Tensor weight;

  Tensor bias;
};

TORCH_MODULE(LayerNorm);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LocalResponseNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Local-response-normalization module implementation. The user-facing
/// holder type `LocalResponseNorm` is generated from it by `TORCH_MODULE`
/// right after the class.
///
/// Only the interface is declared here; every member function is defined out
/// of line.
class TORCH_API LocalResponseNormImpl
    : public Cloneable<LocalResponseNormImpl> {
 public:
  /// Primary constructor taking a complete options object.
  explicit LocalResponseNormImpl(const LocalResponseNormOptions& opts);

  /// Shorthand that wraps `size` in `LocalResponseNormOptions` and forwards to
  /// the options constructor.
  ///
  /// NOTE(review): not `explicit`, so an `int64_t` converts implicitly to this
  /// type. It is kept as-is for source compatibility.
  LocalResponseNormImpl(int64_t size)
      : LocalResponseNormImpl(LocalResponseNormOptions(size)) {}

  /// `Cloneable::reset()` override.
  void reset() override;

  /// Prints a description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Applies the module to `input` and returns the result.
  Tensor forward(const Tensor& input);

  /// Configuration this module was built with.
  LocalResponseNormOptions options;
};

TORCH_MODULE(LocalResponseNorm);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossMapLRN2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Cross-map LRN (2d) module implementation. The user-facing holder type
/// `CrossMapLRN2d` is generated from it by `TORCH_MODULE` right after the
/// class.
class TORCH_API CrossMapLRN2dImpl
    : public torch::nn::Cloneable<CrossMapLRN2dImpl> {
 public:
  /// Primary constructor. Its inline body only copies `opts` into `options`;
  /// it does not call `reset()` itself.
  explicit CrossMapLRN2dImpl(const CrossMapLRN2dOptions& opts)
      : options(opts) {}

  /// Shorthand that wraps `size` in `CrossMapLRN2dOptions` and delegates to
  /// the options constructor.
  ///
  /// NOTE(review): not `explicit`, so an `int64_t` converts implicitly to this
  /// type. It is kept as-is for source compatibility.
  CrossMapLRN2dImpl(int64_t size)
      : CrossMapLRN2dImpl(CrossMapLRN2dOptions(size)) {}

  /// Applies the module to `input` and returns the result; defined out of
  /// line.
  torch::Tensor forward(const torch::Tensor& input);

  /// `Cloneable::reset()` override; defined out of line.
  void reset() override;

  /// Prints a description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration this module was built with.
  CrossMapLRN2dOptions options;
};

TORCH_MODULE(CrossMapLRN2d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GroupNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Group-normalization module implementation. The user-facing holder type
/// `GroupNorm` is generated from it by `TORCH_MODULE` right after the class.
///
/// Only the interface is declared here; every member function is defined out
/// of line.
class TORCH_API GroupNormImpl : public torch::nn::Cloneable<GroupNormImpl> {
 public:
  /// Primary constructor taking a complete options object.
  explicit GroupNormImpl(const GroupNormOptions& opts);

  /// Shorthand that packs the group and channel counts into
  /// `GroupNormOptions` and delegates to the options constructor.
  GroupNormImpl(int64_t num_groups, int64_t num_channels)
      : GroupNormImpl(GroupNormOptions(num_groups, num_channels)) {}

  /// Applies the module to `input` and returns the result.
  Tensor forward(const Tensor& input);

  /// `Cloneable::reset()` override.
  void reset() override;

  /// Declared separately from `reset()`. Presumably it re-initializes
  /// `weight`/`bias`; confirm in the out-of-line definition.
  void reset_parameters();

  /// Prints a description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Configuration this module was built with.
  GroupNormOptions options;

  /// Parameter tensors, set up by the out-of-line reset logic.
  Tensor weight;
  Tensor bias;
};

TORCH_MODULE(GroupNorm);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources