Program Listing for File parameterlist.h
Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/parameterlist.h)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <vector>
namespace torch {
namespace nn {
/// A module holding an ordered list of tensors, each registered as a
/// parameter of this module.
///
/// Every `append` registers its tensor under the key produced by
/// `c10::to_string(parameters_.size())`, so entries are keyed "0", "1", "2", ...
/// in the order they are added. Each tensor keeps the `requires_grad` flag it
/// had when it was appended.
class ParameterListImpl : public Cloneable<ParameterListImpl> {
 public:
  /// Iterators over the key/value items stored in `parameters_`.
  using Iterator = typename std::vector<
      OrderedDict<std::string, torch::Tensor>::Item>::iterator;
  using ConstIterator = typename std::vector<
      OrderedDict<std::string, torch::Tensor>::Item>::const_iterator;

  ParameterListImpl() = default;

  /// Builds the list from a variadic pack of tensors (or OrderedDict items).
  /// Each argument is perfectly forwarded to the matching `append` overload,
  /// so rvalue tensors are moved in.
  template <typename... Tensors>
  explicit ParameterListImpl(Tensors&&... params) {
    parameters_.reserve(sizeof...(Tensors));
    push_back_var(std::forward<Tensors>(params)...);
  }

  /// Builds the list from a variadic pack of const tensors (or OrderedDict
  /// items). This overload is the one selected when every argument is a const
  /// lvalue.
  template <typename... Tensors>
  explicit ParameterListImpl(const Tensors&... params) {
    parameters_.reserve(sizeof...(Tensors));
    // The arguments are const lvalues: pass them through as-is.
    // `std::forward<Tensors>(params)` would try to bind a `const T&` to
    // `T&`/`T&&` and fail to compile whenever this overload is chosen.
    push_back_var(params...);
  }

  /// Intentionally empty: this module has no state of its own to reset.
  void reset() override {}

  /// Prints one line per entry with its key, scalar type and sizes, wrapped
  /// in `torch::nn::ParameterList( ... )`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ParameterList(" << '\n';
    for (const auto& pair : parameters_) {
      stream << "(" << pair.key() << ")"
             << ": Parameter containing: [" << pair.value().scalar_type()
             << " of size " << pair.value().sizes() << "]" << '\n';
    }
    stream << ")";
  }

  /// Appends `param`, moving it into the registration. It is registered
  /// under the next index key with its current `requires_grad` flag.
  void append(torch::Tensor&& param) {
    bool requires_grad = param.requires_grad();
    register_parameter(
        c10::to_string(parameters_.size()), std::move(param), requires_grad);
  }

  /// Appends a copy of `param` under the next index key, preserving its
  /// `requires_grad` flag.
  void append(const torch::Tensor& param) {
    bool requires_grad = param.requires_grad();
    register_parameter(
        c10::to_string(parameters_.size()), param, requires_grad);
  }

  /// Appends the value of `pair`. The item's own key is ignored; the value is
  /// registered under the next index key with its `requires_grad` flag.
  void append(const OrderedDict<std::string, torch::Tensor>::Item& pair) {
    register_parameter(
        c10::to_string(parameters_.size()),
        pair.value(),
        pair.value().requires_grad());
  }

  /// Appends every element of `container` in iteration order. Elements may
  /// be tensors or OrderedDict items (anything an `append` overload accepts).
  ///
  /// NOTE(review): passing this list itself as `container` iterates it while
  /// `append` presumably grows `parameters_`; avoid self-extension.
  template <typename Container>
  void extend(const Container& container) {
    for (const auto& param : container) {
      append(param);
    }
  }

  Iterator begin() {
    return parameters_.begin();
  }

  ConstIterator begin() const {
    return parameters_.begin();
  }

  Iterator end() {
    return parameters_.end();
  }

  ConstIterator end() const {
    return parameters_.end();
  }

  /// Returns the tensor at position `idx`. Fails a TORCH_CHECK if
  /// `idx >= size()`.
  at::Tensor& at(size_t idx) {
    TORCH_CHECK(idx < size(), "Index out of range");
    return parameters_[c10::to_string(idx)];
  }

  /// Const overload of `at`; performs the same bounds check.
  const at::Tensor& at(size_t idx) const {
    TORCH_CHECK(idx < size(), "Index out of range");
    return parameters_[c10::to_string(idx)];
  }

  /// Bounds-checked element access; equivalent to `at(idx)`.
  at::Tensor& operator[](size_t idx) {
    return at(idx);
  }

  /// Bounds-checked element access; equivalent to `at(idx)`.
  const at::Tensor& operator[](size_t idx) const {
    return at(idx);
  }

  /// Number of parameters currently held.
  size_t size() const noexcept {
    return parameters_.size();
  }

  /// True when no parameters have been appended.
  bool is_empty() const noexcept {
    return parameters_.is_empty();
  }

  /// Appends every element of `other` (see `extend`) and returns this list.
  /// The return type is `ParameterListImpl&` because the function returns
  /// `*this`, not a `Container`. This lets it work for any container type
  /// that `extend` accepts.
  template <typename Container>
  ParameterListImpl& operator+=(const Container& other) {
    extend(other);
    return *this;
  }

 private:
  /// Appends `head`, then recurses on the remaining arguments. Recursion
  /// stops when the pack is empty and the no-argument overload is chosen.
  template <typename Head, typename... Tail>
  void push_back_var(Head&& head, Tail&&... tail) {
    append(std::forward<Head>(head));
    push_back_var(std::forward<Tail>(tail)...);
  }

  /// Base case of the `push_back_var` recursion: nothing left to append.
  void push_back_var() {}
};
// Declares the user-facing `ParameterList` type for `ParameterListImpl` (the
// name also appears in `pretty_print`'s output). NOTE(review): the exact
// holder semantics come from the TORCH_MODULE macro defined elsewhere — confirm.
TORCH_MODULE(ParameterList);
} // namespace nn
} // namespace torch