Class ModuleDictImpl

Inheritance Relationships

Base Type

public torch::nn::Cloneable<ModuleDictImpl>

Class Documentation

class torch::nn::ModuleDictImpl : public torch::nn::Cloneable<ModuleDictImpl>

An OrderedDict of Modules that registers its elements by their keys.

torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
  {"linear", Linear(10, 3).ptr()},
  {"conv", Conv2d(1, 2, 3).ptr()},
  {"dropout", Dropout(0.5).ptr()},
};
torch::nn::ModuleDict dict1(ordereddict);

// Iterating yields OrderedDict items; value() is the stored module pointer.
for (const auto &item : *dict1) {
  item.value()->pretty_print(std::cout);
}

std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = {
  {"linear", Linear(10, 3).ptr()},
  {"conv", Conv2d(1, 2, 3).ptr()},
  {"dropout", Dropout(0.5).ptr()},
};
torch::nn::ModuleDict dict2(list);

// Iterating yields OrderedDict items; value() is the stored module pointer.
for (const auto &item : *dict2) {
  item.value()->pretty_print(std::cout);
}

Why use ModuleDict instead of a plain map or OrderedDict? Unlike a manually managed ordered map of modules, a ModuleDict lets you treat the whole container as a single module: a transformation applied to the ModuleDict applies to each module it stores (each of which is a registered submodule of the ModuleDict). For example, calling .to(torch::kCUDA) on a ModuleDict moves each module in the map to CUDA memory:

torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
  {"linear", Linear(10, 3).ptr()},
  {"conv", Conv2d(1, 2, 3).ptr()},
  {"dropout", Dropout(0.5).ptr()},
};
torch::nn::ModuleDict dict(ordereddict);

// Convert all modules to CUDA.
dict->to(torch::kCUDA);

Finally, ModuleDict provides a lightweight container API: you can iterate over its submodules, access them by key, and, after construction, add modules from a vector of key-module pairs, an OrderedDict, or another ModuleDict via update.

Public Types

using Iterator = torch::OrderedDict<std::string, std::shared_ptr<Module>>::Iterator
using ConstIterator = torch::OrderedDict<std::string, std::shared_ptr<Module>>::ConstIterator

Public Functions

ModuleDictImpl() = default
ModuleDictImpl(const std::vector<std::pair<std::string, std::shared_ptr<Module>>> &modules)

Constructs the ModuleDict from a list of string-Module pairs.

ModuleDictImpl(const torch::OrderedDict<std::string, std::shared_ptr<Module>> &modules)

Constructs the ModuleDict from an OrderedDict.

std::vector<std::pair<std::string, std::shared_ptr<Module>>> items() const

Return the items in the ModuleDict.

std::vector<std::string> keys() const

Return the keys in the ModuleDict.

std::vector<std::shared_ptr<Module>> values() const

Return the values in the ModuleDict.

Iterator begin()

Return an iterator to the start of ModuleDict.

ConstIterator begin() const

Return a const iterator to the start of ModuleDict.

Iterator end()

Return an iterator to the end of ModuleDict.

ConstIterator end() const

Return a const iterator to the end of ModuleDict.

size_t size() const noexcept

Return the number of items currently stored in the ModuleDict.

bool empty() const noexcept

Return true if the ModuleDict is empty, otherwise return false.

bool contains(const std::string &key) const noexcept

Checks whether a module with the given key is stored in the ModuleDict.

void clear()

Remove all items from the ModuleDict.

std::shared_ptr<Module> clone(const optional<Device> &device = nullopt) const override

Special cloning function for ModuleDict because it does not use reset().

void reset() override

reset() is empty for ModuleDict, since it does not have parameters of its own.

void pretty_print(std::ostream &stream) const override

Pretty prints the ModuleDict into the given stream.

std::shared_ptr<Module> operator[](const std::string &key) const

Attempts to return the Module associated with the given key.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) beforehand to check for the key without throwing.

template<typename T>
T &at(const std::string &key)

Attempts to return the module at the given key as the requested type.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) beforehand to check for the key without throwing.

template<typename T>
const T &at(const std::string &key) const

Attempts to return the module at the given key as the requested type.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) beforehand to check for the key without throwing.

std::shared_ptr<Module> pop(const std::string &key)

Removes and returns the Module associated with the given key.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) beforehand to check for the key without throwing.

void update(const std::vector<std::pair<std::string, std::shared_ptr<Module>>> &modules)

Updates the ModuleDict with a vector of key-module pairs.

template<typename Container>
void update(const Container &container)

Updates the ModuleDict with key-value pairs from an OrderedDict or another ModuleDict.
