Program Listing for File transformerlayer.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/options/transformerlayer.h)
#pragma once
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>
namespace torch {
namespace nn {
/// Type of the `activation` option carried by the transformer layer option
/// structs that follow.
///
/// Holds exactly one of:
///  - `enumtype::kReLU`
///  - `enumtype::kGELU`
///  - a `std::function<Tensor(const Tensor&)>`, i.e. a callable that takes a
///    tensor by const reference and returns a tensor.
///
/// NOTE(review): presumably the two enum tags select built-in ReLU / GELU
/// activations while the callable supplies a user-defined one -- confirm
/// against the module implementation that consumes this value.
// NOTE(review): `std::variant` and `std::function` are used here without a
// direct `#include <variant>` / `#include <functional>` in this header; they
// are presumably provided transitively by the torch headers included above
// -- confirm, or include them directly.
using activation_t = std::variant<
enumtype::kReLU,
enumtype::kGELU,
std::function<Tensor(const Tensor&)>>;
/// Option bundle for a transformer encoder layer.
///
/// NOTE(review): presumably consumed by the `TransformerEncoderLayer` module;
/// that module is not visible in this header -- confirm.
///
/// Each option is declared through the `TORCH_ARG` macro (from
/// `<torch/arg.h>`). NOTE(review): the macro presumably generates the stored
/// member together with chainable getter/setter accessors of the same name
/// -- confirm against `<torch/arg.h>`.
struct TORCH_API TransformerEncoderLayerOptions {
/// Constructs the options from the two values that have no in-class default.
/// The constructor is deliberately not `explicit` (see the marker below), so
/// a braced pair `{d_model, nhead}` converts implicitly to this type.
/// The definition lives outside this header.
/* implicit */ TransformerEncoderLayerOptions(int64_t d_model, int64_t nhead);
/// No in-class default; presumably initialised from the constructor argument
/// of the same name -- confirm in the constructor definition.
/// NOTE(review): by its name, the model/embedding width -- confirm.
TORCH_ARG(int64_t, d_model);
/// No in-class default; presumably initialised from the constructor argument
/// of the same name -- confirm in the constructor definition.
/// NOTE(review): by its name, the number of attention heads -- confirm.
TORCH_ARG(int64_t, nhead);
/// Default: 2048.
/// NOTE(review): by its name, the width of the feed-forward sub-network
/// -- confirm.
TORCH_ARG(int64_t, dim_feedforward) = 2048;
/// Default: 0.1.
/// NOTE(review): by its name, a dropout probability -- confirm valid range.
TORCH_ARG(double, dropout) = 0.1;
/// Activation selector; one of the alternatives of `activation_t`.
/// Default: `torch::kReLU`.
TORCH_ARG(activation_t, activation) = torch::kReLU;
};
// ============================================================================
/// Option bundle for a transformer decoder layer.
///
/// NOTE(review): presumably consumed by the `TransformerDecoderLayer` module;
/// that module is not visible in this header -- confirm.
///
/// Each option is declared through the `TORCH_ARG` macro (from
/// `<torch/arg.h>`). NOTE(review): the macro presumably generates the stored
/// member together with chainable getter/setter accessors of the same name
/// -- confirm against `<torch/arg.h>`.
struct TORCH_API TransformerDecoderLayerOptions {
/// Constructs the options from the two values that have no in-class default.
/// The definition lives outside this header.
// NOTE(review): this constructor is not `explicit`, so a braced pair
// `{d_model, nhead}` converts implicitly to this type, exactly as for
// `TransformerEncoderLayerOptions`; unlike that struct, it carries no
// `/* implicit */` marker -- confirm the implicit conversion is intended.
TransformerDecoderLayerOptions(int64_t d_model, int64_t nhead);
/// No in-class default; presumably initialised from the constructor argument
/// of the same name -- confirm in the constructor definition.
/// NOTE(review): by its name, the model/embedding width -- confirm.
TORCH_ARG(int64_t, d_model);
/// No in-class default; presumably initialised from the constructor argument
/// of the same name -- confirm in the constructor definition.
/// NOTE(review): by its name, the number of attention heads -- confirm.
TORCH_ARG(int64_t, nhead);
/// Default: 2048.
/// NOTE(review): by its name, the width of the feed-forward sub-network
/// -- confirm.
TORCH_ARG(int64_t, dim_feedforward) = 2048;
/// Default: 0.1.
/// NOTE(review): by its name, a dropout probability -- confirm valid range.
TORCH_ARG(double, dropout) = 0.1;
/// Activation selector; one of the alternatives of `activation_t`.
/// Default: `torch::kReLU`.
TORCH_ARG(activation_t, activation) = torch::kReLU;
};
} // namespace nn
} // namespace torch