Shortcuts

Program Listing for File modulelist.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/modulelist.h)

#pragma once

#include <c10/util/irange.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>

#include <utility>
#include <vector>

namespace torch {
namespace nn {

/// An ordered list of submodules, each held as a `std::shared_ptr<Module>`.
///
/// Every stored module is also registered on this module under the decimal
/// string of its position ("0", "1", ...), so the registered names follow the
/// order of the list. Elements are accessed by index, either type-erased
/// (`ptr(index)`, `operator[]`) or cast to a concrete module type
/// (`at<T>(index)`, `ptr<T>(index)`).
class ModuleListImpl : public Cloneable<ModuleListImpl> {
 public:
  using Iterator = std::vector<std::shared_ptr<Module>>::iterator;
  using ConstIterator = std::vector<std::shared_ptr<Module>>::const_iterator;

  /// Constructs an empty list.
  ModuleListImpl() = default;

  /// Constructs the list from a variadic pack of modules. Each argument is
  /// forwarded, in order, to one of the `push_back()` overloads.
  template <typename... Modules>
  explicit ModuleListImpl(Modules&&... modules) {
    modules_.reserve(sizeof...(Modules));
    push_back_var(std::forward<Modules>(modules)...);
  }

  /// Returns a new `ModuleListImpl` holding a clone of every contained module,
  /// in the same order. `device` is passed through to each module's `clone()`.
  std::shared_ptr<Module> clone(
      const optional<Device>& device = nullopt) const override {
    auto copy = std::make_shared<ModuleListImpl>();
    for (const auto& module : modules_) {
      copy->push_back(module->clone(device));
    }
    return copy;
  }

  /// Intentionally empty: the list itself has nothing to reset.
  void reset() override {}

  /// Writes the fixed text `torch::nn::ModuleList` into `stream`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ModuleList";
  }

  /// Appends `module` to the end of the list and registers it under the name
  /// of its index.
  void push_back(std::shared_ptr<Module> module) {
    modules_.push_back(std::move(module));
    const auto index = modules_.size() - 1;
    register_module(c10::to_string(index), modules_[index]);
  }

  /// Appends a module given by value or reference: the module is moved or
  /// copied into a new `std::shared_ptr` which is then appended.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void push_back(M&& module) {
    using Type = typename std::remove_reference<M>::type;
    push_back(std::make_shared<Type>(std::forward<M>(module)));
  }

  /// Appends the module pointer held by `module_holder`.
  template <typename M>
  void push_back(const ModuleHolder<M>& module_holder) {
    push_back(module_holder.ptr());
  }

  /// Appends every element of `container`, in iteration order, via
  /// `push_back()`.
  template <typename Container>
  void extend(const Container& container) {
    for (const auto& module : container) {
      push_back(module);
    }
  }

  /// Returns an iterator to the first stored module pointer.
  Iterator begin() {
    return modules_.begin();
  }

  /// Returns a const iterator to the first stored module pointer.
  ConstIterator begin() const {
    return modules_.begin();
  }

  /// Returns an iterator one past the last stored module pointer.
  Iterator end() {
    return modules_.end();
  }

  /// Returns a const iterator one past the last stored module pointer.
  ConstIterator end() const {
    return modules_.end();
  }

  /// Returns a reference to the module at `index`, cast to `T`.
  ///
  /// `T` must be a module type (enforced at compile time). A `TORCH_CHECK`
  /// fails when `index` is out of range or when the cast to `T` yields a null
  /// pointer.
  template <typename T>
  T& at(size_t index) {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    auto module = modules_[index]->as<T>();
    TORCH_CHECK(
        module,
        "Unable to cast module[",
        index,
        "] to ",
        c10::demangle(typeid(T).name()));
    return *module;
  }

  /// Const overload of `at<T>()`; performs the same checks and returns a const
  /// reference.
  template <typename T>
  const T& at(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    const auto module = modules_[index]->as<T>();
    TORCH_CHECK(
        module,
        "Unable to cast module[",
        index,
        "] to ",
        c10::demangle(typeid(T).name()));
    return *module;
  }

  /// Returns the type-erased `std::shared_ptr<Module>` stored at `index`.
  /// A `TORCH_CHECK` fails when `index` is out of range.
  std::shared_ptr<Module> ptr(size_t index) const {
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index];
  }

  /// Returns the pointer stored at `index`, converted with
  /// `std::dynamic_pointer_cast<T>`; the result is therefore null when the
  /// stored module is not a `T`. A `TORCH_CHECK` fails when `index` is out of
  /// range.
  template <typename T>
  std::shared_ptr<T> ptr(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::ptr with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return std::dynamic_pointer_cast<T>(modules_[index]);
  }

  /// Equivalent to `ptr(index)`.
  std::shared_ptr<Module> operator[](size_t index) const {
    // This is the only method we can call without a type.
    return ptr(index);
  }

  /// Returns the number of modules in the list.
  size_t size() const noexcept {
    return modules_.size();
  }

  /// Returns true if the list contains no modules.
  bool is_empty() const noexcept {
    return size() == 0;
  }

  /// Inserts `module` so that it ends up at position `index`; modules that
  /// were at `index` or later move one position towards the end.
  ///
  /// `index == size()` is allowed and behaves like `push_back()`. A
  /// `TORCH_CHECK` fails when `index > size()`.
  void insert(size_t index, std::shared_ptr<Module> module) {
    TORCH_CHECK(index <= size(), "Index out of range");

    if (index == size()) {
      push_back(std::move(module));
    } else {
      modules_.insert(
          modules_.begin() + static_cast<Iterator::difference_type>(index),
          std::move(module));

      // Names "index" .. "size() - 2" were registered before the insertion and
      // now refer to the wrong elements: rebind each of them to the module
      // that currently occupies that position.
      for (const auto i : c10::irange(index, size() - 1)) {
        replace_module(c10::to_string(i), modules_[i]);
      }
      // The list grew by one, so the last position has no name yet.
      register_module(c10::to_string(size() - 1), modules_.back());
    }
  }

  /// Inserts the module pointer held by `module_holder` at `index`.
  template <typename M>
  void insert(size_t index, const ModuleHolder<M>& module_holder) {
    insert(index, module_holder.ptr());
  }

  /// Inserts a module given by value or reference: the module is moved or
  /// copied into a new `std::shared_ptr` which is then inserted at `index`.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void insert(size_t index, M&& module) {
    using Type = typename std::remove_reference<M>::type;
    insert(index, std::make_shared<Type>(std::forward<M>(module)));
  }

 private:
  /// Appends `head`, then recurses on `tail` until the parameter pack is
  /// empty, at which point the zero-argument overload below ends the
  /// recursion.
  template <typename Head, typename... Tail>
  void push_back_var(Head&& head, Tail&&... tail) {
    push_back(std::forward<Head>(head));
    push_back_var(std::forward<Tail>(tail)...);
  }

  /// Base case of the variadic recursion: nothing left to append.
  void push_back_var() {}

  /// The contained modules, in list order. They are held through
  /// `std::shared_ptr`, so the list and its callers refer to the same module
  /// objects.
  std::vector<std::shared_ptr<Module>> modules_;
};

/// NOTE(review): `TORCH_MODULE` is defined outside this file; presumably it
/// declares the `ModuleList` holder type that wraps `ModuleListImpl` -- confirm
/// against the macro definition.
TORCH_MODULE(ModuleList);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources