Shortcuts

Program Listing for File ordered_dict.h

Return to documentation for file (torch/csrc/api/include/torch/ordered_dict.h)

#pragma once

#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
/// An insertion-ordered dictionary, similar in spirit to Python's
/// `OrderedDict`.
///
/// Items live in a `std::vector<Item>` in insertion order, and an
/// `std::unordered_map<Key, size_t>` maps every key to its position in that
/// vector. Keys are unique: `insert()` raises an error for a key that is
/// already present.
template <typename Key, typename Value>
class OrderedDict {
 public:
  /// A single (key, value) entry of the dictionary.
  class Item;

  // The lifetime of an iterator is bound to the lifetime of the `OrderedDict`.
  // Iterators point directly into the item vector, so any operation that may
  // grow it (`insert()`, `update()`, `reserve()`) can invalidate all of them,
  // `erase()` invalidates those at or after the erased item, and `clear()`
  // invalidates all of them.
  using Iterator = typename std::vector<Item>::iterator;
  using ConstIterator = typename std::vector<Item>::const_iterator;

  /// Constructs an empty dictionary. `key_description` is the noun used for
  /// keys in error messages (e.g. "Key 'x' is not defined").
  explicit OrderedDict(std::string key_description = "Key");

  /// Copies `other`, including its key description.
  OrderedDict(const OrderedDict& other);

  /// Replaces the contents and key description with a copy of `other`.
  OrderedDict& operator=(const OrderedDict& other);

  // NB: Moves are defaulted (memberwise) and declared unconditionally
  // `noexcept`. A `noexcept` conditional on the move constructors of `index_`
  // and `items_` was attempted first, but that spelling did not compile on
  // Windows.
  OrderedDict(OrderedDict&& other) noexcept = default;
  OrderedDict& operator=(OrderedDict&& other) noexcept = default;

  ~OrderedDict() = default;

  /// Builds a dictionary from a braced list such as `{{"a", 1}, {"b", 2}}`,
  /// keeping the listed order. The key description is the default "Key".
  /*implicit */ OrderedDict(std::initializer_list<Item> initializer_list);

  /// The noun used for keys in error messages.
  const std::string& key_description() const noexcept;

  // Element Access

  /// The first item in insertion order; raises an error if empty.
  Item& front();

  const Item& front() const;

  /// The last item in insertion order; raises an error if empty.
  Item& back();

  const Item& back() const;

  /// The item at position `index`; raises an error if out of bounds.
  Item& operator[](size_t index);

  const Item& operator[](size_t index) const;

  /// The value stored under `key`; raises an error if `key` is absent.
  Value& operator[](const Key& key);

  const Value& operator[](const Key& key) const;

  // Lookup

  /// Pointer to the value stored under `key`, or `nullptr` if absent.
  Value* find(const Key& key) noexcept;

  const Value* find(const Key& key) const noexcept;

  /// Whether `key` is present.
  bool contains(const Key& key) const noexcept;

  // Iterators (insertion order)

  Iterator begin();

  ConstIterator begin() const;

  Iterator end();

  ConstIterator end() const;

  // Capacity

  /// Number of items.
  size_t size() const noexcept;

  /// Whether the dictionary holds no items.
  bool is_empty() const noexcept;

  /// Reserves room for `requested_capacity` items in both the item vector and
  /// the key index.
  void reserve(size_t requested_capacity);

  // Modifiers

  /// Appends a new item and returns a reference to its value. Raises an error
  /// if `key` is already present.
  template <typename K, typename V>
  Value& insert(K&& key, V&& value);

  /// Non-template overload: because `value` is a plain `Value&&`, it also
  /// accepts a braced initializer, which the forwarding template cannot
  /// deduce.
  Value& insert(Key key, Value&& value);

  /// Inserts every item of `other` in order, moving its values out. Raises an
  /// error at the first key that is already present; items inserted before
  /// that point remain.
  void update(OrderedDict&& other);

  /// Inserts a copy of every item of `other` in order. Raises an error at the
  /// first key that is already present; items inserted before that point
  /// remain.
  void update(const OrderedDict& other);

  /// Removes the item stored under `key`; raises an error if absent. Linear in
  /// `size()`, because later items' positions are re-indexed.
  void erase(const Key& key);

  /// Removes all items. The key description is kept.
  void clear();

  // Observers

  /// All items in insertion order.
  const std::vector<Item>& items() const noexcept;

  /// Copies of all keys in insertion order.
  ::std::vector<Key> keys() const;

  /// Copies of all values in insertion order.
  ::std::vector<Value> values() const;

  /// Copies of all (key, value) pairs in insertion order.
  ::std::vector<std::pair<Key, Value>> pairs() const;

  /// Equal when both hold equal keys at the same positions with equal values.
  /// The key description is not compared.
  template <typename K, typename V>
  friend bool operator==(
      const OrderedDict<K, V>& a,
      const OrderedDict<K, V>& b);

 private:
  /// Maps each key to its position in `items_`.
  ::std::unordered_map<Key, size_t> index_;

  /// The items, in insertion order.
  ::std::vector<Item> items_;

  /// The noun used for keys in error messages.
  ::std::string key_description_{"Key"};
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict::Item ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename Key, typename Value>
class OrderedDict<Key, Value>::Item {
 public:
  /// Creates an entry owning `key` and `value`; both are taken by value and
  /// moved into the stored pair.
  Item(Key key, Value value) : pair_(std::move(key), std::move(value)) {}

  // --- Named accessors -----------------------------------------------------

  /// The entry's key. Only a const view is offered, so the key cannot be
  /// changed through the item.
  const Key& key() const noexcept {
    return pair_.first;
  }

  /// Mutable access to the entry's value.
  Value& value() noexcept {
    return pair_.second;
  }

  /// Read-only access to the entry's value.
  const Value& value() const noexcept {
    return pair_.second;
  }

  /// The underlying (key, value) pair.
  const std::pair<Key, Value>& pair() const noexcept {
    return pair_;
  }

  // --- Pointer-like shorthand for the value --------------------------------

  Value& operator*() {
    return pair_.second;
  }

  const Value& operator*() const {
    return pair_.second;
  }

  Value* operator->() {
    return &pair_.second;
  }

  const Value* operator->() const {
    return &pair_.second;
  }

 private:
  ::std::pair<Key, Value> pair_;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename Key, typename Value>
OrderedDict<Key, Value>::OrderedDict(std::string key_description)
    : key_description_(std::move(key_description)) {}

/// Memberwise copy. `items_` is copied in the same order, so every position
/// stored in the copied `index_` still refers to the right item. Copying the
/// vector directly performs a single allocation instead of growing it one
/// `push_back` at a time.
template <typename Key, typename Value>
OrderedDict<Key, Value>::OrderedDict(const OrderedDict& other)
    : index_(other.index_),
      items_(other.items_),
      key_description_(other.key_description_) {}

/// Copy assignment. The full copy is built first and then moved into `*this`
/// through the `noexcept` move assignment. If copying throws, `*this` is left
/// untouched. Self-assignment is a no-op. (Previously `d = d` cleared `items_`
/// while leaving `index_` populated.)
template <typename Key, typename Value>
OrderedDict<Key, Value>& OrderedDict<Key, Value>::operator=(
    const OrderedDict& other) {
  if (this != &other) {
    OrderedDict copy(other);
    *this = std::move(copy);
  }
  return *this;
}

/// Builds a dictionary from a braced list of items, keeping the listed order.
/// Each item goes through `insert()`, so a duplicate key raises the same
/// "already defined" error as any other insertion. Previously a duplicate was
/// appended to `items_` while `index_.emplace` silently ignored it, leaving the
/// two containers out of sync. Elements of a `std::initializer_list` are
/// const, so keys and values are necessarily copied.
template <typename Key, typename Value>
OrderedDict<Key, Value>::OrderedDict(
    std::initializer_list<Item> initializer_list)
    : OrderedDict("Key") {
  reserve(initializer_list.size());
  for (const auto& item : initializer_list) {
    insert(item.key(), item.value());
  }
}

// Iteration walks `items_` directly, so items are visited in insertion order.

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::begin() -> Iterator {
  return items_.begin();
}

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::begin() const -> ConstIterator {
  return items_.begin();
}

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::end() -> Iterator {
  return items_.end();
}

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::end() const -> ConstIterator {
  return items_.end();
}

// front()/back(): the first and last items in insertion order. Both raise an
// error through TORCH_CHECK when the dictionary is empty.

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::front() -> Item& {
  TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
  return items_[0];
}

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::front() const -> const Item& {
  TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
  return items_[0];
}

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::back() -> Item& {
  TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
  return items_[items_.size() - 1];
}

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::back() const -> const Item& {
  TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
  return items_[items_.size() - 1];
}

// Positional access in insertion order, bounds-checked with TORCH_CHECK.

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::operator[](size_t index) -> Item& {
  TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
  Item& selected = items_[index];
  return selected;
}

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::operator[](size_t index) const -> const Item& {
  TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
  const Item& selected = items_[index];
  return selected;
}

/// Returns the value stored under `key`. If `key` is absent, raises an error
/// reading "<key_description> '<key>' is not defined".
///
/// Uses TORCH_CHECK instead of the deprecated AT_ERROR macro; the message text
/// is unchanged. Checking then returning also avoids a code path that falls
/// off the end of a non-void function.
template <typename Key, typename Value>
Value& OrderedDict<Key, Value>::operator[](const Key& key) {
  auto* value = find(key);
  TORCH_CHECK(
      value != nullptr, key_description_, " '", key, "' is not defined");
  return *value;
}

template <typename Key, typename Value>
const Value& OrderedDict<Key, Value>::operator[](const Key& key) const {
  const auto* value = find(key);
  TORCH_CHECK(
      value != nullptr, key_description_, " '", key, "' is not defined");
  return *value;
}

/// Appends a new (key, value) item and returns a reference to the stored
/// value. Raises an error if `key` is already present.
///
/// The item and the index entry are added in two steps. If adding the index
/// entry throws, the freshly appended item is removed again. Without that
/// rollback, `items_` would hold an item that `find()` cannot reach, and a
/// retried insert of the same key would create a duplicate.
template <typename Key, typename Value>
template <typename K, typename V>
Value& OrderedDict<Key, Value>::insert(K&& key, V&& value) {
  TORCH_CHECK(
      index_.count(key) == 0, key_description_, " '", key, "' already defined");
  // Copy `key` into the item, then forward the caller's key into the index.
  items_.emplace_back(key, std::forward<V>(value));
  try {
    index_.emplace(std::forward<K>(key), size() - 1);
  } catch (...) {
    items_.pop_back();
    throw;
  }
  return items_.back().value();
}

/// Convenience overload whose `Value&&` parameter also accepts a braced
/// initializer. It forwards to the generic insert. The explicit template
/// arguments are required: an unqualified call with these arguments would
/// resolve back to this non-template overload.
template <typename Key, typename Value>
Value& OrderedDict<Key, Value>::insert(Key key, Value&& value) {
  return this->template insert<Key, Value>(std::move(key), std::move(value));
}

/// Appends every item of `other`, in order, moving its values out of `other`.
/// Each item goes through `insert()`, so a key that already exists raises an
/// error. Items inserted before the failing one are kept. Keys are only
/// reachable as `const Key&`, so they are copied rather than moved.
template <typename Key, typename Value>
void OrderedDict<Key, Value>::update(OrderedDict&& other) {
  reserve(size() + other.size());
  for (Iterator it = other.begin(); it != other.end(); ++it) {
    const Key& key = it->key();
    insert(key, std::move(it->value()));
  }
}

/// Appends a copy of every item of `other`, in order. Each item goes through
/// `insert()`, so a key that already exists raises an error. Items inserted
/// before the failing one are kept.
template <typename Key, typename Value>
void OrderedDict<Key, Value>::update(const OrderedDict& other) {
  reserve(size() + other.size());
  for (ConstIterator it = other.begin(), last = other.end(); it != last; ++it) {
    insert(it->key(), it->value());
  }
}

// find(): look up the key's position in the index, then return the address of
// the value stored at that position in `items_`, or `nullptr` if the key is
// absent.

template <typename Key, typename Value>
Value* OrderedDict<Key, Value>::find(const Key& key) noexcept {
  const auto slot = index_.find(key);
  return slot == index_.end() ? nullptr : &items_[slot->second].value();
}

template <typename Key, typename Value>
const Value* OrderedDict<Key, Value>::find(const Key& key) const noexcept {
  const auto slot = index_.find(key);
  return slot == index_.end() ? nullptr : &items_[slot->second].value();
}

/// Removes the item stored under `key`. If `key` is absent, raises an error
/// reading "<key_description> '<key>' doesn't exist". This matches the wording
/// of the other error paths and is unchanged for the default description
/// "Key".
///
/// Runs in time linear in `size()`: the vector shifts later items down by one,
/// and every index entry that pointed past the erased slot is decremented.
template <typename Key, typename Value>
void OrderedDict<Key, Value>::erase(const Key& key) {
  auto it = index_.find(key);
  TORCH_CHECK(
      it != index_.end(), key_description_, " '", key, "' doesn't exist");

  const size_t erased_position = it->second;
  index_.erase(it);
  items_.erase(items_.begin() + erased_position);

  // Items after the erased one moved down a slot; keep the index in sync.
  for (auto& entry : index_) {
    if (entry.second > erased_position) {
      --entry.second;
    }
  }
}

/// True when `key` has an entry in the index (and therefore in `items_`).
template <typename Key, typename Value>
bool OrderedDict<Key, Value>::contains(const Key& key) const noexcept {
  return index_.find(key) != index_.end();
}

/// Drops every item and its index entry. The key description is untouched.
template <typename Key, typename Value>
void OrderedDict<Key, Value>::clear() {
  items_.clear();
  index_.clear();
}

// Simple observers; `items_` is the source of truth for the item count.

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::size() const noexcept -> size_t {
  return items_.size();
}

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::is_empty() const noexcept -> bool {
  return items_.size() == 0;
}

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::key_description() const noexcept
    -> const std::string& {
  return key_description_;
}

template <typename Key, typename Value>
auto OrderedDict<Key, Value>::items() const noexcept
    -> const std::vector<Item>& {
  return items_;
}

// keys()/values()/pairs() each return a fresh vector holding copies of the
// corresponding parts of every item, in insertion order.

template <typename Key, typename Value>
::std::vector<Key> OrderedDict<Key, Value>::keys() const {
  std::vector<Key> result;
  result.reserve(items_.size());
  for (size_t i = 0; i < items_.size(); ++i) {
    result.push_back(items_[i].key());
  }
  return result;
}

template <typename Key, typename Value>
::std::vector<Value> OrderedDict<Key, Value>::values() const {
  std::vector<Value> result;
  result.reserve(items_.size());
  for (size_t i = 0; i < items_.size(); ++i) {
    result.push_back(items_[i].value());
  }
  return result;
}

template <typename Key, typename Value>
::std::vector<std::pair<Key, Value>> OrderedDict<Key, Value>::pairs() const {
  std::vector<std::pair<Key, Value>> result;
  result.reserve(items_.size());
  for (size_t i = 0; i < items_.size(); ++i) {
    result.push_back(items_[i].pair());
  }
  return result;
}

/// Pre-sizes both containers so that `requested_capacity` items fit without
/// reallocating `items_` or rehashing `index_`.
template <typename Key, typename Value>
void OrderedDict<Key, Value>::reserve(size_t requested_capacity) {
  this->index_.reserve(requested_capacity); // hash-index buckets
  this->items_.reserve(requested_capacity); // item storage
}

/// Two dictionaries are equal when they hold equal keys at the same positions
/// with equal values. The key description is not compared.
template <typename K, typename V>
bool operator==(
    const torch::OrderedDict<K, V>& a,
    const torch::OrderedDict<K, V>& b) {
  using Item = typename torch::OrderedDict<K, V>::Item;
  // `index_` maps key -> position, so equal indices mean the same keys sit at
  // the same positions in both dictionaries.
  if (a.index_ != b.index_) {
    return false;
  }
  if (a.items_.size() != b.items_.size()) {
    return false;
  }
  // Keys were already compared through the index; only values remain.
  // (Lambda parameters are named lhs/rhs so they do not shadow `a`/`b`.)
  return std::equal(
      a.items_.begin(),
      a.items_.end(),
      b.items_.begin(),
      [](const Item& lhs, const Item& rhs) {
        return lhs.value() == rhs.value();
      });
}

} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources