Shortcuts

Program Listing for File batchnorm.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/batchnorm.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/functional/batchnorm.h>
#include <torch/nn/init.h>
#include <torch/nn/options/batchnorm.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <cstdint>

namespace torch {
namespace nn {

template <size_t D, typename Derived, typename DerivedOptions>
class NormImplBase : public torch::nn::Cloneable<Derived> {
 protected:
  virtual void _check_input_dim(const Tensor& input) = 0;

 public:
  NormImplBase(const DerivedOptions& options_) : options(options_) {
    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
    reset();
  }

  void reset() override {
    if (options.affine()) {
      weight = this->register_parameter(
          "weight", torch::empty({options.num_features()}));
      bias = this->register_parameter(
          "bias", torch::empty({options.num_features()}));
    } else {
      weight =
          this->register_parameter("weight", Tensor(), /*requires_grad=*/false);
      bias =
          this->register_parameter("bias", Tensor(), /*requires_grad=*/false);
    }
    if (options.track_running_stats()) {
      running_mean = this->register_buffer(
          "running_mean", torch::zeros({options.num_features()}));
      running_var = this->register_buffer(
          "running_var", torch::ones({options.num_features()}));
      num_batches_tracked = this->register_buffer(
          "num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
    } else {
      running_mean = this->register_buffer("running_mean", Tensor());
      running_var = this->register_buffer("running_var", Tensor());
      num_batches_tracked =
          this->register_buffer("num_batches_tracked", Tensor());
    }
    reset_parameters();
  }

  void reset_running_stats() {
    if (options.track_running_stats()) {
      running_mean.zero_();
      running_var.fill_(1);
      num_batches_tracked.zero_();
    }
  }

  void reset_parameters() {
    reset_running_stats();
    if (options.affine()) {
      torch::nn::init::ones_(weight);
      torch::nn::init::zeros_(bias);
    }
  }

  DerivedOptions options;

  Tensor weight;

  Tensor bias;

  Tensor running_mean;

  Tensor running_var;

  Tensor num_batches_tracked;
};

template <size_t D, typename Derived>
class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
 public:
  using NormImplBase<D, Derived, BatchNormOptions>::NormImplBase;

  Tensor forward(const Tensor& input) {
    this->_check_input_dim(input);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    double exponential_average_factor;
    if (this->options.momentum() == c10::nullopt) {
      exponential_average_factor = 0.0;
    } else {
      exponential_average_factor = this->options.momentum().value();
    }

    if (this->is_training() && this->options.track_running_stats()) {
      if (this->num_batches_tracked.defined()) {
        this->num_batches_tracked += 1;
        if (this->options.momentum() ==
            c10::nullopt) { // use cumulative moving average
          exponential_average_factor =
              1.0 / this->num_batches_tracked.template item<double>();
        } else { // use exponential moving average
          exponential_average_factor = this->options.momentum().value();
        }
      }
    }

    return torch::nn::functional::detail::batch_norm(
        input,
        this->running_mean,
        this->running_var,
        this->weight,
        this->bias,
        this->is_training() || !this->options.track_running_stats(),
        /*momentum=*/exponential_average_factor,
        this->options.eps());
  }

  void pretty_print(std::ostream& stream) const override {
    stream << std::boolalpha << "torch::nn::BatchNorm" << D << "d("
           << this->options.num_features() << ", "
           << "eps=" << this->options.eps() << ", "
           << "momentum=";

    if (this->options.momentum().has_value()) {
      stream << this->options.momentum().value();
    } else {
      stream << "None";
    }

    stream << ", "
           << "affine=" << this->options.affine() << ", "
           << "track_running_stats=" << this->options.track_running_stats()
           << ")";
  }
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> {
 protected:
  void _check_input_dim(const Tensor& input) override;

 public:
  using BatchNormImplBase<1, BatchNorm1dImpl>::BatchNormImplBase;
};

TORCH_MODULE(BatchNorm1d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm2d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API BatchNorm2dImpl : public BatchNormImplBase<2, BatchNorm2dImpl> {
 protected:
  void _check_input_dim(const Tensor& input) override;

 public:
  using BatchNormImplBase<2, BatchNorm2dImpl>::BatchNormImplBase;
};

TORCH_MODULE(BatchNorm2d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm3d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API BatchNorm3dImpl : public BatchNormImplBase<3, BatchNorm3dImpl> {
 protected:
  void _check_input_dim(const Tensor& input) override;

 public:
  using BatchNormImplBase<3, BatchNorm3dImpl>::BatchNormImplBase;
};

TORCH_MODULE(BatchNorm3d);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources