Program Listing for File transformercoder.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/modules/transformercoder.h)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/container/any.h>
#include <torch/nn/modules/container/modulelist.h>
#include <torch/nn/options/transformercoder.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <ostream>
namespace torch {
namespace nn {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoder
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class for the transformer encoder module.
///
/// This header only declares the interface: apart from the delegating
/// convenience constructor below, every member function is defined out of
/// line, so behavioural details must be checked against the implementation.
class TORCH_API TransformerEncoderImpl
: public Cloneable<TransformerEncoderImpl> {
public:
/// Convenience constructor: packs `encoder_layer` and `num_layers` into a
/// `TransformerEncoderOptions` and delegates to the options constructor.
TransformerEncoderImpl(
TransformerEncoderLayer encoder_layer,
int64_t num_layers)
: TransformerEncoderImpl(
TransformerEncoderOptions(encoder_layer, num_layers)) {}
/// Constructs the module from an options object. `explicit`, so a
/// `TransformerEncoderOptions` is never implicitly converted to a module.
explicit TransformerEncoderImpl(TransformerEncoderOptions options_);
/// Runs the encoder on `src` and returns the resulting `Tensor`.
///
/// `src_mask` and `src_key_padding_mask` are optional; each defaults to a
/// default-constructed `Tensor`.
/// NOTE(review): expected tensor shapes and the exact mask semantics are not
/// visible in this header -- confirm against the implementation.
Tensor forward(
const Tensor& src,
const Tensor& src_mask = {},
const Tensor& src_key_padding_mask = {});
/// Overrides the virtual `reset()` inherited through `Cloneable`.
/// NOTE(review): presumably (re)builds `layers` and `norm` from `options` --
/// the body is not visible here, confirm in the implementation file.
void reset() override;
/// NOTE(review): presumably re-initializes the parameters of the contained
/// submodules -- the body is not visible here, confirm before relying on it.
void reset_parameters();
protected:
// Declares default values for the `forward` arguments at zero-based
// positions 1 and 2 (`src_mask`, `src_key_padding_mask`), each an `AnyValue`
// wrapping a default-constructed `Tensor`. These mirror the `= {}` defaults
// in the `forward` declaration above and must be kept in sync with it.
// NOTE(review): presumably consumed when `forward` is invoked through a
// type-erased wrapper such as `AnyModule` -- see the macro definition.
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})
public:
/// Options object for this module.
/// NOTE(review): presumably the options passed to the constructor.
TransformerEncoderOptions options;
/// Module list member, initialized from `nullptr` at declaration.
/// NOTE(review): presumably filled with the encoder layers later on (e.g. in
/// `reset()`) -- confirm in the implementation file.
ModuleList layers = nullptr;
/// Type-erased module member, default-constructed at declaration.
/// NOTE(review): presumably an optional normalization module, judging by the
/// name -- confirm how it is populated and applied.
AnyModule norm;
};
/// Invokes `TORCH_MODULE` for the name `TransformerEncoder`.
/// NOTE(review): the macro presumably generates the public `TransformerEncoder`
/// wrapper type around `TransformerEncoderImpl`; its definition is not part of
/// this file -- see the included torch/nn/pimpl.h for the exact semantics.
TORCH_MODULE(TransformerEncoder);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoder
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class for the transformer decoder module.
///
/// This header only declares the interface: apart from the delegating
/// convenience constructor below, every member function is defined out of
/// line, so behavioural details must be checked against the implementation.
class TORCH_API TransformerDecoderImpl
: public Cloneable<TransformerDecoderImpl> {
public:
/// Convenience constructor: packs `decoder_layer` and `num_layers` into a
/// `TransformerDecoderOptions` and delegates to the options constructor.
TransformerDecoderImpl(
TransformerDecoderLayer decoder_layer,
int64_t num_layers)
: TransformerDecoderImpl(
TransformerDecoderOptions(decoder_layer, num_layers)) {}
/// Constructs the module from an options object. `explicit`, so a
/// `TransformerDecoderOptions` is never implicitly converted to a module.
explicit TransformerDecoderImpl(TransformerDecoderOptions options_);
/// Overrides the virtual `reset()` inherited through `Cloneable`.
/// NOTE(review): presumably (re)builds `layers` and `norm` from `options` --
/// the body is not visible here, confirm in the implementation file.
void reset() override;
/// NOTE(review): presumably re-initializes the parameters of the contained
/// submodules -- the body is not visible here, confirm before relying on it.
void reset_parameters();
/// Runs the decoder on `tgt` together with `memory` and returns the
/// resulting `Tensor`.
///
/// `tgt` and `memory` are required. The four mask arguments (`tgt_mask`,
/// `memory_mask`, `tgt_key_padding_mask`, `memory_key_padding_mask`) are
/// optional; each defaults to a default-constructed `Tensor`.
/// NOTE(review): expected tensor shapes and the exact mask semantics are not
/// visible in this header -- confirm against the implementation.
Tensor forward(
const Tensor& tgt,
const Tensor& memory,
const Tensor& tgt_mask = {},
const Tensor& memory_mask = {},
const Tensor& tgt_key_padding_mask = {},
const Tensor& memory_key_padding_mask = {});
/// Options object for this module.
/// NOTE(review): presumably the options passed to the constructor.
TransformerDecoderOptions options;
/// Module list member, initialized from `nullptr` at declaration.
/// NOTE(review): presumably filled with the decoder layers later on (e.g. in
/// `reset()`) -- confirm in the implementation file.
ModuleList layers{nullptr};
/// Type-erased module member, default-constructed at declaration.
/// NOTE(review): presumably an optional normalization module, judging by the
/// name -- confirm how it is populated and applied.
AnyModule norm;
protected:
// Declares default values for the `forward` arguments at zero-based
// positions 2 through 5 (the four mask parameters), each an `AnyValue`
// wrapping a default-constructed `Tensor`. These mirror the `= {}` defaults
// in the `forward` declaration above and must be kept in sync with it.
// NOTE(review): presumably consumed when `forward` is invoked through a
// type-erased wrapper such as `AnyModule` -- see the macro definition.
FORWARD_HAS_DEFAULT_ARGS(
{2, AnyValue(Tensor())},
{3, AnyValue(Tensor())},
{4, AnyValue(Tensor())},
{5, AnyValue(Tensor())})
};
/// Invokes `TORCH_MODULE` for the name `TransformerDecoder`.
/// NOTE(review): the macro presumably generates the public `TransformerDecoder`
/// wrapper type around `TransformerDecoderImpl`; its definition is not part of
/// this file -- see the included torch/nn/pimpl.h for the exact semantics.
TORCH_MODULE(TransformerDecoder);
} // namespace nn
} // namespace torch