Program Listing for File rnn.h
Return to documentation for file (torch/csrc/api/include/torch/nn/options/rnn.h)
#pragma once
#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>
namespace torch {
namespace nn {
namespace detail {
struct TORCH_API RNNOptionsBase {
typedef std::variant<
enumtype::kLSTM,
enumtype::kGRU,
enumtype::kRNN_TANH,
enumtype::kRNN_RELU>
rnn_options_base_mode_t;
RNNOptionsBase(
rnn_options_base_mode_t mode,
int64_t input_size,
int64_t hidden_size);
TORCH_ARG(rnn_options_base_mode_t, mode);
TORCH_ARG(int64_t, input_size);
TORCH_ARG(int64_t, hidden_size);
TORCH_ARG(int64_t, num_layers) = 1;
TORCH_ARG(bool, bias) = true;
TORCH_ARG(bool, batch_first) = false;
TORCH_ARG(double, dropout) = 0.0;
TORCH_ARG(bool, bidirectional) = false;
TORCH_ARG(int64_t, proj_size) = 0;
};
} // namespace detail
struct TORCH_API RNNOptions {
typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
RNNOptions(int64_t input_size, int64_t hidden_size);
TORCH_ARG(int64_t, input_size);
TORCH_ARG(int64_t, hidden_size);
TORCH_ARG(int64_t, num_layers) = 1;
TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh;
TORCH_ARG(bool, bias) = true;
TORCH_ARG(bool, batch_first) = false;
TORCH_ARG(double, dropout) = 0.0;
TORCH_ARG(bool, bidirectional) = false;
};
/// Options struct named for the LSTM module: two sizes passed to the
/// constructor plus six options that carry in-class defaults.
///
/// NOTE(review): option names appear to mirror the arguments of
/// `torch.nn.LSTM`; confirm their meaning against that module's docs.
/// `TORCH_ARG` is defined in <torch/arg.h> (not visible here) and presumably
/// declares each field together with its accessor/setter -- confirm there.
struct TORCH_API LSTMOptions {
/// Only declared here; the definition is not part of this header.
LSTMOptions(int64_t input_size, int64_t hidden_size);
// No in-class default for the two sizes.
TORCH_ARG(int64_t, input_size);
TORCH_ARG(int64_t, hidden_size);
// Defaulted options: 1, true, false, 0.0, false, 0 respectively.
TORCH_ARG(int64_t, num_layers) = 1;
TORCH_ARG(bool, bias) = true;
TORCH_ARG(bool, batch_first) = false;
TORCH_ARG(double, dropout) = 0.0;
TORCH_ARG(bool, bidirectional) = false;
// Unlike the RNN and GRU option structs in this file, LSTM also declares
// proj_size.
TORCH_ARG(int64_t, proj_size) = 0;
};
/// Options struct named for the GRU module: two sizes passed to the
/// constructor plus five options that carry in-class defaults.
///
/// NOTE(review): option names appear to mirror the arguments of
/// `torch.nn.GRU`; confirm their meaning against that module's docs.
/// `TORCH_ARG` is defined in <torch/arg.h> (not visible here) and presumably
/// declares each field together with its accessor/setter -- confirm there.
struct TORCH_API GRUOptions {
/// Only declared here; the definition is not part of this header.
GRUOptions(int64_t input_size, int64_t hidden_size);
// No in-class default for the two sizes.
TORCH_ARG(int64_t, input_size);
TORCH_ARG(int64_t, hidden_size);
// Defaulted options: 1, true, false, 0.0, false respectively.
TORCH_ARG(int64_t, num_layers) = 1;
TORCH_ARG(bool, bias) = true;
TORCH_ARG(bool, batch_first) = false;
TORCH_ARG(double, dropout) = 0.0;
TORCH_ARG(bool, bidirectional) = false;
};
namespace detail {
/// Option set in `torch::nn::detail` for single-cell modules. All four
/// options are constructor parameters and none has an in-class default.
///
/// NOTE(review): `TORCH_ARG` is defined in <torch/arg.h> (not visible here)
/// and presumably declares each field together with its accessor/setter --
/// confirm there.
struct TORCH_API RNNCellOptionsBase {
/// Only declared here; the definition is not part of this header.
RNNCellOptionsBase(
int64_t input_size,
int64_t hidden_size,
bool bias,
int64_t num_chunks);
TORCH_ARG(int64_t, input_size);
TORCH_ARG(int64_t, hidden_size);
TORCH_ARG(bool, bias);
// NOTE(review): num_chunks is presumably the number of gate blocks stacked
// in the cell's weight matrices -- verify against the cell implementation.
TORCH_ARG(int64_t, num_chunks);
};
} // namespace detail
struct TORCH_API RNNCellOptions {
typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
RNNCellOptions(int64_t input_size, int64_t hidden_size);
TORCH_ARG(int64_t, input_size);
TORCH_ARG(int64_t, hidden_size);
TORCH_ARG(bool, bias) = true;
TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh;
};
/// Options struct named for the LSTMCell module: two sizes passed to the
/// constructor plus a bias switch that defaults to true.
///
/// NOTE(review): option names appear to mirror the arguments of
/// `torch.nn.LSTMCell`; confirm their meaning against that module's docs.
/// `TORCH_ARG` is defined in <torch/arg.h> (not visible here).
struct TORCH_API LSTMCellOptions {
/// Only declared here; the definition is not part of this header.
LSTMCellOptions(int64_t input_size, int64_t hidden_size);
// No in-class default for the two sizes.
TORCH_ARG(int64_t, input_size);
TORCH_ARG(int64_t, hidden_size);
TORCH_ARG(bool, bias) = true;
};
/// Options struct named for the GRUCell module: two sizes passed to the
/// constructor plus a bias switch that defaults to true.
///
/// NOTE(review): option names appear to mirror the arguments of
/// `torch.nn.GRUCell`; confirm their meaning against that module's docs.
/// `TORCH_ARG` is defined in <torch/arg.h> (not visible here).
struct TORCH_API GRUCellOptions {
/// Only declared here; the definition is not part of this header.
GRUCellOptions(int64_t input_size, int64_t hidden_size);
// No in-class default for the two sizes.
TORCH_ARG(int64_t, input_size);
TORCH_ARG(int64_t, hidden_size);
TORCH_ARG(bool, bias) = true;
};
} // namespace nn
} // namespace torch