Shortcuts

Program Listing for File transformercoder.h

Return to documentation for file (torch/csrc/api/include/torch/nn/options/transformercoder.h)

#pragma once

#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>

#include <torch/nn/modules/container/any.h>
#include <torch/nn/modules/transformerlayer.h>

namespace torch {
namespace nn {

// Options bundle for a transformer encoder: the layer to use, a layer count,
// and a `norm` module. Fields are declared through the TORCH_ARG macro from
// <torch/arg.h>; see that header for the accessors it generates.
struct TORCH_API TransformerEncoderOptions {
  // Builds the options from an existing encoder layer object.
  // This constructor keeps a shallow copy of encoder_layer, so the options
  // retain all the data held in encoder_layer.
  // NOTE(review): the definition is not in this header; the "shallow copy"
  // behavior is taken from the original comment -- confirm against the .cpp.
  TransformerEncoderOptions(
      TransformerEncoderLayer encoder_layer,
      int64_t num_layers);
  // Builds the options from layer options instead of a layer object.
  // This constructor creates a new TransformerEncoderLayer object based on
  // the passed-in encoder_layer_options.
  TransformerEncoderOptions(
      const TransformerEncoderLayerOptions& encoder_layer_options,
      int64_t num_layers);

  // The encoder layer; default-initialized to nullptr until a constructor or
  // setter supplies one.
  TORCH_ARG(TransformerEncoderLayer, encoder_layer) = nullptr;

  // Layer count passed to the constructors; no in-class default is given here.
  // NOTE(review): presumably the number of stacked encoder layers -- confirm.
  TORCH_ARG(int64_t, num_layers);

  // Type-erased module stored as `norm`; no in-class initializer is given here.
  // NOTE(review): presumably an optional normalization module applied by the
  // encoder -- confirm against the module implementation.
  TORCH_ARG(AnyModule, norm);
};

// Options bundle for a transformer decoder: the layer to use, a layer count,
// and a `norm` module. Fields are declared through the TORCH_ARG macro from
// <torch/arg.h>; see that header for the accessors it generates.
struct TORCH_API TransformerDecoderOptions {
  // Builds the options from an existing decoder layer object.
  // This constructor keeps a reference to the passed-in decoder_layer,
  // so the options retain all the data held in decoder_layer.
  // NOTE(review): the definition is not in this header; the reference-keeping
  // behavior is taken from the original comment -- confirm against the .cpp.
  TransformerDecoderOptions(
      TransformerDecoderLayer decoder_layer,
      int64_t num_layers);
  // Builds the options from layer options instead of a layer object.
  // This constructor creates a new TransformerDecoderLayer object
  // based on the passed-in decoder_layer_options.
  TransformerDecoderOptions(
      const TransformerDecoderLayerOptions& decoder_layer_options,
      int64_t num_layers);

  // The decoder layer; default-initialized to nullptr until a constructor or
  // setter supplies one.
  TORCH_ARG(TransformerDecoderLayer, decoder_layer) = nullptr;

  // Layer count passed to the constructors; no in-class default is given here.
  // NOTE(review): presumably the number of stacked decoder layers -- confirm.
  TORCH_ARG(int64_t, num_layers);

  // Type-erased module stored as `norm`; no in-class initializer is given here.
  // NOTE(review): presumably an optional normalization module applied by the
  // decoder -- confirm against the module implementation.
  TORCH_ARG(AnyModule, norm);
};

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources