Program Listing for File transformer.h¶
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/modules/transformer.h
)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/options/transformer.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <ostream>
namespace torch::nn {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Transformer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class TORCH_API TransformerImpl : public Cloneable<TransformerImpl> {
public:
explicit TransformerImpl(TransformerOptions options_);
Tensor forward(
const Tensor& src,
const Tensor& tgt,
const Tensor& src_mask = {},
const Tensor& tgt_mask = {},
const Tensor& memory_mask = {},
const Tensor& src_key_padding_mask = {},
const Tensor& tgt_key_padding_mask = {},
const Tensor& memory_key_padding_mask = {});
void reset() override;
void reset_parameters();
static Tensor generate_square_subsequent_mask(int64_t sz);
protected:
FORWARD_HAS_DEFAULT_ARGS(
{2, AnyValue(Tensor())},
{3, AnyValue(Tensor())},
{4, AnyValue(Tensor())},
{5, AnyValue(Tensor())},
{6, AnyValue(Tensor())},
{7, AnyValue(Tensor())})
public:
TransformerOptions options;
AnyModule encoder;
AnyModule decoder;
};
TORCH_MODULE(Transformer);
} // namespace torch::nn