Shortcuts

Program Listing for File parameterdict.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/parameterdict.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/pimpl.h>
#include <torch/ordered_dict.h>
#include <utility>
#include <vector>

namespace torch {
namespace nn {

/// A module that holds tensors keyed by name.
///
/// The tensors are kept in `parameters_`, the
/// `OrderedDict<std::string, Tensor>` this class reaches through its
/// `Cloneable<ParameterDictImpl>` base; every accessor below is a thin
/// forwarder onto that dictionary.
class ParameterDictImpl : public Cloneable<ParameterDictImpl> {
 public:
  using Iterator = OrderedDict<std::string, Tensor>::Iterator;
  using ConstIterator = OrderedDict<std::string, Tensor>::ConstIterator;

  /// Creates a dictionary with no entries.
  ParameterDictImpl() = default;

  /// Creates a dictionary whose entries are a copy of `params`.
  ///
  /// NOTE(review): the entries are assigned to `parameters_` directly rather
  /// than going through `register_parameter`, so whatever checks that
  /// function performs are not applied here -- confirm this is intended.
  explicit ParameterDictImpl(
      const torch::OrderedDict<std::string, torch::Tensor>& params) {
    parameters_ = params;
  }

  /// Intentionally empty: this override performs no (re)initialisation.
  void reset() override {}

  /// Writes a multi-line description of the dictionary to `stream`.
  ///
  /// One line is written per entry, in the form
  /// `(key): Parameter containing: [<scalar type> of size <sizes>]`,
  /// between an opening `torch::nn::ParameterDict(` line and a closing `)`.
  void pretty_print(std::ostream& stream) const override {
    // '\n' instead of std::endl: same text, without forcing a flush per line.
    stream << "torch::nn::ParameterDict(" << '\n';
    for (const auto& pair : parameters_) {
      stream << "(" << pair.key() << ")"
             << ": Parameter containing: [" << pair.value().scalar_type()
             << " of size " << pair.value().sizes() << "]" << '\n';
    }
    stream << ")";
  }

  /// Registers `param` under `key` and returns a reference to the stored
  /// tensor. The tensor's current `requires_grad()` setting is forwarded to
  /// `register_parameter`.
  Tensor& insert(std::string key, Tensor param) {
    // Read the flag before `param` is moved from on the next line.
    bool requires_grad = param.requires_grad();
    return register_parameter(std::move(key), std::move(param), requires_grad);
  }

  /// Removes the entry stored under `key` and returns its tensor.
  ///
  /// The lookup uses `OrderedDict::operator[]`, which is documented to raise
  /// an error for a missing key; check `contains(key)` first if the key may
  /// be absent.
  Tensor pop(const std::string& key) {
    torch::Tensor v = parameters_[key];
    parameters_.erase(key);
    return v;
  }

  /// Returns the keys of all entries.
  ::std::vector<std::string> keys() const {
    return parameters_.keys();
  }

  /// Returns the tensors of all entries.
  ::std::vector<torch::Tensor> values() const {
    return parameters_.values();
  }

  /// Returns an iterator to the first entry.
  Iterator begin() {
    return parameters_.begin();
  }

  /// Returns a const iterator to the first entry.
  ConstIterator begin() const {
    return parameters_.begin();
  }

  /// Returns an iterator one past the last entry.
  Iterator end() {
    return parameters_.end();
  }

  /// Returns a const iterator one past the last entry.
  ConstIterator end() const {
    return parameters_.end();
  }

  /// Returns the number of entries.
  size_t size() const noexcept {
    return parameters_.size();
  }

  /// Returns true if there are no entries.
  bool empty() const noexcept {
    return parameters_.is_empty();
  }

  /// Merges every entry of `container` into this dictionary.
  ///
  /// `Container` must be iterable and yield items exposing `key()` and
  /// `value()` (for example another `ParameterDictImpl` or an `OrderedDict`).
  /// A key that is already present has its tensor overwritten; a key that is
  /// not yet present is added through `insert()`.
  template <typename Container>
  void update(const Container& container) {
    for (const auto& item : container) {
      // OrderedDict::operator[] raises for an unknown key, so it cannot be
      // used to add entries. Probe with find() and only assign in place when
      // the key already exists.
      if (auto* existing = parameters_.find(item.key())) {
        *existing = item.value();
      } else {
        insert(item.key(), item.value());
      }
    }
  }

  /// Removes all entries.
  void clear() {
    parameters_.clear();
  }

  /// Returns true if an entry is stored under `key`.
  bool contains(const std::string& key) const noexcept {
    return parameters_.contains(key);
  }

  /// Returns the tensor stored under `key`. Forwards to
  /// `OrderedDict::operator[]`, which is documented to raise an error when
  /// the key is missing.
  const Tensor& get(const std::string& key) const {
    return parameters_[key];
  }

  /// Mutable overload of `get`; same lookup behaviour.
  Tensor& get(const std::string& key) {
    return parameters_[key];
  }

  /// Subscript access; same lookup behaviour as `get(key)`.
  Tensor& operator[](const std::string& key) {
    return parameters_[key];
  }

  /// Const subscript access; same lookup behaviour as `get(key)`.
  const Tensor& operator[](const std::string& key) const {
    return parameters_[key];
  }
};

// Declares the user-facing `ParameterDict` type for `ParameterDictImpl`.
// NOTE(review): TORCH_MODULE is not defined in this file; it presumably comes
// from <torch/nn/pimpl.h> and generates a holder class wrapping
// `ParameterDictImpl` -- confirm against that header.
TORCH_MODULE(ParameterDict);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources