Program Listing for File parameterlist.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/parameterlist.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>

#include <vector>

namespace torch {
namespace nn {
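// Holds tensors as an indexed list of parameters, mirroring Python's
// `torch.nn.ParameterList`. Each tensor is registered on the module under
// its positional index, so it participates in `parameters()`, serialization,
// and device/dtype conversion like any other registered parameter.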
class ParameterListImpl : public Cloneable<ParameterListImpl> {
 public:
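  // Iterators run over the underlying `OrderedDict` items; each item exposes
  // `key()` (the stringified index) and `value()` (the tensor).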
  using Iterator = typename std::vector<
      OrderedDict<std::string, torch::Tensor>::Item>::iterator;
  using ConstIterator = typename std::vector<
      OrderedDict<std::string, torch::Tensor>::Item>::const_iterator;

  ParameterListImpl() = default;

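  // Constructs the list from a variadic pack of tensors, appending each one
  // in order.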
  template <typename... Tensors>
  explicit ParameterListImpl(Tensors&&... params) {
    parameters_.reserve(sizeof...(Tensors));
    push_back_var(std::forward<Tensors>(params)...);
  }

  template <typename... Tensors>
  explicit ParameterListImpl(const Tensors&... params) {
    parameters_.reserve(sizeof...(Tensors));
    push_back_var(std::forward<Tensors>(params)...);
  }

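  // Required by `Cloneable`; a no-op here because the registered parameters
  // are the list's only state.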
  void reset() override {}

  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ParameterList(" << std::endl;
    for (const auto& pair : parameters_) {
      stream << "(" << pair.key() << ")"
             << ": Parameter containing: [" << pair.value().scalar_type()
             << " of size " << pair.value().sizes() << "]";
      stream << std::endl;
    }
    stream << ")";
  }

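  // Appends a tensor, registering it under the next positional index; the
  // `requires_grad` flag is taken from the tensor itself.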
  void append(torch::Tensor&& param) {
    bool requires_grad = param.requires_grad();
    register_parameter(
        c10::to_string(parameters_.size()), std::move(param), requires_grad);
  }

  void append(const torch::Tensor& param) {
    bool requires_grad = param.requires_grad();
    register_parameter(
        c10::to_string(parameters_.size()), param, requires_grad);
  }

  void append(const OrderedDict<std::string, torch::Tensor>::Item& pair) {
    register_parameter(
        c10::to_string(parameters_.size()),
        pair.value(),
        pair.value().requires_grad());
  }

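  // Appends every element of `container` in order via `append()`.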
  template <typename Container>
  void extend(const Container& container) {
    for (const auto& param : container) {
      append(param);
    }
  }

  Iterator begin() {
    return parameters_.begin();
  }

  ConstIterator begin() const {
    return parameters_.begin();
  }

  Iterator end() {
    return parameters_.end();
  }

  ConstIterator end() const {
    return parameters_.end();
  }

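  // Bounds-checked element access; parameters are looked up by their
  // stringified index.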
  at::Tensor& at(size_t idx) {
    TORCH_CHECK(idx < size(), "Index out of range");
    return parameters_[c10::to_string(idx)];
  }

  const at::Tensor& at(size_t idx) const {
    TORCH_CHECK(idx < size(), "Index out of range");
    return parameters_[c10::to_string(idx)];
  }

  at::Tensor& operator[](size_t idx) {
    return at(idx);
  }

  const at::Tensor& operator[](size_t idx) const {
    return at(idx);
  }

  size_t size() const noexcept {
    return parameters_.size();
  }
  bool is_empty() const noexcept {
    return parameters_.is_empty();
  }

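  // Appends all elements of `other` to this list. Note that the deduced
  // return type compiles only when `*this` converts to `Container&`.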
  template <typename Container>
  Container& operator+=(const Container& other) {
    extend(other);
    return *this;
  }

 private:
  template <typename Head, typename... Tail>
  void push_back_var(Head&& head, Tail&&... tail) {
    append(std::forward<Head>(head));
    // Recursively calls this method until the parameter pack is empty; the
    // no-argument overload below then terminates the recursion.
    push_back_var(std::forward<Tail>(tail)...);
  }

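  // Base case: ends the recursion once the parameter pack is exhausted.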
  void push_back_var() {}
};
TORCH_MODULE(ParameterList);
} // namespace nn
} // namespace torch
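
A minimal usage sketch follows, assuming a LibTorch build where this header is
reachable through the umbrella header <torch/torch.h>. `MyModuleImpl` below is
a hypothetical module written only for illustration; it is not part of this
file.

#include <torch/torch.h>

#include <iostream>

// Hypothetical module showing the usual pattern: register the list as a
// submodule so its tensors show up in parameters(), serialization, and
// device/dtype moves alongside everything else.
struct MyModuleImpl : torch::nn::Module {
  MyModuleImpl() {
    params = register_module(
        "params",
        torch::nn::ParameterList(
            torch::randn({4, 4}, torch::requires_grad()),
            torch::randn({4}, torch::requires_grad())));
  }
  torch::nn::ParameterList params{nullptr};
};
TORCH_MODULE(MyModule);

int main() {
  MyModule m;

  // append() registers the tensor under the next positional index ("2" here).
  m->params->append(torch::zeros({2, 2}, torch::requires_grad()));

  // Element access goes through the bounds-checked at().
  std::cout << "first parameter: " << m->params->at(0).sizes() << '\n';

  // Iteration yields OrderedDict items: key() is the index string and
  // value() is the tensor.
  for (const auto& pair : *m->params) {
    std::cout << pair.key() << ": " << pair.value().sizes() << '\n';
  }

  // Streaming a Module invokes pretty_print() as defined above.
  std::cout << *m->params << '\n';
}

Since `extend()` accepts any iterable container of tensors, a plain
`std::vector<torch::Tensor>` can be appended in one call, e.g.
`m->params->extend(std::vector<torch::Tensor>{torch::ones({2})});`.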
