Program Listing for File transformer.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/options/transformer.h)
#pragma once
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>
#include <torch/nn/modules/container/any.h>
#include <torch/nn/options/transformerlayer.h>
namespace torch::nn {
/// Option bundle describing how a transformer model should be configured.
///
/// Each `TORCH_ARG(type, name)` line below declares one option of the given
/// type; where an `= value` follows, that value is the option's default.
/// NOTE(review): `TORCH_ARG` is defined in <torch/arg.h> (not visible here);
/// it presumably also generates chainable getter/setter methods named after
/// the option -- confirm against that header.
/// NOTE(review): the per-option meanings given below are inferred from the
/// option names; verify them against the torch::nn::Transformer documentation.
struct TORCH_API TransformerOptions {
// The constructors below cover the commonly used cases.
// Please do not add more unless a new one is shown to be commonly needed.

/// Creates options with every field left at the default declared below.
TransformerOptions() = default;

/// Creates options from an explicit `d_model` and `nhead`.
/// Defined out of line; presumably all other options keep their defaults --
/// confirm in the implementation file.
TransformerOptions(int64_t d_model, int64_t nhead);

/// Creates options from explicit `d_model`, `nhead`, and encoder/decoder
/// layer counts. Defined out of line; presumably the remaining options keep
/// their defaults -- confirm in the implementation file.
TransformerOptions(
int64_t d_model,
int64_t nhead,
int64_t num_encoder_layers,
int64_t num_decoder_layers);

/// Model width (presumably the feature size of encoder/decoder inputs).
/// Default: 512.
TORCH_ARG(int64_t, d_model) = 512;

/// Number of attention heads (presumably per multi-head attention layer).
/// Default: 8.
TORCH_ARG(int64_t, nhead) = 8;

/// Number of encoder layers. Default: 6.
TORCH_ARG(int64_t, num_encoder_layers) = 6;

/// Number of decoder layers. Default: 6.
TORCH_ARG(int64_t, num_decoder_layers) = 6;

/// Feed-forward dimension (presumably the hidden size of each layer's
/// feed-forward sub-network). Default: 2048.
TORCH_ARG(int64_t, dim_feedforward) = 2048;

/// Dropout value (presumably a probability in [0, 1]). Default: 0.1.
TORCH_ARG(double, dropout) = 0.1;

/// Activation selector. Default: `torch::kReLU`.
/// NOTE(review): `activation_t` is not declared in this file; it presumably
/// comes from <torch/nn/options/transformerlayer.h> -- confirm.
TORCH_ARG(activation_t, activation) = torch::kReLU;

/// Optional user-supplied encoder. No initializer is given here, so the
/// option starts as a default-constructed `AnyModule` (presumably empty,
/// meaning "use the built-in encoder" -- confirm in the module code).
TORCH_ARG(AnyModule, custom_encoder);

/// Optional user-supplied decoder. No initializer is given here, so the
/// option starts as a default-constructed `AnyModule` (presumably empty,
/// meaning "use the built-in decoder" -- confirm in the module code).
TORCH_ARG(AnyModule, custom_decoder);
};
} // namespace torch::nn