
Program Listing for File conv.h

File: torch/csrc/api/include/torch/nn/functional/conv.h

#pragma once

#include <torch/nn/options/conv.h>
#include <torch/types.h>

namespace torch::nn::functional {

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
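// The `padding` argument of the N-d conv functionals is a std::variant
// (ConvNdFuncOptions::padding_t) holding either enumtype::kValid,
// enumtype::kSame, or an ExpandingArray<D>. The padding_unwrap overloads
// below map each alternative to what torch::convNd accepts: the string tags
// "valid" / "same", or a plain IntArrayRef of per-dimension padding. The
// detail::convNd helpers dispatch over the variant with std::visit.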

inline std::string padding_unwrap(enumtype::kValid) {
  return "valid";
}

inline std::string padding_unwrap(enumtype::kSame) {
  return "same";
}

template <size_t D>
IntArrayRef padding_unwrap(const ExpandingArray<D>& array) {
  return array;
}

inline Tensor conv1d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<1> stride,
    const Conv1dFuncOptions::padding_t& padding,
    ExpandingArray<1> dilation,
    int64_t groups) {
  return std::visit(
      [&](const auto& pad) {
        return torch::conv1d(
            input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
      },
      padding);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

inline Tensor conv1d(
    const Tensor& input,
    const Tensor& weight,
    const Conv1dFuncOptions& options = {}) {
  return detail::conv1d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.dilation(),
      options.groups());
}
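
// Usage sketch (assumes float tensors shaped for a 1-d convolution):
//   namespace F = torch::nn::functional;
//   auto x = torch::randn({1, 4, 10});  // {batch, in_channels, length}
//   auto w = torch::randn({8, 4, 3});   // {out_channels, in_channels / groups, kW}
//   auto y = F::conv1d(x, w, F::Conv1dFuncOptions().stride(1).padding(torch::kSame));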

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor conv2d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<2> stride,
    const Conv2dFuncOptions::padding_t& padding,
    ExpandingArray<2> dilation,
    int64_t groups) {
  return std::visit(
      [&](const auto& pad) {
        return torch::conv2d(
            input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
      },
      padding);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

inline Tensor conv2d(
    const Tensor& input,
    const Tensor& weight,
    const Conv2dFuncOptions& options = {}) {
  return detail::conv2d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.dilation(),
      options.groups());
}
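
// Usage sketch (assumes float tensors shaped for a 2-d convolution):
//   namespace F = torch::nn::functional;
//   auto x = torch::randn({1, 3, 28, 28});  // {batch, in_channels, height, width}
//   auto w = torch::randn({16, 3, 3, 3});   // {out_channels, in_channels / groups, kH, kW}
//   auto y = F::conv2d(x, w, F::Conv2dFuncOptions().stride(2).padding(1));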

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor conv3d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    ExpandingArray<3> stride,
    const Conv3dFuncOptions::padding_t& padding,
    ExpandingArray<3> dilation,
    int64_t groups) {
  return std::visit(
      [&](const auto& pad) {
        return torch::conv3d(
            input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
      },
      padding);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

inline Tensor conv3d(
    const Tensor& input,
    const Tensor& weight,
    const Conv3dFuncOptions& options = {}) {
  return detail::conv3d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.dilation(),
      options.groups());
}
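
// Usage sketch (assumes float tensors shaped for a 3-d convolution):
//   namespace F = torch::nn::functional;
//   auto x = torch::randn({1, 2, 8, 16, 16});  // {batch, in_channels, depth, height, width}
//   auto w = torch::randn({4, 2, 3, 3, 3});    // {out_channels, in_channels / groups, kD, kH, kW}
//   auto y = F::conv3d(x, w, F::Conv3dFuncOptions().padding(torch::kValid));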

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor conv_transpose1d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  return torch::conv_transpose1d(
      input, weight, bias, stride, padding, output_padding, groups, dilation);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

inline Tensor conv_transpose1d(
    const Tensor& input,
    const Tensor& weight,
    const ConvTranspose1dFuncOptions& options = {}) {
  return detail::conv_transpose1d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.output_padding(),
      options.groups(),
      options.dilation());
}
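
// Usage sketch (assumes float tensors shaped for a 1-d transposed convolution;
// note the weight layout is {in_channels, out_channels / groups, kW}):
//   namespace F = torch::nn::functional;
//   auto x = torch::randn({1, 8, 10});
//   auto w = torch::randn({8, 4, 3});
//   auto y = F::conv_transpose1d(
//       x, w, F::ConvTranspose1dFuncOptions().stride(2).output_padding(1));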

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor conv_transpose2d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  return torch::conv_transpose2d(
      input, weight, bias, stride, padding, output_padding, groups, dilation);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

inline Tensor conv_transpose2d(
    const Tensor& input,
    const Tensor& weight,
    const ConvTranspose2dFuncOptions& options = {}) {
  return detail::conv_transpose2d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.output_padding(),
      options.groups(),
      options.dilation());
}
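
// Usage sketch (assumes float tensors shaped for a 2-d transposed convolution,
// e.g. doubling the spatial size with a stride-2, 2x2 kernel):
//   namespace F = torch::nn::functional;
//   auto x = torch::randn({1, 16, 14, 14});
//   auto w = torch::randn({16, 8, 2, 2});  // {in_channels, out_channels / groups, kH, kW}
//   auto y = F::conv_transpose2d(x, w, F::ConvTranspose2dFuncOptions().stride(2));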

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor conv_transpose3d(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    int64_t groups,
    IntArrayRef dilation) {
  return torch::conv_transpose3d(
      input, weight, bias, stride, padding, output_padding, groups, dilation);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

inline Tensor conv_transpose3d(
    const Tensor& input,
    const Tensor& weight,
    const ConvTranspose3dFuncOptions& options = {}) {
  return detail::conv_transpose3d(
      input,
      weight,
      options.bias(),
      options.stride(),
      options.padding(),
      options.output_padding(),
      options.groups(),
      options.dilation());
}
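
// Usage sketch (assumes float tensors shaped for a 3-d transposed convolution):
//   namespace F = torch::nn::functional;
//   auto x = torch::randn({1, 4, 4, 8, 8});
//   auto w = torch::randn({4, 2, 3, 3, 3});  // {in_channels, out_channels / groups, kD, kH, kW}
//   auto y = F::conv_transpose3d(x, w, F::ConvTranspose3dFuncOptions().padding(1));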

} // namespace torch::nn::functional
