Shortcuts

Program Listing for File any.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/any.h)

#pragma once

#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/any_module_holder.h>
#include <torch/nn/modules/container/any_value.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/variadic.h>

#include <ATen/Device.h>

#include <memory>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

namespace torch {
namespace nn {

/// Type-erased container for a single `nn::Module`.
///
/// The module is kept behind an `AnyModulePlaceholder`. Its `forward()` is
/// invoked by boxing every argument into an `AnyValue`. When a module is
/// stored, `static_assert`s require that its `forward()`:
///  - has a concrete (non-templated) signature;
///  - takes no lvalue-reference parameters;
///  - does not return `void`.
///
/// A default-constructed `AnyModule` is empty (see `is_empty()`). The
/// forwarding and accessor functions throw when called on an empty instance.
class AnyModule {
 public:
  /// Creates an empty `AnyModule`.
  AnyModule() = default;

  // ---- Storing a module ----------------------------------------------------

  /// Stores `module`, sharing ownership with the caller.
  template <class ModuleType>
  explicit AnyModule(std::shared_ptr<ModuleType> module);

  /// Forwards `module` into a new `std::make_shared` allocation and stores it.
  template <
      class ModuleType,
      class = torch::detail::enable_if_module_t<ModuleType>>
  explicit AnyModule(ModuleType&& module);

  /// Stores the module obtained from `module_holder.ptr()`.
  template <class ModuleType>
  explicit AnyModule(const ModuleHolder<ModuleType>& module_holder);

  /// Replaces the stored module with `module`.
  template <class ModuleType>
  AnyModule& operator=(std::shared_ptr<ModuleType> module);

  // ---- Copying, moving, cloning --------------------------------------------

  AnyModule(AnyModule&&) = default;
  AnyModule& operator=(AnyModule&&) = default;

  /// Copies go through the placeholder's `copy()`; an empty source stays empty.
  AnyModule(const AnyModule& other);
  AnyModule& operator=(const AnyModule& other);

  /// Returns an `AnyModule` holding the placeholder's `clone_module(device)`
  /// result, or an empty `AnyModule` if this one is empty.
  AnyModule clone(optional<Device> device = nullopt) const;

  // ---- Invoking forward() --------------------------------------------------

  /// Boxes each argument into an `AnyValue` and calls the stored module's
  /// `forward()`, returning the boxed result. Throws if empty.
  template <class... ArgumentTypes>
  AnyValue any_forward(ArgumentTypes&&... arguments);

  /// Like `any_forward()`, but unboxes the result as `ReturnType`.
  template <class ReturnType = torch::Tensor, class... ArgumentTypes>
  ReturnType forward(ArgumentTypes&&... arguments);

  // ---- Accessing the stored module -----------------------------------------

  /// Returns the stored module, which must be exactly of type `T` (checked at
  /// runtime against `type_info()`). Throws on mismatch or if empty.
  template <class T, class = torch::detail::enable_if_module_t<T>>
  T& get();

  template <class T, class = torch::detail::enable_if_module_t<T>>
  const T& get() const;

  /// Builds a `T` from `ptr<ContainedType>()`. Intended for holder types that
  /// expose a `ContainedType` alias.
  template <class T, class ContainedType = typename T::ContainedType>
  T get() const;

  /// Returns the stored module as a `std::shared_ptr<Module>`. Throws if empty.
  std::shared_ptr<Module> ptr() const;

  /// Type-checked (like `get<T>()`) shared pointer to the stored module.
  template <class T, class = torch::detail::enable_if_module_t<T>>
  std::shared_ptr<T> ptr() const;

  /// `std::type_info` of the stored module. Throws if empty.
  const std::type_info& type_info() const;

  /// True when no module is stored.
  bool is_empty() const noexcept;

 private:
  /// Validates `forward()`'s signature at compile time and creates the holder.
  template <
      class ModuleType,
      class Class,
      class ReturnType,
      class... ArgumentTypes>
  std::unique_ptr<AnyModulePlaceholder> make_holder(
      std::shared_ptr<ModuleType>&& module,
      ReturnType (Class::*)(ArgumentTypes...));

  /// Runtime type check plus downcast, keyed on `forward()`'s signature.
  template <class ModuleType, class ReturnType, class... ArgumentTypes>
  ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const;

  template <class ModuleType>
  ModuleType& get_() const;

  /// Owns the type-erased module; null when the `AnyModule` is empty.
  std::unique_ptr<AnyModulePlaceholder> content_;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Stores `module` (shared ownership) in a freshly created holder.
///
/// Compile-time requirements:
///  - `ModuleType` must derive from `nn::Module`;
///  - it must expose a `forward()` with a non-templated signature;
///  - `forward()` must not take lvalue references nor return `void`
///    (checked inside `make_holder()`).
///
/// NOTE: the member initializer takes the address of `ModuleType::forward`.
/// For some invalid types the compiler may therefore report an error there
/// before the `static_assert`s in the body get a chance to fire.
template <typename ModuleType>
AnyModule::AnyModule(std::shared_ptr<ModuleType> module)
    : content_(make_holder(
          std::move(module),
          &std::remove_reference<ModuleType>::type::forward)) {
  // `AnyModule` can only store an `nn::Module` subclass object that provides
  // a `forward()` method that has a non-templatized return type.
  // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s
  // `forward()` method has a templatized return type.)
  static_assert(
      torch::detail::is_module<ModuleType>::value,
      "Can only store object derived from nn::Module into AnyModule");
  // Adjacent string literals are concatenated verbatim, so each fragment must
  // carry its own separating space (previously "nn::Sequentialinto").
  static_assert(
      torch::detail::has_forward<ModuleType>::value,
      "Can only store module with a forward() method that has a non-templatized"
      " argument type and return type into AnyModule (e.g. we cannot store nn::Sequential"
      " into AnyModule, because its forward() method's argument type and return type are templatized."
      " If you need to use nn::Sequentials inside each other you can subclass "
      "nn::Sequential and write a non-templatized forward function for it. You can checkout "
      "https://github.com/pytorch/vision/blob/2f46070f3cb1ea894d82578f3dc5677f82f34958/torchvision/csrc/models/mnasnet.cpp#L59 "
      "for an example on how to do this.).");
}

// Wraps a module passed directly. The argument is forwarded into a new
// `std::make_shared<ModuleType>` allocation, which is then handed to the
// `std::shared_ptr` constructor; that constructor performs all the validation.
// NOTE(review): `ModuleType` is passed unchanged to `make_shared`, so an lvalue
// argument (which deduces `ModuleType` as a reference type) presumably fails
// to compile here. Confirm whether lvalues are meant to be supported.
template <class ModuleType, class>
AnyModule::AnyModule(ModuleType&& module)
    : AnyModule(std::make_shared<ModuleType>(
          std::forward<ModuleType>(module))) {}

// Unwraps `module_holder` via its `ptr()` and delegates to the
// `std::shared_ptr` constructor. The new AnyModule presumably shares ownership
// of the module with the holder.
template <class ModuleType>
AnyModule::AnyModule(const ModuleHolder<ModuleType>& module_holder)
    : AnyModule(module_holder.ptr()) {}

// Copy constructor. A non-empty source is duplicated through its placeholder's
// `copy()`; an empty source leaves `content_` null. Whether `copy()` is deep or
// shallow is decided by `AnyModulePlaceholder`, which is not visible here.
inline AnyModule::AnyModule(const AnyModule& other) {
  if (other.content_ != nullptr) {
    content_ = other.content_->copy();
  }
}

// Copy assignment. Self-assignment is a no-op. Otherwise `content_` becomes a
// `copy()` of the source's placeholder, or null if the source is empty.
inline AnyModule& AnyModule::operator=(const AnyModule& other) {
  if (this == &other) {
    return *this;
  }
  if (other.content_) {
    content_ = other.content_->copy();
  } else {
    content_ = nullptr;
  }
  return *this;
}

// Returns a new AnyModule whose content is `content_->clone_module(device)`.
// Cloning an empty AnyModule yields an empty AnyModule.
inline AnyModule AnyModule::clone(optional<Device> device) const {
  AnyModule result;
  if (content_) {
    result.content_ = content_->clone_module(device);
  }
  return result;
}

// Replaces the stored module. A temporary AnyModule is built first (so the
// `std::shared_ptr` constructor's compile-time checks apply) and then
// move-assigned into *this.
template <class ModuleType>
AnyModule& AnyModule::operator=(std::shared_ptr<ModuleType> module) {
  *this = AnyModule(std::move(module));
  return *this;
}

// Boxes every argument into an `AnyValue`, collects the boxes in argument
// order and passes them to the stored module's `forward()` through the
// placeholder. Throws if the AnyModule is empty.
template <typename... ArgumentTypes>
AnyValue AnyModule::any_forward(ArgumentTypes&&... arguments) {
  TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");

  std::vector<AnyValue> boxed_arguments;
  boxed_arguments.reserve(sizeof...(arguments));

  auto append_boxed = [&boxed_arguments](AnyValue&& boxed) {
    boxed_arguments.push_back(std::move(boxed));
  };
  torch::apply(
      append_boxed, AnyValue(std::forward<ArgumentTypes>(arguments))...);

  return content_->forward(std::move(boxed_arguments));
}

template <typename ReturnType, typename... ArgumentTypes>
ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
  return any_forward(std::forward<ArgumentTypes>(arguments)...)
      .template get<ReturnType>();
}

// Mutable access to the stored module as a `T`. Throws if empty, or (inside
// `get_()`) if the stored type is not exactly `T`.
template <class T, class>
T& AnyModule::get() {
  TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
  T& stored_module = get_<T>();
  return stored_module;
}

// Read-only access to the stored module as a `T`. Throws if empty, or (inside
// `get_()`) if the stored type is not exactly `T`.
template <class T, class>
const T& AnyModule::get() const {
  TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
  const T& stored_module = get_<T>();
  return stored_module;
}

// Builds a `T` (e.g. a module holder) from the type-checked shared pointer
// `ptr<ContainedType>()`, where `ContainedType` defaults to
// `T::ContainedType`.
template <class T, class ContainedType>
T AnyModule::get() const {
  std::shared_ptr<ContainedType> contained = ptr<ContainedType>();
  return T(std::move(contained));
}

// Returns the stored module through the placeholder as a base-class
// `std::shared_ptr<Module>`. Throws if empty.
inline std::shared_ptr<Module> AnyModule::ptr() const {
  TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
  std::shared_ptr<Module> stored = content_->ptr();
  return stored;
}

template <typename T, typename>
std::shared_ptr<T> AnyModule::ptr() const {
  TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
  // Call get() but discard the value, just to do the type checking.
  get_<T>();
  return std::dynamic_pointer_cast<T>(ptr());
}

// `std::type_info` recorded by the placeholder for the stored module.
// Throws if empty.
inline const std::type_info& AnyModule::type_info() const {
  TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule");
  const std::type_info& stored_type = content_->type_info;
  return stored_type;
}

// True when no module has been stored (i.e. the placeholder pointer is null).
inline bool AnyModule::is_empty() const noexcept {
  return !content_;
}

// Private Methods

template <
    typename ModuleType,
    typename Class,
    typename ReturnType,
    typename... ArgumentTypes>
std::unique_ptr<AnyModulePlaceholder> AnyModule::make_holder(
    std::shared_ptr<ModuleType>&& module,
    ReturnType (Class::*)(ArgumentTypes...)) {
  static_assert(
      torch::detail::check_not_lvalue_references<ArgumentTypes...>(),
      "Modules stored inside AnyModule must not take references. "
      "Use pointers instead.");
  static_assert(
      !std::is_void<ReturnType>::value,
      "AnyModule cannot store modules that return void "
      "(you can return a dummy value).");
  return std::make_unique<
      AnyModuleHolder<decay_t<ModuleType>, ArgumentTypes...>>(
      std::move(module));
}

// Type-checked downcast helper. Passes `forward()`'s member-function pointer
// to the overload below, which deduces the argument types needed to name the
// concrete `AnyModuleHolder`.
template <class ModuleType>
ModuleType& AnyModule::get_() const {
  using StoredType = typename std::remove_reference<ModuleType>::type;
  static_assert(
      torch::detail::has_forward<StoredType>::value,
      "Can only call AnyModule::get<T> with a type T that has a forward method");
  return get_(&StoredType::forward);
}

// Returns the stored module as `ModuleType&` when the stored dynamic type is
// exactly `ModuleType`; otherwise throws with both (demangled) type names.
// `type_info()` itself throws if the AnyModule is empty. The argument types
// deduced from `forward()` name the concrete holder for the static_cast.
template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
ModuleType& AnyModule::get_(
    ReturnType (ModuleType::*)(ArgumentTypes...)) const {
  // Compare the type_info objects themselves, not their hash_code()s: equal
  // hash codes are only guaranteed for equal types, so two distinct types may
  // collide, and a false match would make the static_cast below reinterpret
  // the holder as the wrong type (undefined behavior).
  if (typeid(ModuleType) == type_info()) {
    return *static_cast<AnyModuleHolder<ModuleType, ArgumentTypes...>&>(
                *content_)
                .module;
  }
  AT_ERROR(
      "Attempted to cast module of type ",
      c10::demangle(type_info().name()),
      " to type ",
      c10::demangle(typeid(ModuleType).name()));
}

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources