// Program listing for file transformercoder.h
// (torch/csrc/api/include/torch/nn/options/transformercoder.h)
#pragma once
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>
#include <torch/nn/modules/container/any.h>
#include <torch/nn/modules/transformerlayer.h>
namespace torch {
namespace nn {
// Options bundle for a transformer encoder: the encoder layer to use, the
// number of layers, and a `norm` module. Both constructors are only declared
// in this header; their definitions are not visible here.
struct TORCH_API TransformerEncoderOptions {
// Builds the options from an existing TransformerEncoderLayer.
// This constructor will keep a shallow copy of encoder_layer, so it keeps all
// the data in encoder_layer.
TransformerEncoderOptions(
TransformerEncoderLayer encoder_layer,
int64_t num_layers);
// Builds the options from layer options rather than from a layer.
// This constructor will create a new TransformerEncoderLayer obj based on
// the passed in encoder_layer_options.
TransformerEncoderOptions(
const TransformerEncoderLayerOptions& encoder_layer_options,
int64_t num_layers);
// NOTE(review): TORCH_ARG comes from <torch/arg.h>, which is not shown here;
// presumably it declares the member together with its accessor functions --
// confirm against that header.
//
// The encoder layer. The in-class initializer is nullptr.
TORCH_ARG(TransformerEncoderLayer, encoder_layer) = nullptr;
// Number of layers. No in-class initializer is given here; both declared
// constructors take a num_layers argument.
TORCH_ARG(int64_t, num_layers);
// Module held under the name `norm`; presumably a normalization module used
// by the encoder -- TODO confirm. No in-class initializer is given here.
TORCH_ARG(AnyModule, norm);
};
// Options bundle for a transformer decoder: the decoder layer to use, the
// number of layers, and a `norm` module. Both constructors are only declared
// in this header; their definitions are not visible here.
struct TORCH_API TransformerDecoderOptions {
// Builds the options from an existing TransformerDecoderLayer.
// This constructor will keep a reference to the passed in decoder_layer,
// so it keeps all the data in decoder_layer.
TransformerDecoderOptions(
TransformerDecoderLayer decoder_layer,
int64_t num_layers);
// Builds the options from layer options rather than from a layer.
// This constructor will create a new TransformerDecoderLayer obj,
// based on the passed in decoder_layer_options.
TransformerDecoderOptions(
const TransformerDecoderLayerOptions& decoder_layer_options,
int64_t num_layers);
// NOTE(review): TORCH_ARG comes from <torch/arg.h>, which is not shown here;
// presumably it declares the member together with its accessor functions --
// confirm against that header.
//
// The decoder layer. The in-class initializer is nullptr.
TORCH_ARG(TransformerDecoderLayer, decoder_layer) = nullptr;
// Number of layers. No in-class initializer is given here; both declared
// constructors take a num_layers argument.
TORCH_ARG(int64_t, num_layers);
// Module held under the name `norm`; presumably a normalization module used
// by the decoder -- TODO confirm. No in-class initializer is given here.
TORCH_ARG(AnyModule, norm);
};
} // namespace nn
} // namespace torch