Program Listing for File parameterdict.h
Return to the documentation for file `torch/csrc/api/include/torch/nn/modules/container/parameterdict.h`.
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/pimpl.h>
#include <torch/ordered_dict.h>
#include <utility>
#include <vector>
namespace torch {
namespace nn {
/// A module holding an ordered, string-keyed collection of parameter tensors.
///
/// Entries live in the `parameters_` dictionary inherited from the module base
/// class. Iteration order follows that dictionary's insertion order.
class ParameterDictImpl : public Cloneable<ParameterDictImpl> {
 public:
  using Iterator = OrderedDict<std::string, Tensor>::Iterator;
  using ConstIterator = OrderedDict<std::string, Tensor>::ConstIterator;

  ParameterDictImpl() = default;

  /// Constructs the dict from an existing ordered dictionary of tensors.
  ///
  /// NOTE: `params` is assigned to `parameters_` wholesale. Unlike `insert`,
  /// it does not go through `register_parameter`.
  explicit ParameterDictImpl(
      const torch::OrderedDict<std::string, torch::Tensor>& params) {
    parameters_ = params;
  }

  /// No-op: this module creates no tensors of its own to reinitialize. All
  /// entries are supplied by the caller.
  void reset() override {}

  /// Writes one line per entry with its key, scalar type and size, wrapped in
  /// `torch::nn::ParameterDict( ... )`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ParameterDict(" << '\n';
    for (const auto& pair : parameters_) {
      stream << "(" << pair.key() << ")"
             << ": Parameter containing: [" << pair.value().scalar_type()
             << " of size " << pair.value().sizes() << "]" << '\n';
    }
    stream << ")";
  }

  /// Registers `param` under `key` via `register_parameter`.
  ///
  /// The tensor's current `requires_grad` flag is kept. Returns the reference
  /// handed back by `register_parameter`.
  Tensor& insert(std::string key, Tensor param) {
    // Read the flag before `param` is moved from below.
    bool requires_grad = param.requires_grad();
    return register_parameter(std::move(key), std::move(param), requires_grad);
  }

  /// Removes `key` and returns the tensor that was stored under it.
  ///
  /// The lookup goes through `parameters_[key]`, so an unknown key is reported
  /// by `torch::OrderedDict` before anything is erased.
  Tensor pop(const std::string& key) {
    torch::Tensor v = parameters_[key];
    parameters_.erase(key);
    return v;
  }

  /// Returns a copy of all keys, in order.
  ::std::vector<std::string> keys() const {
    return parameters_.keys();
  }

  /// Returns a copy of all tensors, in key order.
  ::std::vector<torch::Tensor> values() const {
    return parameters_.values();
  }

  /// Iterators over the underlying ordered dictionary. Items expose `key()`
  /// and `value()`.
  Iterator begin() {
    return parameters_.begin();
  }
  ConstIterator begin() const {
    return parameters_.begin();
  }
  Iterator end() {
    return parameters_.end();
  }
  ConstIterator end() const {
    return parameters_.end();
  }

  /// Number of stored parameters.
  size_t size() const noexcept {
    return parameters_.size();
  }

  /// True if no parameters are stored.
  bool empty() const noexcept {
    return parameters_.is_empty();
  }

  /// Overwrites existing entries with the values found in `container`.
  ///
  /// `Container` must be iterable and yield items with `key()` and `value()`
  /// accessors, e.g. `torch::OrderedDict<std::string, Tensor>` or another
  /// ParameterDict.
  ///
  /// Only keys already present can be updated. `parameters_[key]` reports an
  /// error for an unknown key instead of inserting it; use `insert` to add
  /// new entries.
  ///
  /// Every key is resolved before any value is written. If `container`
  /// contains an unknown key, the error is raised and this dict is left
  /// unmodified.
  template <typename Container>
  void update(const Container& container) {
    // Pass 1: resolve each destination slot. An unknown key fails here,
    // before any entry has been changed.
    std::vector<std::pair<Tensor*, Tensor>> staged;
    for (const auto& item : container) {
      staged.emplace_back(&parameters_[item.key()], item.value());
    }
    // Pass 2: commit. Nothing is inserted into or erased from `parameters_`
    // between the passes, so the slot pointers are still valid.
    for (auto& entry : staged) {
      *entry.first = std::move(entry.second);
    }
  }

  /// Removes all entries.
  void clear() {
    parameters_.clear();
  }

  /// True if `key` is present.
  bool contains(const std::string& key) const noexcept {
    return parameters_.contains(key);
  }

  /// Returns the tensor stored under `key`. An unknown key is reported by
  /// `torch::OrderedDict::operator[]`.
  const Tensor& get(const std::string& key) const {
    return parameters_[key];
  }
  Tensor& get(const std::string& key) {
    return parameters_[key];
  }

  /// Same as `get(key)`.
  Tensor& operator[](const std::string& key) {
    return parameters_[key];
  }
  const Tensor& operator[](const std::string& key) const {
    return parameters_[key];
  }
};
// Declares `ParameterDict`, the user-facing holder type wrapping
// ParameterDictImpl (see the TORCH_MODULE macro in torch/nn/pimpl.h).
TORCH_MODULE(ParameterDict);
} // namespace nn
} // namespace torch