Shortcuts

Program Listing for File sequential.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/sequential.h)

#pragma once

#include <torch/detail/static.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/any.h>
#include <torch/nn/modules/container/named_any.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <c10/util/Exception.h>

#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace torch::nn {

/// An ordered container of type-erased modules (`AnyModule`) that are chained
/// in `forward()`: the first module receives the caller's arguments and every
/// following module receives the value produced by the module before it.
///
/// Every module added through `push_back()` is stored in `modules_` and is
/// also registered on this module via `register_module()`, either under an
/// explicit name or under its positional index rendered as a string.
class SequentialImpl : public Cloneable<SequentialImpl> {
 public:
  using Iterator = std::vector<AnyModule>::iterator;
  using ConstIterator = std::vector<AnyModule>::const_iterator;

  /// Constructs an empty container.
  SequentialImpl() = default;

  /// Constructs the container from a pack of modules. Each argument is
  /// forwarded to one of the single-argument `push_back()` overloads, so it is
  /// registered under its positional index.
  template <typename... Modules>
  explicit SequentialImpl(Modules&&... modules) {
    modules_.reserve(sizeof...(Modules));
    push_back(std::forward<Modules>(modules)...);
  }

  /// Constructs the container from an `OrderedDict` of named `AnyModule`s.
  /// Each value is moved out of the dictionary and registered under its key.
  explicit SequentialImpl(
      torch::OrderedDict<std::string, AnyModule>&& ordered_dict) {
    modules_.reserve(ordered_dict.size());
    for (auto& item : ordered_dict) {
      push_back(item.key(), std::move(item.value()));
    }
  }

  /// Constructs the container from a braced list of `NamedAnyModule`s, adding
  /// each entry under the name it carries.
  explicit SequentialImpl(std::initializer_list<NamedAnyModule> named_modules) {
    modules_.reserve(named_modules.size());
    for (const auto& named_module : named_modules) {
      push_back(named_module.name(), named_module.module());
    }
  }

  /// Returns a new `SequentialImpl` holding `module.clone(device)` for every
  /// stored module, in the same order.
  ///
  /// NOTE(review): the clones are appended without a name, so (assuming
  /// `AnyModule::clone` returns an `AnyModule`) they are registered under
  /// their positional index rather than under the names used in this
  /// container -- confirm whether that is intended.
  std::shared_ptr<Module> clone(
      const std::optional<Device>& device = std::nullopt) const override {
    auto clone = std::make_shared<SequentialImpl>();
    // Match the constructors: size the storage once instead of growing it
    // while appending.
    clone->modules_.reserve(modules_.size());
    for (const auto& module : modules_) {
      clone->push_back(module.clone(device));
    }
    return clone;
  }

  /// Intentionally empty: this override performs no work.
  void reset() override {}

  /// Writes the literal text `torch::nn::Sequential` to `stream`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::Sequential";
  }

  /// Runs the stored modules in order and returns the final result.
  ///
  /// `inputs` are forwarded to the first module; each later module is called
  /// with the previous module's (type-erased) output. The final value is then
  /// extracted as `ReturnType` (default: `Tensor`).
  ///
  /// Fails a `TORCH_CHECK` if the container is empty, or if the value produced
  /// by the last module is not of type `ReturnType`.
  template <typename ReturnType = Tensor, typename... InputTypes>
  ReturnType forward(InputTypes&&... inputs) {
    TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential");

    auto iterator = modules_.begin();
    auto input = iterator->any_forward(std::forward<InputTypes>(inputs)...);

    for (++iterator; iterator != modules_.end(); ++iterator) {
      input = iterator->any_forward(std::move(input));
    }

    // Check the return value and give a nice error message if the requested
    // return type was incorrect.
    if (auto* return_value = input.template try_get<ReturnType>()) {
      return std::move(*return_value);
    }
    TORCH_CHECK(
        false,
        "The type of the return value is ",
        c10::demangle(input.type_info().name()),
        ", but you asked for type ",
        c10::demangle(typeid(ReturnType).name()));
  }

  /// Appends a module held by `shared_ptr`, named after its positional index.
  template <typename ModuleType>
  void push_back(std::shared_ptr<ModuleType> module_ptr) {
    push_back(std::to_string(modules_.size()), std::move(module_ptr));
  }

  /// Appends a module held by `shared_ptr` under `name`, wrapping it in an
  /// `AnyModule`.
  template <typename ModuleType>
  void push_back(std::string name, std::shared_ptr<ModuleType> module_ptr) {
    push_back(std::move(name), AnyModule(std::move(module_ptr)));
  }

  /// Appends a module passed by value or reference, named after its
  /// positional index. Only participates in overload resolution when
  /// `torch::detail::enable_if_module_t<M>` is well-formed.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void push_back(M&& module) {
    push_back(std::to_string(modules_.size()), std::forward<M>(module));
  }

  /// Appends a module passed by value or reference under `name`. The module is
  /// copied or moved into a freshly allocated `shared_ptr` first.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void push_back(std::string name, M&& module) {
    using Type = typename std::remove_reference_t<M>;
    push_back(std::move(name), std::make_shared<Type>(std::forward<M>(module)));
  }

  /// Appends the module owned by `module_holder`, named after its positional
  /// index.
  template <typename M>
  void push_back(const ModuleHolder<M>& module_holder) {
    push_back(std::to_string(modules_.size()), module_holder);
  }

  /// Appends the module owned by `module_holder` (its `ptr()`) under `name`.
  template <typename M>
  void push_back(std::string name, const ModuleHolder<M>& module_holder) {
    push_back(std::move(name), module_holder.ptr());
  }

  /// Appends every element of `container` via the single-argument
  /// `push_back()`, so each one is named after its positional index.
  template <typename Container>
  void extend(const Container& container) {
    for (const auto& module : container) {
      push_back(module);
    }
  }

  /// Appends an already type-erased module, named after its positional index.
  void push_back(AnyModule any_module) {
    push_back(std::to_string(modules_.size()), std::move(any_module));
  }

  /// Appends `any_module` and registers it as a submodule under `name`. All
  /// other `push_back()` overloads end up here.
  ///
  /// If obtaining the module pointer or registering it throws, the append is
  /// rolled back and the exception is rethrown, so `modules_` never contains
  /// an entry that is not also a registered submodule.
  void push_back(std::string name, AnyModule any_module) {
    modules_.push_back(std::move(any_module));
    try {
      register_module(std::move(name), modules_.back().ptr());
    } catch (...) {
      // Undo the append so storage and registered children stay in sync.
      modules_.pop_back();
      throw;
    }
  }

  /// Returns an iterator to the first stored `AnyModule`.
  Iterator begin() {
    return modules_.begin();
  }

  /// Returns a const iterator to the first stored `AnyModule`.
  ConstIterator begin() const {
    return modules_.begin();
  }

  /// Returns the past-the-end iterator of the stored `AnyModule`s.
  Iterator end() {
    return modules_.end();
  }

  /// Returns the past-the-end const iterator of the stored `AnyModule`s.
  ConstIterator end() const {
    return modules_.end();
  }

  /// Returns the module at `index` as a `T&` (via `AnyModule::get<T>()`).
  /// `T` must satisfy `torch::detail::is_module`; fails a `TORCH_CHECK` if
  /// `index` is out of range.
  template <typename T>
  T& at(size_t index) {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call Sequential::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index].get<T>();
  }

  /// Const counterpart of `at()`: returns the module at `index` as a
  /// `const T&`, with the same compile-time and range checks.
  template <typename T>
  const T& at(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call Sequential::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index].get<T>();
  }

  /// Returns the module at `index` as a `std::shared_ptr<Module>`; fails a
  /// `TORCH_CHECK` if `index` is out of range.
  std::shared_ptr<Module> ptr(size_t index) const {
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index].ptr();
  }

  /// Returns the module at `index` as a `std::shared_ptr<T>` (via
  /// `AnyModule::ptr<T>()`). `T` must satisfy `torch::detail::is_module`;
  /// fails a `TORCH_CHECK` if `index` is out of range.
  template <typename T>
  std::shared_ptr<T> ptr(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call Sequential::ptr with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index].ptr<T>();
  }

  /// Shorthand for the untyped `ptr(index)`.
  std::shared_ptr<Module> operator[](size_t index) const {
    // This is the only method we can call without a type.
    return ptr(index);
  }

  /// Returns the number of stored modules.
  size_t size() const noexcept {
    return modules_.size();
  }

  /// Returns true if no modules are stored.
  bool is_empty() const noexcept {
    return size() == 0;
  }

 private:
  /// Helper that unpacks a pack of two or more modules for the variadic
  /// constructor.
  ///
  /// It is disabled when the first argument is a name (a `std::string` of any
  /// value category or cv-qualification, or something that decays to
  /// `const char*`), so that `push_back(name, module)` always resolves to the
  /// public named overloads. The check uses `std::decay_t<First>` because
  /// `First` is deduced as `std::string&` / `const std::string&` for lvalue
  /// names; comparing the un-decayed type would let this private overload be
  /// selected for such calls.
  template <
      typename First,
      typename Second,
      typename... Rest,
      typename = std::enable_if_t<
          !std::is_same_v<std::decay_t<First>, std::string> &&
          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
          !std::is_same_v<std::decay_t<First>, std::decay_t<const char (&)[]>>>>
  void push_back(First&& first, Second&& second, Rest&&... rest) {
    push_back(std::forward<First>(first));
    // Recursively calls this method until the parameter pack only has one
    // entry left, which is then handled by a single-argument `push_back()`.
    push_back(std::forward<Second>(second), std::forward<Rest>(rest)...);
  }

  /// Terminates the pack expansion when there is nothing to add.
  void push_back() {}

  // The stored modules, in the order they are executed by `forward()`. Every
  // entry is also registered as a submodule by
  // `push_back(std::string, AnyModule)`.
  std::vector<AnyModule> modules_;
};

/// Holder type for `SequentialImpl`. It derives from
/// `torch::nn::ModuleHolder<SequentialImpl>`, inherits that base's
/// constructors, and adds the two constructors declared below.
class Sequential : public torch::nn::ModuleHolder<SequentialImpl> {
 public:
  // Bring every constructor of the holder base into this class.
  using torch::nn::ModuleHolder<SequentialImpl>::ModuleHolder;

  /// Default construction; delegates to the holder base's default
  /// constructor.
  Sequential() : torch::nn::ModuleHolder<SequentialImpl>() {}

  /// Constructs the holder around a new `SequentialImpl` built from a braced
  /// list of named modules.
  Sequential(std::initializer_list<NamedAnyModule> named_modules)
      : torch::nn::ModuleHolder<SequentialImpl>(
            std::make_shared<SequentialImpl>(named_modules)) {}
};
} // namespace torch::nn

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources