Program Listing for File transformerlayer.h

File: torch/csrc/api/include/torch/nn/modules/transformerlayer.h

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/activation.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/normalization.h>
#include <torch/nn/options/transformerlayer.h>
#include <torch/nn/pimpl.h>

#include <torch/types.h>

#include <ostream>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoderLayer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
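/// A single layer of a Transformer encoder: self-attention followed by a
/// position-wise feedforward network, each with residual connection, dropout,
/// and layer normalization. See the Python documentation for
/// `torch.nn.TransformerEncoderLayer` for the exact behavior.
///
/// Example (a minimal usage sketch; the option values are illustrative):
/// ```
/// TransformerEncoderLayer encoderLayer(
///     TransformerEncoderLayerOptions(512, 8).dropout(0.1));
/// const auto src = torch::rand({10, 32, 512});  // (seq_len, batch, d_model)
/// const auto out = encoderLayer(src);
/// ```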

class TORCH_API TransformerEncoderLayerImpl
    : public Cloneable<TransformerEncoderLayerImpl> {
 public:
  TransformerEncoderLayerImpl(int64_t d_model, int64_t nhead)
      : TransformerEncoderLayerImpl(
            TransformerEncoderLayerOptions(d_model, nhead)) {}
  explicit TransformerEncoderLayerImpl(TransformerEncoderLayerOptions options_);

  /// Encoder layer forward pass.
  /// Args:
  ///   src: the sequence to the encoder layer (required).
  ///   src_mask: the mask for the src sequence (optional).
  ///   src_key_padding_mask: the mask for the src keys per batch (optional).
  Tensor forward(
      const Tensor& src,
      const Tensor& src_mask = {},
      const Tensor& src_key_padding_mask = {});

  /// Initializes (or re-initializes) the submodules of this layer.
  void reset() override;

  /// Re-initializes the parameters of this layer's submodules.
  void reset_parameters();

 protected:
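  // FORWARD_HAS_DEFAULT_ARGS (from torch/nn/modules/common.h) registers the
  // default values of forward()'s optional arguments (here, argument indices
  // 1 and 2) so this module can still be invoked with fewer arguments through
  // type-erased wrappers such as AnyModule and Sequential.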
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})

 public:
  /// The options with which this `TransformerEncoderLayer` was constructed.
  TransformerEncoderLayerOptions options;

  /// Self-attention block.
  MultiheadAttention self_attn = nullptr;

  /// First linear layer of the feedforward block.
  Linear linear1 = nullptr;

  /// Dropout inside the feedforward block.
  Dropout dropout = nullptr;

  /// Second linear layer of the feedforward block.
  Linear linear2 = nullptr;

  /// Layer normalization: post self-attention and post feedforward.
  LayerNorm norm1 = nullptr;
  LayerNorm norm2 = nullptr;

  /// Dropout: post self-attention and post feedforward.
  Dropout dropout1 = nullptr;
  Dropout dropout2 = nullptr;
};

TORCH_MODULE(TransformerEncoderLayer);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoderLayer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
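/// A single layer of a Transformer decoder: self-attention over the target
/// sequence, multi-head attention over the encoder memory, and a
/// position-wise feedforward network, each with residual connection, dropout,
/// and layer normalization. See the Python documentation for
/// `torch.nn.TransformerDecoderLayer` for the exact behavior.
///
/// Example (a minimal usage sketch; the option values are illustrative):
/// ```
/// TransformerDecoderLayer decoderLayer(
///     TransformerDecoderLayerOptions(512, 8).dropout(0.2));
/// const auto memory = torch::rand({10, 32, 512});  // encoder output
/// const auto tgt = torch::rand({20, 32, 512});     // target sequence
/// const auto out = decoderLayer(tgt, memory);
/// ```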

class TORCH_API TransformerDecoderLayerImpl
    : public Cloneable<TransformerDecoderLayerImpl> {
 public:
  TransformerDecoderLayerImpl(int64_t d_model, int64_t nhead)
      : TransformerDecoderLayerImpl(
            TransformerDecoderLayerOptions(d_model, nhead)) {}
  explicit TransformerDecoderLayerImpl(TransformerDecoderLayerOptions options_);

  void reset() override;

  void reset_parameters();

  /// Decoder layer forward pass.
  /// Args:
  ///   tgt: the sequence to the decoder layer (required).
  ///   memory: the sequence from the last layer of the encoder (required).
  ///   tgt_mask: the mask for the tgt sequence (optional).
  ///   memory_mask: the mask for the memory sequence (optional).
  ///   tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
  ///   memory_key_padding_mask: the mask for the memory keys per batch
  ///     (optional).
  Tensor forward(
      Tensor tgt,
      const Tensor& memory,
      const Tensor& tgt_mask = {},
      const Tensor& memory_mask = {},
      const Tensor& tgt_key_padding_mask = {},
      const Tensor& memory_key_padding_mask = {});

  /// The options with which this `TransformerDecoderLayer` was constructed.
  TransformerDecoderLayerOptions options;

  /// Self-attention over the target sequence.
  MultiheadAttention self_attn{nullptr};

  /// Dropout, post self-attention.
  Dropout dropout1{nullptr};

  /// Layer normalization, post self-attention.
  LayerNorm norm1{nullptr};

  /// Multi-head attention over the encoder memory.
  MultiheadAttention multihead_attn{nullptr};

  /// Dropout, post multi-head attention.
  Dropout dropout2{nullptr};

  /// Layer normalization, post multi-head attention.
  LayerNorm norm2{nullptr};

  /// First linear layer of the feedforward block.
  Linear linear1{nullptr};

  /// Dropout inside the feedforward block.
  Dropout dropout{nullptr};

  /// Second linear layer of the feedforward block.
  Linear linear2{nullptr};

  /// Dropout, post feedforward.
  Dropout dropout3{nullptr};

  /// Layer normalization, post feedforward.
  LayerNorm norm3{nullptr};

 protected:
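  // As above, FORWARD_HAS_DEFAULT_ARGS registers the default values of
  // forward()'s optional arguments (indices 2 through 5: tgt_mask through
  // memory_key_padding_mask) for calls through AnyModule/Sequential.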
  FORWARD_HAS_DEFAULT_ARGS(
      {2, AnyValue(Tensor())},
      {3, AnyValue(Tensor())},
      {4, AnyValue(Tensor())},
      {5, AnyValue(Tensor())})

  /// Applies the activation function selected in `options` (e.g. ReLU or
  /// GELU) to `input`.
  Tensor activation(const Tensor& input);
};

TORCH_MODULE(TransformerDecoderLayer);

} // namespace nn
} // namespace torch
