Program Listing for File conv.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/options/conv.h)
#pragma once
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/expanding_array.h>
#include <torch/types.h>
namespace torch::nn {
namespace detail {
/// Closed set of tag types accepted by the `padding_mode` option of the
/// convolution option structs in this file. Exactly one alternative is held
/// at a time; the structs below default it to `torch::kZeros`.
///
/// Written as a `using` alias to match every other alias in this header; the
/// resulting type is identical to the former `typedef`.
using conv_padding_mode_t = std::variant<
    enumtype::kZeros,
    enumtype::kReflect,
    enumtype::kReplicate,
    enumtype::kCircular>;
/// Type of the `padding` option for a `D`-dimensional (non-transposed)
/// convolution: either explicit values held in an `ExpandingArray<D>`, or one
/// of the tag types `enumtype::kValid` / `enumtype::kSame`.
template <size_t D>
using conv_padding_t =
std::variant<ExpandingArray<D>, enumtype::kValid, enumtype::kSame>;
/// Option set for a `D`-dimensional convolution. Compared with `ConvOptions`
/// it additionally carries the `transposed` and `output_padding` options.
///
/// Each `TORCH_ARG(T, name)` line declares one option. The constructor
/// initializes `in_channels_`, `out_channels_` and `kernel_size_`, which are
/// not declared anywhere else in this struct, so the macro is what introduces
/// the `name_` members (see <torch/arg.h> for the exact expansion).
template <size_t D>
struct ConvNdOptions {
/// Type accepted by the `padding` option (explicit values or a
/// `kValid` / `kSame` tag).
using padding_t = conv_padding_t<D>;
/// Builds the options from the three required values. All other options
/// keep the default written next to their `TORCH_ARG` declaration.
ConvNdOptions(
int64_t in_channels,
int64_t out_channels,
ExpandingArray<D> kernel_size)
: in_channels_(in_channels),
out_channels_(out_channels),
kernel_size_(std::move(kernel_size)) {}
/// Input channel count; required, set by the constructor.
TORCH_ARG(int64_t, in_channels);
/// Output channel count; required, set by the constructor.
TORCH_ARG(int64_t, out_channels);
/// Kernel size, one entry per dimension; required, set by the constructor.
TORCH_ARG(ExpandingArray<D>, kernel_size);
/// Stride; defaults to 1 (see <torch/expanding_array.h> for how a single
/// value initializes an `ExpandingArray<D>`).
TORCH_ARG(ExpandingArray<D>, stride) = 1;
/// Padding; defaults to 0. May also be set to a `kValid` / `kSame` tag, as
/// allowed by `padding_t`.
TORCH_ARG(padding_t, padding) = 0;
// The explicit `public:` implies TORCH_ARG leaves the struct in a non-public
// section, so access is reopened for the hand-written overload below.
public:
/// Convenience overload that lets `padding` be set from a braced list of
/// integers: wraps the list in an `IntArrayRef` and forwards to the
/// `padding` setter declared via TORCH_ARG. `decltype(auto)` passes through
/// exactly what that setter returns.
decltype(auto) padding(std::initializer_list<int64_t> il) {
return padding(IntArrayRef{il});
}
/// Dilation; defaults to 1.
TORCH_ARG(ExpandingArray<D>, dilation) = 1;
/// Transposed flag; defaults to false. (Presumably selects a transposed
/// convolution -- confirm at the use site.)
TORCH_ARG(bool, transposed) = false;
/// Output padding; defaults to 0.
TORCH_ARG(ExpandingArray<D>, output_padding) = 0;
/// Group count; defaults to 1.
TORCH_ARG(int64_t, groups) = 1;
/// Bias flag; defaults to true. (Presumably whether a bias term is used --
/// confirm at the use site.)
TORCH_ARG(bool, bias) = true;
/// Padding mode, one of the `conv_padding_mode_t` tags; defaults to
/// `torch::kZeros`.
TORCH_ARG(conv_padding_mode_t, padding_mode) = torch::kZeros;
};
} // namespace detail
// ============================================================================
/// Option set for a `D`-dimensional convolution, exposed as `Conv1dOptions`,
/// `Conv2dOptions` and `Conv3dOptions` further down in this file.
///
/// Each `TORCH_ARG(T, name)` line declares one option. The constructor
/// initializes `in_channels_`, `out_channels_` and `kernel_size_`, which are
/// not declared anywhere else in this struct, so the macro is what introduces
/// the `name_` members (see <torch/arg.h> for the exact expansion).
template <size_t D>
struct ConvOptions {
/// Type accepted by the `padding_mode` option.
using padding_mode_t = detail::conv_padding_mode_t;
/// Type accepted by the `padding` option (explicit values or a
/// `kValid` / `kSame` tag).
using padding_t = detail::conv_padding_t<D>;
/// Builds the options from the three required values. All other options
/// keep the default written next to their `TORCH_ARG` declaration.
ConvOptions(
int64_t in_channels,
int64_t out_channels,
ExpandingArray<D> kernel_size)
: in_channels_(in_channels),
out_channels_(out_channels),
kernel_size_(std::move(kernel_size)) {}
/// Input channel count; required, set by the constructor.
TORCH_ARG(int64_t, in_channels);
/// Output channel count; required, set by the constructor.
TORCH_ARG(int64_t, out_channels);
/// Kernel size, one entry per dimension; required, set by the constructor.
TORCH_ARG(ExpandingArray<D>, kernel_size);
/// Stride; defaults to 1.
TORCH_ARG(ExpandingArray<D>, stride) = 1;
/// Padding; defaults to 0. May also be set to a `kValid` / `kSame` tag, as
/// allowed by `padding_t`.
TORCH_ARG(padding_t, padding) = 0;
// The explicit `public:` implies TORCH_ARG leaves the struct in a non-public
// section, so access is reopened for the hand-written overload below.
public:
/// Convenience overload that lets `padding` be set from a braced list of
/// integers: wraps the list in an `IntArrayRef` and forwards to the
/// `padding` setter declared via TORCH_ARG. `decltype(auto)` passes through
/// exactly what that setter returns.
decltype(auto) padding(std::initializer_list<int64_t> il) {
return padding(IntArrayRef{il});
}
/// Dilation; defaults to 1.
TORCH_ARG(ExpandingArray<D>, dilation) = 1;
/// Group count; defaults to 1.
TORCH_ARG(int64_t, groups) = 1;
/// Bias flag; defaults to true. (Presumably whether a bias term is used --
/// confirm at the use site.)
TORCH_ARG(bool, bias) = true;
/// Padding mode, one of the `padding_mode_t` tags; defaults to
/// `torch::kZeros`.
TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
};
/// `ConvOptions` instantiated for one dimension (`D = 1`).
using Conv1dOptions = ConvOptions<1>;
/// `ConvOptions` instantiated for two dimensions (`D = 2`).
using Conv2dOptions = ConvOptions<2>;
/// `ConvOptions` instantiated for three dimensions (`D = 3`).
using Conv3dOptions = ConvOptions<3>;
// ============================================================================
namespace functional {
/// Option set for the functional form of a `D`-dimensional convolution,
/// exposed as `Conv1dFuncOptions`, `Conv2dFuncOptions` and
/// `Conv3dFuncOptions`. It has no user-declared constructor: every option
/// starts at the default written next to its `TORCH_ARG` declaration.
template <size_t D>
struct ConvFuncOptions {
/// Type accepted by the `padding` option (explicit values or a
/// `kValid` / `kSame` tag).
using padding_t = torch::nn::detail::conv_padding_t<D>;
/// Bias tensor; defaults to a default-constructed `Tensor()`. (Presumably
/// that default means "no bias" -- confirm at the use site.)
TORCH_ARG(torch::Tensor, bias) = Tensor();
/// Stride; defaults to 1.
TORCH_ARG(ExpandingArray<D>, stride) = 1;
/// Padding; defaults to 0. May also be set to a `kValid` / `kSame` tag, as
/// allowed by `padding_t`.
TORCH_ARG(padding_t, padding) = 0;
// The explicit `public:` implies TORCH_ARG leaves the struct in a non-public
// section, so access is reopened for the hand-written overload below.
public:
/// Convenience overload that lets `padding` be set from a braced list of
/// integers: wraps the list in an `IntArrayRef` and forwards to the
/// `padding` setter declared via TORCH_ARG. `decltype(auto)` passes through
/// exactly what that setter returns.
decltype(auto) padding(std::initializer_list<int64_t> il) {
return padding(IntArrayRef{il});
}
/// Dilation; defaults to 1.
TORCH_ARG(ExpandingArray<D>, dilation) = 1;
/// Group count; defaults to 1.
TORCH_ARG(int64_t, groups) = 1;
};
/// `ConvFuncOptions` instantiated for one dimension (`D = 1`).
using Conv1dFuncOptions = ConvFuncOptions<1>;
/// `ConvFuncOptions` instantiated for two dimensions (`D = 2`).
using Conv2dFuncOptions = ConvFuncOptions<2>;
/// `ConvFuncOptions` instantiated for three dimensions (`D = 3`).
using Conv3dFuncOptions = ConvFuncOptions<3>;
} // namespace functional
// ============================================================================
/// Option set for a `D`-dimensional transposed convolution, exposed as
/// `ConvTranspose1dOptions`, `ConvTranspose2dOptions` and
/// `ConvTranspose3dOptions`.
///
/// Unlike `ConvOptions`, `padding` here is a plain `ExpandingArray<D>` (no
/// `kValid` / `kSame` tags), and an `output_padding` option is present.
///
/// Each `TORCH_ARG(T, name)` line declares one option. The constructor
/// initializes `in_channels_`, `out_channels_` and `kernel_size_`, which are
/// not declared anywhere else in this struct, so the macro is what introduces
/// the `name_` members (see <torch/arg.h> for the exact expansion).
template <size_t D>
struct ConvTransposeOptions {
/// Type accepted by the `padding_mode` option.
using padding_mode_t = detail::conv_padding_mode_t;
/// Builds the options from the three required values. All other options
/// keep the default written next to their `TORCH_ARG` declaration.
ConvTransposeOptions(
int64_t in_channels,
int64_t out_channels,
ExpandingArray<D> kernel_size)
: in_channels_(in_channels),
out_channels_(out_channels),
kernel_size_(std::move(kernel_size)) {}
/// Input channel count; required, set by the constructor.
TORCH_ARG(int64_t, in_channels);
/// Output channel count; required, set by the constructor.
TORCH_ARG(int64_t, out_channels);
/// Kernel size, one entry per dimension; required, set by the constructor.
TORCH_ARG(ExpandingArray<D>, kernel_size);
/// Stride; defaults to 1.
TORCH_ARG(ExpandingArray<D>, stride) = 1;
/// Padding; defaults to 0.
TORCH_ARG(ExpandingArray<D>, padding) = 0;
/// Output padding; defaults to 0.
TORCH_ARG(ExpandingArray<D>, output_padding) = 0;
/// Group count; defaults to 1.
TORCH_ARG(int64_t, groups) = 1;
/// Bias flag; defaults to true. (Presumably whether a bias term is used --
/// confirm at the use site.)
TORCH_ARG(bool, bias) = true;
/// Dilation; defaults to 1.
TORCH_ARG(ExpandingArray<D>, dilation) = 1;
/// Padding mode, one of the `padding_mode_t` tags; defaults to
/// `torch::kZeros`. NOTE(review): which modes the transposed modules
/// actually accept is not visible here -- confirm at the use site.
TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
};
/// `ConvTransposeOptions` instantiated for one dimension (`D = 1`).
using ConvTranspose1dOptions = ConvTransposeOptions<1>;
/// `ConvTransposeOptions` instantiated for two dimensions (`D = 2`).
using ConvTranspose2dOptions = ConvTransposeOptions<2>;
/// `ConvTransposeOptions` instantiated for three dimensions (`D = 3`).
using ConvTranspose3dOptions = ConvTransposeOptions<3>;
// ============================================================================
namespace functional {
/// Option set for the functional form of a `D`-dimensional transposed
/// convolution, exposed as `ConvTranspose1dFuncOptions`,
/// `ConvTranspose2dFuncOptions` and `ConvTranspose3dFuncOptions`. It has no
/// user-declared constructor: every option starts at the default written
/// next to its `TORCH_ARG` declaration.
template <size_t D>
struct ConvTransposeFuncOptions {
/// Bias tensor; defaults to a default-constructed `Tensor()`. (Presumably
/// that default means "no bias" -- confirm at the use site.)
TORCH_ARG(torch::Tensor, bias) = Tensor();
/// Stride; defaults to 1.
TORCH_ARG(ExpandingArray<D>, stride) = 1;
/// Padding; defaults to 0.
TORCH_ARG(ExpandingArray<D>, padding) = 0;
/// Output padding; defaults to 0.
TORCH_ARG(ExpandingArray<D>, output_padding) = 0;
/// Group count; defaults to 1.
TORCH_ARG(int64_t, groups) = 1;
/// Dilation; defaults to 1.
TORCH_ARG(ExpandingArray<D>, dilation) = 1;
};
/// `ConvTransposeFuncOptions` instantiated for one dimension (`D = 1`).
using ConvTranspose1dFuncOptions = ConvTransposeFuncOptions<1>;
/// `ConvTransposeFuncOptions` instantiated for two dimensions (`D = 2`).
using ConvTranspose2dFuncOptions = ConvTransposeFuncOptions<2>;
/// `ConvTransposeFuncOptions` instantiated for three dimensions (`D = 3`).
using ConvTranspose3dFuncOptions = ConvTransposeFuncOptions<3>;
} // namespace functional
} // namespace torch::nn