Program Listing for File batchnorm.h (torch/csrc/api/include/torch/nn/modules/batchnorm.h)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/functional/batchnorm.h>
#include <torch/nn/init.h>
#include <torch/nn/options/batchnorm.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <cstdint>
namespace torch {
namespace nn {
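// Base class for the normalization modules below: registers the affine
// parameters (weight, bias) and the running-statistics buffers
// (running_mean, running_var, num_batches_tracked) according to the
// supplied options.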
template <size_t D, typename Derived, typename DerivedOptions>
class NormImplBase : public torch::nn::Cloneable<Derived> {
protected:
virtual void _check_input_dim(const Tensor& input) = 0;
public:
NormImplBase(const DerivedOptions& options_) : options(options_) {
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
reset();
}
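  // Recreates parameters and buffers from the current options: affine
  // parameters only when affine() is set, running-statistics buffers only
  // when track_running_stats() is set (undefined tensors otherwise).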
void reset() override {
if (options.affine()) {
weight = this->register_parameter(
"weight", torch::empty({options.num_features()}));
bias = this->register_parameter(
"bias", torch::empty({options.num_features()}));
} else {
weight =
this->register_parameter("weight", Tensor(), /*requires_grad=*/false);
bias =
this->register_parameter("bias", Tensor(), /*requires_grad=*/false);
}
if (options.track_running_stats()) {
running_mean = this->register_buffer(
"running_mean", torch::zeros({options.num_features()}));
running_var = this->register_buffer(
"running_var", torch::ones({options.num_features()}));
num_batches_tracked = this->register_buffer(
"num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
} else {
running_mean = this->register_buffer("running_mean", Tensor());
running_var = this->register_buffer("running_var", Tensor());
num_batches_tracked =
this->register_buffer("num_batches_tracked", Tensor());
}
reset_parameters();
}
void reset_running_stats() {
if (options.track_running_stats()) {
running_mean.zero_();
running_var.fill_(1);
num_batches_tracked.zero_();
}
}
void reset_parameters() {
reset_running_stats();
if (options.affine()) {
torch::nn::init::ones_(weight);
torch::nn::init::zeros_(bias);
}
}
DerivedOptions options;
Tensor weight;
Tensor bias;
Tensor running_mean;
Tensor running_var;
Tensor num_batches_tracked;
};
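// Adds the batch-norm forward pass on top of NormImplBase; D matches the
// 1d/2d/3d variant implemented by the derived class.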
template <size_t D, typename Derived>
class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
public:
using NormImplBase<D, Derived, BatchNormOptions>::NormImplBase;
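  // In training mode (with tracked stats) this updates running_mean and
  // running_var; a null momentum selects a cumulative moving average for
  // that update, otherwise an exponential moving average is used. In
  // evaluation mode the tracked running statistics are used for
  // normalization.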
Tensor forward(const Tensor& input) {
this->_check_input_dim(input);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double exponential_average_factor;
if (this->options.momentum() == c10::nullopt) {
exponential_average_factor = 0.0;
} else {
exponential_average_factor = this->options.momentum().value();
}
if (this->is_training() && this->options.track_running_stats()) {
if (this->num_batches_tracked.defined()) {
this->num_batches_tracked += 1;
if (this->options.momentum() ==
c10::nullopt) { // use cumulative moving average
exponential_average_factor =
1.0 / this->num_batches_tracked.template item<double>();
} else { // use exponential moving average
exponential_average_factor = this->options.momentum().value();
}
}
}
return torch::nn::functional::detail::batch_norm(
input,
this->running_mean,
this->running_var,
this->weight,
this->bias,
this->is_training() || !this->options.track_running_stats(),
/*momentum=*/exponential_average_factor,
this->options.eps());
}
void pretty_print(std::ostream& stream) const override;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
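// Expects 2D (N, C) or 3D (N, C, L) input; the dimension check is defined
// in the corresponding .cpp file.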
class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> {
protected:
void _check_input_dim(const Tensor& input) override;
public:
using BatchNormImplBase<1, BatchNorm1dImpl>::BatchNormImplBase;
};
TORCH_MODULE(BatchNorm1d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm2d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
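// Expects 4D (N, C, H, W) input.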
class TORCH_API BatchNorm2dImpl : public BatchNormImplBase<2, BatchNorm2dImpl> {
protected:
void _check_input_dim(const Tensor& input) override;
public:
using BatchNormImplBase<2, BatchNorm2dImpl>::BatchNormImplBase;
};
TORCH_MODULE(BatchNorm2d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm3d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
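// Expects 5D (N, C, D, H, W) input.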
class TORCH_API BatchNorm3dImpl : public BatchNormImplBase<3, BatchNorm3dImpl> {
protected:
void _check_input_dim(const Tensor& input) override;
public:
using BatchNormImplBase<3, BatchNorm3dImpl>::BatchNormImplBase;
};
TORCH_MODULE(BatchNorm3d);
} // namespace nn
} // namespace torch
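For reference, the following is a minimal usage sketch of the modules declared above; it is not part of the header. It assumes only the public C++ frontend API (torch::nn::BatchNorm2d, torch::nn::BatchNormOptions), and the tensor shapes and hyperparameter values are illustrative.

#include <torch/torch.h>

#include <iostream>

int main() {
  // BatchNorm2d over 16 channels; the eps and momentum values shown match
  // the library defaults and are set here only for illustration.
  torch::nn::BatchNorm2d bn(
      torch::nn::BatchNormOptions(16).eps(1e-5).momentum(0.1));

  // Training mode: normalize with batch statistics and update the
  // running_mean / running_var buffers registered in reset().
  bn->train();
  auto x = torch::randn({8, 16, 32, 32});
  auto y = bn->forward(x);

  // Evaluation mode: normalize with the tracked running statistics instead.
  bn->eval();
  auto y_eval = bn->forward(x);

  std::cout << y.sizes() << " " << y_eval.sizes() << '\n';
  return 0;
}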