Shortcuts

Program Listing for File fold.h

Return to documentation for file (torch/csrc/api/include/torch/nn/functional/fold.h)

#pragma once

#include <torch/nn/options/fold.h>

namespace torch {
namespace nn {
namespace functional {

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Implementation helper behind the public `fold` functional.
//
// Accepts `input` only when it has 2 or 3 dimensions (the error text calls
// these "unbatched" and "batched" respectively) and forwards it, together with
// the geometry arguments in the order shown, to `torch::col2im`. Any other
// dimensionality fails the TORCH_CHECK below, whose message reports the
// dimensionality that was actually received.
inline Tensor fold(const Tensor& input,
                   ExpandingArray<2> output_size,
                   ExpandingArray<2> kernel_size,
                   ExpandingArray<2> dilation,
                   ExpandingArray<2> padding,
                   ExpandingArray<2> stride) {
  const auto ndim = input.dim();

  // Validate up front so the happy path below reads straight through.
  TORCH_CHECK(
      ndim == 2 || ndim == 3,
      "Input Error: Only unbatched (2D) or batched (3D) input Tensors are supported "
      "(got ", ndim, "D)");

  return torch::col2im(
      input, output_size, kernel_size, dilation, padding, stride);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Functional form of fold, driven by a `FoldFuncOptions` bundle.
///
/// Reads `output_size`, `kernel_size`, `dilation`, `padding` and `stride`
/// from `options` and hands them, along with `input`, to `detail::fold`,
/// returning whatever that helper returns.
inline Tensor fold(const Tensor& input, const FoldFuncOptions& options) {
  const auto& output_size = options.output_size();
  const auto& kernel_size = options.kernel_size();
  const auto& dilation = options.dilation();
  const auto& padding = options.padding();
  const auto& stride = options.stride();

  return detail::fold(
      input, output_size, kernel_size, dilation, padding, stride);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
// Implementation helper behind the public `unfold` functional.
//
// Accepts `input` only when it has exactly 4 dimensions and forwards it,
// together with the geometry arguments in the order shown, to
// `torch::im2col`. Any other dimensionality fails the TORCH_CHECK below, whose
// message reports the dimensionality that was actually received.
inline Tensor unfold(const Tensor& input,
                     ExpandingArray<2> kernel_size,
                     ExpandingArray<2> dilation,
                     ExpandingArray<2> padding,
                     ExpandingArray<2> stride) {
  const auto ndim = input.dim();

  // Validate up front so the happy path below reads straight through.
  TORCH_CHECK(
      ndim == 4,
      "Input Error: Only 4D input Tensors are supported "
      "(got ", ndim, "D)");

  return torch::im2col(input, kernel_size, dilation, padding, stride);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Functional form of unfold, driven by an `UnfoldFuncOptions` bundle.
///
/// Reads `kernel_size`, `dilation`, `padding` and `stride` from `options` and
/// hands them, along with `input`, to `detail::unfold`, returning whatever
/// that helper returns.
inline Tensor unfold(const Tensor& input, const UnfoldFuncOptions& options) {
  const auto& kernel_size = options.kernel_size();
  const auto& dilation = options.dilation();
  const auto& padding = options.padding();
  const auto& stride = options.stride();

  return detail::unfold(input, kernel_size, dilation, padding, stride);
}

} // namespace functional
} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources