Program Listing for File padding.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/modules/padding.h)
#pragma once
#include <torch/expanding_array.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/functional/padding.h>
#include <torch/csrc/Export.h>
namespace torch {
namespace nn {
/// Shared base class for the ReflectionPad1d/2d/3d modules declared below.
///
/// `D` selects the `ReflectionPadOptions<D>` specialization used by the
/// module; `Derived` is the concrete module type passed through to
/// `torch::nn::Cloneable<Derived>` (each subclass in this file passes itself).
/// Apart from the delegating constructor, member functions are only declared
/// here; their definitions are not part of this header.
template <size_t D, typename Derived>
class TORCH_API ReflectionPadImpl : public torch::nn::Cloneable<Derived> {
public:
/// Convenience constructor: builds a `ReflectionPadOptions<D>` from `padding`
/// (an `ExpandingArray<D * 2>`) and delegates to the options constructor.
/// NOTE(review): presumably the `D * 2` entries are one (before, after) pair
/// per padded dimension -- confirm against `ReflectionPadOptions`.
ReflectionPadImpl(ExpandingArray<D * 2> padding)
: ReflectionPadImpl(ReflectionPadOptions<D>(padding)) {}
/// Constructs the module from a complete options object.
explicit ReflectionPadImpl(const ReflectionPadOptions<D>& options_);
/// Overrides the `reset()` hook inherited from the base class.
void reset() override;
/// Takes `input` by const reference and returns a new `Tensor`.
/// NOTE(review): presumably the reflection-padded copy of `input` as
/// configured by `options` -- confirm against the out-of-line definition.
Tensor forward(const Tensor& input);
/// Overrides the inherited `pretty_print` hook; presumably writes a
/// human-readable description of this module to `stream`.
void pretty_print(std::ostream& stream) const override;
/// Options object for this module (publicly accessible); presumably
/// initialized from the constructor argument.
ReflectionPadOptions<D> options;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Concrete module type: `ReflectionPadImpl` instantiated with `D = 1`.
/// NOTE(review): presumably applies reflection padding over a 1-D input --
/// confirm against the implementation.
class TORCH_API ReflectionPad1dImpl
: public ReflectionPadImpl<1, ReflectionPad1dImpl> {
public:
// Inherit the base-class constructors (padding array and options object).
using ReflectionPadImpl<1, ReflectionPad1dImpl>::ReflectionPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ReflectionPad1d` wrapper type around `ReflectionPad1dImpl` -- confirm.
TORCH_MODULE(ReflectionPad1d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Concrete module type: `ReflectionPadImpl` instantiated with `D = 2`.
/// NOTE(review): presumably applies reflection padding over a 2-D input --
/// confirm against the implementation.
class TORCH_API ReflectionPad2dImpl
: public ReflectionPadImpl<2, ReflectionPad2dImpl> {
public:
// Inherit the base-class constructors (padding array and options object).
using ReflectionPadImpl<2, ReflectionPad2dImpl>::ReflectionPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ReflectionPad2d` wrapper type around `ReflectionPad2dImpl` -- confirm.
TORCH_MODULE(ReflectionPad2d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Concrete module type: `ReflectionPadImpl` instantiated with `D = 3`.
/// NOTE(review): presumably applies reflection padding over a 3-D input --
/// confirm against the implementation.
class TORCH_API ReflectionPad3dImpl
: public ReflectionPadImpl<3, ReflectionPad3dImpl> {
public:
// Inherit the base-class constructors (padding array and options object).
using ReflectionPadImpl<3, ReflectionPad3dImpl>::ReflectionPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ReflectionPad3d` wrapper type around `ReflectionPad3dImpl` -- confirm.
TORCH_MODULE(ReflectionPad3d);
// ============================================================================
/// Shared base class for the ReplicationPad1d/2d/3d modules declared below.
///
/// `D` selects the `ReplicationPadOptions<D>` specialization used by the
/// module; `Derived` is the concrete module type passed through to
/// `torch::nn::Cloneable<Derived>` (each subclass in this file passes itself).
/// Apart from the delegating constructor, member functions are only declared
/// here; their definitions are not part of this header.
template <size_t D, typename Derived>
class TORCH_API ReplicationPadImpl : public torch::nn::Cloneable<Derived> {
public:
/// Convenience constructor: builds a `ReplicationPadOptions<D>` from
/// `padding` (an `ExpandingArray<D * 2>`) and delegates to the options
/// constructor.
/// NOTE(review): presumably the `D * 2` entries are one (before, after) pair
/// per padded dimension -- confirm against `ReplicationPadOptions`.
ReplicationPadImpl(ExpandingArray<D * 2> padding)
: ReplicationPadImpl(ReplicationPadOptions<D>(padding)) {}
/// Constructs the module from a complete options object.
explicit ReplicationPadImpl(const ReplicationPadOptions<D>& options_);
/// Overrides the `reset()` hook inherited from the base class.
void reset() override;
/// Takes `input` by const reference and returns a new `Tensor`.
/// NOTE(review): presumably the replication-padded copy of `input` as
/// configured by `options` -- confirm against the out-of-line definition.
Tensor forward(const Tensor& input);
/// Overrides the inherited `pretty_print` hook; presumably writes a
/// human-readable description of this module to `stream`.
void pretty_print(std::ostream& stream) const override;
/// Options object for this module (publicly accessible); presumably
/// initialized from the constructor argument.
ReplicationPadOptions<D> options;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Concrete module type: `ReplicationPadImpl` instantiated with `D = 1`.
/// NOTE(review): presumably applies replication padding over a 1-D input --
/// confirm against the implementation.
class TORCH_API ReplicationPad1dImpl
: public ReplicationPadImpl<1, ReplicationPad1dImpl> {
public:
// Inherit the base-class constructors (padding array and options object).
using ReplicationPadImpl<1, ReplicationPad1dImpl>::ReplicationPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ReplicationPad1d` wrapper type around `ReplicationPad1dImpl` -- confirm.
TORCH_MODULE(ReplicationPad1d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Concrete module type: `ReplicationPadImpl` instantiated with `D = 2`.
/// NOTE(review): presumably applies replication padding over a 2-D input --
/// confirm against the implementation.
class TORCH_API ReplicationPad2dImpl
: public ReplicationPadImpl<2, ReplicationPad2dImpl> {
public:
// Inherit the base-class constructors (padding array and options object).
using ReplicationPadImpl<2, ReplicationPad2dImpl>::ReplicationPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ReplicationPad2d` wrapper type around `ReplicationPad2dImpl` -- confirm.
TORCH_MODULE(ReplicationPad2d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Concrete module type: `ReplicationPadImpl` instantiated with `D = 3`.
/// NOTE(review): presumably applies replication padding over a 3-D input --
/// confirm against the implementation.
class TORCH_API ReplicationPad3dImpl
: public ReplicationPadImpl<3, ReplicationPad3dImpl> {
public:
// Inherit the base-class constructors (padding array and options object).
using ReplicationPadImpl<3, ReplicationPad3dImpl>::ReplicationPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ReplicationPad3d` wrapper type around `ReplicationPad3dImpl` -- confirm.
TORCH_MODULE(ReplicationPad3d);
// ============================================================================
/// Shared base class for the ZeroPad1d/2d/3d modules declared below.
///
/// `D` selects the `ZeroPadOptions<D>` specialization used by the module;
/// `Derived` is the concrete module type passed through to
/// `torch::nn::Cloneable<Derived>` (each subclass in this file passes itself).
/// Apart from the delegating constructor, member functions are only declared
/// here; their definitions are not part of this header.
template <size_t D, typename Derived>
class TORCH_API ZeroPadImpl : public torch::nn::Cloneable<Derived> {
public:
/// Convenience constructor: builds a `ZeroPadOptions<D>` from `padding`
/// (an `ExpandingArray<D * 2>`) and delegates to the options constructor.
/// NOTE(review): presumably the `D * 2` entries are one (before, after) pair
/// per padded dimension -- confirm against `ZeroPadOptions`.
ZeroPadImpl(ExpandingArray<D * 2> padding)
: ZeroPadImpl(ZeroPadOptions<D>(padding)) {}
/// Constructs the module from a complete options object.
explicit ZeroPadImpl(const ZeroPadOptions<D>& options_);
/// Overrides the `reset()` hook inherited from the base class.
void reset() override;
/// Takes `input` by const reference and returns a new `Tensor`.
/// NOTE(review): presumably the zero-padded copy of `input` as configured by
/// `options` -- confirm against the out-of-line definition.
Tensor forward(const Tensor& input);
/// Overrides the inherited `pretty_print` hook; presumably writes a
/// human-readable description of this module to `stream`.
void pretty_print(std::ostream& stream) const override;
/// Options object for this module (publicly accessible); presumably
/// initialized from the constructor argument.
ZeroPadOptions<D> options;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ZeroPad1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Applies ZeroPad over a 1-D input.
// Concrete module type: `ZeroPadImpl` instantiated with `D = 1`.
class TORCH_API ZeroPad1dImpl : public ZeroPadImpl<1, ZeroPad1dImpl> {
public:
// Inherit the base-class constructors (padding array and options object).
using ZeroPadImpl<1, ZeroPad1dImpl>::ZeroPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ZeroPad1d` wrapper type around `ZeroPad1dImpl` -- confirm.
TORCH_MODULE(ZeroPad1d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ZeroPad2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Applies ZeroPad over a 2-D input.
// Concrete module type: `ZeroPadImpl` instantiated with `D = 2`.
class TORCH_API ZeroPad2dImpl : public ZeroPadImpl<2, ZeroPad2dImpl> {
public:
// Inherit the base-class constructors (padding array and options object).
using ZeroPadImpl<2, ZeroPad2dImpl>::ZeroPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ZeroPad2d` wrapper type around `ZeroPad2dImpl` -- confirm.
TORCH_MODULE(ZeroPad2d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ZeroPad3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Applies ZeroPad over a 3-D input.
// Concrete module type: `ZeroPadImpl` instantiated with `D = 3`.
class TORCH_API ZeroPad3dImpl : public ZeroPadImpl<3, ZeroPad3dImpl> {
public:
// Inherit the base-class constructors (padding array and options object).
using ZeroPadImpl<3, ZeroPad3dImpl>::ZeroPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ZeroPad3d` wrapper type around `ZeroPad3dImpl` -- confirm.
TORCH_MODULE(ZeroPad3d);
// ============================================================================
/// Shared base class for the ConstantPad1d/2d/3d modules declared below.
///
/// `D` selects the `ConstantPadOptions<D>` specialization used by the module;
/// `Derived` is the concrete module type passed through to
/// `torch::nn::Cloneable<Derived>` (each subclass in this file passes itself).
/// Apart from the delegating constructor, member functions are only declared
/// here; their definitions are not part of this header.
template <size_t D, typename Derived>
class TORCH_API ConstantPadImpl : public torch::nn::Cloneable<Derived> {
public:
/// Convenience constructor: builds a `ConstantPadOptions<D>` from `padding`
/// (an `ExpandingArray<D * 2>`) and `value`, then delegates to the options
/// constructor.
/// NOTE(review): presumably `value` is the fill value written into the
/// padded region and the `D * 2` entries are one (before, after) pair per
/// padded dimension -- confirm against `ConstantPadOptions`.
ConstantPadImpl(ExpandingArray<D * 2> padding, double value)
: ConstantPadImpl(ConstantPadOptions<D>(padding, value)) {}
/// Constructs the module from a complete options object.
explicit ConstantPadImpl(const ConstantPadOptions<D>& options_);
/// Overrides the `reset()` hook inherited from the base class.
void reset() override;
/// Takes `input` by const reference and returns a new `Tensor`.
/// NOTE(review): presumably the constant-padded copy of `input` as
/// configured by `options` -- confirm against the out-of-line definition.
Tensor forward(const Tensor& input);
/// Overrides the inherited `pretty_print` hook; presumably writes a
/// human-readable description of this module to `stream`.
void pretty_print(std::ostream& stream) const override;
/// Options object for this module (publicly accessible); presumably
/// initialized from the constructor argument.
ConstantPadOptions<D> options;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConstantPad1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Concrete module type: `ConstantPadImpl` instantiated with `D = 1`.
/// NOTE(review): presumably applies constant-value padding over a 1-D input
/// -- confirm against the implementation.
class TORCH_API ConstantPad1dImpl
: public ConstantPadImpl<1, ConstantPad1dImpl> {
public:
// Inherit the base-class constructors ((padding, value) and options object).
using ConstantPadImpl<1, ConstantPad1dImpl>::ConstantPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ConstantPad1d` wrapper type around `ConstantPad1dImpl` -- confirm.
TORCH_MODULE(ConstantPad1d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConstantPad2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Concrete module type: `ConstantPadImpl` instantiated with `D = 2`.
/// NOTE(review): presumably applies constant-value padding over a 2-D input
/// -- confirm against the implementation.
class TORCH_API ConstantPad2dImpl
: public ConstantPadImpl<2, ConstantPad2dImpl> {
public:
// Inherit the base-class constructors ((padding, value) and options object).
using ConstantPadImpl<2, ConstantPad2dImpl>::ConstantPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ConstantPad2d` wrapper type around `ConstantPad2dImpl` -- confirm.
TORCH_MODULE(ConstantPad2d);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConstantPad3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Concrete module type: `ConstantPadImpl` instantiated with `D = 3`.
/// NOTE(review): presumably applies constant-value padding over a 3-D input
/// -- confirm against the implementation.
class TORCH_API ConstantPad3dImpl
: public ConstantPadImpl<3, ConstantPad3dImpl> {
public:
// Inherit the base-class constructors ((padding, value) and options object).
using ConstantPadImpl<3, ConstantPad3dImpl>::ConstantPadImpl;
};
// TORCH_MODULE is defined outside this file; presumably it declares the
// `ConstantPad3d` wrapper type around `ConstantPad3dImpl` -- confirm.
TORCH_MODULE(ConstantPad3d);
} // namespace nn
} // namespace torch