Class Module

Inheritance Relationships

Base Type

  • public std::enable_shared_from_this< Module >

Class Documentation

class torch::nn::Module : public std::enable_shared_from_this<Module>

The base class for all modules in PyTorch.

Note

The design and implementation of this class are largely based on the Python API. You may want to consult the Python documentation for torch.nn.Module for further clarification on certain methods or behavior.

A Module is an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. A Module may contain further Modules (“submodules”), each with their own implementation, persistent data and further submodules. Modules can thus be said to form a recursive tree structure. A Module is registered as a submodule to another Module by calling register_module(), typically from within a parent module’s constructor.

A distinction is made between three kinds of persistent data that may be associated with a Module:

  1. Parameters: tensors that record gradients, typically weights updated during the backward step (e.g. the weight of a Linear module),

  2. Buffers: tensors that do not record gradients, typically updated during the forward step, such as running statistics (e.g. mean and variance in the BatchNorm module),

  3. Any additional state, not necessarily tensors, required for the implementation or configuration of a Module.

The first two kinds of state are special in that they may be registered with the Module system to allow convenient access and batch configuration. For example, registered parameters in any Module may be iterated over via the parameters() accessor. Further, changing the data type or device of a Module’s registered parameters can be done conveniently via Module::to(), e.g. module->to(torch::kCUDA) to move all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during a clone() operation, which performs a deep copy of a cloneable Module hierarchy.

Parameters are registered with a Module via register_parameter(). Buffers are registered separately via register_buffer(). These methods are part of the public API of Module and are typically invoked from within a concrete Module’s constructor.

Subclassed by torch::nn::Cloneable< AdaptiveAvgPool1dImpl >, torch::nn::Cloneable< AdaptiveAvgPool2dImpl >, torch::nn::Cloneable< AdaptiveAvgPool3dImpl >, torch::nn::Cloneable< AdaptiveLogSoftmaxWithLossImpl >, torch::nn::Cloneable< AdaptiveMaxPool1dImpl >, torch::nn::Cloneable< AdaptiveMaxPool2dImpl >, torch::nn::Cloneable< AdaptiveMaxPool3dImpl >, torch::nn::Cloneable< AlphaDropoutImpl >, torch::nn::Cloneable< AvgPool1dImpl >, torch::nn::Cloneable< AvgPool2dImpl >, torch::nn::Cloneable< AvgPool3dImpl >, torch::nn::Cloneable< BatchNorm1dImpl >, torch::nn::Cloneable< BatchNorm2dImpl >, torch::nn::Cloneable< BatchNorm3dImpl >, torch::nn::Cloneable< BCELossImpl >, torch::nn::Cloneable< BCEWithLogitsLossImpl >, torch::nn::Cloneable< BilinearImpl >, torch::nn::Cloneable< CELUImpl >, torch::nn::Cloneable< ConstantPad1dImpl >, torch::nn::Cloneable< ConstantPad2dImpl >, torch::nn::Cloneable< ConstantPad3dImpl >, torch::nn::Cloneable< Conv1dImpl >, torch::nn::Cloneable< Conv2dImpl >, torch::nn::Cloneable< Conv3dImpl >, torch::nn::Cloneable< ConvTranspose1dImpl >, torch::nn::Cloneable< ConvTranspose2dImpl >, torch::nn::Cloneable< ConvTranspose3dImpl >, torch::nn::Cloneable< CosineEmbeddingLossImpl >, torch::nn::Cloneable< CosineSimilarityImpl >, torch::nn::Cloneable< CrossEntropyLossImpl >, torch::nn::Cloneable< CrossMapLRN2dImpl >, torch::nn::Cloneable< CTCLossImpl >, torch::nn::Cloneable< Dropout2dImpl >, torch::nn::Cloneable< Dropout3dImpl >, torch::nn::Cloneable< DropoutImpl >, torch::nn::Cloneable< ELUImpl >, torch::nn::Cloneable< EmbeddingBagImpl >, torch::nn::Cloneable< EmbeddingImpl >, torch::nn::Cloneable< FeatureAlphaDropoutImpl >, torch::nn::Cloneable< FlattenImpl >, torch::nn::Cloneable< FoldImpl >, torch::nn::Cloneable< FractionalMaxPool2dImpl >, torch::nn::Cloneable< FractionalMaxPool3dImpl >, torch::nn::Cloneable< FunctionalImpl >, torch::nn::Cloneable< GELUImpl >, torch::nn::Cloneable< GLUImpl >, torch::nn::Cloneable< GroupNormImpl >, torch::nn::Cloneable< GRUCellImpl >, torch::nn::Cloneable< GRUImpl >, torch::nn::Cloneable< HardshrinkImpl >, torch::nn::Cloneable< HardtanhImpl >, torch::nn::Cloneable< HingeEmbeddingLossImpl >, torch::nn::Cloneable< HuberLossImpl >, torch::nn::Cloneable< IdentityImpl >, torch::nn::Cloneable< InstanceNorm1dImpl >, torch::nn::Cloneable< InstanceNorm2dImpl >, torch::nn::Cloneable< InstanceNorm3dImpl >, torch::nn::Cloneable< KLDivLossImpl >, torch::nn::Cloneable< L1LossImpl >, torch::nn::Cloneable< LayerNormImpl >, torch::nn::Cloneable< LeakyReLUImpl >, torch::nn::Cloneable< LinearImpl >, torch::nn::Cloneable< LocalResponseNormImpl >, torch::nn::Cloneable< LogSigmoidImpl >, torch::nn::Cloneable< LogSoftmaxImpl >, torch::nn::Cloneable< LPPool1dImpl >, torch::nn::Cloneable< LPPool2dImpl >, torch::nn::Cloneable< LSTMCellImpl >, torch::nn::Cloneable< LSTMImpl >, torch::nn::Cloneable< MarginRankingLossImpl >, torch::nn::Cloneable< MaxPool1dImpl >, torch::nn::Cloneable< MaxPool2dImpl >, torch::nn::Cloneable< MaxPool3dImpl >, torch::nn::Cloneable< MaxUnpool1dImpl >, torch::nn::Cloneable< MaxUnpool2dImpl >, torch::nn::Cloneable< MaxUnpool3dImpl >, torch::nn::Cloneable< MishImpl >, torch::nn::Cloneable< ModuleDictImpl >, torch::nn::Cloneable< ModuleListImpl >, torch::nn::Cloneable< MSELossImpl >, torch::nn::Cloneable< MultiheadAttentionImpl >, torch::nn::Cloneable< MultiLabelMarginLossImpl >, torch::nn::Cloneable< MultiLabelSoftMarginLossImpl >, torch::nn::Cloneable< MultiMarginLossImpl >, torch::nn::Cloneable< NLLLossImpl >, torch::nn::Cloneable< PairwiseDistanceImpl >, torch::nn::Cloneable< ParameterDictImpl >, torch::nn::Cloneable< ParameterListImpl >, torch::nn::Cloneable< PixelShuffleImpl >, torch::nn::Cloneable< PixelUnshuffleImpl >, torch::nn::Cloneable< PoissonNLLLossImpl >, torch::nn::Cloneable< PReLUImpl >, torch::nn::Cloneable< ReflectionPad1dImpl >, torch::nn::Cloneable< ReflectionPad2dImpl >, torch::nn::Cloneable< ReflectionPad3dImpl >, torch::nn::Cloneable< ReLU6Impl >, torch::nn::Cloneable< ReLUImpl >, torch::nn::Cloneable< ReplicationPad1dImpl >, torch::nn::Cloneable< ReplicationPad2dImpl >, torch::nn::Cloneable< ReplicationPad3dImpl >, torch::nn::Cloneable< RNNCellImpl >, torch::nn::Cloneable< RNNImpl >, torch::nn::Cloneable< RReLUImpl >, torch::nn::Cloneable< SELUImpl >, torch::nn::Cloneable< SequentialImpl >, torch::nn::Cloneable< SigmoidImpl >, torch::nn::Cloneable< SiLUImpl >, torch::nn::Cloneable< SmoothL1LossImpl >, torch::nn::Cloneable< SoftMarginLossImpl >, torch::nn::Cloneable< Softmax2dImpl >, torch::nn::Cloneable< SoftmaxImpl >, torch::nn::Cloneable< SoftminImpl >, torch::nn::Cloneable< SoftplusImpl >, torch::nn::Cloneable< SoftshrinkImpl >, torch::nn::Cloneable< SoftsignImpl >, torch::nn::Cloneable< TanhImpl >, torch::nn::Cloneable< TanhshrinkImpl >, torch::nn::Cloneable< ThresholdImpl >, torch::nn::Cloneable< TransformerDecoderImpl >, torch::nn::Cloneable< TransformerDecoderLayerImpl >, torch::nn::Cloneable< TransformerEncoderImpl >, torch::nn::Cloneable< TransformerEncoderLayerImpl >, torch::nn::Cloneable< TransformerImpl >, torch::nn::Cloneable< TripletMarginLossImpl >, torch::nn::Cloneable< TripletMarginWithDistanceLossImpl >, torch::nn::Cloneable< UnflattenImpl >, torch::nn::Cloneable< UnfoldImpl >, torch::nn::Cloneable< UpsampleImpl >, torch::nn::Cloneable< ZeroPad2dImpl >, torch::nn::Cloneable< Derived >

