Shortcuts

Program Listing for File module.h

Return to documentation for file (torch/csrc/api/include/torch/nn/module.h)

#pragma once

#include <torch/nn/modules/container/any_module_holder.h>
#include <torch/nn/modules/container/any_value.h>
#include <torch/nn/pimpl.h>
#include <torch/ordered_dict.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <ATen/ATen.h>

#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <string>
#include <type_traits>

namespace torch {
namespace nn {

/// Base class for neural-network modules in the C++ frontend.
///
/// A `Module` owns three named collections -- parameters (`parameters_`),
/// buffers (`buffers_`) and child modules (`children_`) -- plus a
/// training-mode flag (`is_training_`). It derives from
/// `std::enable_shared_from_this` so that member functions can obtain a
/// `std::shared_ptr<Module>` to the object itself (see
/// `shared_from_this_checked`).
///
/// Most member functions are only declared here; their definitions live
/// outside this header. Documentation for those members is limited to what the
/// signatures show, and anything beyond that is marked as an assumption.
class TORCH_API Module : public std::enable_shared_from_this<Module> {
 public:
  // Callback signatures accepted by the `apply` overloads below. They differ
  // in whether the callback receives the module by mutable reference, by const
  // reference or as a `shared_ptr`, and in whether it additionally receives a
  // name string.
  using ModuleApplyFunction = std::function<void(Module&)>;
  using ConstModuleApplyFunction = std::function<void(const Module&)>;
  using NamedModuleApplyFunction =
      std::function<void(const std::string&, Module&)>;
  using ConstNamedModuleApplyFunction =
      std::function<void(const std::string&, const Module&)>;
  using ModulePointerApplyFunction =
      std::function<void(const std::shared_ptr<Module>&)>;
  using NamedModulePointerApplyFunction =
      std::function<void(const std::string&, const std::shared_ptr<Module>&)>;

  /// Constructs the module with an explicitly supplied name.
  explicit Module(std::string name);

  /// Constructs the module without supplying a name.
  /// NOTE(review): `name_` is a `mutable std::optional`, which suggests the
  /// name is filled in lazily by `name()` -- confirm in the implementation.
  Module();
  // Copy and move operations are the compiler-generated member-wise ones.
  Module(const Module&) = default;
  Module& operator=(const Module&) = default;
  Module(Module&&) noexcept = default;
  Module& operator=(Module&&) noexcept = default;

  /// Virtual so that derived modules can be destroyed through a `Module`
  /// pointer.
  virtual ~Module() = default;

  /// Returns the name of this module.
  /// NOTE(review): because `name_` is `mutable`, this const accessor may
  /// populate it on first use -- confirm in the implementation.
  const std::string& name() const noexcept;

  /// Returns a copy of this module as a `shared_ptr`, optionally for the
  /// given `device` (no device is passed by default).
  /// NOTE(review): `Cloneable` is a friend of this class and `clone_` is a
  /// private virtual hook, so the usable implementation presumably comes from
  /// `Cloneable<Derived>` -- confirm what the base implementation does.
  virtual std::shared_ptr<Module> clone(
      const std::optional<Device>& device = nullopt) const;

  // The `apply` overloads invoke `function` on modules of this hierarchy.
  // The named variants additionally pass a name string to the callback;
  // `name_prefix` defaults to the empty string.
  // NOTE(review): traversal order, whether this module itself is visited, and
  // how names are composed from `name_prefix` are defined outside this header
  // -- confirm there.

  /// Applies `function` to modules, passing each as `Module&`.
  void apply(const ModuleApplyFunction& function);

  /// Applies `function` to modules, passing each as `const Module&`.
  void apply(const ConstModuleApplyFunction& function) const;

  /// Applies `function` to modules, passing a name and a `Module&`.
  void apply(
      const NamedModuleApplyFunction& function,
      const std::string& name_prefix = std::string());

  /// Applies `function` to modules, passing a name and a `const Module&`.
  void apply(
      const ConstNamedModuleApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  /// Applies `function` to modules, passing each as a `shared_ptr<Module>`.
  void apply(const ModulePointerApplyFunction& function) const;

  /// Applies `function` to modules, passing a name and a
  /// `shared_ptr<Module>`.
  void apply(
      const NamedModulePointerApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  /// Returns the parameters of this module as a vector of tensors.
  /// `recurse` (default `true`) presumably selects whether parameters of
  /// submodules are included -- confirm in the implementation.
  std::vector<Tensor> parameters(bool recurse = true) const;

  /// Like `parameters()`, but returns an `OrderedDict` keyed by name.
  /// `to_impl` calls this with `recurse = false` after handling the children
  /// itself.
  OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const;

  /// Returns the buffers of this module as a vector of tensors.
  /// `recurse` (default `true`) presumably selects whether buffers of
  /// submodules are included -- confirm in the implementation.
  std::vector<Tensor> buffers(bool recurse = true) const;

  /// Like `buffers()`, but returns an `OrderedDict` keyed by name.
  OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const;

  /// Returns modules of this hierarchy as shared pointers.
  /// `include_self` (default `true`) presumably selects whether this module
  /// is part of the result -- confirm in the implementation.
  std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const;

  /// Like `modules()`, but returns an `OrderedDict` keyed by name.
  /// NOTE(review): how `name_prefix` is combined with child names is defined
  /// outside this header -- confirm there.
  OrderedDict<std::string, std::shared_ptr<Module>> named_modules(
      const std::string& name_prefix = std::string(),
      bool include_self = true) const;

  /// Returns child modules as shared pointers (presumably only the direct
  /// children stored in `children_` -- confirm).
  std::vector<std::shared_ptr<Module>> children() const;

  /// Like `children()`, but returns an `OrderedDict` keyed by name.
  OrderedDict<std::string, std::shared_ptr<Module>> named_children() const;

  /// Switches training mode on or off (`on` defaults to `true`).
  /// NOTE(review): whether the change propagates to submodules is defined
  /// outside this header -- confirm there.
  virtual void train(bool on = true);

  /// Presumably equivalent to `train(false)` -- confirm in the implementation.
  void eval();

  /// Reports whether the module is in training mode (presumably the value of
  /// `is_training_`, which is initialized to `true`).
  virtual bool is_training() const noexcept;

  // The `to` overloads convert the module to a device and/or dtype.
  // NOTE(review): they presumably forward to `to_impl`, which recurses into
  // the children and then converts this module's own parameters and buffers;
  // the effect of `non_blocking` is determined by the tensor conversion it is
  // forwarded to -- confirm.

  /// Converts the module to the given `device` and `dtype`.
  virtual void to(
      torch::Device device,
      torch::Dtype dtype,
      bool non_blocking = false);

  /// Converts the module to the given `dtype`.
  virtual void to(torch::Dtype dtype, bool non_blocking = false);

  /// Converts the module to the given `device`.
  virtual void to(torch::Device device, bool non_blocking = false);

  /// Clears gradients held by the module.
  /// NOTE(review): `set_to_none` (default `true`) presumably selects between
  /// resetting gradients to undefined tensors and filling them with zeros --
  /// confirm in the implementation.
  virtual void zero_grad(bool set_to_none = true);

  /// Casts this module to the type contained in the holder `ModuleType`
  /// (for example `LinearImpl` for `Linear`). Forwards to the non-holder
  /// overload below, so the result is null when the dynamic type does not
  /// match.
  template <typename ModuleType>
  typename ModuleType::ContainedType* as() noexcept;

  /// Const counterpart of the holder overload above.
  template <typename ModuleType>
  const typename ModuleType::ContainedType* as() const noexcept;

  /// Casts this module to `ModuleType*` with `dynamic_cast`; yields null when
  /// this object is not a `ModuleType`. Only participates for types that are
  /// not module holders (see `disable_if_module_holder_t`).
  template <
      typename ModuleType,
      typename = torch::detail::disable_if_module_holder_t<ModuleType>>
  ModuleType* as() noexcept;

  /// Const counterpart of the non-holder overload above.
  template <
      typename ModuleType,
      typename = torch::detail::disable_if_module_holder_t<ModuleType>>
  const ModuleType* as() const noexcept;

  /// Writes this module into `archive` (details are defined outside this
  /// header).
  virtual void save(serialize::OutputArchive& archive) const;

  /// Reads this module's state from `archive` (details are defined outside
  /// this header).
  virtual void load(serialize::InputArchive& archive);

  /// Writes a textual description of this module to `stream`.
  /// NOTE(review): presumably prints a single line for this module only, with
  /// `pretty_print_recursive` handling nesting -- confirm.
  virtual void pretty_print(std::ostream& stream) const;

  /// Reports whether this module can be serialized.
  virtual bool is_serializable() const;

  /// Registers `tensor` under `name` as a parameter and returns a reference
  /// to a tensor (presumably the entry stored in `parameters_` -- confirm).
  /// `requires_grad` defaults to `true`.
  Tensor& register_parameter(
      std::string name,
      Tensor tensor,
      bool requires_grad = true);

  /// Registers `tensor` under `name` as a buffer and returns a reference to a
  /// tensor (presumably the entry stored in `buffers_` -- confirm).
  Tensor& register_buffer(std::string name, Tensor tensor);

  /// Registers `module` as a child under `name` and returns it cast back to
  /// `ModuleType`. The name must be non-empty and must not contain a dot;
  /// both conditions are enforced with `TORCH_CHECK`. The module is inserted
  /// into `children_`.
  template <typename ModuleType>
  std::shared_ptr<ModuleType> register_module(
      std::string name,
      std::shared_ptr<ModuleType> module);

  /// Holder overload: registers the pointer obtained from
  /// `module_holder.ptr()` via the `shared_ptr` overload above.
  template <typename ModuleType>
  std::shared_ptr<ModuleType> register_module(
      std::string name,
      ModuleHolder<ModuleType> module_holder);

  /// Replaces the child stored under `name` with `module` (by assigning to
  /// `children_[name]`) and returns it cast back to `ModuleType`.
  /// Only the registration is updated: if the owning class also keeps the
  /// submodule in a data member, assign the returned pointer to that member as
  /// well, e.g. `submodule_ = replace_module("name", new_module);`.
  /// NOTE(review): the behavior for a name that was never registered depends
  /// on `OrderedDict::operator[]` -- confirm (presumably it fails).
  template <typename ModuleType>
  std::shared_ptr<ModuleType> replace_module(
      const std::string& name,
      std::shared_ptr<ModuleType> module);

  /// Holder overload: replaces the child with the pointer obtained from
  /// `module_holder.ptr()` via the `shared_ptr` overload above. As with that
  /// overload, a data member holding the submodule must be reassigned by the
  /// caller.
  template <typename ModuleType>
  std::shared_ptr<ModuleType> replace_module(
      const std::string& name,
      ModuleHolder<ModuleType> module_holder);

  /// Removes the child registered under `name` (details are defined outside
  /// this header).
  void unregister_module(const std::string& name);

 protected:
  /// Reports whether `forward` has default arguments. The base implementation
  /// always returns `false`; the error messages in the two hooks below point
  /// to the `FORWARD_HAS_DEFAULT_ARGS` macro as the way to override all three.
  virtual bool _forward_has_default_args() {
    return false;
  }

  /// Hook that must be overridden by modules whose `forward` has default
  /// arguments. The base implementation unconditionally fails a
  /// `TORCH_CHECK`, so it has no return statement.
  virtual unsigned int _forward_num_required_args() {
    TORCH_CHECK(
        false,
        "torch::nn::Module subclass that has default arguments in `forward` method ",
        "must override `_forward_num_required_args` method. Please use ",
        "`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
  }

  /// Hook that must be overridden by modules whose `forward` has default
  /// arguments. The base implementation ignores `arguments` and
  /// unconditionally fails a `TORCH_CHECK`, so it has no return statement.
  virtual std::vector<AnyValue> _forward_populate_default_args(
      std::vector<AnyValue>&& arguments) {
    TORCH_CHECK(
        false,
        "torch::nn::Module subclass that has default arguments in `forward` method ",
        "must override `_forward_populate_default_args` method. Please use ",
        "`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
  }

  // Named parameter tensors of this module. Protected rather than private,
  // which is why the lint check below is suppressed.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  OrderedDict<std::string, Tensor> parameters_;

 private:
  // Friend classes.

  // `Cloneable` needs access to the private state of `Module`.
  template <typename Derived>
  friend class Cloneable;

  // `AnyModuleHolder` needs access to the non-public members of `Module`
  // (presumably the protected `_forward_*` hooks -- confirm).
  template <typename ModuleType, typename... ArgumentTypes>
  friend struct AnyModuleHolder;

  // Stream-insertion operator for modules; declared a friend so it can use
  // private members (presumably `pretty_print_recursive` -- confirm).
  TORCH_API friend std::ostream& operator<<(
      std::ostream& stream,
      const nn::Module& module);

  // data parallel using this method to configure gradient edges during the
  // replicate step.
  template <typename ModuleType>
  friend void replicate_grad_edges(
      const std::shared_ptr<Module>& module,
      const std::vector<std::shared_ptr<ModuleType>>& replicas,
      const std::vector<Device>& devices);

  // Private methods.

  // Private virtual hook related to `clone()`, taking the other module and an
  // optional device. NOTE(review): presumably overridden by
  // `Cloneable<Derived>` -- confirm.
  virtual void clone_(Module& other, const std::optional<Device>& device);

  // Shared implementation for conversions: calls `to(ts...)` on every child,
  // then replaces the data of every non-recursive parameter and buffer with
  // the result of `.to(ts...)` on it.
  template <typename... Ts>
  void to_impl(Ts&&... ts);

  // Helper for printing that takes the current `indentation` string.
  void pretty_print_recursive(
      std::ostream& stream,
      const std::string& indentation) const;

  // Applies `function` to submodules, passing a name and a shared pointer;
  // `name_prefix` defaults to the empty string.
  void apply_to_submodules(
      const NamedModulePointerApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  // Returns a `shared_ptr` to this module. NOTE(review): the `_checked`
  // suffix suggests it reports an error when the module is not owned by a
  // `shared_ptr` -- confirm in the implementation.
  std::shared_ptr<Module> shared_from_this_checked() const;

  // Named buffer tensors of this module.
  OrderedDict<std::string, Tensor> buffers_;

  // Named child modules; written by `register_module` and `replace_module`.
  OrderedDict<std::string, std::shared_ptr<Module>> children_;

  // Optional module name; `mutable` so that const member functions can set it.
  mutable std::optional<std::string> name_;

  // Training-mode flag; a freshly constructed module starts with `true`.
  bool is_training_{true};
};

// Stream-style insertion of a module into an output archive. Takes the archive
// by mutable reference and the module as a shared pointer, and returns a
// reference to an output archive so that calls can be chained.
// NOTE(review): the definition is not in this header; it presumably delegates
// to `Module::save` -- confirm in the implementation.
TORCH_API serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const std::shared_ptr<nn::Module>& module);

// Stream-style extraction of a module from an input archive. The shared
// pointer itself is const, but the pointee is a non-const `nn::Module`, so the
// module it refers to can be modified. Returns a reference to an input archive
// so that calls can be chained.
// NOTE(review): the definition is not in this header; it presumably delegates
// to `Module::load` -- confirm in the implementation.
TORCH_API serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    const std::shared_ptr<nn::Module>& module);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename ModuleType>
typename ModuleType::ContainedType* Module::as() noexcept {
  // `ModuleType` is a `ModuleHolder` such as `Linear`. The class that actually
  // derives from `nn::Module` is its contained type (`LinearImpl` in that
  // example), so the cast is delegated to the overload for that type.
  using Contained = typename ModuleType::ContainedType;
  return as<Contained>();
}

template <typename ModuleType>
const typename ModuleType::ContainedType* Module::as() const noexcept {
  // Const variant: `ModuleType` is a `ModuleHolder` such as `Linear`, whose
  // contained type (`LinearImpl`) is the one inheriting from `nn::Module`, so
  // the cast is delegated to the overload for that type.
  using Contained = typename ModuleType::ContainedType;
  return as<Contained>();
}

template <typename ModuleType, typename>
ModuleType* Module::as() noexcept {
  return dynamic_cast<ModuleType*>(this);
}

template <typename ModuleType, typename>
const ModuleType* Module::as() const noexcept {
  return dynamic_cast<const ModuleType*>(this);
}

// Registers `module` as a child named `name` and hands the stored pointer back
// to the caller, cast to the concrete `ModuleType`.
// The name is validated first: it must be non-empty and must not contain a
// dot.
template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
    std::string name,
    std::shared_ptr<ModuleType> module) {
  TORCH_CHECK(!name.empty(), "Submodule name must not be empty");
  TORCH_CHECK(
      name.find('.') == std::string::npos,
      "Submodule name must not contain a dot (got '",
      name,
      "')");
  // `children_` stores `shared_ptr<Module>`, so the entry returned by
  // `insert` is cast back down to the type the caller passed in.
  return std::dynamic_pointer_cast<ModuleType>(
      children_.insert(std::move(name), std::move(module)));
}

// Holder overload: unwraps the `ModuleHolder` and registers the pointer it
// holds through the `shared_ptr` overload.
template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
    std::string name,
    ModuleHolder<ModuleType> module_holder) {
  auto held_module = module_holder.ptr();
  return register_module(std::move(name), std::move(held_module));
}

// Overwrites the child stored under `name` with `module` and returns the
// stored pointer cast to the concrete `ModuleType`.
template <typename ModuleType>
std::shared_ptr<ModuleType> Module::replace_module(
    const std::string& name,
    std::shared_ptr<ModuleType> module) {
  // Look the entry up once, then assign through the reference.
  auto& slot = children_[name];
  slot = std::move(module);
  return std::dynamic_pointer_cast<ModuleType>(slot);
}

// Holder overload: unwraps the `ModuleHolder` and replaces the child through
// the `shared_ptr` overload.
template <typename ModuleType>
std::shared_ptr<ModuleType> Module::replace_module(
    const std::string& name,
    ModuleHolder<ModuleType> module_holder) {
  auto held_module = module_holder.ptr();
  return replace_module(name, std::move(held_module));
}

// Shared implementation of the conversion routines. `ts` is expanded as-is
// (never forwarded or moved from) because the same arguments are reused for
// every child, parameter and buffer below.
template <typename... Ts>
void Module::to_impl(Ts&&... ts) {
  // Step 1: let every child module convert itself.
  for (auto& submodule : children_) {
    submodule.value()->to(ts...);
  }
  // Step 2: convert this module's own parameters. Children were handled
  // above, hence no recursion here.
  for (auto& param_entry : named_parameters(/*recurse=*/false)) {
    auto converted = autograd::Variable(*param_entry).to(ts...);
    param_entry->set_data(converted);
  }
  // Step 3: convert this module's own buffers in the same way.
  for (auto& buffer_entry : named_buffers(/*recurse=*/false)) {
    auto converted = autograd::Variable(*buffer_entry).to(ts...);
    buffer_entry->set_data(converted);
  }
}

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources