Program Listing for File sequential.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/sequential.h)
#pragma once
#include <torch/detail/static.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/any.h>
#include <torch/nn/modules/container/named_any.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <c10/util/Exception.h>
#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
namespace torch::nn {
/// An ordered list of type-erased modules that are run one after another.
///
/// Every module handed to a constructor or to `push_back()` is stored as an
/// `AnyModule` in insertion order and is also registered as a child module
/// under its name. Modules added without an explicit name are named after
/// the number of modules already stored at that moment ("0", "1", ...).
class SequentialImpl : public Cloneable<SequentialImpl> {
 public:
  using Iterator = std::vector<AnyModule>::iterator;
  using ConstIterator = std::vector<AnyModule>::const_iterator;

  /// Constructs an empty `Sequential`.
  SequentialImpl() = default;

  /// Constructs the `Sequential` from a variadic list of modules. Each
  /// argument is forwarded to the matching single-argument `push_back()`
  /// overload, in order.
  template <typename... Modules>
  explicit SequentialImpl(Modules&&... modules) {
    modules_.reserve(sizeof...(Modules));
    push_back(std::forward<Modules>(modules)...);
  }

  /// Constructs the `Sequential` from an `OrderedDict` of named `AnyModule`s.
  /// The values are moved out of `ordered_dict`; the keys become the module
  /// names.
  explicit SequentialImpl(
      torch::OrderedDict<std::string, AnyModule>&& ordered_dict) {
    modules_.reserve(ordered_dict.size());
    for (auto& item : ordered_dict) {
      push_back(item.key(), std::move(item.value()));
    }
  }

  /// Constructs the `Sequential` from a braced list of `NamedAnyModule`s.
  /// Each entry is added under its own name, in list order.
  explicit SequentialImpl(std::initializer_list<NamedAnyModule> named_modules) {
    modules_.reserve(named_modules.size());
    for (const auto& named_module : named_modules) {
      push_back(named_module.name(), named_module.module());
    }
  }

  /// Returns a new `SequentialImpl` that holds `clone(device)` of every
  /// stored module, in the same order.
  ///
  /// Each clone is registered under the name its original carries among this
  /// module's children, so the copy's child names match the original's
  /// instead of being reset to "0", "1", ....
  std::shared_ptr<Module> clone(
      const std::optional<Device>& device = std::nullopt) const override {
    auto copy = std::make_shared<SequentialImpl>();
    copy->modules_.reserve(modules_.size());
    // `push_back()` registers one child per stored module, in the same order,
    // so the i-th child normally is the i-th module. The pointer comparison
    // guards against the two lists having drifted apart; in that case the
    // default index-based name is used, as before.
    const auto children = named_children();
    auto child = children.begin();
    for (const auto& module : modules_) {
      if (child != children.end() && child->value() == module.ptr()) {
        copy->push_back(child->key(), module.clone(device));
      } else {
        copy->push_back(module.clone(device));
      }
      if (child != children.end()) {
        ++child;
      }
    }
    return copy;
  }

  /// `reset()` is a no-op for `Sequential`; `clone()` is overridden above and
  /// builds the copy itself.
  void reset() override {}

  /// Writes the text `torch::nn::Sequential` into `stream`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::Sequential";
  }

  /// Runs the stored modules as a chain.
  ///
  /// `inputs` are forwarded to the first module's `any_forward()`. The result
  /// of each module is then moved into the next module's `any_forward()`. The
  /// final result is extracted as `ReturnType` (default `Tensor`).
  ///
  /// Fails a `TORCH_CHECK` when the `Sequential` is empty, or when the last
  /// module's result does not hold a `ReturnType`.
  template <typename ReturnType = Tensor, typename... InputTypes>
  ReturnType forward(InputTypes&&... inputs) {
    TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential");

    auto iterator = modules_.begin();
    auto input = iterator->any_forward(std::forward<InputTypes>(inputs)...);

    for (++iterator; iterator != modules_.end(); ++iterator) {
      input = iterator->any_forward(std::move(input));
    }

    // Check the return value and give a nice error message if the requested
    // return type was incorrect.
    if (auto* return_value = input.template try_get<ReturnType>()) {
      return std::move(*return_value);
    }
    TORCH_CHECK(
        false,
        "The type of the return value is ",
        c10::demangle(input.type_info().name()),
        ", but you asked for type ",
        c10::demangle(typeid(ReturnType).name()));
  }

  /// Adds a module given as a `shared_ptr`, named after the current size.
  template <typename ModuleType>
  void push_back(std::shared_ptr<ModuleType> module_ptr) {
    push_back(std::to_string(modules_.size()), std::move(module_ptr));
  }

  /// Adds a module given as a `shared_ptr` under `name`. The pointer is
  /// wrapped in an `AnyModule`.
  template <typename ModuleType>
  void push_back(std::string name, std::shared_ptr<ModuleType> module_ptr) {
    push_back(std::move(name), AnyModule(std::move(module_ptr)));
  }

  /// Adds a module object (any type accepted by `enable_if_module_t`), named
  /// after the current size.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void push_back(M&& module) {
    push_back(std::to_string(modules_.size()), std::forward<M>(module));
  }

  /// Adds a module object under `name`. The object is copied or moved into a
  /// fresh `std::shared_ptr` of its own (reference-stripped) type.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void push_back(std::string name, M&& module) {
    using Type = typename std::remove_reference_t<M>;
    push_back(std::move(name), std::make_shared<Type>(std::forward<M>(module)));
  }

  /// Adds the module held by `module_holder`, named after the current size.
  template <typename M>
  void push_back(const ModuleHolder<M>& module_holder) {
    push_back(std::to_string(modules_.size()), module_holder);
  }

  /// Adds the module held by `module_holder` (its `ptr()`) under `name`.
  template <typename M>
  void push_back(std::string name, const ModuleHolder<M>& module_holder) {
    push_back(std::move(name), module_holder.ptr());
  }

  /// Calls `push_back()` on every element of `container`, in iteration order.
  template <typename Container>
  void extend(const Container& container) {
    for (const auto& module : container) {
      push_back(module);
    }
  }

  /// Adds a type-erased `AnyModule`, named after the current size.
  void push_back(AnyModule any_module) {
    push_back(std::to_string(modules_.size()), std::move(any_module));
  }

  /// Adds a type-erased `AnyModule` under `name`. All other `push_back()`
  /// overloads end up here: the module is appended to the list and then
  /// registered as a child via `register_module()`.
  void push_back(std::string name, AnyModule any_module) {
    modules_.push_back(std::move(any_module));
    const auto index = modules_.size() - 1;
    register_module(std::move(name), modules_[index].ptr());
  }

  /// Returns an iterator to the first stored `AnyModule`.
  Iterator begin() {
    return modules_.begin();
  }

  /// Returns a const iterator to the first stored `AnyModule`.
  ConstIterator begin() const {
    return modules_.begin();
  }

  /// Returns an iterator one past the last stored `AnyModule`.
  Iterator end() {
    return modules_.end();
  }

  /// Returns a const iterator one past the last stored `AnyModule`.
  ConstIterator end() const {
    return modules_.end();
  }

  /// Returns a reference to the module at `index`, accessed as type `T` via
  /// `AnyModule::get<T>()`. `T` must satisfy `torch::detail::is_module`.
  /// Fails a `TORCH_CHECK` when `index` is out of range.
  template <typename T>
  T& at(size_t index) {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call Sequential::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index].get<T>();
  }

  /// Const overload of `at()`; same checks, returns a const reference.
  template <typename T>
  const T& at(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call Sequential::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index].get<T>();
  }

  /// Returns the module at `index` as a `std::shared_ptr<Module>`.
  /// Fails a `TORCH_CHECK` when `index` is out of range.
  std::shared_ptr<Module> ptr(size_t index) const {
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index].ptr();
  }

  /// Returns the module at `index` as a `std::shared_ptr<T>` via
  /// `AnyModule::ptr<T>()`. `T` must satisfy `torch::detail::is_module`.
  /// Fails a `TORCH_CHECK` when `index` is out of range.
  template <typename T>
  std::shared_ptr<T> ptr(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call Sequential::ptr with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index].ptr<T>();
  }

  /// Like `ptr(index)`.
  std::shared_ptr<Module> operator[](size_t index) const {
    // This is the only method we can call without a type.
    return ptr(index);
  }

  /// Returns the number of stored modules.
  size_t size() const noexcept {
    return modules_.size();
  }

  /// Returns true when no modules are stored.
  bool is_empty() const noexcept {
    return size() == 0;
  }

 private:
  /// Takes a first module and a non-empty remainder and adds them in order.
  ///
  /// The `enable_if` keeps this overload out of the way of the two-argument
  /// `push_back(name, module)` overloads. `First` is a forwarding reference,
  /// so it is deduced as `std::string&` / `const std::string&` for lvalue
  /// names; the check therefore has to look at the decayed type, otherwise
  /// only rvalue `std::string` names would be excluded.
  template <
      typename First,
      typename Second,
      typename... Rest,
      typename = std::enable_if_t<
          !std::is_same_v<std::decay_t<First>, std::string> &&
          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
          !std::is_same_v<std::decay_t<First>, std::decay_t<const char (&)[]>>>>
  void push_back(First&& first, Second&& second, Rest&&... rest) {
    push_back(std::forward<First>(first));
    // Recursively calls this method until only one argument is left, which
    // is then handled by one of the single-argument `push_back()` overloads.
    push_back(std::forward<Second>(second), std::forward<Rest>(rest)...);
  }

  /// No-op overload for an empty argument list.
  void push_back() {}

  // The stored modules, in insertion order. `push_back()` also registers each
  // one as a child module of this `Sequential`.
  std::vector<AnyModule> modules_;
};
/// A `ModuleHolder` subclass for `SequentialImpl`.
///
/// Inherits every `ModuleHolder<SequentialImpl>` constructor through the
/// using-declaration and adds the two constructors declared below.
class Sequential : public torch::nn::ModuleHolder<SequentialImpl> {
public:
using torch::nn::ModuleHolder<SequentialImpl>::ModuleHolder;
/// Default constructor; delegates to `ModuleHolder`'s default constructor.
Sequential() : ModuleHolder() {}
/// Builds a `SequentialImpl` from a braced list of named modules, wraps it in
/// a `std::shared_ptr`, and hands that pointer to `ModuleHolder`.
Sequential(std::initializer_list<NamedAnyModule> named_modules)
: ModuleHolder(std::make_shared<SequentialImpl>(named_modules)) {}
};
} // namespace torch::nn