Shortcuts

Program Listing for File transformerlayer.h

Return to documentation for file (torch/csrc/api/include/torch/nn/options/transformerlayer.h)

#pragma once

#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>

namespace torch {
namespace nn {

/// Activation selector shared by the Transformer layer option structs.
/// Holds either one of the built-in enum tags (`enumtype::kReLU`,
/// `enumtype::kGELU`; e.g. `torch::kReLU`) or an arbitrary unary callable
/// mapping a `Tensor` to a `Tensor`.
using activation_t =
    std::variant<enumtype::kReLU, enumtype::kGELU,
                 std::function<Tensor(const Tensor&)>>;

/// Options for the Transformer encoder layer.
///
/// `d_model` and `nhead` are required and supplied through the constructor.
/// The remaining fields carry the in-class defaults shown below and can be
/// overridden through the accessors that `TORCH_ARG` generates for each
/// field (see torch/arg.h).
struct TORCH_API TransformerEncoderLayerOptions {
  // Intentionally not `explicit`: the options may also be written as a
  // braced pair such as `{512, 8}` wherever this type is expected.
  /* implicit */ TransformerEncoderLayerOptions(int64_t d_model, int64_t nhead);

  /// Model dimension. Required; set by the constructor.
  /// NOTE(review): presumably the number of expected input features, as in
  /// Python's `nn.TransformerEncoderLayer` -- confirm.
  TORCH_ARG(int64_t, d_model);

  /// Required; set by the constructor.
  /// NOTE(review): presumably the number of attention heads -- confirm.
  TORCH_ARG(int64_t, nhead);

  /// Default: 2048.
  /// NOTE(review): presumably the width of the feedforward sub-layer.
  TORCH_ARG(int64_t, dim_feedforward) = 2048;

  /// Dropout value. Default: 0.1.
  TORCH_ARG(double, dropout) = 0.1;

  /// Activation function: one of the `activation_t` enum tags or a custom
  /// `Tensor(const Tensor&)` callable. Default: `torch::kReLU`.
  TORCH_ARG(activation_t, activation) = torch::kReLU;
};

// ============================================================================

/// Options for the Transformer decoder layer.
///
/// Mirrors the fields and defaults of `TransformerEncoderLayerOptions`:
/// `d_model` and `nhead` are required and supplied through the constructor,
/// and the remaining fields carry the in-class defaults shown below. Each
/// field can be overridden through the accessors that `TORCH_ARG` generates
/// for it (see torch/arg.h).
struct TORCH_API TransformerDecoderLayerOptions {
  // Intentionally not `explicit` (same as the encoder options): the options
  // may also be written as a braced pair such as `{512, 8}`.
  /* implicit */ TransformerDecoderLayerOptions(int64_t d_model, int64_t nhead);

  /// Model dimension. Required; set by the constructor.
  /// NOTE(review): presumably the number of expected input features, as in
  /// Python's `nn.TransformerDecoderLayer` -- confirm.
  TORCH_ARG(int64_t, d_model);

  /// Required; set by the constructor.
  /// NOTE(review): presumably the number of attention heads -- confirm.
  TORCH_ARG(int64_t, nhead);

  /// Default: 2048.
  /// NOTE(review): presumably the width of the feedforward sub-layer.
  TORCH_ARG(int64_t, dim_feedforward) = 2048;

  /// Dropout value. Default: 0.1.
  TORCH_ARG(double, dropout) = 0.1;

  /// Activation function: one of the `activation_t` enum tags or a custom
  /// `Tensor(const Tensor&)` callable. Default: `torch::kReLU`.
  TORCH_ARG(activation_t, activation) = torch::kReLU;
};

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources