Program Listing for File modulelist.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/modulelist.h)
#pragma once
#include <c10/util/irange.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <utility>
#include <vector>
namespace torch {
namespace nn {
/// An ordered list of `Module`s held by `std::shared_ptr`.
///
/// Every module stored in the list is also registered as a submodule of the
/// list, using its position (as a decimal string) as the registration name.
/// Elements are accessed by index, either untyped (`ptr(i)` / `operator[]`)
/// or with a checked downcast to a concrete module type (`at<T>(i)`).
class ModuleListImpl : public Cloneable<ModuleListImpl> {
 public:
  using Iterator = std::vector<std::shared_ptr<Module>>::iterator;
  using ConstIterator = std::vector<std::shared_ptr<Module>>::const_iterator;

  /// Constructs an empty list.
  ModuleListImpl() = default;

  /// Constructs the list from a variadic pack of modules. Each argument is
  /// forwarded to one of the `push_back` overloads, in order.
  template <typename... Modules>
  explicit ModuleListImpl(Modules&&... modules) {
    modules_.reserve(sizeof...(Modules));
    push_back_var(std::forward<Modules>(modules)...);
  }

  /// Returns a new `ModuleListImpl` holding a clone of every contained module.
  /// `device` is passed through unchanged to each module's own `clone`.
  std::shared_ptr<Module> clone(
      const optional<Device>& device = nullopt) const override {
    auto clone = std::make_shared<ModuleListImpl>();
    for (const auto& module : modules_) {
      clone->push_back(module->clone(device));
    }
    return clone;
  }

  /// Intentionally a no-op: the list has no state of its own to (re)build.
  void reset() override {}

  /// Writes the fixed name `torch::nn::ModuleList` into `stream`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ModuleList";
  }

  /// Appends `module` to the list and registers it as a submodule named after
  /// its (new, last) index.
  void push_back(std::shared_ptr<Module> module) {
    modules_.push_back(std::move(module));
    const auto index = modules_.size() - 1;
    register_module(c10::to_string(index), modules_[index]);
  }

  /// Appends a module given by value or reference. The module is moved or
  /// copied into a freshly allocated `std::shared_ptr` of its concrete type.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void push_back(M&& module) {
    using Type = typename std::remove_reference<M>::type;
    push_back(std::make_shared<Type>(std::forward<M>(module)));
  }

  /// Appends the module owned by `module_holder`; the pointer returned by
  /// `module_holder.ptr()` is stored.
  template <typename M>
  void push_back(const ModuleHolder<M>& module_holder) {
    push_back(module_holder.ptr());
  }

  /// Appends every element of `container`, in iteration order, via
  /// `push_back`.
  template <typename Container>
  void extend(const Container& container) {
    for (const auto& module : container) {
      push_back(module);
    }
  }

  /// Returns an iterator to the first stored module pointer.
  Iterator begin() {
    return modules_.begin();
  }

  /// Returns a const iterator to the first stored module pointer.
  ConstIterator begin() const {
    return modules_.begin();
  }

  /// Returns an iterator one past the last stored module pointer.
  Iterator end() {
    return modules_.end();
  }

  /// Returns a const iterator one past the last stored module pointer.
  ConstIterator end() const {
    return modules_.end();
  }

  /// Returns a reference to the module at `index`, viewed as type `T`.
  /// `T` must be a module type (enforced at compile time). A `TORCH_CHECK`
  /// fails if `index` is out of range or if `as<T>()` yields null for the
  /// stored module.
  template <typename T>
  T& at(size_t index) {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    auto module = modules_[index]->as<T>();
    TORCH_CHECK(
        module,
        "Unable to cast module[",
        index,
        "] to ",
        c10::demangle(typeid(T).name()));
    return *module;
  }

  /// Const overload of `at<T>()`; performs the same range and cast checks.
  template <typename T>
  const T& at(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    const auto module = modules_[index]->as<T>();
    TORCH_CHECK(
        module,
        "Unable to cast module[",
        index,
        "] to ",
        c10::demangle(typeid(T).name()));
    return *module;
  }

  /// Returns the stored `std::shared_ptr<Module>` at `index`. A `TORCH_CHECK`
  /// fails if `index` is out of range.
  std::shared_ptr<Module> ptr(size_t index) const {
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index];
  }

  /// Returns the module at `index` as a `std::shared_ptr<T>`. A `TORCH_CHECK`
  /// fails if `index` is out of range. Unlike `at<T>()`, the result of the
  /// `std::dynamic_pointer_cast` is returned without a null check, so the
  /// returned pointer is empty when the stored module is not a `T`.
  template <typename T>
  std::shared_ptr<T> ptr(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::ptr with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return std::dynamic_pointer_cast<T>(modules_[index]);
  }

  /// Equivalent to `ptr(index)`.
  std::shared_ptr<Module> operator[](size_t index) const {
    // This is the only method we can call without a type.
    return ptr(index);
  }

  /// Returns the number of modules in the list.
  size_t size() const noexcept {
    return modules_.size();
  }

  /// Returns true if the list contains no modules.
  bool is_empty() const noexcept {
    return size() == 0;
  }

  /// Inserts `module` at position `index`, shifting later elements one slot
  /// towards the end. `index == size()` is allowed and appends. A
  /// `TORCH_CHECK` fails if `index > size()`.
  ///
  /// Registration names are kept in sync with positions: every already
  /// registered name from `index` up to the old last index is re-bound to the
  /// module now stored at that position, and the new last index is registered
  /// for the module that was shifted to the back.
  void insert(size_t index, std::shared_ptr<Module> module) {
    TORCH_CHECK(index <= size(), "Index out of range");

    if (index == size()) {
      push_back(std::move(module));
    } else {
      modules_.insert(
          modules_.begin() + static_cast<Iterator::difference_type>(index),
          std::move(module));

      // Names "index" .. "size() - 2" already exist; each one must now refer
      // to the module at its own position `i`, not only to position `index`.
      for (const auto i : c10::irange(index, size() - 1)) {
        replace_module(c10::to_string(i), modules_[i]);
      }
      // The last position has no registration yet.
      register_module(c10::to_string(size() - 1), modules_.back());
    }
  }

  /// Inserts the module owned by `module_holder` at position `index`.
  template <typename M>
  void insert(size_t index, const ModuleHolder<M>& module_holder) {
    insert(index, module_holder.ptr());
  }

  /// Inserts a module given by value or reference at position `index`. The
  /// module is moved or copied into a freshly allocated `std::shared_ptr` of
  /// its concrete type.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void insert(size_t index, M&& module) {
    using Type = typename std::remove_reference<M>::type;
    insert(index, std::make_shared<Type>(std::forward<M>(module)));
  }

 private:
  /// Pushes `head`, then recurses on the remaining arguments. Each call peels
  /// one argument off the pack; the recursion ends at the no-argument
  /// overload below.
  template <typename Head, typename... Tail>
  void push_back_var(Head&& head, Tail&&... tail) {
    push_back(std::forward<Head>(head));
    push_back_var(std::forward<Tail>(tail)...);
  }

  /// The base case, when the list of modules is empty.
  void push_back_var() {}

  // Modules are stored as `shared_ptr<Module>` so the list has reference
  // semantics: the same pointers are handed out by `ptr()` and registered as
  // submodules.
  std::vector<std::shared_ptr<Module>> modules_;
};
// NOTE(review): TORCH_MODULE is defined outside this file; presumably it
// generates the `ModuleList` holder type that wraps `ModuleListImpl` — confirm
// against the macro's definition.
TORCH_MODULE(ModuleList);
} // namespace nn
} // namespace torch