Shortcuts

Program Listing for File pimpl.h

Return to documentation for file (torch/csrc/api/include/torch/nn/pimpl.h)

#pragma once

#include <torch/arg.h>
#include <torch/detail/static.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <torch/csrc/utils/variadic.h>

#include <memory>
#include <type_traits>
#include <utility>

namespace torch {
namespace detail {
// Dump all the template metaprogramming in this file.
//
// The `#include` below is placed inside the namespace braces, so everything
// that pimpl-inl.h declares at its top level ends up in `torch::detail`.
// The rest of this header refers to `torch::detail::ModuleHolderIndicator`,
// `torch::detail::is_module_holder_of` and
// `torch::detail::return_type_of_forward_t`, none of which are declared in
// this file, so they are presumably provided by that header.
#include <torch/csrc/api/include/torch/nn/pimpl-inl.h>
} // namespace detail

namespace nn {

/// A pointer-like wrapper around a `std::shared_ptr<Contained>`.
///
/// A holder is either *empty* (its pointer is null) or refers to a
/// `Contained` object. Copying a holder copies the `shared_ptr`, so the copies
/// refer to the same `Contained`.
///
/// Every accessor that needs the contained object (`get()`, `ptr()`,
/// `operator->`, `operator*`, `operator()` and `operator[]`) goes through a
/// `TORCH_CHECK` and fails with "Accessing empty ModuleHolder" when the holder
/// is empty, rather than dereferencing a null pointer. `is_empty()` and
/// `operator bool` can be used to test for the empty state without failing.
template <typename Contained>
class ModuleHolder : torch::detail::ModuleHolderIndicator {
 protected:
  /// The wrapped object; null when the holder is in the empty state.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::shared_ptr<Contained> impl_;

 public:
  using ContainedType = Contained;

  /// Default-constructs the contained object.
  ///
  /// Only valid when `Contained` is default constructible; otherwise the
  /// `static_assert` below rejects the instantiation with a message that
  /// points the user at the `= nullptr` form.
  ModuleHolder() : impl_(default_construct()) {
    static_assert(
        std::is_default_constructible<Contained>::value,
        "You are trying to default construct a module which has "
        "no default constructor. Use = nullptr to give it the empty state "
        "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`).");
  }

  /// Constructs an empty holder (no `Contained` is created).
  /* implicit */ ModuleHolder(std::nullptr_t) : impl_(nullptr) {}

  /// Constructs the contained object from the given arguments, which are
  /// perfectly forwarded to a constructor of `Contained`.
  ///
  /// The overload is disabled when exactly one argument is passed and
  /// `torch::detail::is_module_holder_of<Head, ContainedType>` is true --
  /// presumably so that constructing a holder from another holder of the same
  /// type is not captured by this forwarding constructor (confirm against
  /// pimpl-inl.h).
  template <
      typename Head,
      typename... Tail,
      typename = typename std::enable_if<
          !(torch::detail::is_module_holder_of<Head, ContainedType>::value &&
            (sizeof...(Tail) == 0))>::type>
  explicit ModuleHolder(Head&& head, Tail&&... tail)
      // NOTE(review): this could likely be `std::make_shared`; `new` is kept
      // so that the allocation layout and the context in which `Contained`'s
      // constructor is accessed stay exactly as they were.
      : impl_(new Contained(
            std::forward<Head>(head),
            std::forward<Tail>(tail)...)) {}

  /// Wraps an existing pointer. A null `module` yields an empty holder.
  /* implicit */ ModuleHolder(std::shared_ptr<Contained> module)
      : impl_(std::move(module)) {}

  /// Returns true if the holder is not empty.
  explicit operator bool() const noexcept {
    return !is_empty();
  }

  /// Member access on the contained object. Fails if the holder is empty.
  Contained* operator->() {
    return get();
  }

  /// Member access on the contained object. Fails if the holder is empty.
  const Contained* operator->() const {
    return get();
  }

  /// Returns a reference to the contained object. Fails if the holder is
  /// empty.
  Contained& operator*() {
    return *get();
  }

  /// Returns a const reference to the contained object. Fails if the holder
  /// is empty.
  const Contained& operator*() const {
    return *get();
  }

  /// Returns the underlying `shared_ptr`. Fails if the holder is empty.
  const std::shared_ptr<Contained>& ptr() const {
    TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
    return impl_;
  }

  /// Returns a raw pointer to the contained object. Fails if the holder is
  /// empty, so the returned pointer is never null.
  Contained* get() {
    TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
    return impl_.get();
  }

  /// Returns a raw const pointer to the contained object. Fails if the holder
  /// is empty, so the returned pointer is never null.
  const Contained* get() const {
    TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
    return impl_.get();
  }

  /// Calls `forward()` on the contained object with the forwarded arguments
  /// and returns its result. Fails if the holder is empty.
  template <typename... Args>
  auto operator()(Args&&... args)
      -> torch::detail::return_type_of_forward_t<Contained, Args...> {
    // This will not compile if the module does not have a `forward()` method
    // (as expected).
    // NOTE: `std::forward` is qualified to prevent VS2017 emitting
    // error C2872: 'std': ambiguous symbol
    //
    // Go through `get()` instead of `impl_` so that calling an empty holder
    // trips the same TORCH_CHECK as every other accessor rather than
    // dereferencing a null `shared_ptr` (undefined behavior).
    return get()->forward(::std::forward<Args>(args)...);
  }

  /// Forwards to the subscript operator of the contained object, preserving
  /// the value category of its result. Fails if the holder is empty.
  template <typename Arg>
  decltype(auto) operator[](Arg&& arg) {
    // `get()` performs the empty-holder check; see `operator()`.
    return (*get())[::std::forward<Arg>(arg)];
  }

  /// Returns true if the holder does not refer to any object.
  bool is_empty() const noexcept {
    return impl_ == nullptr;
  }

 private:

  /// Selected when `Contained` is default constructible: creates the object
  /// used by the default constructor.
  template <
      typename T = Contained,
      typename = torch::enable_if_t<std::is_default_constructible<T>::value>>
  std::shared_ptr<Contained> default_construct() {
    return std::make_shared<Contained>();
  }

  /// Selected when `Contained` is NOT default constructible. It only exists
  /// so that the default constructor's member initializer still has an
  /// overload to call; the constructor's `static_assert` then reports the
  /// actual problem.
  template <typename T = Contained>
  torch::disable_if_t<
      std::is_default_constructible<T>::value,
      std::shared_ptr<Contained>>
  default_construct() {
    return nullptr;
  }
};

/// Writes the object held by `module` to `stream` using the stream insertion
/// operator of `ModuleType`.
///
/// The holder is dereferenced first, so an empty holder is rejected by
/// `ModuleHolder`'s own check before anything is written.
template <typename ModuleType>
std::ostream& operator<<(
    std::ostream& stream,
    const nn::ModuleHolder<ModuleType>& module) {
  const ModuleType& contained = *module;
  return stream << contained;
}

/// Serializes the object held by `module` into `archive` by handing the
/// holder's underlying `shared_ptr` to the archive's own `operator<<`.
///
/// `ptr()` is used to obtain the pointer, so an empty holder is rejected
/// before the archive is touched.
template <typename ModuleType>
serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const nn::ModuleHolder<ModuleType>& module) {
  const std::shared_ptr<ModuleType>& held = module.ptr();
  return archive << held;
}

/// Deserializes from `archive` into the object held by `module` by handing
/// the holder's underlying `shared_ptr` to the archive's own `operator>>`.
///
/// `ptr()` is used to obtain the pointer, so an empty holder is rejected
/// before the archive is read.
template <typename ModuleType>
serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    nn::ModuleHolder<ModuleType>& module) {
  const std::shared_ptr<ModuleType>& held = module.ptr();
  return archive >> held;
}

} // namespace nn
} // namespace torch

// Workaround for CUDA 10.2 and below not allowing attribute unused on
// using declarations.
//
// TORCH_UNUSED_EXCEPT_CUDA expands to nothing when __CUDACC__ is defined and
// to C10_UNUSED otherwise. It is applied to the `Impl` alias emitted by
// TORCH_MODULE_IMPL below.
#ifdef __CUDACC__
#define TORCH_UNUSED_EXCEPT_CUDA
#else
#define TORCH_UNUSED_EXCEPT_CUDA C10_UNUSED
#endif

// Defines a class `Name` that publicly derives from
// `torch::nn::ModuleHolder<ImplType>`, inherits all of its constructors and
// exposes `ImplType` under the member alias `Name::Impl`.
//
// The expansion ends at the closing brace of the class, so the use site must
// supply the terminating semicolon, e.g. `TORCH_MODULE_IMPL(Foo, FooImpl);`.
#define TORCH_MODULE_IMPL(Name, ImplType)                              \
  class Name : public torch::nn::ModuleHolder<ImplType> { /* NOLINT */ \
   public:                                                             \
    using torch::nn::ModuleHolder<ImplType>::ModuleHolder;             \
    using Impl TORCH_UNUSED_EXCEPT_CUDA = ImplType;                    \
  }

// Shorthand for TORCH_MODULE_IMPL where the implementation class name is
// formed by token-pasting `Impl` onto `Name`, i.e. `TORCH_MODULE(Foo)` expands
// to `TORCH_MODULE_IMPL(Foo, FooImpl)`.
#define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources