Program Listing for File instancenorm.h
Return to documentation for file (torch/csrc/api/include/torch/nn/modules/instancenorm.h)
#pragma once
#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/options/instancenorm.h>
namespace torch {
namespace nn {
/// Shared implementation behind the InstanceNorm1d/2d/3d modules.
///
/// `D` is the number printed in the module name by `pretty_print` and is used
/// by `forward` to recognise input without a batch dimension. `Derived` is the
/// concrete module type and is forwarded to `NormImplBase`.
template <size_t D, typename Derived>
class InstanceNormImpl
    : public torch::nn::NormImplBase<D, Derived, InstanceNormOptions> {
 private:
  /// Calls the functional `instance_norm` on `input` with this module's
  /// running statistics, weight, bias, momentum and eps.
  Tensor apply_instance_norm(const Tensor& input) {
    // Flag handed to instance_norm: true while the module is training, or
    // whenever running statistics are not being tracked.
    const bool use_input_stats =
        this->is_training() || !this->options.track_running_stats();

    return torch::nn::functional::detail::instance_norm(
        input,
        this->running_mean,
        this->running_var,
        this->weight,
        this->bias,
        use_input_stats,
        this->options.momentum(),
        this->options.eps());
  }

  /// Normalizes an input that has no batch dimension: a leading dimension is
  /// inserted before normalizing and removed again from the result.
  Tensor handle_no_batch_input(const Tensor& input) {
    const Tensor with_batch_dim = input.unsqueeze(0);
    return this->apply_instance_norm(with_batch_dim).squeeze(0);
  }

 public:
  // Inherit the constructors of NormImplBase.
  using torch::nn::NormImplBase<D, Derived, InstanceNormOptions>::NormImplBase;

  /// Validates `input` through `_check_input_dim`, then returns its
  /// instance-normalized result. Batched and unbatched inputs are accepted.
  Tensor forward(const Tensor& input) {
    this->_check_input_dim(input);

    // An input with D + 1 dimensions carries no batch dimension:
    //   InstanceNorm1D: 2D is unbatched, 3D is batched
    //   InstanceNorm2D: 3D is unbatched, 4D is batched
    //   InstanceNorm3D: 4D is unbatched, 5D is batched
    const bool lacks_batch_dim = (input.dim() == D + 1);

    return lacks_batch_dim ? this->handle_no_batch_input(input)
                           : this->apply_instance_norm(input);
  }

  /// Writes a one-line description of the module to `stream`, in the form
  /// `torch::nn::InstanceNorm<D>d(<num_features>, eps=..., momentum=...,
  /// affine=..., track_running_stats=...)`.
  ///
  /// `std::boolalpha` is set on `stream` and is not reset afterwards.
  void pretty_print(std::ostream& stream) const override {
    const auto& opts = this->options;

    stream << std::boolalpha;
    stream << "torch::nn::InstanceNorm" << D << "d(";
    stream << opts.num_features() << ", ";
    stream << "eps=" << opts.eps() << ", ";
    stream << "momentum=" << opts.momentum() << ", ";
    stream << "affine=" << opts.affine() << ", ";
    stream << "track_running_stats=" << opts.track_running_stats();
    stream << ")";
  }
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm1d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Instance normalization module built on `InstanceNormImpl` with D = 1.
/// With D = 1, `InstanceNormImpl::forward` treats 2D input as unbatched and
/// 3D input as batched.
class TORCH_API InstanceNorm1dImpl
    : public InstanceNormImpl<1, InstanceNorm1dImpl> {
 public:
  // Inherit the constructors of the shared implementation.
  using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl;

 protected:
  /// Input check invoked at the start of `forward`; only declared here, the
  /// definition lives out of line.
  void _check_input_dim(const Tensor& input) override;
};

/// Declares `InstanceNorm1d`, the module holder type for `InstanceNorm1dImpl`.
TORCH_MODULE(InstanceNorm1d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm2d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Instance normalization module built on `InstanceNormImpl` with D = 2.
/// With D = 2, `InstanceNormImpl::forward` treats 3D input as unbatched and
/// 4D input as batched.
class TORCH_API InstanceNorm2dImpl
    : public InstanceNormImpl<2, InstanceNorm2dImpl> {
 public:
  // Inherit the constructors of the shared implementation.
  using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl;

 protected:
  /// Input check invoked at the start of `forward`; only declared here, the
  /// definition lives out of line.
  void _check_input_dim(const Tensor& input) override;
};

/// Declares `InstanceNorm2d`, the module holder type for `InstanceNorm2dImpl`.
TORCH_MODULE(InstanceNorm2d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm3d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Instance normalization module built on `InstanceNormImpl` with D = 3.
/// With D = 3, `InstanceNormImpl::forward` treats 4D input as unbatched and
/// 5D input as batched.
class TORCH_API InstanceNorm3dImpl
    : public InstanceNormImpl<3, InstanceNorm3dImpl> {
 public:
  // Inherit the constructors of the shared implementation.
  using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl;

 protected:
  /// Input check invoked at the start of `forward`; only declared here, the
  /// definition lives out of line.
  void _check_input_dim(const Tensor& input) override;
};

/// Declares `InstanceNorm3d`, the module holder type for `InstanceNorm3dImpl`.
TORCH_MODULE(InstanceNorm3d);
} // namespace nn
} // namespace torch