Shortcuts

Program Listing for File instancenorm.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/instancenorm.h)

#pragma once

#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/options/instancenorm.h>

namespace torch {
namespace nn {

/// Shared implementation of the InstanceNorm1d / InstanceNorm2d /
/// InstanceNorm3d modules.
///
/// Template parameters:
///   D       - spatial rank handled by the concrete module (1, 2 or 3). An
///             input with exactly D + 1 dims is treated as unbatched; any
///             other rank accepted by `_check_input_dim` is passed through
///             as a batched input.
///   Derived - the concrete module type, forwarded to `NormImplBase` (the
///             concrete InstanceNorm{1,2,3}dImpl classes pass themselves).
///
/// Running statistics, affine parameters and options are all accessed from
/// the `NormImplBase` base; this class only adds the forward pass and
/// `pretty_print`.
template <size_t D, typename Derived>
class InstanceNormImpl
    : public torch::nn::NormImplBase<D, Derived, InstanceNormOptions> {
 private:
  /// Runs the functional instance norm on `input` using this module's
  /// running buffers, affine parameters and options.
  Tensor apply_instance_norm(const Tensor& input) {
    return torch::nn::functional::detail::instance_norm(
        input,
        this->running_mean,
        this->running_var,
        this->weight,
        this->bias,
        // True in training mode, or whenever running stats are not tracked.
        // Presumably this selects the statistics of `input` itself instead of
        // the running buffers - confirm against functional::detail.
        this->is_training() || !this->options.track_running_stats(),
        this->options.momentum(),
        this->options.eps());
  }

  /// Normalizes an unbatched input by adding a leading batch dimension of
  /// size 1, normalizing, and removing that dimension from the result.
  Tensor handle_no_batch_input(const Tensor& input) {
    return this->apply_instance_norm(input.unsqueeze(0)).squeeze(0);
  }

 public:
  // Inherit every NormImplBase constructor.
  using torch::nn::NormImplBase<D, Derived, InstanceNormOptions>::NormImplBase;

  /// Checks `input` with `_check_input_dim` and applies instance
  /// normalization to it.
  ///
  /// For InstanceNorm1D, 2D is unbatched and 3D is batched
  /// For InstanceNorm2D, 3D is unbatched and 4D is batched
  /// For InstanceNorm3D, 4D is unbatched and 5D is batched
  Tensor forward(const Tensor& input) {
    this->_check_input_dim(input);

    // Tensor::dim() is signed (int64_t) while D is size_t: convert D up front
    // so the comparison below is not a signed/unsigned mix (-Wsign-compare).
    constexpr int64_t kUnbatchedDim = static_cast<int64_t>(D) + 1;

    // check if input does not have a batch-dim
    if (input.dim() == kUnbatchedDim) {
      return this->handle_no_batch_input(input);
    }

    return this->apply_instance_norm(input);
  }

  /// Prints the module as
  /// `torch::nn::InstanceNorm<D>d(<num_features>, eps=..., momentum=...,
  /// affine=..., track_running_stats=...)` with booleans spelled out.
  ///
  /// NOTE(review): `std::boolalpha` is a sticky stream flag, so `stream` is
  /// left in boolalpha mode after this call. This is kept as-is to avoid
  /// changing how anything printed afterwards to the same stream looks.
  void pretty_print(std::ostream& stream) const override {
    stream << std::boolalpha << "torch::nn::InstanceNorm" << D << "d("
           << this->options.num_features() << ", "
           << "eps=" << this->options.eps() << ", "
           << "momentum=" << this->options.momentum() << ", "
           << "affine=" << this->options.affine() << ", "
           << "track_running_stats=" << this->options.track_running_stats()
           << ")";
  }
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm1d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Instance normalization with one spatial dimension: per the comments in
/// `InstanceNormImpl::forward`, 2D input is unbatched and 3D is batched.
class TORCH_API InstanceNorm1dImpl
    : public InstanceNormImpl<1, InstanceNorm1dImpl> {
 public:
  // Expose the shared base's constructors unchanged.
  using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl;

 protected:
  // Rank check for this module; defined out of line (not in this header).
  void _check_input_dim(const Tensor& input) override;
};

// Module-holder declaration for `InstanceNorm1dImpl` (libtorch TORCH_MODULE
// convention).
TORCH_MODULE(InstanceNorm1d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm2d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Instance normalization with two spatial dimensions: per the comments in
/// `InstanceNormImpl::forward`, 3D input is unbatched and 4D is batched.
class TORCH_API InstanceNorm2dImpl
    : public InstanceNormImpl<2, InstanceNorm2dImpl> {
 public:
  // Expose the shared base's constructors unchanged.
  using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl;

 protected:
  // Rank check for this module; defined out of line (not in this header).
  void _check_input_dim(const Tensor& input) override;
};

// Module-holder declaration for `InstanceNorm2dImpl` (libtorch TORCH_MODULE
// convention).
TORCH_MODULE(InstanceNorm2d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm3d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Instance normalization with three spatial dimensions: per the comments in
/// `InstanceNormImpl::forward`, 4D input is unbatched and 5D is batched.
class TORCH_API InstanceNorm3dImpl
    : public InstanceNormImpl<3, InstanceNorm3dImpl> {
 public:
  // Expose the shared base's constructors unchanged.
  using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl;

 protected:
  // Rank check for this module; defined out of line (not in this header).
  void _check_input_dim(const Tensor& input) override;
};

// Module-holder declaration for `InstanceNorm3dImpl` (libtorch TORCH_MODULE
// convention).
TORCH_MODULE(InstanceNorm3d);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources