Program Listing for File conv.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/functional/conv.h)
#pragma once
#include <torch/nn/options/conv.h>
#include <torch/types.h>
namespace torch::nn::functional {
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
/// Overload selected for the `enumtype::kValid` padding tag.
///
/// The argument is unnamed and unused; only its type picks this overload.
///
/// @return The string "valid".
inline std::string padding_unwrap(enumtype::kValid /*tag*/) {
  return std::string("valid");
}
/// Overload selected for the `enumtype::kSame` padding tag.
///
/// The argument is unnamed and unused; only its type picks this overload.
///
/// @return The string "same".
inline std::string padding_unwrap(enumtype::kSame /*tag*/) {
  return std::string("same");
}
/// Overload selected when the padding is given as an `ExpandingArray<D>`.
///
/// @param array The padding values.
/// @return `array` implicitly converted to an `IntArrayRef`.
///
/// NOTE(review): `IntArrayRef` is presumably a non-owning view into `array`,
/// in which case the result must not outlive it — confirm against
/// `ExpandingArray`'s conversion operator.
template <size_t D>
IntArrayRef padding_unwrap(const ExpandingArray<D>& array) {
  IntArrayRef unwrapped = array;
  return unwrapped;
}
/// Implementation helper behind `torch::nn::functional::conv1d`.
///
/// `padding` is dispatched through `std::visit`: whichever alternative is
/// currently held is passed to the matching `padding_unwrap` overload, and the
/// unwrapped value is handed to `torch::conv1d` along with all the other
/// arguments, which are forwarded unchanged.
///
/// @return Whatever `torch::conv1d` returns for the given arguments.
inline Tensor conv1d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<1> stride,
    const Conv1dFuncOptions::padding_t& padding,
    ExpandingArray<1> dilation,
    int64_t groups) {
  // Generic visitor: instantiated once per alternative of `padding`.
  const auto run_with_padding = [&](const auto& active_padding) {
    return torch::conv1d(
        input,
        weight,
        bias,
        stride,
        padding_unwrap(active_padding),
        dilation,
        groups);
  };
  return std::visit(run_with_padding, padding);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Functional form of conv1d.
///
/// Reads bias, stride, padding, dilation and groups from `options` and
/// forwards them, together with `input` and `weight`, to `detail::conv1d`.
///
/// @param input   Tensor passed as the convolution input.
/// @param weight  Tensor passed as the convolution weight.
/// @param options Remaining settings; a default-constructed
///                `Conv1dFuncOptions` is used when omitted.
/// @return The tensor produced by `detail::conv1d`.
inline Tensor conv1d(
    const Tensor& input,
    const Tensor& weight,
    const Conv1dFuncOptions& options = {}) {
  const auto& bias = options.bias();
  const auto& stride = options.stride();
  const auto& padding = options.padding();
  const auto& dilation = options.dilation();
  const auto& groups = options.groups();
  return detail::conv1d(
      input, weight, bias, stride, padding, dilation, groups);
}
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
/// Implementation helper behind `torch::nn::functional::conv2d`.
///
/// `padding` is dispatched through `std::visit`: whichever alternative is
/// currently held is passed to the matching `padding_unwrap` overload, and the
/// unwrapped value is handed to `torch::conv2d` along with all the other
/// arguments, which are forwarded unchanged.
///
/// @return Whatever `torch::conv2d` returns for the given arguments.
inline Tensor conv2d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<2> stride,
    const Conv2dFuncOptions::padding_t& padding,
    ExpandingArray<2> dilation,
    int64_t groups) {
  // Generic visitor: instantiated once per alternative of `padding`.
  const auto run_with_padding = [&](const auto& active_padding) {
    return torch::conv2d(
        input,
        weight,
        bias,
        stride,
        padding_unwrap(active_padding),
        dilation,
        groups);
  };
  return std::visit(run_with_padding, padding);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Functional form of conv2d.
///
/// Reads bias, stride, padding, dilation and groups from `options` and
/// forwards them, together with `input` and `weight`, to `detail::conv2d`.
///
/// @param input   Tensor passed as the convolution input.
/// @param weight  Tensor passed as the convolution weight.
/// @param options Remaining settings; a default-constructed
///                `Conv2dFuncOptions` is used when omitted.
/// @return The tensor produced by `detail::conv2d`.
inline Tensor conv2d(
    const Tensor& input,
    const Tensor& weight,
    const Conv2dFuncOptions& options = {}) {
  const auto& bias = options.bias();
  const auto& stride = options.stride();
  const auto& padding = options.padding();
  const auto& dilation = options.dilation();
  const auto& groups = options.groups();
  return detail::conv2d(
      input, weight, bias, stride, padding, dilation, groups);
}
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
/// Implementation helper behind `torch::nn::functional::conv3d`.
///
/// `padding` is dispatched through `std::visit`: whichever alternative is
/// currently held is passed to the matching `padding_unwrap` overload, and the
/// unwrapped value is handed to `torch::conv3d` along with all the other
/// arguments, which are forwarded unchanged.
///
/// @return Whatever `torch::conv3d` returns for the given arguments.
inline Tensor conv3d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<3> stride,
    const Conv3dFuncOptions::padding_t& padding,
    ExpandingArray<3> dilation,
    int64_t groups) {
  // Generic visitor: instantiated once per alternative of `padding`.
  const auto run_with_padding = [&](const auto& active_padding) {
    return torch::conv3d(
        input,
        weight,
        bias,
        stride,
        padding_unwrap(active_padding),
        dilation,
        groups);
  };
  return std::visit(run_with_padding, padding);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Functional form of conv3d.
///
/// Reads bias, stride, padding, dilation and groups from `options` and
/// forwards them, together with `input` and `weight`, to `detail::conv3d`.
///
/// @param input   Tensor passed as the convolution input.
/// @param weight  Tensor passed as the convolution weight.
/// @param options Remaining settings; a default-constructed
///                `Conv3dFuncOptions` is used when omitted.
/// @return The tensor produced by `detail::conv3d`.
inline Tensor conv3d(
    const Tensor& input,
    const Tensor& weight,
    const Conv3dFuncOptions& options = {}) {
  const auto& bias = options.bias();
  const auto& stride = options.stride();
  const auto& padding = options.padding();
  const auto& dilation = options.dilation();
  const auto& groups = options.groups();
  return detail::conv3d(
      input, weight, bias, stride, padding, dilation, groups);
}
// ============================================================================
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
/// Implementation helper behind `torch::nn::functional::conv_transpose1d`.
///
/// Forwards every argument, unchanged and in the same order, to
/// `torch::conv_transpose1d` and returns its result.
inline Tensor conv_transpose1d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  Tensor result = torch::conv_transpose1d(
      input,
      weight,
      bias,
      stride,
      padding,
      output_padding,
      groups,
      dilation);
  return result;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Functional form of conv_transpose1d.
///
/// Reads bias, stride, padding, output_padding, groups and dilation from
/// `options` and forwards them, together with `input` and `weight`, to
/// `detail::conv_transpose1d`.
///
/// @param input   Tensor passed as the transposed-convolution input.
/// @param weight  Tensor passed as the transposed-convolution weight.
/// @param options Remaining settings; a default-constructed
///                `ConvTranspose1dFuncOptions` is used when omitted.
/// @return The tensor produced by `detail::conv_transpose1d`.
inline Tensor conv_transpose1d(
    const Tensor& input,
    const Tensor& weight,
    const ConvTranspose1dFuncOptions& options = {}) {
  const auto& bias = options.bias();
  const auto& stride = options.stride();
  const auto& padding = options.padding();
  const auto& output_padding = options.output_padding();
  const auto& groups = options.groups();
  const auto& dilation = options.dilation();
  return detail::conv_transpose1d(
      input, weight, bias, stride, padding, output_padding, groups, dilation);
}
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
/// Implementation helper behind `torch::nn::functional::conv_transpose2d`.
///
/// Forwards every argument, unchanged and in the same order, to
/// `torch::conv_transpose2d` and returns its result.
inline Tensor conv_transpose2d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  Tensor result = torch::conv_transpose2d(
      input,
      weight,
      bias,
      stride,
      padding,
      output_padding,
      groups,
      dilation);
  return result;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Functional form of conv_transpose2d.
///
/// Reads bias, stride, padding, output_padding, groups and dilation from
/// `options` and forwards them, together with `input` and `weight`, to
/// `detail::conv_transpose2d`.
///
/// @param input   Tensor passed as the transposed-convolution input.
/// @param weight  Tensor passed as the transposed-convolution weight.
/// @param options Remaining settings; a default-constructed
///                `ConvTranspose2dFuncOptions` is used when omitted.
/// @return The tensor produced by `detail::conv_transpose2d`.
inline Tensor conv_transpose2d(
    const Tensor& input,
    const Tensor& weight,
    const ConvTranspose2dFuncOptions& options = {}) {
  const auto& bias = options.bias();
  const auto& stride = options.stride();
  const auto& padding = options.padding();
  const auto& output_padding = options.output_padding();
  const auto& groups = options.groups();
  const auto& dilation = options.dilation();
  return detail::conv_transpose2d(
      input, weight, bias, stride, padding, output_padding, groups, dilation);
}
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
/// Implementation helper behind `torch::nn::functional::conv_transpose3d`.
///
/// Forwards every argument, unchanged and in the same order, to
/// `torch::conv_transpose3d` and returns its result.
inline Tensor conv_transpose3d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  Tensor result = torch::conv_transpose3d(
      input,
      weight,
      bias,
      stride,
      padding,
      output_padding,
      groups,
      dilation);
  return result;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
/// Functional form of conv_transpose3d.
///
/// Reads bias, stride, padding, output_padding, groups and dilation from
/// `options` and forwards them, together with `input` and `weight`, to
/// `detail::conv_transpose3d`.
///
/// @param input   Tensor passed as the transposed-convolution input.
/// @param weight  Tensor passed as the transposed-convolution weight.
/// @param options Remaining settings; a default-constructed
///                `ConvTranspose3dFuncOptions` is used when omitted.
/// @return The tensor produced by `detail::conv_transpose3d`.
inline Tensor conv_transpose3d(
    const Tensor& input,
    const Tensor& weight,
    const ConvTranspose3dFuncOptions& options = {}) {
  const auto& bias = options.bias();
  const auto& stride = options.stride();
  const auto& padding = options.padding();
  const auto& output_padding = options.output_padding();
  const auto& groups = options.groups();
  const auto& dilation = options.dilation();
  return detail::conv_transpose3d(
      input, weight, bias, stride, padding, output_padding, groups, dilation);
}
} // namespace torch::nn::functional