
Class ModuleDictImpl

Inheritance Relationships

Base Type

public torch::nn::Cloneable<ModuleDictImpl>

Class Documentation

class ModuleDictImpl : public torch::nn::Cloneable<ModuleDictImpl>

An OrderedDict of Modules that registers its elements by their keys.

#include <torch/torch.h>
#include <iostream>

using namespace torch::nn;

int main() {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    {"linear", Linear(10, 3).ptr()},
    {"conv", Conv2d(1, 2, 3).ptr()},
    {"dropout", Dropout(0.5).ptr()},
  };
  torch::nn::ModuleDict dict1(ordereddict);

  // Iterating a ModuleDict yields key-value items, not bare modules.
  for (const auto &item : *dict1) {
    std::cout << item.key() << ": ";
    item.value()->pretty_print(std::cout);
    std::cout << '\n';
  }

  std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = {
    {"linear", Linear(10, 3).ptr()},
    {"conv", Conv2d(1, 2, 3).ptr()},
    {"dropout", Dropout(0.5).ptr()},
  };
  torch::nn::ModuleDict dict2(list);

  for (const auto &item : *dict2) {
    std::cout << item.key() << ": ";
    item.value()->pretty_print(std::cout);
    std::cout << '\n';
  }
}

Why should you use ModuleDict instead of a simple map or OrderedDict? The value a ModuleDict provides over manually managing an ordered map of modules is that the whole container can be treated as a single module: a transformation performed on the ModuleDict is applied to each module it stores, since each one is registered as a submodule of the ModuleDict. For example, calling .to(torch::kCUDA) on a ModuleDict moves every module in it to CUDA memory:

#include <torch/torch.h>

using namespace torch::nn;

int main() {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
    {"linear", Linear(10, 3).ptr()},
    {"conv", Conv2d(1, 2, 3).ptr()},
    {"dropout", Dropout(0.5).ptr()},
  };
  torch::nn::ModuleDict dict(ordereddict);

  // Convert all modules to CUDA, if a CUDA device is available.
  if (torch::cuda::is_available()) {
    dict->to(torch::kCUDA);
  }
}

Finally, ModuleDict provides a lightweight container API: iteration over its submodules, access by key, and, after construction, adding new modules via update from a vector of key-module pairs, an OrderedDict, or another ModuleDict.

Public Types

using Iterator = torch::OrderedDict<std::string, std::shared_ptr<Module>>::Iterator
using ConstIterator = torch::OrderedDict<std::string, std::shared_ptr<Module>>::ConstIterator

Public Functions

ModuleDictImpl() = default
inline explicit ModuleDictImpl(const std::vector<std::pair<std::string, std::shared_ptr<Module>>> &modules)

Constructs the ModuleDict from a list of string-Module pairs.

inline explicit ModuleDictImpl(const torch::OrderedDict<std::string, std::shared_ptr<Module>> &modules)

Constructs the ModuleDict from an OrderedDict.

inline std::vector<std::pair<std::string, std::shared_ptr<Module>>> items() const

Return the items in the ModuleDict.

inline std::vector<std::string> keys() const

Return the keys in the ModuleDict.

inline std::vector<std::shared_ptr<Module>> values() const

Return the values in the ModuleDict.

inline Iterator begin()

Return an iterator to the start of ModuleDict.

inline ConstIterator begin() const

Return a const iterator to the start of ModuleDict.

inline Iterator end()

Return an iterator to the end of ModuleDict.

inline ConstIterator end() const

Return a const iterator to the end of ModuleDict.

inline size_t size() const noexcept

Return the number of items currently stored in the ModuleDict.

inline bool empty() const noexcept

Return true if the ModuleDict is empty, otherwise return false.

inline bool contains(const std::string &key) const noexcept

Return true if a module with the given key is stored in the ModuleDict, otherwise return false.

inline void clear()

Remove all items from the ModuleDict.

inline virtual std::shared_ptr<Module> clone(const std::optional<Device> &device = std::nullopt) const override

Special cloning function for ModuleDict because it does not use reset().

inline virtual void reset() override

reset() is empty for ModuleDict, since it does not have parameters of its own.

inline virtual void pretty_print(std::ostream &stream) const override

Pretty prints the ModuleDict into the given stream.

inline std::shared_ptr<Module> operator[](const std::string &key) const

Attempts to return the Module associated with the given key.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) first for non-throwing access.

template<typename T>
inline T &at(const std::string &key)

Attempts to return the module at the given key as the requested type.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) first for non-throwing access.

template<typename T>
inline const T &at(const std::string &key) const

Attempts to return the module at the given key as the requested type.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) first for non-throwing access.

inline std::shared_ptr<Module> pop(const std::string &key)

Removes and returns the Module associated with the given key.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) first to avoid the exception.

inline void update(const std::vector<std::pair<std::string, std::shared_ptr<Module>>> &modules)

Updates the ModuleDict with a vector of key-module pairs.

template<typename Container>
inline void update(const Container &container)

Updates the ModuleDict with key-value pairs from an OrderedDict or a ModuleDict.
