Program Listing for File transformerlayer.h
Return to documentation for file (torch/csrc/api/include/torch/nn/modules/transformerlayer.h)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/activation.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/normalization.h>
#include <torch/nn/options/transformerlayer.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <ostream>
namespace torch::nn {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoderLayer
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class for a single Transformer encoder layer.
///
/// Only declarations live in this header: `forward`, `reset`,
/// `reset_parameters` and the options constructor are defined out of line.
class TORCH_API TransformerEncoderLayerImpl
    : public Cloneable<TransformerEncoderLayerImpl> {
 public:
  /// Convenience constructor: packs `d_model` and `nhead` into a
  /// `TransformerEncoderLayerOptions` and delegates to the options
  /// constructor below.
  TransformerEncoderLayerImpl(int64_t d_model, int64_t nhead)
      : TransformerEncoderLayerImpl(
            TransformerEncoderLayerOptions(d_model, nhead)) {}

  /// Constructs the layer from a complete options object.
  explicit TransformerEncoderLayerImpl(TransformerEncoderLayerOptions options_);

  /// Runs the layer on `src` and returns the resulting tensor.
  ///
  /// `src_mask` and `src_key_padding_mask` may be omitted; each then binds to
  /// a default-constructed `Tensor`.
  Tensor forward(
      const Tensor& src,
      const Tensor& src_mask = {},
      const Tensor& src_key_padding_mask = {});

  /// Overrides the `reset()` hook inherited through `Cloneable`.
  /// NOTE(review): presumably this is where the sub-module holders below get
  /// their real modules -- confirm against the implementation file.
  void reset() override;

  /// NOTE(review): presumably re-initialises the parameters of the
  /// sub-modules -- confirm against the implementation file.
  void reset_parameters();

  /// The options this layer was constructed with.
  TransformerEncoderLayerOptions options;

  // Sub-module holders. Each one is initialised from `nullptr` here, so no
  // sub-module exists until something assigns it. The declaration order is
  // kept exactly as it was because it fixes the object layout.
  MultiheadAttention self_attn{nullptr};
  Linear linear1{nullptr};
  Dropout dropout{nullptr};
  Linear linear2{nullptr};
  LayerNorm norm1{nullptr};
  LayerNorm norm2{nullptr};
  Dropout dropout1{nullptr};
  Dropout dropout2{nullptr};

 protected:
  // The indices 1 and 2 line up with the zero-based positions of the two
  // defaulted `forward` parameters (`src_mask`, `src_key_padding_mask`);
  // keep them in sync with the signature above.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})
};

/// Declares the `TransformerEncoderLayer` type for
/// `TransformerEncoderLayerImpl` through the `TORCH_MODULE` macro
/// (this header includes torch/nn/pimpl.h for it).
TORCH_MODULE(TransformerEncoderLayer);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoderLayer
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class for a single Transformer decoder layer.
///
/// Only declarations live in this header: `forward`, `reset`,
/// `reset_parameters`, `activation` and the options constructor are defined
/// out of line.
class TORCH_API TransformerDecoderLayerImpl
    : public Cloneable<TransformerDecoderLayerImpl> {
 public:
  /// Convenience constructor: packs `d_model` and `nhead` into a
  /// `TransformerDecoderLayerOptions` and delegates to the options
  /// constructor below.
  TransformerDecoderLayerImpl(int64_t d_model, int64_t nhead)
      : TransformerDecoderLayerImpl(
            TransformerDecoderLayerOptions(d_model, nhead)) {}

  /// Constructs the layer from a complete options object.
  explicit TransformerDecoderLayerImpl(TransformerDecoderLayerOptions options_);

  /// Runs the layer and returns the resulting tensor.
  ///
  /// `tgt` is taken by value and `memory` by const reference. The four mask
  /// arguments may be omitted; each then binds to a default-constructed
  /// `Tensor`.
  Tensor forward(
      Tensor tgt,
      const Tensor& memory,
      const Tensor& tgt_mask = {},
      const Tensor& memory_mask = {},
      const Tensor& tgt_key_padding_mask = {},
      const Tensor& memory_key_padding_mask = {});

  /// Overrides the `reset()` hook inherited through `Cloneable`.
  /// NOTE(review): presumably this is where the sub-module holders below get
  /// their real modules -- confirm against the implementation file.
  void reset() override;

  /// NOTE(review): presumably re-initialises the parameters of the
  /// sub-modules -- confirm against the implementation file.
  void reset_parameters();

  /// The options this layer was constructed with.
  TransformerDecoderLayerOptions options;

  // Sub-module holders. Each one is initialised from `nullptr` here, so no
  // sub-module exists until something assigns it. The declaration order is
  // kept exactly as it was because it fixes the object layout.
  MultiheadAttention self_attn{nullptr};
  Dropout dropout1{nullptr};
  LayerNorm norm1{nullptr};
  MultiheadAttention multihead_attn{nullptr};
  Dropout dropout2{nullptr};
  LayerNorm norm2{nullptr};
  Linear linear1{nullptr};
  Dropout dropout{nullptr};
  Linear linear2{nullptr};
  Dropout dropout3{nullptr};
  LayerNorm norm3{nullptr};

 protected:
  // The indices 2..5 line up with the zero-based positions of the four
  // defaulted `forward` parameters (`tgt_mask` through
  // `memory_key_padding_mask`); keep them in sync with the signature above.
  FORWARD_HAS_DEFAULT_ARGS(
      {2, AnyValue(Tensor())},
      {3, AnyValue(Tensor())},
      {4, AnyValue(Tensor())},
      {5, AnyValue(Tensor())})

  /// Internal helper that maps `input` to a new tensor.
  /// NOTE(review): presumably applies the activation selected in `options`
  /// -- confirm against the implementation file.
  Tensor activation(const Tensor& input);
};

/// Declares the `TransformerDecoderLayer` type for
/// `TransformerDecoderLayerImpl` through the `TORCH_MODULE` macro
/// (this header includes torch/nn/pimpl.h for it).
TORCH_MODULE(TransformerDecoderLayer);
} // namespace torch::nn