Public Types

using ModuleApplyFunction = std::function<void(Module&)>
using ConstModuleApplyFunction = std::function<void(const Module&)>
using NamedModuleApplyFunction = std::function<void(const std::string&, Module&)>
using ConstNamedModuleApplyFunction = std::function<void(const std::string&, const Module&)>
using ModulePointerApplyFunction = std::function<void(const std::shared_ptr<Module>&)>
using NamedModulePointerApplyFunction = std::function<void(const std::string&, const std::shared_ptr<Module>&)>

Public Functions

Module(std::string name)

Tells the base Module about the name of the submodule.

Module()

Constructs the module without immediate knowledge of the submodule’s name.

The name of the submodule is inferred via RTTI (if possible) the first time .name() is invoked.

~Module() = default
const std::string &name() const noexcept

Returns the name of the Module.

A Module has an associated name, which is a string representation of the kind of concrete Module it represents, such as "Linear" for the Linear module. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual circumstance that you have this feature disabled, you may want to manually name your Modules by passing the string name to the Module base class’ constructor.

std::shared_ptr<Module> clone(const optional<Device> &device = nullopt) const

Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.

Optionally, this method sets the current device to the one supplied before cloning. If no device is given, each parameter and buffer will be moved to the device of its source.

Attention

Attempting to call the clone() method inherited from the base Module class (the one documented here) will fail. To inherit an actual implementation of clone(), you must subclass Cloneable. Cloneable is templatized on the concrete module type, and can thus properly copy a Module. This method is provided on the base class’ API solely for an easier-to-use polymorphic interface.

void apply(const ModuleApplyFunction &function)

Applies the function to the Module and recursively to every submodule.

The function must accept a Module&.

void apply(const ConstModuleApplyFunction &function) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const Module&.

void apply(const NamedModuleApplyFunction &function, const std::string &name_prefix = std::string())

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

void apply(const ConstNamedModuleApplyFunction &function, const std::string &name_prefix = std::string()) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

void apply(const ModulePointerApplyFunction &function) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::shared_ptr<Module>&.

void apply(const NamedModulePointerApplyFunction &function, const std::string &name_prefix = std::string()) const

Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const std::shared_ptr<Module>&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

std::vector<Tensor> parameters(bool recurse = true) const

Returns the parameters of this Module and if recurse is true, also recursively of every submodule.

OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const

Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule.

std::vector<Tensor> buffers(bool recurse = true) const

Returns the buffers of this Module and if recurse is true, also recursively of every submodule.

OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const

Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule.

std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const

Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position.

Warning

Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.

OrderedDict<std::string, std::shared_ptr<Module>> named_modules(const std::string &name_prefix = std::string(), bool include_self = true) const

Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position.

If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

Warning

Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.

std::vector<std::shared_ptr<Module>> children() const

Returns the direct submodules of this Module.

OrderedDict<std::string, std::shared_ptr<Module>> named_children() const

Returns an OrderedDict of the direct submodules of this Module and their keys.

void train(bool on = true)

Enables “training” mode if on is true; otherwise enables “eval” mode.

void eval()

Calls train(false) to enable “eval” mode.

Do not override this method, override train() instead.

bool is_training() const noexcept

True if the module is in training mode.

Every Module has a boolean associated with it that determines whether the Module is currently in training mode (set via .train()) or in evaluation (inference) mode (set via .eval()). This property is exposed via is_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See the BatchNorm or Dropout modules for examples of Modules that use different code paths depending on this property.

void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)

Recursively casts all parameters to the given dtype and device.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

void to(torch::Dtype dtype, bool non_blocking = false)

Recursively casts all parameters to the given dtype.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

void to(torch::Device device, bool non_blocking = false)

Recursively moves all parameters to the given device.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

void zero_grad()

Recursively zeros out the grad value of each registered parameter.

template<typename ModuleType>
ModuleType::ContainedType *as() noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}

MyModule module;
module->apply(initialize_weights);

template<typename ModuleType>
const ModuleType::ContainedType *as() const noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType *as() noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::LinearImpl>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}

MyModule module;
module.apply(initialize_weights);

template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType *as() const noexcept

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::LinearImpl>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}

MyModule module;
module.apply(initialize_weights);

void save(serialize::OutputArchive &archive) const

Serializes the Module into the given OutputArchive.

If the Module contains unserializable submodules (e.g. nn::Functional), those submodules are skipped when serializing.

void load(serialize::InputArchive &archive)

Deserializes the Module from the given InputArchive.

If the Module contains unserializable submodules (e.g. nn::Functional), we don’t check the existence of those submodules in the InputArchive when deserializing.

void pretty_print(std::ostream &stream) const

Streams a pretty representation of the Module into the given stream.

By default, this representation will be the name of the module (taken from name()), followed by a recursive pretty print of all of the Module’s submodules.

Override this method to change the pretty print. As declared above, the method returns void and writes directly into the given stream.

bool is_serializable() const

Returns whether the Module is serializable.

Tensor &register_parameter(std::string name, Tensor tensor, bool requires_grad = true)

Registers a parameter with this Module.

A parameter should be any gradient-recording tensor used in the implementation of your Module. Registering it makes it available to methods such as parameters(), clone() or to().

Note that registering an undefined Tensor (e.g. module.register_parameter("param", Tensor())) is allowed, and is equivalent to module.register_parameter("param", None) in the Python API.

MyModule::MyModule() {
  weight_ = register_parameter("weight", torch::randn({A, B}));
}

Tensor &register_buffer(std::string name, Tensor tensor)

Registers a buffer with this Module.

A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as buffers(), clone() or to().

MyModule::MyModule() {
  mean_ = register_buffer("mean", torch::empty({num_features_}));
}

template<typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, std::shared_ptr<ModuleType> module)

Registers a submodule with this Module.

Registering a module makes it available to methods such as modules(), clone() or to().

MyModule::MyModule() {
  submodule_ = register_module("linear", torch::nn::Linear(3, 4));
}

template<typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, ModuleHolder<ModuleType> module_holder)

Registers a submodule with this Module.

This method deals with ModuleHolders.

Registering a module makes it available to methods such as modules(), clone() or to().

MyModule::MyModule() {
  submodule_ = register_module("linear", torch::nn::Linear(3, 4));
}

template<typename ModuleType>
std::shared_ptr<ModuleType> replace_module(const std::string &name, std::shared_ptr<ModuleType> module)

Replaces a submodule registered with this Module.

This takes care of the registration. If you store submodules as members, you should also reassign the member, e.g. module->submodule_ = module->replace_module("linear", torch::nn::Linear(3, 4)); It only works when a module with the given name is already registered.

This is useful for replacing a module after initialization, e.g. for finetuning.

template<typename ModuleType>
std::shared_ptr<ModuleType> replace_module(const std::string &name, ModuleHolder<ModuleType> module_holder)

Replaces a submodule registered with this Module.

This method deals with ModuleHolders.

This takes care of the registration. If you store submodules as members, you should also reassign the member, e.g. module->submodule_ = module->replace_module("linear", linear_holder); It only works when a module with the given name is already registered.

This is useful for replacing a module after initialization, e.g. for finetuning.

void unregister_module(const std::string &name)

Unregisters a submodule from this Module.

If no module with the given name is registered, an exception is thrown.

Protected Functions

bool _forward_has_default_args()

The following three functions allow a module with default arguments in its forward method to be used in a Sequential module.

You should NEVER override these functions manually. Instead, you should use the FORWARD_HAS_DEFAULT_ARGS macro.

unsigned int _forward_num_required_args()
std::vector<AnyValue> _forward_populate_default_args(std::vector<AnyValue> &&arguments)

Protected Attributes

OrderedDict<std::string, Tensor> parameters_

The registered parameters of this Module.

Declared protected in order to allow access to parameters_ from ParameterDict and ParameterList.
