Shortcuts

Program Listing for File transformer.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/transformer.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/options/transformer.h>
#include <torch/nn/pimpl.h>
#include <torch/nn/modules/common.h>

#include <torch/types.h>

#include <ostream>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Transformer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class for the `Transformer` module declared in this header.
///
/// Derives from `Cloneable<TransformerImpl>` and stores its configuration in
/// `options` together with two type-erased submodules, `encoder` and
/// `decoder`. Only declarations are visible here; the behavior of each member
/// function is defined in the corresponding implementation file.
///
/// NOTE(review): the parameter names mirror the Python `torch.nn.Transformer`
/// API (attention masks and key-padding masks); expected tensor shapes and
/// dtypes are not visible in this header — confirm against the implementation.
// NOLINTNEXTLINE(bugprone-exception-escape)
class TORCH_API TransformerImpl : public Cloneable<TransformerImpl> {

  public:
    /// Constructs the module from `options_` (taken by value). `explicit`
    /// prevents implicit conversion from `TransformerOptions`.
    explicit TransformerImpl(TransformerOptions options_);


    /// Runs the module on a source tensor `src` and a target tensor `tgt` and
    /// returns a `Tensor`.
    ///
    /// `src` and `tgt` are required. The six mask arguments are optional and
    /// each defaults to a default-constructed `Tensor` (`{}`).
    ///
    /// NOTE(review): presumably `src_mask`, `tgt_mask` and `memory_mask` are
    /// attention masks and the `*_key_padding_mask` arguments mark padded
    /// positions; how an omitted mask is treated is not shown here — confirm.
    Tensor forward(
      const Tensor& src,
      const Tensor& tgt,
      const Tensor& src_mask = {},
      const Tensor& tgt_mask = {},
      const Tensor& memory_mask = {},
      const Tensor& src_key_padding_mask = {},
      const Tensor& tgt_key_padding_mask = {},
      const Tensor& memory_key_padding_mask = {});

    /// Overrides the base-class `reset()` hook.
    /// NOTE(review): presumably (re)builds `encoder` and `decoder` from
    /// `options` — confirm in the implementation.
    void reset() override;

    /// NOTE(review): presumably re-initializes the module's parameters; the
    /// initialization scheme is not visible in this header — confirm.
    void reset_parameters();

    /// Static helper that returns a mask `Tensor` for the given size `sz`.
    /// NOTE(review): by its name this is a square (`sz` x `sz`) mask for
    /// subsequent-position masking; the exact fill values are not shown
    /// here — confirm.
    static Tensor generate_square_subsequent_mask(int64_t sz);

  protected:
    // Lists one `AnyValue(Tensor())` default per argument index 2..7, i.e. the
    // six defaulted mask parameters of `forward` above (0-based positions).
    // NOTE(review): presumably this lets type-erased `AnyModule` callers omit
    // those arguments; keep the indices and values in sync with the `forward`
    // signature — confirm against the macro definition.
    FORWARD_HAS_DEFAULT_ARGS(
      {2, AnyValue(Tensor())},
      {3, AnyValue(Tensor())},
      {4, AnyValue(Tensor())},
      {5, AnyValue(Tensor())},
      {6, AnyValue(Tensor())},
      {7, AnyValue(Tensor())})

  public:
    /// Configuration for this module (public data member).
    TransformerOptions options;

    /// Encoder submodule, held as a type-erased `AnyModule`.
    AnyModule encoder;

    /// Decoder submodule, held as a type-erased `AnyModule`.
    AnyModule decoder;
};

/// Invokes the `TORCH_MODULE` macro for the name `Transformer`.
/// NOTE(review): presumably this defines the `Transformer` holder type that
/// wraps `TransformerImpl` (macro expected to come from torch/nn/pimpl.h) —
/// confirm against the macro definition.
TORCH_MODULE(Transformer);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources