Shortcuts

Program Listing for File conv.h

Return to documentation for file (torch/csrc/api/include/torch/nn/options/conv.h)

#pragma once

#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/expanding_array.h>
#include <torch/types.h>

namespace torch {
namespace nn {

namespace detail {

/// Type of the `padding_mode` option used by the convolution option structs in
/// this file. Holds exactly one of the tag types `enumtype::kZeros`,
/// `enumtype::kReflect`, `enumtype::kReplicate` or `enumtype::kCircular`.
///
/// Written as an alias-declaration (rather than `typedef`) so it reads the
/// same way as the `conv_padding_t` alias template that follows it.
using conv_padding_mode_t = std::variant<
    enumtype::kZeros,
    enumtype::kReflect,
    enumtype::kReplicate,
    enumtype::kCircular>;

/// Type of the `padding` option of the non-transposed convolution option
/// structs in this file. Holds either an `ExpandingArray<D>` or one of the tag
/// types `enumtype::kValid` / `enumtype::kSame`.
template <size_t D>
using conv_padding_t =
    std::variant<ExpandingArray<D>, enumtype::kValid, enumtype::kSame>;

/// Option bundle parameterised on the dimension count `D`. Its option names
/// are the union of those declared by `ConvOptions` and `ConvTransposeOptions`
/// later in this file, plus a `transposed` flag.
/// NOTE(review): presumably the shared internal options behind both the
/// regular and the transposed convolution modules -- confirm against the
/// module implementation.
///
/// `TORCH_ARG(T, name)` is defined in <torch/arg.h> (not visible here). Judging
/// from its use below it declares a data member `name_` (the constructor
/// initialises `in_channels_` etc.) and a callable `name(...)` (forwarded to by
/// the `padding` overload); a trailing `= value` gives that member its default.
template <size_t D>
struct ConvNdOptions {
  /// Variant type accepted by the `padding` option.
  using padding_t = conv_padding_t<D>;
  /// Stores the three options that have no default. All other options keep the
  /// default written next to their `TORCH_ARG` until changed.
  ConvNdOptions(
      int64_t in_channels,
      int64_t out_channels,
      ExpandingArray<D> kernel_size)
      : in_channels_(in_channels),
        out_channels_(out_channels),
        kernel_size_(std::move(kernel_size)) {}

  /// Required; set by the constructor.
  TORCH_ARG(int64_t, in_channels);

  /// Required; set by the constructor.
  TORCH_ARG(int64_t, out_channels);

  /// Required; set by the constructor (moved into place).
  TORCH_ARG(ExpandingArray<D>, kernel_size);

  /// Default `1`. The scalar initialiser relies on `ExpandingArray<D>` accepting
  /// a single value (presumably repeated for all `D` entries -- confirm in
  /// <torch/expanding_array.h>).
  TORCH_ARG(ExpandingArray<D>, stride) = 1;

  /// Default `0`. May instead hold `enumtype::kValid` or `enumtype::kSame`, the
  /// other alternatives of `padding_t`.
  TORCH_ARG(padding_t, padding) = 0;

  // `TORCH_ARG` presumably leaves the access level non-public, so it is
  // reopened here for the hand-written overload below.
 public:
  /// Convenience overload for braced lists such as `padding({1, 2})`: wraps the
  /// list in an `IntArrayRef` and forwards it to the `padding(...)` callable
  /// declared by `TORCH_ARG`, returning whatever that call returns.
  decltype(auto) padding(std::initializer_list<int64_t> il) {
    return padding(IntArrayRef{il});
  }

  /// Default `1`.
  TORCH_ARG(ExpandingArray<D>, dilation) = 1;

  /// Default `false`.
  /// NOTE(review): presumably selects the transposed convolution -- confirm.
  TORCH_ARG(bool, transposed) = false;

  /// Default `0`.
  TORCH_ARG(ExpandingArray<D>, output_padding) = 0;

  /// Default `1`.
  TORCH_ARG(int64_t, groups) = 1;

  /// Default `true`.
  /// NOTE(review): presumably whether a bias term is used -- confirm.
  TORCH_ARG(bool, bias) = true;

  /// Default `torch::kZeros`; the other alternatives of `conv_padding_mode_t`
  /// are `kReflect`, `kReplicate` and `kCircular`.
  TORCH_ARG(conv_padding_mode_t, padding_mode) = torch::kZeros;
};

} // namespace detail

// ============================================================================

/// Options for a `D`-dimensional convolution, aliased below as
/// `Conv1dOptions`, `Conv2dOptions` and `Conv3dOptions`.
///
/// `TORCH_ARG(T, name)` is defined in <torch/arg.h> (not visible here). Judging
/// from its use below it declares a data member `name_` (the constructor
/// initialises `in_channels_` etc.) and a callable `name(...)` (forwarded to by
/// the `padding` overload); a trailing `= value` gives that member its default.
template <size_t D>
struct ConvOptions {
  /// Variant type accepted by the `padding_mode` option.
  using padding_mode_t = detail::conv_padding_mode_t;
  /// Variant type accepted by the `padding` option.
  using padding_t = detail::conv_padding_t<D>;

  /// Stores the three options that have no default. All other options keep the
  /// default written next to their `TORCH_ARG` until changed.
  ConvOptions(
      int64_t in_channels,
      int64_t out_channels,
      ExpandingArray<D> kernel_size)
      : in_channels_(in_channels),
        out_channels_(out_channels),
        kernel_size_(std::move(kernel_size)) {}

  /// Required; set by the constructor.
  TORCH_ARG(int64_t, in_channels);

  /// Required; set by the constructor.
  TORCH_ARG(int64_t, out_channels);

  /// Required; set by the constructor (moved into place).
  TORCH_ARG(ExpandingArray<D>, kernel_size);

  /// Default `1`. The scalar initialiser relies on `ExpandingArray<D>` accepting
  /// a single value (presumably repeated for all `D` entries -- confirm in
  /// <torch/expanding_array.h>).
  TORCH_ARG(ExpandingArray<D>, stride) = 1;

  /// Default `0`. May instead hold `enumtype::kValid` or `enumtype::kSame`, the
  /// other alternatives of `padding_t`.
  TORCH_ARG(padding_t, padding) = 0;

  // `TORCH_ARG` presumably leaves the access level non-public, so it is
  // reopened here for the hand-written overload below.
 public:
  /// Convenience overload for braced lists such as `padding({1, 2})`: wraps the
  /// list in an `IntArrayRef` and forwards it to the `padding(...)` callable
  /// declared by `TORCH_ARG`, returning whatever that call returns.
  decltype(auto) padding(std::initializer_list<int64_t> il) {
    return padding(IntArrayRef{il});
  }

  /// Default `1`.
  TORCH_ARG(ExpandingArray<D>, dilation) = 1;

  /// Default `1`.
  TORCH_ARG(int64_t, groups) = 1;

  /// Default `true`.
  /// NOTE(review): presumably whether a bias term is used -- confirm.
  TORCH_ARG(bool, bias) = true;

  /// Default `torch::kZeros`; the other alternatives of `padding_mode_t` are
  /// `kReflect`, `kReplicate` and `kCircular`.
  TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
};

/// `ConvOptions` with `D = 1`.
using Conv1dOptions = ConvOptions<1>;

/// `ConvOptions` with `D = 2`.
using Conv2dOptions = ConvOptions<2>;

/// `ConvOptions` with `D = 3`.
using Conv3dOptions = ConvOptions<3>;

// ============================================================================

namespace functional {

/// Options in the `functional` namespace for a `D`-dimensional convolution,
/// aliased below as `Conv1dFuncOptions`, `Conv2dFuncOptions` and
/// `Conv3dFuncOptions`. Every option has a default, so the struct has no
/// user-written constructor.
///
/// `TORCH_ARG(T, name)` is defined in <torch/arg.h> (not visible here). Judging
/// from its use below it declares an option `name` with a callable
/// `name(...)` (forwarded to by the `padding` overload); the trailing
/// `= value` supplies the option's default.
template <size_t D>
struct ConvFuncOptions {
  /// Variant type accepted by the `padding` option.
  using padding_t = torch::nn::detail::conv_padding_t<D>;

  /// Default is a default-constructed `Tensor`.
  /// NOTE(review): presumably an undefined tensor meaning "no bias" -- confirm.
  TORCH_ARG(torch::Tensor, bias) = Tensor();

  /// Default `1`. The scalar initialiser relies on `ExpandingArray<D>` accepting
  /// a single value (presumably repeated for all `D` entries -- confirm in
  /// <torch/expanding_array.h>).
  TORCH_ARG(ExpandingArray<D>, stride) = 1;

  /// Default `0`. May instead hold `enumtype::kValid` or `enumtype::kSame`, the
  /// other alternatives of `padding_t`.
  TORCH_ARG(padding_t, padding) = 0;

  // `TORCH_ARG` presumably leaves the access level non-public, so it is
  // reopened here for the hand-written overload below.
 public:
  /// Convenience overload for braced lists such as `padding({1, 2})`: wraps the
  /// list in an `IntArrayRef` and forwards it to the `padding(...)` callable
  /// declared by `TORCH_ARG`, returning whatever that call returns.
  decltype(auto) padding(std::initializer_list<int64_t> il) {
    return padding(IntArrayRef{il});
  }

  /// Default `1`.
  TORCH_ARG(ExpandingArray<D>, dilation) = 1;

  /// Default `1`.
  TORCH_ARG(int64_t, groups) = 1;
};

/// `ConvFuncOptions` with `D = 1`.
using Conv1dFuncOptions = ConvFuncOptions<1>;

/// `ConvFuncOptions` with `D = 2`.
using Conv2dFuncOptions = ConvFuncOptions<2>;

/// `ConvFuncOptions` with `D = 3`.
using Conv3dFuncOptions = ConvFuncOptions<3>;

} // namespace functional

// ============================================================================

/// Options for a `D`-dimensional transposed convolution, aliased below as
/// `ConvTranspose1dOptions`, `ConvTranspose2dOptions` and
/// `ConvTranspose3dOptions`.
///
/// Unlike `ConvOptions`, `padding` here is a plain `ExpandingArray<D>` (no
/// `kValid` / `kSame` alternatives and no braced-list overload), and there is
/// an additional `output_padding` option.
///
/// `TORCH_ARG(T, name)` is defined in <torch/arg.h> (not visible here). Judging
/// from its use below it declares a data member `name_` (the constructor
/// initialises `in_channels_` etc.); a trailing `= value` gives that member
/// its default.
template <size_t D>
struct ConvTransposeOptions {
  /// Variant type accepted by the `padding_mode` option.
  using padding_mode_t = detail::conv_padding_mode_t;

  /// Stores the three options that have no default. All other options keep the
  /// default written next to their `TORCH_ARG` until changed.
  ConvTransposeOptions(
      int64_t in_channels,
      int64_t out_channels,
      ExpandingArray<D> kernel_size)
      : in_channels_(in_channels),
        out_channels_(out_channels),
        kernel_size_(std::move(kernel_size)) {}

  /// Required; set by the constructor.
  TORCH_ARG(int64_t, in_channels);

  /// Required; set by the constructor.
  TORCH_ARG(int64_t, out_channels);

  /// Required; set by the constructor (moved into place).
  TORCH_ARG(ExpandingArray<D>, kernel_size);

  /// Default `1`. The scalar initialiser relies on `ExpandingArray<D>` accepting
  /// a single value (presumably repeated for all `D` entries -- confirm in
  /// <torch/expanding_array.h>).
  TORCH_ARG(ExpandingArray<D>, stride) = 1;

  /// Default `0`.
  TORCH_ARG(ExpandingArray<D>, padding) = 0;

  /// Default `0`.
  TORCH_ARG(ExpandingArray<D>, output_padding) = 0;

  /// Default `1`.
  TORCH_ARG(int64_t, groups) = 1;

  /// Default `true`.
  /// NOTE(review): presumably whether a bias term is used -- confirm.
  TORCH_ARG(bool, bias) = true;

  /// Default `1`.
  TORCH_ARG(ExpandingArray<D>, dilation) = 1;

  /// Default `torch::kZeros`. The type also admits `kReflect`, `kReplicate`
  /// and `kCircular`.
  /// NOTE(review): whether the transposed modules accept the non-zero modes is
  /// not visible here -- confirm against the module implementation.
  TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
};

/// `ConvTransposeOptions` with `D = 1`.
using ConvTranspose1dOptions = ConvTransposeOptions<1>;

/// `ConvTransposeOptions` with `D = 2`.
using ConvTranspose2dOptions = ConvTransposeOptions<2>;

/// `ConvTransposeOptions` with `D = 3`.
using ConvTranspose3dOptions = ConvTransposeOptions<3>;

// ============================================================================

namespace functional {

/// Options in the `functional` namespace for a `D`-dimensional transposed
/// convolution, aliased below as `ConvTranspose1dFuncOptions`,
/// `ConvTranspose2dFuncOptions` and `ConvTranspose3dFuncOptions`. Every option
/// has a default, so the struct has no user-written constructor.
///
/// `TORCH_ARG(T, name)` is defined in <torch/arg.h> (not visible here); the
/// trailing `= value` supplies each option's default.
template <size_t D>
struct ConvTransposeFuncOptions {
  /// Default is a default-constructed `Tensor`.
  /// NOTE(review): presumably an undefined tensor meaning "no bias" -- confirm.
  TORCH_ARG(torch::Tensor, bias) = Tensor();

  /// Default `1`. The scalar initialiser relies on `ExpandingArray<D>` accepting
  /// a single value (presumably repeated for all `D` entries -- confirm in
  /// <torch/expanding_array.h>).
  TORCH_ARG(ExpandingArray<D>, stride) = 1;

  /// Default `0`.
  TORCH_ARG(ExpandingArray<D>, padding) = 0;

  /// Default `0`.
  TORCH_ARG(ExpandingArray<D>, output_padding) = 0;

  /// Default `1`.
  TORCH_ARG(int64_t, groups) = 1;

  /// Default `1`.
  TORCH_ARG(ExpandingArray<D>, dilation) = 1;
};

/// `ConvTransposeFuncOptions` with `D = 1`.
using ConvTranspose1dFuncOptions = ConvTransposeFuncOptions<1>;

/// `ConvTransposeFuncOptions` with `D = 2`.
using ConvTranspose2dFuncOptions = ConvTransposeFuncOptions<2>;

/// `ConvTransposeFuncOptions` with `D = 3`.
using ConvTranspose3dFuncOptions = ConvTransposeFuncOptions<3>;

} // namespace functional

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources