// Program listing for file normalization.h
// (torch/csrc/api/include/torch/nn/modules/normalization.h)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/functional/normalization.h>
#include <torch/nn/modules/_functions.h>
#include <torch/nn/options/normalization.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <cstddef>
#include <utility>
#include <vector>
namespace torch {
namespace nn {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LayerNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Module implementation behind the `LayerNorm` holder declared right after
/// this class.
///
/// The configuration is kept in `options`. `weight` and `bias` are tensor
/// members; how they are created and initialised is decided by the
/// out-of-line `reset()` / `reset_parameters()` definitions, which are not
/// part of this header.
class TORCH_API LayerNormImpl : public torch::nn::Cloneable<LayerNormImpl> {
 public:
  /// Convenience constructor: wraps `normalized_shape` in a
  /// `LayerNormOptions` and delegates to the options-taking constructor.
  ///
  /// The vector is received by value, so it is moved into the options object
  /// instead of being copied a second time. Left non-explicit to keep the
  /// existing interface unchanged.
  LayerNormImpl(std::vector<int64_t> normalized_shape)
      : LayerNormImpl(LayerNormOptions(std::move(normalized_shape))) {}

  /// Constructs the module from a complete set of options (defined out of
  /// line).
  explicit LayerNormImpl(LayerNormOptions options_);

  /// Override of the base-class `reset()` hook (defined out of line).
  void reset() override;

  /// Defined out of line.
  /// NOTE(review): presumably (re)initialises `weight` and `bias` -- confirm
  /// against the .cpp definition.
  void reset_parameters();

  /// Writes a textual description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Runs the module on `input` and returns the resulting tensor.
  /// NOTE(review): the computation itself is defined out of line; presumably
  /// layer normalization configured by `options` -- confirm in the .cpp.
  Tensor forward(const Tensor& input);

  /// Options this module was constructed with.
  LayerNormOptions options;

  /// NOTE(review): presumably the learnable scale -- confirm in `reset()`.
  Tensor weight;

  /// NOTE(review): presumably the learnable shift -- confirm in `reset()`.
  Tensor bias;
};

/// Declares the `LayerNorm` type for `LayerNormImpl` through the
/// `TORCH_MODULE` macro.
TORCH_MODULE(LayerNorm);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LocalResponseNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Module implementation behind the `LocalResponseNorm` holder declared right
/// after this class. The only state declared here is the `options` member.
class TORCH_API LocalResponseNormImpl
    : public torch::nn::Cloneable<LocalResponseNormImpl> {
 public:
  // --- construction -------------------------------------------------------

  /// Convenience constructor: builds `LocalResponseNormOptions` from `size`
  /// and delegates to the options-taking constructor.
  LocalResponseNormImpl(int64_t size)
      : LocalResponseNormImpl(LocalResponseNormOptions(size)) {}

  /// Constructs the module from a complete set of options (defined out of
  /// line).
  explicit LocalResponseNormImpl(const LocalResponseNormOptions& options_);

  // --- module interface ---------------------------------------------------

  /// Override of the base-class `reset()` hook (defined out of line).
  void reset() override;

  /// Writes a textual description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Runs the module on `input` and returns the resulting tensor.
  /// NOTE(review): the computation is defined out of line; presumably local
  /// response normalization configured by `options` -- confirm in the .cpp.
  torch::Tensor forward(const torch::Tensor& input);

  // --- state --------------------------------------------------------------

  /// Options this module was constructed with.
  LocalResponseNormOptions options;
};

/// Declares the `LocalResponseNorm` type for `LocalResponseNormImpl` through
/// the `TORCH_MODULE` macro.
TORCH_MODULE(LocalResponseNorm);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossMapLRN2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Module implementation behind the `CrossMapLRN2d` holder declared right
/// after this class. The only state declared here is the `options` member.
class TORCH_API CrossMapLRN2dImpl : public Cloneable<CrossMapLRN2dImpl> {
 public:
  // --- construction -------------------------------------------------------

  /// Convenience constructor: builds `CrossMapLRN2dOptions` from `size` and
  /// delegates to the options-taking constructor.
  CrossMapLRN2dImpl(int64_t size)
      : CrossMapLRN2dImpl(CrossMapLRN2dOptions(size)) {}

  /// Stores a copy of `options_` in `options`; the inline body does nothing
  /// else.
  explicit CrossMapLRN2dImpl(const CrossMapLRN2dOptions& options_)
      : options(options_) {
  }

  // --- module interface ---------------------------------------------------

  /// Runs the module on `input` and returns the resulting tensor.
  /// NOTE(review): the computation is defined out of line; presumably
  /// cross-map local response normalization configured by `options` --
  /// confirm in the .cpp.
  Tensor forward(const Tensor& input);

  /// Override of the base-class `reset()` hook (defined out of line).
  void reset() override;

  /// Writes a textual description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  // --- state --------------------------------------------------------------

  /// Options this module was constructed with.
  CrossMapLRN2dOptions options;
};

/// Declares the `CrossMapLRN2d` type for `CrossMapLRN2dImpl` through the
/// `TORCH_MODULE` macro.
TORCH_MODULE(CrossMapLRN2d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GroupNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Module implementation behind the `GroupNorm` holder declared right after
/// this class.
///
/// The configuration is kept in `options`. `weight` and `bias` are tensor
/// members; how they are created and initialised is decided by the
/// out-of-line `reset()` / `reset_parameters()` definitions, which are not
/// part of this header.
class TORCH_API GroupNormImpl : public Cloneable<GroupNormImpl> {
 public:
  // --- construction -------------------------------------------------------

  /// Convenience constructor: builds `GroupNormOptions` from `num_groups` and
  /// `num_channels` and delegates to the options-taking constructor.
  GroupNormImpl(int64_t num_groups, int64_t num_channels)
      : GroupNormImpl(GroupNormOptions(num_groups, num_channels)) {}

  /// Constructs the module from a complete set of options (defined out of
  /// line).
  explicit GroupNormImpl(const GroupNormOptions& options_);

  // --- module interface ---------------------------------------------------

  /// Runs the module on `input` and returns the resulting tensor.
  /// NOTE(review): the computation is defined out of line; presumably group
  /// normalization configured by `options` -- confirm in the .cpp.
  torch::Tensor forward(const torch::Tensor& input);

  /// Override of the base-class `reset()` hook (defined out of line).
  void reset() override;

  /// Defined out of line.
  /// NOTE(review): presumably (re)initialises `weight` and `bias` -- confirm
  /// against the .cpp definition.
  void reset_parameters();

  /// Writes a textual description of this module to `stream`.
  void pretty_print(std::ostream& stream) const override;

  // --- state (declaration order intentionally unchanged) -------------------

  /// Options this module was constructed with.
  GroupNormOptions options;

  /// NOTE(review): presumably the learnable scale -- confirm in `reset()`.
  torch::Tensor weight;

  /// NOTE(review): presumably the learnable shift -- confirm in `reset()`.
  torch::Tensor bias;
};

/// Declares the `GroupNorm` type for `GroupNormImpl` through the
/// `TORCH_MODULE` macro.
TORCH_MODULE(GroupNorm);
} // namespace nn
} // namespace torch