Program Listing for File pimpl.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/pimpl.h)
#pragma once
#include <torch/arg.h>
#include <torch/detail/static.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>
#include <torch/csrc/utils/variadic.h>
#include <memory>
#include <type_traits>
#include <utility>
namespace torch {
namespace detail {
// Dump all the template metaprogramming in this file.
#include <torch/csrc/api/include/torch/nn/pimpl-inl.h>
} // namespace detail
namespace nn {
/// Wrapper around a `std::shared_ptr<Contained>` that gives a module
/// implementation class a handle type with pointer-like access.
///
/// A holder is either *empty* (`impl_` is null) or refers to a `Contained`
/// instance. Every accessor that dereferences the stored pointer first checks
/// for the empty state via `TORCH_CHECK`, so misuse of an empty holder is
/// reported as an error instead of dereferencing a null pointer.
template <typename Contained>
class ModuleHolder : torch::detail::ModuleHolderIndicator {
 protected:
  /// Shared handle to the wrapped module; null when the holder is empty.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::shared_ptr<Contained> impl_;

 public:
  /// The type of the wrapped module implementation.
  using ContainedType = Contained;

  /// Default-constructs the contained module.
  ///
  /// Only usable when `Contained` is default constructible; otherwise the
  /// `static_assert` below fires with a hint to use `nullptr` instead.
  ModuleHolder() : impl_(default_construct()) {
    static_assert(
        std::is_default_constructible<Contained>::value,
        "You are trying to default construct a module which has "
        "no default constructor. Use = nullptr to give it the empty state "
        "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`).");
  }

  /// Constructs the holder in the empty state (no contained module).
  /* implicit */ ModuleHolder(std::nullptr_t) : impl_(nullptr) {}

  /// Constructs a new `Contained` by forwarding all arguments to its
  /// constructor.
  ///
  /// The `enable_if` removes this overload when it is called with exactly one
  /// argument that is itself a holder of `ContainedType`, so it is not
  /// selected in place of the holder's copy/move construction.
  template <
      typename Head,
      typename... Tail,
      typename = typename std::enable_if<
          !(torch::detail::is_module_holder_of<Head, ContainedType>::value &&
            (sizeof...(Tail) == 0))>::type>
  explicit ModuleHolder(Head&& head, Tail&&... tail)
      : impl_(new Contained(
            std::forward<Head>(head),
            std::forward<Tail>(tail)...)) {}

  /// Wraps an existing shared pointer to a module (which may be null).
  /* implicit */ ModuleHolder(std::shared_ptr<Contained> module)
      : impl_(std::move(module)) {}

  /// Returns true if the holder contains a module, false if it is empty.
  explicit operator bool() const noexcept {
    return !is_empty();
  }

  /// Forwards to the contained module. Errors if the holder is empty.
  Contained* operator->() {
    return get();
  }

  /// Forwards to the contained module. Errors if the holder is empty.
  const Contained* operator->() const {
    return get();
  }

  /// Returns a reference to the contained module. Errors if the holder is
  /// empty.
  Contained& operator*() {
    return *get();
  }

  /// Returns a const reference to the contained module. Errors if the holder
  /// is empty.
  const Contained& operator*() const {
    return *get();
  }

  /// Returns the underlying shared pointer. Errors if the holder is empty.
  const std::shared_ptr<Contained>& ptr() const {
    TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
    return impl_;
  }

  /// Returns a raw pointer to the contained module. Errors if the holder is
  /// empty.
  Contained* get() {
    TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
    return impl_.get();
  }

  /// Returns a raw const pointer to the contained module. Errors if the holder
  /// is empty.
  const Contained* get() const {
    TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
    return impl_.get();
  }

  /// Calls the `forward()` method of the contained module with the given
  /// arguments. Errors if the holder is empty.
  template <typename... Args>
  auto operator()(Args&&... args)
      -> torch::detail::return_type_of_forward_t<Contained, Args...> {
    // This will not compile if the module does not have a `forward()` method
    // (as expected).
    // NOTE: `std::forward` is qualified to prevent VS2017 emitting
    // error C2872: 'std': ambiguous symbol
    // Go through `get()` rather than `impl_` directly so that an empty holder
    // is reported via TORCH_CHECK, consistent with `operator->`/`operator*`.
    return get()->forward(::std::forward<Args>(args)...);
  }

  /// Forwards to the subscript operator of the contained module.
  /// Errors if the holder is empty.
  template <typename Arg>
  decltype(auto) operator[](Arg&& arg) {
    // `get()` performs the empty-state check before the dereference.
    return (*get())[::std::forward<Arg>(arg)];
  }

  /// Returns true if the holder does not contain a module.
  bool is_empty() const noexcept {
    return impl_ == nullptr;
  }

 private:
  /// Selected when `Contained` is default constructible: creates the module.
  template <
      typename T = Contained,
      typename = torch::enable_if_t<std::is_default_constructible<T>::value>>
  std::shared_ptr<Contained> default_construct() {
    return std::make_shared<Contained>();
  }

  /// Selected when `Contained` is not default constructible: yields a null
  /// pointer so that the default constructor's initializer still compiles and
  /// the `static_assert` in its body can produce the readable diagnostic.
  template <typename T = Contained>
  torch::disable_if_t<
      std::is_default_constructible<T>::value,
      std::shared_ptr<Contained>>
  default_construct() {
    return nullptr;
  }
};
/// Streams the module contained in `module` to `stream` by delegating to the
/// `operator<<` of the contained module type. Dereferencing an empty holder
/// is reported as an error by the holder itself.
template <typename ModuleType>
std::ostream& operator<<(
    std::ostream& stream,
    const nn::ModuleHolder<ModuleType>& module) {
  const ModuleType& contained = *module;
  return stream << contained;
}
/// Writes the module contained in `module` to the output `archive` by passing
/// the holder's shared pointer to the archive's own `operator<<`.
template <typename ModuleType>
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const nn::ModuleHolder<ModuleType>& module) {
  const std::shared_ptr<ModuleType>& contained = module.ptr();
  return archive << contained;
}
/// Reads from the input `archive` into the module contained in `module` by
/// passing the holder's shared pointer to the archive's own `operator>>`.
template <typename ModuleType>
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    nn::ModuleHolder<ModuleType>& module) {
  const std::shared_ptr<ModuleType>& contained = module.ptr();
  return archive >> contained;
}
} // namespace nn
} // namespace torch
// Workaround for CUDA 10.2 and below not allowing attribute unused on
// using declarations.
//
// TORCH_UNUSED_EXCEPT_CUDA expands to nothing when compiling with
// `__CUDACC__` defined, and to `C10_UNUSED` otherwise.
#ifdef __CUDACC__
#define TORCH_UNUSED_EXCEPT_CUDA
#else
#define TORCH_UNUSED_EXCEPT_CUDA C10_UNUSED
#endif
/// Defines a class `Name` that publicly inherits from
/// `torch::nn::ModuleHolder<ImplType>` and inherits all of its constructors.
/// The generated class also exposes `Impl` as a type alias for `ImplType`;
/// the alias is annotated with `TORCH_UNUSED_EXCEPT_CUDA`.
/// The macro body has no trailing semicolon, so the use site must supply one.
#define TORCH_MODULE_IMPL(Name, ImplType) \
class Name : public torch::nn::ModuleHolder<ImplType> { /* NOLINT */ \
public: \
using torch::nn::ModuleHolder<ImplType>::ModuleHolder; \
using Impl TORCH_UNUSED_EXCEPT_CUDA = ImplType; \
}
/// Like `TORCH_MODULE_IMPL`, but derives the implementation type name by
/// appending `Impl` to `Name` (e.g. `TORCH_MODULE(Linear)` wraps `LinearImpl`).
#define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl)