Shortcuts

Program Listing for File transformercoder.h

Return to the documentation for file torch/csrc/api/include/torch/nn/modules/transformercoder.h

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/container/any.h>
#include <torch/nn/modules/container/modulelist.h>
#include <torch/nn/options/transformercoder.h>
#include <torch/nn/pimpl.h>

#include <torch/types.h>

#include <ostream>
#include <utility>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoder
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Encoder module composed of a repeated encoder layer plus an optional
/// final normalization module.
///
/// NOTE(review): intended as the C++ frontend counterpart of Python's
/// `torch.nn.TransformerEncoder`; `forward`, `reset` and `reset_parameters`
/// are defined out of line, so exact semantics should be confirmed there.
class TORCH_API TransformerEncoderImpl
    : public Cloneable<TransformerEncoderImpl> {
 public:
  /// Convenience constructor: packs `encoder_layer` and `num_layers` into a
  /// `TransformerEncoderOptions` and delegates to the options constructor.
  ///
  /// `encoder_layer` is a by-value sink parameter, so it is moved into the
  /// options rather than copied a second time.
  TransformerEncoderImpl(
      TransformerEncoderLayer encoder_layer,
      int64_t num_layers)
      : TransformerEncoderImpl(
            TransformerEncoderOptions(std::move(encoder_layer), num_layers)) {}

  /// Constructs the encoder from a fully specified options object.
  explicit TransformerEncoderImpl(TransformerEncoderOptions options_);

  /// Runs `src` through the encoder.
  ///
  /// @param src input sequence (shape conventions follow the encoder layer —
  ///        TODO confirm against the implementation).
  /// @param src_mask optional attention mask; defaults to a
  ///        default-constructed Tensor, presumably meaning "no mask".
  /// @param src_key_padding_mask optional key-padding mask; defaults to a
  ///        default-constructed Tensor, presumably meaning "no mask".
  /// @return the encoded output tensor.
  Tensor forward(
      const Tensor& src,
      const Tensor& src_mask = {},
      const Tensor& src_key_padding_mask = {});

  /// Overrides the virtual `reset()` inherited via `Cloneable`; presumably
  /// (re)builds `layers` and `norm` from `options` — confirm in the .cpp.
  void reset() override;

  /// Presumably re-initializes the parameters of the submodules; defined
  /// out of line.
  void reset_parameters();

 protected:
  // Registers default values (empty Tensors) for forward() arguments at
  // indices 1 and 2, i.e. `src_mask` and `src_key_padding_mask`, so callers
  // going through type-erased wrappers (e.g. AnyModule) can presumably omit
  // them.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})

 public:
  /// Options this module was constructed with.
  TransformerEncoderOptions options;

  /// Container for the encoder layers; starts as a null holder and is
  /// presumably populated by `reset()`.
  ModuleList layers{nullptr};

  /// Optional final normalization module (type-erased); presumably left
  /// empty when none is configured in `options` — confirm.
  AnyModule norm;
};

// Defines `TransformerEncoder` via the TORCH_MODULE macro (torch/nn/pimpl.h);
// presumably the user-facing holder type wrapping `TransformerEncoderImpl`.
TORCH_MODULE(TransformerEncoder);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoder
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Decoder module composed of a repeated decoder layer plus an optional
/// final normalization module.
///
/// NOTE(review): intended as the C++ frontend counterpart of Python's
/// `torch.nn.TransformerDecoder`; `forward`, `reset` and `reset_parameters`
/// are defined out of line, so exact semantics should be confirmed there.
class TORCH_API TransformerDecoderImpl
    : public Cloneable<TransformerDecoderImpl> {
 public:
  /// Convenience constructor: packs `decoder_layer` and `num_layers` into a
  /// `TransformerDecoderOptions` and delegates to the options constructor.
  ///
  /// `decoder_layer` is a by-value sink parameter, so it is moved into the
  /// options rather than copied a second time.
  TransformerDecoderImpl(
      TransformerDecoderLayer decoder_layer,
      int64_t num_layers)
      : TransformerDecoderImpl(
            TransformerDecoderOptions(std::move(decoder_layer), num_layers)) {}

  /// Constructs the decoder from a fully specified options object.
  explicit TransformerDecoderImpl(TransformerDecoderOptions options_);

  /// Runs `tgt` through the decoder, attending to `memory`.
  ///
  /// @param tgt target sequence (shape conventions follow the decoder layer —
  ///        TODO confirm against the implementation).
  /// @param memory output of an encoder that `tgt` attends to (presumably).
  /// @param tgt_mask optional mask for `tgt`; defaults to a
  ///        default-constructed Tensor, presumably meaning "no mask".
  /// @param memory_mask optional mask for `memory`; same default.
  /// @param tgt_key_padding_mask optional key-padding mask for `tgt`; same
  ///        default.
  /// @param memory_key_padding_mask optional key-padding mask for `memory`;
  ///        same default.
  /// @return the decoded output tensor.
  Tensor forward(
      const Tensor& tgt,
      const Tensor& memory,
      const Tensor& tgt_mask = {},
      const Tensor& memory_mask = {},
      const Tensor& tgt_key_padding_mask = {},
      const Tensor& memory_key_padding_mask = {});

  /// Overrides the virtual `reset()` inherited via `Cloneable`; presumably
  /// (re)builds `layers` and `norm` from `options` — confirm in the .cpp.
  void reset() override;

  /// Presumably re-initializes the parameters of the submodules; defined
  /// out of line.
  void reset_parameters();

  /// Options this module was constructed with.
  TransformerDecoderOptions options;

  /// Container for the decoder layers; starts as a null holder and is
  /// presumably populated by `reset()`.
  ModuleList layers{nullptr};

  /// Optional final normalization module (type-erased); presumably left
  /// empty when none is configured in `options` — confirm.
  AnyModule norm;

 protected:
  // Registers default values (empty Tensors) for forward() arguments at
  // indices 2..5, i.e. `tgt_mask`, `memory_mask`, `tgt_key_padding_mask` and
  // `memory_key_padding_mask`, so callers going through type-erased wrappers
  // (e.g. AnyModule) can presumably omit them.
  FORWARD_HAS_DEFAULT_ARGS(
      {2, AnyValue(Tensor())},
      {3, AnyValue(Tensor())},
      {4, AnyValue(Tensor())},
      {5, AnyValue(Tensor())})
};

// Defines `TransformerDecoder` via the TORCH_MODULE macro (torch/nn/pimpl.h);
// presumably the user-facing holder type wrapping `TransformerDecoderImpl`.
TORCH_MODULE(TransformerDecoder);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources