Class Module
Defined in File module.h
Inheritance Relationships
Base Type
public std::enable_shared_from_this< Module >
Derived Types
public torch::nn::Cloneable< AdaptiveAvgPool1dImpl >
public torch::nn::Cloneable< AdaptiveAvgPool2dImpl >
public torch::nn::Cloneable< AdaptiveAvgPool3dImpl >
public torch::nn::Cloneable< AdaptiveLogSoftmaxWithLossImpl >
public torch::nn::Cloneable< AdaptiveMaxPool1dImpl >
public torch::nn::Cloneable< AdaptiveMaxPool2dImpl >
public torch::nn::Cloneable< AdaptiveMaxPool3dImpl >
public torch::nn::Cloneable< AlphaDropoutImpl >
public torch::nn::Cloneable< AvgPool1dImpl >
public torch::nn::Cloneable< AvgPool2dImpl >
public torch::nn::Cloneable< AvgPool3dImpl >
public torch::nn::Cloneable< BatchNorm1dImpl >
public torch::nn::Cloneable< BatchNorm2dImpl >
public torch::nn::Cloneable< BatchNorm3dImpl >
public torch::nn::Cloneable< BCELossImpl >
public torch::nn::Cloneable< BCEWithLogitsLossImpl >
public torch::nn::Cloneable< BilinearImpl >
public torch::nn::Cloneable< CELUImpl >
public torch::nn::Cloneable< ConstantPad1dImpl >
public torch::nn::Cloneable< ConstantPad2dImpl >
public torch::nn::Cloneable< ConstantPad3dImpl >
public torch::nn::Cloneable< Conv1dImpl >
public torch::nn::Cloneable< Conv2dImpl >
public torch::nn::Cloneable< Conv3dImpl >
public torch::nn::Cloneable< ConvTranspose1dImpl >
public torch::nn::Cloneable< ConvTranspose2dImpl >
public torch::nn::Cloneable< ConvTranspose3dImpl >
public torch::nn::Cloneable< CosineEmbeddingLossImpl >
public torch::nn::Cloneable< CosineSimilarityImpl >
public torch::nn::Cloneable< CrossEntropyLossImpl >
public torch::nn::Cloneable< CrossMapLRN2dImpl >
public torch::nn::Cloneable< CTCLossImpl >
public torch::nn::Cloneable< Dropout2dImpl >
public torch::nn::Cloneable< Dropout3dImpl >
public torch::nn::Cloneable< DropoutImpl >
public torch::nn::Cloneable< ELUImpl >
public torch::nn::Cloneable< EmbeddingBagImpl >
public torch::nn::Cloneable< EmbeddingImpl >
public torch::nn::Cloneable< FeatureAlphaDropoutImpl >
public torch::nn::Cloneable< FlattenImpl >
public torch::nn::Cloneable< FoldImpl >
public torch::nn::Cloneable< FractionalMaxPool2dImpl >
public torch::nn::Cloneable< FractionalMaxPool3dImpl >
public torch::nn::Cloneable< FunctionalImpl >
public torch::nn::Cloneable< GELUImpl >
public torch::nn::Cloneable< GLUImpl >
public torch::nn::Cloneable< GroupNormImpl >
public torch::nn::Cloneable< GRUCellImpl >
public torch::nn::Cloneable< GRUImpl >
public torch::nn::Cloneable< HardshrinkImpl >
public torch::nn::Cloneable< HardtanhImpl >
public torch::nn::Cloneable< HingeEmbeddingLossImpl >
public torch::nn::Cloneable< HuberLossImpl >
public torch::nn::Cloneable< IdentityImpl >
public torch::nn::Cloneable< InstanceNorm1dImpl >
public torch::nn::Cloneable< InstanceNorm2dImpl >
public torch::nn::Cloneable< InstanceNorm3dImpl >
public torch::nn::Cloneable< KLDivLossImpl >
public torch::nn::Cloneable< L1LossImpl >
public torch::nn::Cloneable< LayerNormImpl >
public torch::nn::Cloneable< LeakyReLUImpl >
public torch::nn::Cloneable< LinearImpl >
public torch::nn::Cloneable< LocalResponseNormImpl >
public torch::nn::Cloneable< LogSigmoidImpl >
public torch::nn::Cloneable< LogSoftmaxImpl >
public torch::nn::Cloneable< LPPool1dImpl >
public torch::nn::Cloneable< LPPool2dImpl >
public torch::nn::Cloneable< LSTMCellImpl >
public torch::nn::Cloneable< LSTMImpl >
public torch::nn::Cloneable< MarginRankingLossImpl >
public torch::nn::Cloneable< MaxPool1dImpl >
public torch::nn::Cloneable< MaxPool2dImpl >
public torch::nn::Cloneable< MaxPool3dImpl >
public torch::nn::Cloneable< MaxUnpool1dImpl >
public torch::nn::Cloneable< MaxUnpool2dImpl >
public torch::nn::Cloneable< MaxUnpool3dImpl >
public torch::nn::Cloneable< MishImpl >
public torch::nn::Cloneable< ModuleDictImpl >
public torch::nn::Cloneable< ModuleListImpl >
public torch::nn::Cloneable< MSELossImpl >
public torch::nn::Cloneable< MultiheadAttentionImpl >
public torch::nn::Cloneable< MultiLabelMarginLossImpl >
public torch::nn::Cloneable< MultiLabelSoftMarginLossImpl >
public torch::nn::Cloneable< MultiMarginLossImpl >
public torch::nn::Cloneable< NLLLossImpl >
public torch::nn::Cloneable< PairwiseDistanceImpl >
public torch::nn::Cloneable< ParameterDictImpl >
public torch::nn::Cloneable< ParameterListImpl >
public torch::nn::Cloneable< PixelShuffleImpl >
public torch::nn::Cloneable< PixelUnshuffleImpl >
public torch::nn::Cloneable< PoissonNLLLossImpl >
public torch::nn::Cloneable< PReLUImpl >
public torch::nn::Cloneable< ReflectionPad1dImpl >
public torch::nn::Cloneable< ReflectionPad2dImpl >
public torch::nn::Cloneable< ReflectionPad3dImpl >
public torch::nn::Cloneable< ReLU6Impl >
public torch::nn::Cloneable< ReLUImpl >
public torch::nn::Cloneable< ReplicationPad1dImpl >
public torch::nn::Cloneable< ReplicationPad2dImpl >
public torch::nn::Cloneable< ReplicationPad3dImpl >
public torch::nn::Cloneable< RNNCellImpl >
public torch::nn::Cloneable< RNNImpl >
public torch::nn::Cloneable< RReLUImpl >
public torch::nn::Cloneable< SELUImpl >
public torch::nn::Cloneable< SequentialImpl >
public torch::nn::Cloneable< SigmoidImpl >
public torch::nn::Cloneable< SiLUImpl >
public torch::nn::Cloneable< SmoothL1LossImpl >
public torch::nn::Cloneable< SoftMarginLossImpl >
public torch::nn::Cloneable< Softmax2dImpl >
public torch::nn::Cloneable< SoftmaxImpl >
public torch::nn::Cloneable< SoftminImpl >
public torch::nn::Cloneable< SoftplusImpl >
public torch::nn::Cloneable< SoftshrinkImpl >
public torch::nn::Cloneable< SoftsignImpl >
public torch::nn::Cloneable< TanhImpl >
public torch::nn::Cloneable< TanhshrinkImpl >
public torch::nn::Cloneable< ThresholdImpl >
public torch::nn::Cloneable< TransformerDecoderImpl >
public torch::nn::Cloneable< TransformerDecoderLayerImpl >
public torch::nn::Cloneable< TransformerEncoderImpl >
public torch::nn::Cloneable< TransformerEncoderLayerImpl >
public torch::nn::Cloneable< TransformerImpl >
public torch::nn::Cloneable< TripletMarginLossImpl >
public torch::nn::Cloneable< TripletMarginWithDistanceLossImpl >
public torch::nn::Cloneable< UnflattenImpl >
public torch::nn::Cloneable< UnfoldImpl >
public torch::nn::Cloneable< UpsampleImpl >
public torch::nn::Cloneable< ZeroPad1dImpl >
public torch::nn::Cloneable< ZeroPad2dImpl >
public torch::nn::Cloneable< ZeroPad3dImpl >
public torch::nn::Cloneable< Derived >
Class Documentation
-
class torch::nn::Module : public std::enable_shared_from_this<Module>
The base class for all modules in PyTorch.
Note
The design and implementation of this class is largely based on the Python API. You may want to consult the Python documentation for torch.nn.Module for further clarification on certain methods or behavior.
A Module is an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. A Module may contain further Modules ("submodules"), each with its own implementation, persistent data and further submodules. Modules can thus be said to form a recursive tree structure. A Module is registered as a submodule of another Module by calling register_module(), typically from within the parent module's constructor.
A distinction is made between three kinds of persistent data that may be associated with a Module:
1. Parameters: tensors that record gradients, typically weights updated during the backward step (e.g. the weight of a Linear module).
2. Buffers: tensors that do not record gradients, typically updated during the forward step, such as running statistics (e.g. mean and variance in the BatchNorm module).
3. Any additional state, not necessarily tensors, required for the implementation or configuration of a Module.
The first two kinds of state are special in that they may be registered with the Module system to allow convenient access and batch configuration. For example, registered parameters in any Module may be iterated over via the parameters() accessor. Further, the data type or device of a Module's registered parameters can be changed conveniently via Module::to(), e.g. module->to(torch::kCUDA) moves all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during a clone() operation, which performs a deep copy of a cloneable Module hierarchy.
Parameters are registered with a Module via register_parameter. Buffers are registered separately via register_buffer. These methods are part of the public API of Module and are typically invoked from within a concrete Module's constructor.
Subclassed by the torch::nn::Cloneable specializations listed under Derived Types.
Public Types
Public Functions
-
Module()
Constructs the module without immediate knowledge of the submodule's name. The name of the submodule is inferred via RTTI (if possible) the first time .name() is invoked.
-
~Module() = default
-
const std::string &name() const noexcept
Returns the name of the Module.
A Module has an associated name, which is a string representation of the kind of concrete Module it represents, such as "Linear" for the Linear module. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual case that you have this feature disabled, you may want to name your Modules manually by passing the string name to the Module base class constructor.
-
std::shared_ptr<Module> clone(const optional<Device> &device = nullopt) const
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.
Optionally, this method sets the current device to the one supplied before cloning. If no device is given, each cloned parameter and buffer is placed on the device of its source.
Attention
Calling the clone() method inherited from the base Module class (the one documented here) will fail. To inherit an actual implementation of clone(), you must subclass Cloneable. Cloneable is templatized on the concrete module type, and can thus properly copy a Module. This method is provided on the base class's API solely for an easier-to-use polymorphic interface.
-
void apply(const ModuleApplyFunction &function)
Applies the function to the Module and recursively to every submodule. The function must accept a Module&.
-
void apply(const ConstModuleApplyFunction &function) const
Applies the function to the Module and recursively to every submodule. The function must accept a const Module&.
-
void apply(const NamedModuleApplyFunction &function, const std::string &name_prefix = std::string())
Applies the function to the Module and recursively to every submodule. The function must accept a const std::string& for the key of the module, and a Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
-
void apply(const ConstNamedModuleApplyFunction &function, const std::string &name_prefix = std::string()) const
Applies the function to the Module and recursively to every submodule. The function must accept a const std::string& for the key of the module, and a const Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
-
void apply(const ModulePointerApplyFunction &function) const
Applies the function to the Module and recursively to every submodule. The function must accept a const std::shared_ptr<Module>&.
-
void apply(const NamedModulePointerApplyFunction &function, const std::string &name_prefix = std::string()) const
Applies the function to the Module and recursively to every submodule. The function must accept a const std::string& for the key of the module, and a const std::shared_ptr<Module>&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
-
std::vector<Tensor> parameters(bool recurse = true) const
Returns the parameters of this Module and, if recurse is true, also recursively those of every submodule.
-
OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const
Returns an OrderedDict with the parameters of this Module along with their keys and, if recurse is true, also recursively those of every submodule.
-
std::vector<Tensor> buffers(bool recurse = true) const
Returns the buffers of this Module and, if recurse is true, also recursively those of every submodule.
-
OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const
Returns an OrderedDict with the buffers of this Module along with their keys and, if recurse is true, also recursively those of every submodule.
-
std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const
Returns the submodules of this Module (the entire submodule hierarchy) and, if include_self is true, also inserts a shared_ptr to this module in the first position.
Warning
Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
-
OrderedDict<std::string, std::shared_ptr<Module>> named_modules(const std::string &name_prefix = std::string(), bool include_self = true) const
Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys and, if include_self is true, also inserts a shared_ptr to this module in the first position.
If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
Warning
Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
-
std::vector<std::shared_ptr<Module>> children() const
Returns the direct submodules of this Module.
-
OrderedDict<std::string, std::shared_ptr<Module>> named_children() const
Returns an OrderedDict of the direct submodules of this Module and their keys.
-
void train(bool on = true)
Enables "training" mode (or "eval" mode if on is false) for this Module and recursively for every submodule.
-
void eval()
Calls train(false) to enable "eval" mode.
Do not override this method; override train() instead.
-
bool is_training() const noexcept
True if the module is in training mode.
Every Module has a boolean associated with it that determines whether the Module is currently in training mode (set via .train()) or in evaluation (inference) mode (set via .eval()). This property is exposed via is_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See the BatchNorm or Dropout modules for examples of Modules that use different code paths depending on this property.
-
void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
Recursively casts all parameters to the given dtype and device.
If non_blocking is true and the source is in pinned memory and the destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
-
void to(torch::Dtype dtype, bool non_blocking = false)
Recursively casts all parameters to the given dtype.
If non_blocking is true and the source is in pinned memory and the destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
-
void to(torch::Device device, bool non_blocking = false)
Recursively moves all parameters to the given device.
If non_blocking is true and the source is in pinned memory and the destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
-
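void zero_grad(bool set_to_none = true)
Recursively resets the grad value of each registered parameter. If set_to_none is true (the default), each gradient is reset to an undefined tensor; otherwise it is filled with zeros.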
-
template<typename ModuleType>
ModuleType::ContainedType *as() noexcept
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().
    void initialize_weights(nn::Module& module) {
      torch::NoGradGuard no_grad;
      if (auto* linear = module.as<nn::Linear>()) {
        linear->weight.normal_(0.0, 0.02);
      }
    }

    MyModule module;
    module->apply(initialize_weights);
-
template<typename ModuleType>
const ModuleType::ContainedType *as() const noexcept
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().
-
template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType *as() noexcept
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().
    void initialize_weights(nn::Module& module) {
      torch::NoGradGuard no_grad;
      if (auto* linear = module.as<nn::Linear>()) {
        linear->weight.normal_(0.0, 0.02);
      }
    }

    MyModule module;
    module.apply(initialize_weights);
-
template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType *as() const noexcept
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().
    void initialize_weights(nn::Module& module) {
      torch::NoGradGuard no_grad;
      if (auto* linear = module.as<nn::Linear>()) {
        linear->weight.normal_(0.0, 0.02);
      }
    }

    MyModule module;
    module.apply(initialize_weights);
-
void save(serialize::OutputArchive &archive) const
Serializes the Module into the given OutputArchive.
If the Module contains unserializable submodules (e.g. nn::Functional), those submodules are skipped when serializing.
-
void load(serialize::InputArchive &archive)
Deserializes the Module from the given InputArchive.
If the Module contains unserializable submodules (e.g. nn::Functional), their presence in the InputArchive is not checked when deserializing.
-
void pretty_print(std::ostream &stream) const
Streams a pretty representation of the Module into the given stream.
By default, this representation will be the name of the module (taken from name()), followed by a recursive pretty print of all of the Module's submodules.
Override this method to change the pretty print; the override should write its representation to the given stream.
-
Tensor &register_parameter(std::string name, Tensor tensor, bool requires_grad = true)
Registers a parameter with this Module.
A parameter should be any gradient-recording tensor used in the implementation of your Module. Registering it makes it available to methods such as parameters(), clone() or to().
Note that registering an undefined Tensor (e.g. module.register_parameter("param", Tensor())) is allowed, and is equivalent to module.register_parameter("param", None) in the Python API.
    MyModule::MyModule() {
      weight_ = register_parameter("weight", torch::randn({A, B}));
    }
-
Tensor &register_buffer(std::string name, Tensor tensor)
Registers a buffer with this Module.
A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as buffers(), clone() or to().
    MyModule::MyModule() {
      mean_ = register_buffer("mean", torch::empty({num_features_}));
    }
-
Registers a submodule with this Module.
Registering a module makes it available to methods such as modules(), clone() or to().
    MyModule::MyModule() {
      submodule_ = register_module("linear", torch::nn::Linear(3, 4));
    }
-
Registers a submodule with this Module.
This method deals with ModuleHolders.
Registering a module makes it available to methods such as modules(), clone() or to().
    MyModule::MyModule() {
      submodule_ = register_module("linear", torch::nn::Linear(3, 4));
    }
-
Replaces a submodule registered with this Module.
This takes care of the registration. If you also keep the submodule in a member variable, you should reassign it:
    module->submodule_ = module->replace_module("linear", torch::nn::Linear(3, 4));
This only works when a module with that name is already registered.
This is useful for replacing a module after initialization, e.g. for finetuning.
-
Replaces a submodule registered with this Module.
This method deals with ModuleHolders.
This takes care of the registration. If you also keep the submodule in a member variable, you should reassign it:
    module->submodule_ = module->replace_module("linear", linear_holder);
This only works when a module with that name is already registered.
This is useful for replacing a module after initialization, e.g. for finetuning.
Protected Functions
-
bool _forward_has_default_args()
The following three functions allow a module with default arguments in its forward method to be used in a Sequential module.
You should NEVER override these functions manually. Instead, you should use the FORWARD_HAS_DEFAULT_ARGS macro.
-
unsigned int _forward_num_required_args()
Protected Attributes
-
OrderedDict<std::string, Tensor> parameters_
The registered parameters of this Module.
In order to access parameters_ in ParameterDict and ParameterList