Program Listing for File any.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/any.h)

#pragma once

#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/any_module_holder.h>
#include <torch/nn/modules/container/any_value.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/variadic.h>

#include <ATen/Device.h>

#include <memory>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

namespace torch {
namespace nn {

/// Type-erased holder for a module object.
///
/// The wrapped module lives behind an `AnyModulePlaceholder`, which lets a
/// single non-template class store, copy and invoke modules of different
/// concrete types.
class AnyModule {
 public:
  /// Creates an empty `AnyModule`; `is_empty()` returns true for it.
  AnyModule() = default;

  /// Wraps the module owned by the given `shared_ptr`.
  template <typename ModuleType>
  explicit AnyModule(std::shared_ptr<ModuleType> module);

  /// Wraps a module passed by value/rvalue by first placing it into a
  /// `std::shared_ptr`. Only participates in overload resolution when
  /// `torch::detail::enable_if_module_t<ModuleType>` is well-formed.
  template <
      typename ModuleType,
      typename = torch::detail::enable_if_module_t<ModuleType>>
  explicit AnyModule(ModuleType&& module);

  /// Wraps the module referenced by a `ModuleHolder` (uses its `ptr()`).
  template <typename ModuleType>
  explicit AnyModule(const ModuleHolder<ModuleType>& module_holder);

  /// Move construction and move assignment use the defaulted member-wise
  /// behavior (i.e. that of the `std::unique_ptr` member).
  AnyModule(AnyModule&&) = default;
  AnyModule& operator=(AnyModule&&) = default;

  /// Copies `other` by asking its placeholder for a `copy()`; copying an empty
  /// `AnyModule` yields an empty `AnyModule`.
  AnyModule(const AnyModule& other);
  AnyModule& operator=(const AnyModule& other);

  /// Returns a new `AnyModule` holding the result of `clone_module(device)` on
  /// the stored placeholder, or an empty `AnyModule` if this one is empty.
  AnyModule clone(optional<Device> device = nullopt) const;

  /// Replaces the stored module with the one owned by `module`. Provided so
  /// that assignment works despite the constructors being `explicit`.
  template <typename ModuleType>
  AnyModule& operator=(std::shared_ptr<ModuleType> module);

  /// Invokes the stored module's forward with the given arguments and returns
  /// the result wrapped in an `AnyValue`. Fails if the `AnyModule` is empty.
  template <typename... ArgumentTypes>
  AnyValue any_forward(ArgumentTypes&&... arguments);

  /// Like `any_forward()`, but extracts the result as `ReturnType`
  /// (`torch::Tensor` unless specified otherwise).
  template <typename ReturnType = torch::Tensor, typename... ArgumentTypes>
  ReturnType forward(ArgumentTypes&&... arguments);

  /// Returns a reference to the stored module as type `T`. Fails if the
  /// `AnyModule` is empty or the stored module is not of type `T`.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  T& get();

  /// Const overload of `get()`.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  const T& get() const;

  /// Returns the stored module wrapped in holder type `T`, constructed from
  /// `ptr<ContainedType>()`.
  template <typename T, typename ContainedType = typename T::ContainedType>
  T get() const;

  /// Returns the stored module as a `std::shared_ptr<Module>`. Fails if the
  /// `AnyModule` is empty.
  std::shared_ptr<Module> ptr() const;

  /// Like `ptr()`, but casts the pointer to `std::shared_ptr<T>`.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  std::shared_ptr<T> ptr() const;

  /// Returns the `std::type_info` recorded by the stored placeholder. Fails if
  /// the `AnyModule` is empty.
  const std::type_info& type_info() const;

  /// Returns true if no module is stored.
  bool is_empty() const noexcept;

 private:
  /// Creates the placeholder for `module`. The member-function-pointer
  /// parameter is unnamed; it exists only so that `ReturnType` and
  /// `ArgumentTypes` can be deduced from the module's `forward()`.
  template <
      typename ModuleType,
      typename Class,
      typename ReturnType,
      typename... ArgumentTypes>
  std::unique_ptr<AnyModulePlaceholder> make_holder(
      std::shared_ptr<ModuleType>&& module,
      ReturnType (Class::*)(ArgumentTypes...));

  /// Helper behind the `get()` overloads; the unnamed member-function-pointer
  /// parameter is used only to deduce the `forward()` argument types.
  template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
  ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const;

  /// Helper behind the `get()` overloads; forwards to the overload above with
  /// `&ModuleType::forward`.
  template <typename ModuleType>
  ModuleType& get_() const;

  /// The type-erased module; null when the `AnyModule` is empty.
  std::unique_ptr<AnyModulePlaceholder> content_;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Wraps the module owned by `module`. The address of the module type's
/// `forward` member is handed to `make_holder()` so that the argument types of
/// `forward()` can be deduced there.
template <typename ModuleType>
AnyModule::AnyModule(std::shared_ptr<ModuleType> module)
    : content_(make_holder(
          std::move(module),
          &std::remove_reference<ModuleType>::type::forward)) {
  // `AnyModule` can only store an `nn::Module` subclass object that provides
  // a `forward()` method that has a non-templatized return type.
  // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s
  // `forward()` method has a templatized return type.)
  static_assert(
      torch::detail::is_module<ModuleType>::value,
      "Can only store object derived from nn::Module into AnyModule");
  static_assert(
      torch::detail::has_forward<ModuleType>::value,
      "Can only store module with a forward() method that has a non-templatized"
      " argument type and return type into AnyModule (e.g. we cannot store nn::Sequential"
      "into AnyModule, because its forward() method's argument type and return type are templatized."
      " If you need to use nn::Sequentials inside each other you can subclass "
      "nn::Sequential and write a non-templatized forward function for it. You can checkout "
      // NOTE(review): the literal below presumably held a link to the example
      // referred to by the message; restore it if the URL is known.
      " "
      "for an example on how to do this.).");
}

/// Wraps a module object: the argument is forwarded into a freshly allocated
/// `std::shared_ptr<ModuleType>`, and construction is then delegated to the
/// `shared_ptr` constructor.
template <typename ModuleType, typename>
AnyModule::AnyModule(ModuleType&& module)
    : AnyModule(std::make_shared<ModuleType>(
          std::forward<ModuleType>(module))) {
}

/// Wraps the module referenced by `module_holder`; delegates to the
/// `shared_ptr` constructor with the holder's `ptr()`.
template <typename ModuleType>
AnyModule::AnyModule(const ModuleHolder<ModuleType>& module_holder)
    : AnyModule(module_holder.ptr()) {
}

/// Copy constructor: takes a `copy()` of the other object's placeholder when
/// it has one; otherwise the new object stays empty.
inline AnyModule::AnyModule(const AnyModule& other) {
  // `content_` is default-initialized to null, so nothing to do for an empty
  // source.
  if (other.content_) {
    content_ = other.content_->copy();
  }
}

/// Copy assignment: replaces the stored placeholder with a `copy()` of
/// `other`'s placeholder (or null if `other` is empty). Self-assignment is a
/// no-op. Returns `*this`.
inline AnyModule& AnyModule::operator=(const AnyModule& other) {
  if (this != &other) {
    content_ = other.content_ ? other.content_->copy() : nullptr;
  }
  return *this;
}

/// Returns a new `AnyModule` whose placeholder is obtained from
/// `clone_module(device)` on the stored placeholder; an empty `AnyModule`
/// clones to an empty `AnyModule`.
inline AnyModule AnyModule::clone(optional<Device> device) const {
  AnyModule clone;
  clone.content_ = content_ ? content_->clone_module(device) : nullptr;
  return clone;
}

/// Replaces the stored module with the one owned by `module` by constructing a
/// temporary `AnyModule` from it and assigning that to `*this`.
template <typename ModuleType>
AnyModule& AnyModule::operator=(std::shared_ptr<ModuleType> module) {
  // NOLINTNEXTLINE(cppcoreguidelines-c-copy-assignment-signature)
  return (*this = AnyModule(std::move(module)));
}

/// Invokes the stored module's forward with `arguments` and returns the result
/// as an `AnyValue`. Raises via `TORCH_CHECK` if the `AnyModule` is empty.
template <typename... ArgumentTypes>
AnyValue AnyModule::any_forward(ArgumentTypes&&... arguments) {
  TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
  std::vector<AnyValue> values;
  // The number of arguments is known at compile time; allocate once.
  values.reserve(sizeof...(ArgumentTypes));
  // Wrap every argument in an `AnyValue` and append it, preserving order.
  torch::apply(
      [&values](AnyValue&& value) { values.push_back(std::move(value)); },
      AnyValue(std::forward<ArgumentTypes>(arguments))...);
  return content_->forward(std::move(values));
}

/// Calls `any_forward()` with the given arguments and extracts the resulting
/// `AnyValue` as `ReturnType`.
template <typename ReturnType, typename... ArgumentTypes>
ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
  return any_forward(std::forward<ArgumentTypes>(arguments)...)
      .template get<ReturnType>();
}

/// Returns a reference to the stored module as type `T`. Raises via
/// `TORCH_CHECK` if the `AnyModule` is empty; the type check itself is done by
/// `get_<T>()`.
template <typename T, typename>
T& AnyModule::get() {
  TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
  return get_<T>();
}

/// Const overload: returns a const reference to the stored module as type `T`.
/// Raises via `TORCH_CHECK` if the `AnyModule` is empty.
template <typename T, typename>
const T& AnyModule::get() const {
  TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
  return get_<T>();
}

/// Returns the stored module wrapped in holder type `T`, constructed from the
/// typed pointer `ptr<ContainedType>()`.
template <typename T, typename ContainedType>
T AnyModule::get() const {
  return T(ptr<ContainedType>());
}

/// Returns the stored module as a `std::shared_ptr<Module>`. Raises via
/// `TORCH_CHECK` if the `AnyModule` is empty.
inline std::shared_ptr<Module> AnyModule::ptr() const {
  TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
  return content_->ptr();
}

template <typename T, typename>
std::shared_ptr<T> AnyModule::ptr() const {
  TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
  // Call get() but discard the value, just to do the type checking.
  return std::dynamic_pointer_cast<T>(ptr());

/// Returns the `std::type_info` recorded by the stored placeholder. Raises via
/// `TORCH_CHECK` if the `AnyModule` is empty.
inline const std::type_info& AnyModule::type_info() const {
  TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule");
  return content_->type_info;
}

/// Returns true if no module is stored (the placeholder pointer is null).
inline bool AnyModule::is_empty() const noexcept {
  return content_ == nullptr;
}

// Private Methods

template <
    typename ModuleType,
    typename Class,
    typename ReturnType,
    typename... ArgumentTypes>
std::unique_ptr<AnyModulePlaceholder> AnyModule::make_holder(
    std::shared_ptr<ModuleType>&& module,
    ReturnType (Class::*)(ArgumentTypes...)) {
      "Modules stored inside AnyModule must not take references. "
      "Use pointers instead.");
      "AnyModule cannot store modules that return void "
      "(you can return a dummy value).");
  return std::make_unique<
      AnyModuleHolder<std::decay_t<ModuleType>, ArgumentTypes...>>(

template <typename ModuleType>
ModuleType& AnyModule::get_() const {
  using M = typename std::remove_reference<ModuleType>::type;
      "Can only call AnyModule::get<T> with a type T that has a forward method");
  return get_(&M::forward);

template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
ModuleType& AnyModule::get_(
    ReturnType (ModuleType::*)(ArgumentTypes...)) const {
  if (typeid(ModuleType).hash_code() == type_info().hash_code()) {
    return *static_cast<AnyModuleHolder<ModuleType, ArgumentTypes...>&>(
      "Attempted to cast module of type ",
      " to type ",

} // namespace nn
} // namespace torch


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources