Program Listing for File module.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/module.h)
#pragma once
#include <torch/nn/modules/container/any_module_holder.h>
#include <torch/nn/modules/container/any_value.h>
#include <torch/nn/pimpl.h>
#include <torch/ordered_dict.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>
#include <ATen/ATen.h>
#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
namespace torch::nn {
// Base class for modules in the `torch::nn` C++ API.
//
// A `Module` keeps three ordered dictionaries -- parameters (`parameters_`),
// buffers (`buffers_`) and child modules (`children_`) -- together with an
// optional name (`name_`) and a training-mode flag (`is_training_`). It derives
// from `std::enable_shared_from_this`, so a module can produce a `shared_ptr`
// to itself (see `shared_from_this_checked()`).
//
// Only the templates defined further down in this header (`as`,
// `register_module`, `replace_module`, `to_impl`) have their behavior visible
// here; every other member is defined out of line.
class TORCH_API Module : public std::enable_shared_from_this<Module> {
public:
// Callback signatures accepted by the `apply()` overloads. They differ in
// whether the module is passed by (const) reference or as a `shared_ptr`, and
// in whether a `std::string` name is passed alongside it.
using ModuleApplyFunction = std::function<void(Module&)>;
using ConstModuleApplyFunction = std::function<void(const Module&)>;
using NamedModuleApplyFunction =
std::function<void(const std::string&, Module&)>;
using ConstNamedModuleApplyFunction =
std::function<void(const std::string&, const Module&)>;
using ModulePointerApplyFunction =
std::function<void(const std::shared_ptr<Module>&)>;
using NamedModulePointerApplyFunction =
std::function<void(const std::string&, const std::shared_ptr<Module>&)>;
// Constructs a module from an explicit name (presumably stored in `name_`;
// the definition is not in this header).
explicit Module(std::string name);
// Constructs a module without an explicit name.
Module();
// Copy and move operations are the compiler-generated member-wise ones.
Module(const Module&) = default;
Module& operator=(const Module&) = default;
Module(Module&&) noexcept = default;
Module& operator=(Module&&) noexcept = default;
// Virtual destructor: `Module` is used as a polymorphic base.
virtual ~Module() = default;
// Returns the module's name. NOTE(review): `name_` is a `mutable` optional,
// which suggests the name may be computed lazily on first access -- confirm
// against the implementation file.
const std::string& name() const noexcept;
// Returns a copy of this module as a `shared_ptr<Module>`, optionally given a
// target `device`. Overridable; the base behavior is defined out of line.
virtual std::shared_ptr<Module> clone(
const std::optional<Device>& device = std::nullopt) const;
// `apply()` overloads: invoke `function` on modules of this module's
// hierarchy. The overloads differ only in what the callback receives (see the
// callback aliases above); the named variants additionally take a
// `name_prefix`, which defaults to the empty string. The exact traversal
// (order, and which modules are visited) is defined out of line.
void apply(const ModuleApplyFunction& function);
void apply(const ConstModuleApplyFunction& function) const;
void apply(
const NamedModuleApplyFunction& function,
const std::string& name_prefix = std::string());
void apply(
const ConstNamedModuleApplyFunction& function,
const std::string& name_prefix = std::string()) const;
void apply(const ModulePointerApplyFunction& function) const;
void apply(
const NamedModulePointerApplyFunction& function,
const std::string& name_prefix = std::string()) const;
// Accessors for parameters and buffers, either as a plain vector or as an
// ordered dictionary keyed by name. `recurse` defaults to true; presumably it
// controls whether tensors of submodules are included (`to_impl` below passes
// `recurse=false` after having handled the children separately).
std::vector<Tensor> parameters(bool recurse = true) const;
OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const;
std::vector<Tensor> buffers(bool recurse = true) const;
OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const;
// Accessors for modules of the hierarchy. `include_self` presumably decides
// whether this module itself is part of the result -- TODO confirm.
std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const;
OrderedDict<std::string, std::shared_ptr<Module>> named_modules(
const std::string& name_prefix = std::string(),
bool include_self = true) const;
// Accessors for child modules, as a vector or keyed by name (presumably the
// contents of `children_`).
std::vector<std::shared_ptr<Module>> children() const;
OrderedDict<std::string, std::shared_ptr<Module>> named_children() const;
// Training-mode control. `train(on)` switches the mode, `eval()` takes no
// argument, and `is_training()` reports the current mode (presumably backed
// by `is_training_`, which is initialized to true).
virtual void train(bool on = true);
void eval();
virtual bool is_training() const noexcept;
// Convert the module to the given device and/or dtype. `non_blocking`
// defaults to false. Presumably implemented through `to_impl` (defined at the
// bottom of this header), which converts children, parameters and buffers.
virtual void to(
torch::Device device,
torch::Dtype dtype,
bool non_blocking = false);
virtual void to(torch::Dtype dtype, bool non_blocking = false);
virtual void to(torch::Device device, bool non_blocking = false);
// Resets gradients; `set_to_none` defaults to true. Exact semantics are
// defined out of line.
virtual void zero_grad(bool set_to_none = true);
// `as<ModuleType>()`: checked downcast of this module.
//
// The first two overloads accept a type exposing `ContainedType` (a
// `ModuleHolder` such as `Linear`) and forward to the overloads for the
// contained implementation type. The last two accept any other type and
// perform a `dynamic_cast` on `this`, which yields `nullptr` when this object
// is not a `ModuleType`.
template <typename ModuleType>
typename ModuleType::ContainedType* as() noexcept;
template <typename ModuleType>
const typename ModuleType::ContainedType* as() const noexcept;
template <
typename ModuleType,
typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType* as() noexcept;
template <
typename ModuleType,
typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType* as() const noexcept;
// Serialization hooks: write this module into / read it from an archive.
virtual void save(serialize::OutputArchive& archive) const;
virtual void load(serialize::InputArchive& archive);
// Writes a description of this module to `stream`.
virtual void pretty_print(std::ostream& stream) const;
// Reports whether the module can be serialized.
virtual bool is_serializable() const;
// Registers `tensor` under `name` and returns a `Tensor&` (presumably the
// entry stored in `parameters_`). `requires_grad` defaults to true.
Tensor& register_parameter(
std::string name,
Tensor tensor,
bool requires_grad = true);
// Registers `tensor` under `name` and returns a `Tensor&` (presumably the
// entry stored in `buffers_`).
Tensor& register_buffer(std::string name, Tensor tensor);
// Registers `module` as a child under `name` and returns it downcast to
// `ModuleType`. The name must be non-empty and must not contain a dot;
// otherwise the `TORCH_CHECK`s in the definition below fail.
template <typename ModuleType>
std::shared_ptr<ModuleType> register_module(
std::string name,
std::shared_ptr<ModuleType> module);
// Same as above, for a `ModuleHolder`: registers the pointer obtained from
// `module_holder.ptr()`.
template <typename ModuleType>
std::shared_ptr<ModuleType> register_module(
std::string name,
ModuleHolder<ModuleType> module_holder);
// Replaces the submodule registered under `name` with `module` and returns
// the new submodule downcast to `ModuleType`. Only the entry in `children_`
// is updated: if the submodule is also held in a member variable, assign the
// returned pointer to that member as well.
template <typename ModuleType>
std::shared_ptr<ModuleType> replace_module(
const std::string& name,
std::shared_ptr<ModuleType> module);
// Same as above, for a `ModuleHolder`: replaces with the pointer obtained
// from `module_holder.ptr()`. As above, also assign the returned pointer to
// any member variable that holds the submodule.
template <typename ModuleType>
std::shared_ptr<ModuleType> replace_module(
const std::string& name,
ModuleHolder<ModuleType> module_holder);
// Removes the registration of the submodule called `name` (definition not in
// this header).
void unregister_module(const std::string& name);
protected:
// The three `_forward_*` hooks below support modules whose `forward` method
// has default arguments; per the error messages, subclasses are expected to
// override them via the `FORWARD_HAS_DEFAULT_ARGS` macro.

// Base implementation: reports that `forward` has no default arguments.
virtual bool _forward_has_default_args() {
return false;
}
// Base implementation: always fails the check with a message asking the
// subclass to override this method. There is no `return` statement; the
// function relies on `TORCH_CHECK(false, ...)` never returning normally.
virtual unsigned int _forward_num_required_args() {
TORCH_CHECK(
false,
"torch::nn::Module subclass that has default arguments in `forward` method ",
"must override `_forward_num_required_args` method. Please use ",
"`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
}
// Base implementation: always fails the check with a message asking the
// subclass to override this method; `arguments` is not used here.
virtual std::vector<AnyValue> _forward_populate_default_args(
std::vector<AnyValue>&& arguments) {
TORCH_CHECK(
false,
"torch::nn::Module subclass that has default arguments in `forward` method ",
"must override `_forward_populate_default_args` method. Please use ",
"`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
}
// Parameters of this module, keyed by name. Protected (not private), hence
// the lint suppression on the next line.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
OrderedDict<std::string, Tensor> parameters_;
private:
// Friend classes.
template <typename Derived>
friend class Cloneable;
template <typename ModuleType, typename... ArgumentTypes>
friend struct AnyModuleHolder;
// Stream output for modules (presumably built on `pretty_print` /
// `pretty_print_recursive`).
TORCH_API friend std::ostream& operator<<(
std::ostream& stream,
const nn::Module& module);
// data parallel using this method to configure gradient edges during the
// replicate step.
template <typename ModuleType>
friend void replicate_grad_edges(
const std::shared_ptr<Module>& module,
const std::vector<std::shared_ptr<ModuleType>>& replicas,
const std::vector<Device>& devices);
// Private methods.
// Virtual hook taking the module `other` and an optional `device`;
// presumably the worker behind `clone()` (see the `Cloneable` friend).
virtual void clone_(Module& other, const std::optional<Device>& device);
// Shared implementation for conversions: applies `to(ts...)` to every child,
// then to every directly owned parameter and buffer (defined below).
template <typename... Ts>
void to_impl(Ts&&... ts);
// Helper for printing that carries the current `indentation` string.
void pretty_print_recursive(
std::ostream& stream,
const std::string& indentation) const;
// Helper that invokes `function` with a name and a `shared_ptr` for
// submodules; `name_prefix` defaults to the empty string.
void apply_to_submodules(
const NamedModulePointerApplyFunction& function,
const std::string& name_prefix = std::string()) const;
// Returns a `shared_ptr` to this module; the "checked" suffix suggests it
// reports an error when the module is not owned by a `shared_ptr` -- TODO
// confirm against the implementation file.
std::shared_ptr<Module> shared_from_this_checked() const;
// Buffers of this module, keyed by name.
OrderedDict<std::string, Tensor> buffers_;
// Registered child modules, keyed by name (written by `register_module` and
// `replace_module`).
OrderedDict<std::string, std::shared_ptr<Module>> children_;
// Optional module name; `mutable` so that const members can set it.
mutable std::optional<std::string> name_;
// Training-mode flag; a freshly constructed module starts in training mode.
bool is_training_{true};
};
// Stream-style serialization of a module held by `shared_ptr` into an
// `OutputArchive`. Returns an `OutputArchive&` (presumably `archive` itself, so
// that calls can be chained; the definition is not in this header).
TORCH_API serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const std::shared_ptr<nn::Module>& module);
// Stream-style deserialization from an `InputArchive` into a module held by
// `shared_ptr`. The pointer itself is taken by const reference, so only the
// pointed-to module can be modified. Returns an `InputArchive&` (presumably
// `archive` itself; the definition is not in this header).
TORCH_API serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
const std::shared_ptr<nn::Module>& module);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename ModuleType>
typename ModuleType::ContainedType* Module::as() noexcept {
  // `ModuleType` is a `ModuleHolder` (e.g. `Linear`). The class that actually
  // inherits from `nn::Module` is its contained implementation type (e.g.
  // `LinearImpl`), so the cast is delegated to the overload for that type.
  using Contained = typename ModuleType::ContainedType;
  return as<Contained>();
}
template <typename ModuleType>
const typename ModuleType::ContainedType* Module::as() const noexcept {
  // Const counterpart: `ModuleType` is a `ModuleHolder` (e.g. `Linear`), whose
  // contained implementation type (e.g. `LinearImpl`) is the one inheriting
  // from `nn::Module`; delegate to the const overload for that type.
  using Contained = typename ModuleType::ContainedType;
  return as<Contained>();
}
template <typename ModuleType, typename>
ModuleType* Module::as() noexcept {
return dynamic_cast<ModuleType*>(this);
}
template <typename ModuleType, typename>
const ModuleType* Module::as() const noexcept {
return dynamic_cast<const ModuleType*>(this);
}
// Registers `module` as a child of this module under `name`.
//
// Fails a `TORCH_CHECK` when `name` is empty or contains a '.'. On success the
// module is inserted into `children_` and the stored pointer is returned,
// downcast back to `ModuleType`.
template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
    std::string name,
    std::shared_ptr<ModuleType> module) {
  // Validate the name before touching `children_`.
  TORCH_CHECK(!name.empty(), "Submodule name must not be empty");
  TORCH_CHECK(
      name.find('.') == std::string::npos,
      "Submodule name must not contain a dot (got '",
      name,
      "')");
  // `children_` holds `shared_ptr<Module>`, so the inserted entry has to be
  // cast back to the concrete type for the caller.
  return std::dynamic_pointer_cast<ModuleType>(
      children_.insert(std::move(name), std::move(module)));
}
// `ModuleHolder` convenience overload: registers the pointer returned by
// `module_holder.ptr()` through the `shared_ptr` overload.
template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
    std::string name,
    ModuleHolder<ModuleType> module_holder) {
  auto held_module = module_holder.ptr();
  return register_module(std::move(name), std::move(held_module));
}
// Overwrites the entry of `children_` stored under `name` with `module` and
// returns the stored pointer downcast to `ModuleType`.
template <typename ModuleType>
std::shared_ptr<ModuleType> Module::replace_module(
    const std::string& name,
    std::shared_ptr<ModuleType> module) {
  // Look the slot up first, then assign into it.
  auto& slot = children_[name];
  slot = std::move(module);
  return std::dynamic_pointer_cast<ModuleType>(slot);
}
// `ModuleHolder` convenience overload: replaces the submodule with the pointer
// returned by `module_holder.ptr()` through the `shared_ptr` overload.
template <typename ModuleType>
std::shared_ptr<ModuleType> Module::replace_module(
    const std::string& name,
    ModuleHolder<ModuleType> module_holder) {
  auto held_module = module_holder.ptr();
  return replace_module(name, std::move(held_module));
}
template <typename... Ts>
void Module::to_impl(Ts&&... ts) {
// First call `to()` on every child module.
for (auto& child : children_) {
child.value()->to(ts...);
}
// Then move every parameter to the new dtype/device.
for (auto& parameter : named_parameters(/*recurse=*/false)) {
parameter->set_data(autograd::Variable(*parameter).to(ts...));
}
// Then move every buffer to the new dtype/device.
for (auto& buffer : named_buffers(/*recurse=*/false)) {
buffer->set_data(autograd::Variable(*buffer).to(ts...));
}
}
} // namespace torch::nn