Program Listing for File conv.h

Source: torch/csrc/api/include/torch/nn/modules/conv.h

#pragma once

#include <c10/util/irange.h>
#include <c10/util/overloaded.h>

#include <torch/expanding_array.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/init.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/utils.h>
#include <torch/nn/options/conv.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/Export.h>

#include <cstddef>
#include <vector>

namespace torch {
namespace nn {

template <size_t D, typename Derived>
class ConvNdImpl : public torch::nn::Cloneable<Derived> {
 public:
  explicit ConvNdImpl(detail::ConvNdOptions<D> options_)
      : options(std::move(options_)) {
    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
    reset();
  }

  void reset() override {
    TORCH_CHECK(
        options.in_channels() > 0 && options.groups() > 0 &&
            options.out_channels() > 0,
        "in_channels, groups and out_channels must be a positive integer.");
    TORCH_CHECK(
        options.in_channels() % options.groups() == 0,
        "in_channels must be divisible by groups");
    TORCH_CHECK(
        options.out_channels() % options.groups() == 0,
        "out_channels must be divisible by groups");

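    // Resolve the padding option into _reversed_padding_repeated_twice: the
    // per-dimension (left, right) amounts stored last-dimension-first, the
    // ordering F::pad expects for the non-zeros padding modes.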
    std::visit(
        c10::overloaded(
            [&](enumtype::kValid) {
              _reversed_padding_repeated_twice.resize(2 * D);
              std::fill_n(_reversed_padding_repeated_twice.begin(), 2 * D, 0);
            },
            [&](enumtype::kSame) {
              for (const auto i : c10::irange(D)) {
                const auto stride = (*options.stride())[i];
                TORCH_CHECK(
                    stride == 1,
                    "padding='same' is not supported for strided convolutions");
              }

              _reversed_padding_repeated_twice.resize(2 * D);
              for (const auto i : c10::irange(D)) {
                const auto dilation = (*options.dilation())[i];
                const auto kernel_size = (*options.kernel_size())[i];
                const auto total_padding = dilation * (kernel_size - 1);
                auto left_pad = total_padding / 2;
                auto right_pad = total_padding - left_pad;
                _reversed_padding_repeated_twice[2 * i] = left_pad;
                _reversed_padding_repeated_twice[2 * i + 1] = right_pad;
              }
            },
            [&](const ExpandingArray<D>& pad) {
              _reversed_padding_repeated_twice =
                  torch::nn::modules::utils::_reverse_repeat_vector(pad, 2);
            }),
        options.padding());
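    // Note: for padding='same' the total padding dilation * (kernel_size - 1)
    // is split across both sides, with the extra element going to the right
    // when the total is odd, e.g. kernel_size=4, dilation=1 gives left_pad=1,
    // right_pad=2.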

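    // Weight layout: transposed convolutions store
    // (in_channels, out_channels / groups, *kernel_size); regular
    // convolutions store (out_channels, in_channels / groups, *kernel_size).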
    if (options.transposed()) {
      std::vector<int64_t> weight_sizes = {
          options.in_channels(), options.out_channels() / options.groups()};
      weight_sizes.insert(
          weight_sizes.end(),
          (*options.kernel_size()).begin(),
          (*options.kernel_size()).end());
      weight = this->register_parameter("weight", torch::empty(weight_sizes));
    } else {
      std::vector<int64_t> weight_sizes = {
          options.out_channels(), options.in_channels() / options.groups()};
      weight_sizes.insert(
          weight_sizes.end(),
          (*options.kernel_size()).begin(),
          (*options.kernel_size()).end());
      weight = this->register_parameter("weight", torch::empty(weight_sizes));
    }

    if (options.bias()) {
      bias = this->register_parameter(
          "bias", torch::empty({options.out_channels()}));
    } else {
      this->register_parameter("bias", Tensor(), /*requires_grad=*/false);
    }

    reset_parameters();
  }

  void reset_parameters() {
    init::kaiming_uniform_(
        weight,
        /*a=*/std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
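    // With a = sqrt(5) this draws weights from U(-b, b) with
    // b = 1 / sqrt(fan_in), matching the bias bound below and the default
    // initialization of the Python nn.ConvNd modules.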

    if (bias.defined()) {
      auto [fan_in, fan_out] = init::_calculate_fan_in_and_fan_out(weight);
      auto bound = 1 / std::sqrt(fan_in);
      init::uniform_(bias, -bound, bound);
    }
  }

  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::Conv" << D << "d"
           << "(" << options.in_channels() << ", " << options.out_channels()
           << ", kernel_size=" << options.kernel_size()
           << ", stride=" << options.stride();
    std::visit(
        c10::overloaded(
            [&](enumtype::kValid) { stream << ", padding='valid'"; },
            [&](enumtype::kSame) { stream << ", padding='same'"; },
            [&](const ExpandingArray<D>& pad) {
              if (*pad != *ExpandingArray<D>(0)) {
                stream << ", padding=" << pad;
              }
            }),
        options.padding());
    if (*options.dilation() != *ExpandingArray<D>(1)) {
      stream << ", dilation=" << options.dilation();
    }
    if (*options.output_padding() != *ExpandingArray<D>(0)) {
      stream << ", output_padding=" << options.output_padding();
    }
    if (options.groups() != 1) {
      stream << ", groups=" << options.groups();
    }
    if (!options.bias()) {
      stream << ", bias=" << std::boolalpha << false;
    }
    if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
      stream << ", padding_mode="
             << enumtype::get_enum_name(options.padding_mode());
    }
    stream << ")";
  }

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  detail::ConvNdOptions<D> options;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  Tensor weight;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  Tensor bias;

 protected:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<int64_t> _reversed_padding_repeated_twice;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API Conv1dImpl : public ConvNdImpl<1, Conv1dImpl> {
 public:
  Conv1dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<1> kernel_size)
      : Conv1dImpl(
            Conv1dOptions(input_channels, output_channels, kernel_size)) {}
  explicit Conv1dImpl(Conv1dOptions options_);
  Tensor forward(const Tensor& input);
};

TORCH_MODULE(Conv1d);
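
// Usage sketch (illustrative; not part of the header). Shapes follow the
// (N, C, L) convention for 1-D inputs:
//
//   torch::nn::Conv1d conv(
//       torch::nn::Conv1dOptions(/*in_channels=*/16, /*out_channels=*/33,
//                                /*kernel_size=*/3)
//           .stride(2));
//   auto input = torch::randn({20, 16, 50});
//   auto output = conv->forward(input); // -> (20, 33, 24)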

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API Conv2dImpl : public ConvNdImpl<2, Conv2dImpl> {
 public:
  Conv2dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<2> kernel_size)
      : Conv2dImpl(
            Conv2dOptions(input_channels, output_channels, kernel_size)) {}
  explicit Conv2dImpl(Conv2dOptions options_);
  Tensor forward(const Tensor& input);

 protected:
  Tensor _conv_forward(const Tensor& input, const Tensor& weight);
};

TORCH_MODULE(Conv2d);
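
// Usage sketch (illustrative; not part of the header). The padding option
// accepts a per-dimension size or one of the torch::kValid / torch::kSame
// tokens; 'same' requires stride 1 and preserves the spatial extent:
//
//   torch::nn::Conv2d conv(
//       torch::nn::Conv2dOptions(/*in_channels=*/3, /*out_channels=*/8,
//                                /*kernel_size=*/3)
//           .padding(torch::kSame));
//   auto input = torch::randn({1, 3, 32, 32});
//   auto output = conv->forward(input); // -> (1, 8, 32, 32)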

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API Conv3dImpl : public ConvNdImpl<3, Conv3dImpl> {
 public:
  Conv3dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<3> kernel_size)
      : Conv3dImpl(
            Conv3dOptions(input_channels, output_channels, kernel_size)) {}
  explicit Conv3dImpl(Conv3dOptions options_);
  Tensor forward(const Tensor& input);
};

TORCH_MODULE(Conv3d);
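
// Usage sketch (illustrative; not part of the header). With groups > 1, each
// group of in_channels / groups inputs maps to out_channels / groups outputs:
//
//   torch::nn::Conv3d conv(
//       torch::nn::Conv3dOptions(/*in_channels=*/4, /*out_channels=*/8,
//                                /*kernel_size=*/3)
//           .groups(2)
//           .bias(false));
//   auto input = torch::randn({1, 4, 10, 10, 10});
//   auto output = conv->forward(input); // -> (1, 8, 8, 8, 8)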

// ~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <size_t D, typename Derived>
class ConvTransposeNdImpl : public ConvNdImpl<D, Derived> {
 public:
  using torch::nn::ConvNdImpl<D, Derived>::ConvNdImpl;
  explicit ConvTransposeNdImpl(detail::ConvNdOptions<D> options_)
      : ConvNdImpl<D, Derived>(std::move(options_)) {
    TORCH_INTERNAL_ASSERT(
        std::holds_alternative<ExpandingArray<D>>(this->options.padding()),
        "ConvTranspose padding cannot be a string");
  }

  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ConvTranspose" << D << "d"
           << "(" << this->options.in_channels() << ", "
           << this->options.out_channels()
           << ", kernel_size=" << this->options.kernel_size()
           << ", stride=" << this->options.stride();
    const auto& pad = padding();
    if (*pad != *ExpandingArray<D>(0)) {
      stream << ", padding=" << pad;
    }
    if (*this->options.dilation() != *ExpandingArray<D>(1)) {
      stream << ", dilation=" << this->options.dilation();
    }
    if (*this->options.output_padding() != *ExpandingArray<D>(0)) {
      stream << ", output_padding=" << this->options.output_padding();
    }
    if (this->options.groups() != 1) {
      stream << ", groups=" << this->options.groups();
    }
    if (!this->options.bias()) {
      stream << ", bias=" << std::boolalpha << false;
    }
    if (!std::get_if<enumtype::kZeros>(&this->options.padding_mode())) {
      stream << ", padding_mode="
             << enumtype::get_enum_name(this->options.padding_mode());
    }
    stream << ")";
  }

 protected:
  const ExpandingArray<D>& padding() const {
    return std::get<ExpandingArray<D>>(this->options.padding());
  }

  std::vector<int64_t> _output_padding(
      const Tensor& input,
      const std::optional<at::IntArrayRef>& output_size,
      const ExpandingArray<D>& stride,
      const ExpandingArray<D>& padding,
      const ExpandingArray<D>& kernel_size);
};
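
// Along each spatial dimension, a transposed convolution produces
//   out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
//         + output_padding + 1
// elements. forward() accepts an optional output_size to disambiguate, since
// several output shapes can be consistent with the same input when stride > 1.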

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API ConvTranspose1dImpl
    : public ConvTransposeNdImpl<1, ConvTranspose1dImpl> {
 public:
  ConvTranspose1dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<1> kernel_size)
      : ConvTranspose1dImpl(ConvTranspose1dOptions(
            input_channels,
            output_channels,
            kernel_size)) {}
  explicit ConvTranspose1dImpl(ConvTranspose1dOptions options_);
  Tensor forward(
      const Tensor& input,
      const std::optional<at::IntArrayRef>& output_size = std::nullopt);

 protected:
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional<at::IntArrayRef>())})
};

TORCH_MODULE(ConvTranspose1d);
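
// Usage sketch (illustrative; not part of the header):
//
//   torch::nn::ConvTranspose1d up(
//       torch::nn::ConvTranspose1dOptions(/*in_channels=*/16,
//                                         /*out_channels=*/8,
//                                         /*kernel_size=*/4)
//           .stride(2));
//   auto input = torch::randn({1, 16, 25});
//   auto output = up->forward(input); // -> (1, 8, 52) by the formula above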

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API ConvTranspose2dImpl
    : public ConvTransposeNdImpl<2, ConvTranspose2dImpl> {
 public:
  ConvTranspose2dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<2> kernel_size)
      : ConvTranspose2dImpl(ConvTranspose2dOptions(
            input_channels,
            output_channels,
            kernel_size)) {}
  explicit ConvTranspose2dImpl(ConvTranspose2dOptions options_);
  Tensor forward(
      const Tensor& input,
      const std::optional<at::IntArrayRef>& output_size = std::nullopt);

 protected:
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional<at::IntArrayRef>())})
};

TORCH_MODULE(ConvTranspose2d);
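
// Usage sketch (illustrative; not part of the header). The common
// kernel_size=4, stride=2, padding=1 configuration exactly doubles the
// spatial extent:
//
//   torch::nn::ConvTranspose2d up(
//       torch::nn::ConvTranspose2dOptions(/*in_channels=*/64,
//                                         /*out_channels=*/32,
//                                         /*kernel_size=*/4)
//           .stride(2)
//           .padding(1));
//   auto input = torch::randn({1, 64, 16, 16});
//   auto output = up->forward(input); // -> (1, 32, 32, 32)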

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

class TORCH_API ConvTranspose3dImpl
    : public ConvTransposeNdImpl<3, ConvTranspose3dImpl> {
 public:
  ConvTranspose3dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<3> kernel_size)
      : ConvTranspose3dImpl(ConvTranspose3dOptions(
            input_channels,
            output_channels,
            kernel_size)) {}
  explicit ConvTranspose3dImpl(ConvTranspose3dOptions options_);
  Tensor forward(
      const Tensor& input,
      const std::optional<at::IntArrayRef>& output_size = std::nullopt);

 protected:
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional<at::IntArrayRef>())})
};

TORCH_MODULE(ConvTranspose3d);

} // namespace nn
} // namespace torch
