
Class TransformerEncoderLayerImpl

Inheritance Relationships

Base Type

public torch::nn::Cloneable<TransformerEncoderLayerImpl>

Class Documentation

class TransformerEncoderLayerImpl : public torch::nn::Cloneable<TransformerEncoderLayerImpl>

TransformerEncoderLayer module.

See https://pytorch.org/docs/main/generated/torch.nn.TransformerEncoderLayer.html to learn about the exact behavior of this encoder layer model.

See the documentation for the torch::nn::TransformerEncoderLayerOptions class to learn which constructor arguments this encoder layer model supports.

Example:

TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, 8).dropout(0.1));

Public Functions

inline TransformerEncoderLayerImpl(int64_t d_model, int64_t nhead)
explicit TransformerEncoderLayerImpl(TransformerEncoderLayerOptions options_)
Tensor forward(const Tensor &src, const Tensor &src_mask = {}, const Tensor &src_key_padding_mask = {})
virtual void reset() override

reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.

void reset_parameters()

Public Members

TransformerEncoderLayerOptions options

options with which this TransformerEncoderLayer was constructed

MultiheadAttention self_attn = nullptr

self-attention layer

Linear linear1 = nullptr

first linear layer of the feedforward network (d_model to dim_feedforward)

Dropout dropout = nullptr

dropout layer inside the feedforward network, applied after the activation

Linear linear2 = nullptr

second linear layer of the feedforward network (dim_feedforward to d_model)

LayerNorm norm1 = nullptr

normalization layer applied after the self-attention residual connection, before the feedforward network

LayerNorm norm2 = nullptr

normalization layer applied after the feedforward residual connection

Dropout dropout1 = nullptr

dropout layer applied to the self-attention output, before the feedforward network

Dropout dropout2 = nullptr

dropout layer applied to the feedforward output

Protected Functions

inline virtual bool _forward_has_default_args() override

This function and the two below it allow a module with default arguments in its forward method to be used in a Sequential module.

You should NEVER override these functions manually. Instead, you should use the FORWARD_HAS_DEFAULT_ARGS macro.

inline virtual unsigned int _forward_num_required_args() override
inline std::vector<torch::nn::AnyValue> _forward_populate_default_args(std::vector<torch::nn::AnyValue> &&arguments) override

Friends

friend struct torch::nn::AnyModuleHolder
