Shortcuts

Program Listing for File transformer.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/transformer.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/options/transformer.h>
#include <torch/nn/pimpl.h>

#include <torch/types.h>

#include <ostream>

namespace torch::nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Transformer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class of the `Transformer` module (the `Transformer` name
/// itself is introduced by the `TORCH_MODULE(Transformer)` line that follows
/// this class).
///
/// The class holds the options it is configured with plus two sub-modules,
/// `encoder` and `decoder`, both stored type-erased as `AnyModule`.
/// NOTE(review): this header only declares the interface; all behavior lives
/// in the corresponding .cpp file, so semantic notes below are marked as
/// assumptions where the declaration alone cannot confirm them.
class TORCH_API TransformerImpl : public Cloneable<TransformerImpl> {
 public:
  /// Constructs the module from `options_`. `explicit` prevents an implicit
  /// conversion from `TransformerOptions` to `TransformerImpl`.
  explicit TransformerImpl(TransformerOptions options_);

  /// Runs the module on a source/target pair and returns a single `Tensor`.
  ///
  /// `src` and `tgt` are required. The six mask arguments are optional and
  /// each defaults to a default-constructed `Tensor` (`{}`).
  /// NOTE(review): presumably a default-constructed mask means "no mask" and
  /// the argument names mirror Python's `nn.Transformer.forward` -- confirm
  /// expected shapes/dtypes against the implementation.
  Tensor forward(
      const Tensor& src,
      const Tensor& tgt,
      const Tensor& src_mask = {},
      const Tensor& tgt_mask = {},
      const Tensor& memory_mask = {},
      const Tensor& src_key_padding_mask = {},
      const Tensor& tgt_key_padding_mask = {},
      const Tensor& memory_key_padding_mask = {});

  /// Overrides the base-class `reset()` hook.
  /// NOTE(review): presumably (re)builds `encoder`/`decoder` from `options`
  /// -- confirm in the .cpp.
  void reset() override;

  /// NOTE(review): presumably re-initializes the module's parameters; the
  /// initialization scheme is not visible from this header -- confirm.
  void reset_parameters();

  /// Static helper returning a `Tensor` derived from the size `sz`; it needs
  /// no `TransformerImpl` instance.
  /// NOTE(review): by its name this presumably builds an `sz` x `sz` mask
  /// that hides subsequent (future) positions -- confirm fill values and
  /// dtype in the implementation.
  static Tensor generate_square_subsequent_mask(int64_t sz);

 protected:
  // Declares default values for forward() argument positions 2..7
  // (zero-based), i.e. exactly the six mask parameters that have `= {}`
  // defaults in the declaration above. Keep this list in sync with that
  // signature when adding or removing defaulted arguments.
  // NOTE(review): presumably required so the module can be called through
  // `AnyModule` with trailing arguments omitted -- confirm against the macro
  // definition.
  FORWARD_HAS_DEFAULT_ARGS(
      {2, AnyValue(Tensor())},
      {3, AnyValue(Tensor())},
      {4, AnyValue(Tensor())},
      {5, AnyValue(Tensor())},
      {6, AnyValue(Tensor())},
      {7, AnyValue(Tensor())})

 public:
  /// Configuration of this module.
  /// NOTE(review): presumably the `options_` passed to the constructor.
  TransformerOptions options;

  /// Encoder sub-module, held type-erased as `AnyModule`.
  AnyModule encoder;

  /// Decoder sub-module, held type-erased as `AnyModule`.
  AnyModule decoder;
};

// NOTE(review): TORCH_MODULE is not defined in this file (presumably it comes
// from <torch/nn/pimpl.h>, included above) and is expected to declare the
// user-facing `Transformer` holder type wrapping `TransformerImpl` -- confirm
// against the macro definition.
TORCH_MODULE(Transformer);

} // namespace torch::nn

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources