Program Listing for File any.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/any.h)
#pragma once
#include <torch/nn/module.h>
#include <torch/nn/modules/container/any_module_holder.h>
#include <torch/types.h>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
namespace torch::nn {
/// Type-erased container for a single module.
///
/// The wrapped module lives behind a `std::unique_ptr<AnyModulePlaceholder>`
/// (`content_`). A default-constructed `AnyModule` holds nothing; most
/// accessors raise through `TORCH_CHECK` when called on an empty instance.
class AnyModule {
public:
/// Creates an empty `AnyModule` (`is_empty()` returns true).
AnyModule() = default;
/// Wraps the module owned by `module`. `ModuleType` must satisfy
/// `torch::detail::is_module` and `torch::detail::has_forward`
/// (checked with `static_assert` in the definition).
template <typename ModuleType>
explicit AnyModule(std::shared_ptr<ModuleType> module);
/// Wraps a module passed as an object: the definition places it in a
/// `std::make_shared` allocation and delegates to the `shared_ptr` overload.
/// Participates in overload resolution only when
/// `torch::detail::enable_if_module_t<ModuleType>` is well-formed.
template <
typename ModuleType,
typename = torch::detail::enable_if_module_t<ModuleType>>
explicit AnyModule(ModuleType&& module);
/// Wraps the module referred to by `module_holder.ptr()`.
template <typename ModuleType>
explicit AnyModule(const ModuleHolder<ModuleType>& module_holder);
/// Move construction and assignment transfer the holder pointer.
AnyModule(AnyModule&&) = default;
AnyModule& operator=(AnyModule&&) = default;
/// Copy construction and assignment obtain a new holder from
/// `AnyModulePlaceholder::copy()` (or stay empty when `other` is empty).
/// NOTE(review): whether `copy()` shares or duplicates the underlying module
/// is decided in `AnyModulePlaceholder`, which is not shown here.
AnyModule(const AnyModule& other);
AnyModule& operator=(const AnyModule& other);
/// Returns a new `AnyModule` whose holder comes from
/// `AnyModulePlaceholder::clone_module(device)`; cloning an empty instance
/// yields an empty instance.
AnyModule clone(std::optional<Device> device = std::nullopt) const;
/// Replaces the current contents with an `AnyModule` built from `module`.
template <typename ModuleType>
AnyModule& operator=(std::shared_ptr<ModuleType> module);
/// Wraps each argument in an `AnyValue`, passes them as a vector to the
/// holder's `forward()`, and returns the resulting `AnyValue`.
/// Raises via `TORCH_CHECK` if this `AnyModule` is empty.
template <typename... ArgumentTypes>
AnyValue any_forward(ArgumentTypes&&... arguments);
/// Calls `any_forward()` with the given arguments and extracts the result
/// with `AnyValue::get<ReturnType>()`. `ReturnType` defaults to
/// `torch::Tensor`.
template <typename ReturnType = torch::Tensor, typename... ArgumentTypes>
ReturnType forward(ArgumentTypes&&... arguments);
/// Returns a reference to the contained module as a `T`.
/// Raises via `TORCH_CHECK` if this `AnyModule` is empty or if the stored
/// type does not match `T`.
template <typename T, typename = torch::detail::enable_if_module_t<T>>
T& get();
/// Const overload of `get<T>()`; same checks, returns `const T&`.
template <typename T, typename = torch::detail::enable_if_module_t<T>>
const T& get() const;
/// Overload for types exposing a nested `ContainedType`: returns a `T`
/// constructed from `ptr<ContainedType>()`.
/// NOTE(review): presumably intended for `ModuleHolder` wrappers - confirm.
template <typename T, typename ContainedType = typename T::ContainedType>
T get() const;
/// Returns the holder's module pointer as `std::shared_ptr<Module>`.
/// Raises via `TORCH_CHECK` if this `AnyModule` is empty.
std::shared_ptr<Module> ptr() const;
/// Returns the module pointer cast to `std::shared_ptr<T>` after verifying
/// that the stored type matches `T` (raises via `TORCH_CHECK` otherwise, or
/// if this `AnyModule` is empty).
template <typename T, typename = torch::detail::enable_if_module_t<T>>
std::shared_ptr<T> ptr() const;
/// Returns the `type_info` recorded by the holder.
/// Raises via `TORCH_CHECK` if this `AnyModule` is empty.
const std::type_info& type_info() const;
/// Returns true when no module is stored (`content_ == nullptr`).
bool is_empty() const noexcept;
private:
/// Builds an `AnyModuleHolder` for `module`. The member-function-pointer
/// parameter is unnamed; it exists only so the return and argument types of
/// `forward()` can be deduced.
template <
typename ModuleType,
typename Class,
typename ReturnType,
typename... ArgumentTypes>
std::unique_ptr<AnyModulePlaceholder> make_holder(
std::shared_ptr<ModuleType>&& module,
ReturnType (Class::*)(ArgumentTypes...));
/// Type-checked access to the stored module; the member-function-pointer
/// parameter supplies `ModuleType` and the `forward()` argument types.
template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const;
/// Convenience form that passes `&ModuleType::forward` to the overload above.
template <typename ModuleType>
ModuleType& get_() const;
/// Type-erased holder of the stored module; null when empty.
std::unique_ptr<AnyModulePlaceholder> content_;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Wraps the module owned by `module` in a type-erased holder.
///
/// The holder type is derived by `make_holder()` from `ModuleType` and from
/// the signature of `ModuleType::forward`, which is why a pointer to that
/// member is passed along with the module.
///
/// Compile-time requirements (enforced by the `static_assert`s in the body):
/// `ModuleType` must satisfy `torch::detail::is_module` and
/// `torch::detail::has_forward`.
template <typename ModuleType>
AnyModule::AnyModule(std::shared_ptr<ModuleType> module)
    : content_(make_holder(
          std::move(module),
          &std::remove_reference_t<ModuleType>::forward)) {
  // `AnyModule` can only store an `nn::Module` subclass object that provides
  // a `forward()` method that has a non-templatized return type.
  // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s
  // `forward()` method has a templatized return type.)
  static_assert(
      torch::detail::is_module<ModuleType>::value,
      "Can only store object derived from nn::Module into AnyModule");
  // Adjacent string literals are concatenated verbatim, so every fragment
  // must carry its own separating space.
  static_assert(
      torch::detail::has_forward<ModuleType>::value,
      "Can only store module with a forward() method that has a non-templatized"
      " argument type and return type into AnyModule (e.g. we cannot store nn::Sequential"
      " into AnyModule, because its forward() method's argument type and return type are templatized."
      " If you need to use nn::Sequentials inside each other you can subclass "
      "nn::Sequential and write a non-templatized forward function for it. You can checkout "
      "https://github.com/pytorch/vision/blob/2f46070f3cb1ea894d82578f3dc5677f82f34958/torchvision/csrc/models/mnasnet.cpp#L59 "
      "for an example on how to do this.).");
}
/// Wraps a module passed as an object.
///
/// The argument is perfectly forwarded into a fresh `std::make_shared`
/// allocation (moved from for rvalues, copied for lvalues), and construction
/// is then delegated to the `std::shared_ptr` overload.
///
/// `ModuleType` is a forwarding-reference parameter, so it may be deduced as
/// a reference or cv-qualified type. The allocation therefore uses
/// `std::decay_t<ModuleType>`, which is the same type as `ModuleType` for a
/// plain rvalue argument.
/// NOTE(review): whether lvalue arguments reach this overload depends on how
/// `torch::detail::enable_if_module_t` treats reference types - confirm.
template <typename ModuleType, typename>
AnyModule::AnyModule(ModuleType&& module)
    : AnyModule(std::make_shared<std::decay_t<ModuleType>>(
          std::forward<ModuleType>(module))) {}
/// Constructs an `AnyModule` from a `ModuleHolder` by delegating to another
/// `AnyModule` constructor with the result of `module_holder.ptr()`.
/// NOTE(review): `ptr()` presumably returns `std::shared_ptr<ModuleType>`, which
/// would select the `std::shared_ptr` overload - confirm against `ModuleHolder`.
template <typename ModuleType>
AnyModule::AnyModule(const ModuleHolder<ModuleType>& module_holder)
: AnyModule(module_holder.ptr()) {}
/// Copy constructor: starts out empty and, when `other` holds a module,
/// takes a new holder produced by `other.content_->copy()`.
inline AnyModule::AnyModule(const AnyModule& other) {
  if (other.content_) {
    content_ = other.content_->copy();
  }
}
/// Copy assignment: replaces the current holder with a copy of `other`'s
/// holder, or clears it when `other` is empty. Self-assignment is a no-op.
inline AnyModule& AnyModule::operator=(const AnyModule& other) {
  if (this == &other) {
    return *this;
  }
  if (other.content_) {
    content_ = other.content_->copy();
  } else {
    content_ = nullptr;
  }
  return *this;
}
/// Returns a new `AnyModule` whose holder is obtained from
/// `content_->clone_module(device)`. Cloning an empty `AnyModule` returns an
/// empty `AnyModule`.
inline AnyModule AnyModule::clone(std::optional<Device> device) const {
  AnyModule result;
  if (content_) {
    result.content_ = content_->clone_module(device);
  }
  return result;
}
/// Replaces the stored module with `module`: builds a temporary `AnyModule`
/// from the pointer and move-assigns it into `*this`.
template <typename ModuleType>
AnyModule& AnyModule::operator=(std::shared_ptr<ModuleType> module) {
  AnyModule replacement(std::move(module));
  return *this = std::move(replacement);
}
/// Invokes the stored module's `forward()` in type-erased form.
///
/// Each argument is wrapped in an `AnyValue`; the wrapped values are gathered
/// in call order into a vector that is handed to the holder's `forward()`.
/// Raises via `TORCH_CHECK` when no module is stored.
template <typename... ArgumentTypes>
AnyValue AnyModule::any_forward(ArgumentTypes&&... arguments) {
  TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
  std::vector<AnyValue> packed;
  // The argument count is known at compile time, so allocate exactly once.
  packed.reserve(sizeof...(ArgumentTypes));
  const auto append = [&packed](AnyValue&& boxed) {
    packed.push_back(std::move(boxed));
  };
  torch::apply(append, AnyValue(std::forward<ArgumentTypes>(arguments))...);
  return content_->forward(std::move(packed));
}
template <typename ReturnType, typename... ArgumentTypes>
ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
return any_forward(std::forward<ArgumentTypes>(arguments)...)
.template get<ReturnType>();
}
template <typename T, typename>
T& AnyModule::get() {
TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
return get_<T>();
}
template <typename T, typename>
const T& AnyModule::get() const {
TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
return get_<T>();
}
/// Returns a `T` constructed from the typed module pointer
/// `ptr<ContainedType>()`; the emptiness and type checks happen inside that
/// call.
template <typename T, typename ContainedType>
T AnyModule::get() const {
  std::shared_ptr<ContainedType> contained = ptr<ContainedType>();
  return T(std::move(contained));
}
/// Returns the holder's module pointer as a `std::shared_ptr<Module>`.
/// Raises via `TORCH_CHECK` when no module is stored.
inline std::shared_ptr<Module> AnyModule::ptr() const {
  TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
  std::shared_ptr<Module> erased = content_->ptr();
  return erased;
}
template <typename T, typename>
std::shared_ptr<T> AnyModule::ptr() const {
TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
// Call get() but discard the value, just to do the type checking.
get_<T>();
return std::dynamic_pointer_cast<T>(ptr());
}
/// Returns the `type_info` member recorded by the holder.
/// Raises via `TORCH_CHECK` when no module is stored.
inline const std::type_info& AnyModule::type_info() const {
  TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule");
  const auto& holder = *content_;
  return holder.type_info;
}
/// Returns true when this `AnyModule` holds no module, i.e. the holder
/// pointer is null.
inline bool AnyModule::is_empty() const noexcept {
  return !content_;
}
// Private Methods
template <
typename ModuleType,
typename Class,
typename ReturnType,
typename... ArgumentTypes>
std::unique_ptr<AnyModulePlaceholder> AnyModule::make_holder(
std::shared_ptr<ModuleType>&& module,
ReturnType (Class::*)(ArgumentTypes...)) {
static_assert(
torch::detail::check_not_lvalue_references<ArgumentTypes...>(),
"Modules stored inside AnyModule must not take references. "
"Use pointers instead.");
static_assert(
!std::is_void_v<ReturnType>,
"AnyModule cannot store modules that return void "
"(you can return a dummy value).");
return std::make_unique<
AnyModuleHolder<std::decay_t<ModuleType>, ArgumentTypes...>>(
std::move(module));
}
/// Entry point for type-checked access: strips any reference from
/// `ModuleType`, requires that the result has a `forward` method, and passes
/// a pointer to that method to the overload that performs the runtime check.
template <typename ModuleType>
ModuleType& AnyModule::get_() const {
  using Bare = std::remove_reference_t<ModuleType>;
  static_assert(
      torch::detail::has_forward<Bare>::value,
      "Can only call AnyModule::get<T> with a type T that has a forward method");
  return get_(&Bare::forward);
}
template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
ModuleType& AnyModule::get_(
ReturnType (ModuleType::*)(ArgumentTypes...)) const {
if (typeid(ModuleType).hash_code() == type_info().hash_code()) {
return *static_cast<AnyModuleHolder<ModuleType, ArgumentTypes...>&>(
*content_)
.module;
}
TORCH_CHECK(
false,
"Attempted to cast module of type ",
c10::demangle(type_info().name()),
" to type ",
c10::demangle(typeid(ModuleType).name()));
}
} // namespace torch::nn