Shortcuts

Program Listing for File batchnorm.h

Return to documentation for file (torch/csrc/api/include/torch/nn/functional/batchnorm.h)

#pragma once

#include <c10/util/irange.h>
#include <torch/nn/options/batchnorm.h>
#include <torch/types.h>

namespace torch {
namespace nn {
namespace functional {

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Shared implementation behind torch::nn::functional::batch_norm.
//
// Validates the arguments and then forwards them to torch::batch_norm,
// together with the cuDNN preference read from at::globalContext().
//
// Parameters:
//   input        - tensor to normalize; must have at least 2 dimensions.
//   running_mean - forwarded unchanged to torch::batch_norm.
//   running_var  - forwarded unchanged to torch::batch_norm.
//   weight, bias - forwarded unchanged to torch::batch_norm.
//   training     - enables the "more than 1 value per channel" check below and
//                  is forwarded to torch::batch_norm.
//   momentum     - must hold a value; the contained double is forwarded.
//   eps          - forwarded unchanged to torch::batch_norm.
//
// Fails through TORCH_CHECK when:
//   * input has fewer than 2 dimensions,
//   * training is true and the product of all input extents except
//     dimension 1 equals 1,
//   * momentum is empty.
inline Tensor batch_norm(
    const Tensor& input,
    const Tensor& running_mean,
    const Tensor& running_var,
    Tensor weight,
    Tensor bias,
    bool training,
    std::optional<double> momentum,
    double eps) {
  TORCH_CHECK(
      input.dim() >= 2,
      "Expected at least 2 input dimensions, but got ",
      input.dim());
  if (training) {
    auto size = input.sizes();
    // Product of every extent except dimension 1: start from dimension 0 and
    // fold in dimensions 2..N-1. Going by the error message below, this is the
    // number of values available per channel.
    int64_t size_prods = size[0];
    for (const auto i : c10::irange(size.size() - 2)) {
      size_prods *= size[i + 2];
    }
    TORCH_CHECK(
        size_prods != 1,
        "Expected more than 1 value per channel when training, got input size ",
        size);
  }

  // Reject an empty momentum explicitly rather than letting the optional
  // access below fail with an uninformative std::bad_optional_access. Kept
  // after the input checks so those are still reported first.
  TORCH_CHECK(
      momentum.has_value(),
      "Expected momentum to have a value for functional batch_norm, but got nullopt");

  return torch::batch_norm(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      training,
      *momentum,
      eps,
      at::globalContext().userEnabledCuDNN());
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Applies batch normalization to `input` using the settings in `options`.
///
/// Reads weight, bias, training, momentum and eps from `options` and delegates
/// to `detail::batch_norm` along with `input`, `running_mean` and
/// `running_var`. The tensor produced by that call is returned unchanged.
///
/// \param input         Tensor to normalize.
/// \param running_mean  Passed through to `detail::batch_norm`.
/// \param running_var   Passed through to `detail::batch_norm`.
/// \param options       Functional options; a default-constructed
///                      `BatchNormFuncOptions` is used when omitted.
/// \return The result of `detail::batch_norm`.
inline Tensor batch_norm(
    const Tensor& input,
    const Tensor& running_mean,
    const Tensor& running_var,
    const BatchNormFuncOptions& options = {}) {
  // Name each option value up front so the delegation reads as a plain
  // positional call.
  const auto& weight = options.weight();
  const auto& bias = options.bias();
  const auto& training = options.training();
  const auto& momentum = options.momentum();
  const auto& eps = options.eps();

  return detail::batch_norm(
      input, running_mean, running_var, weight, bias, training, momentum, eps);
}

} // namespace functional
} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources