Shortcuts

Program Listing for File rnn.h

Return to documentation for file (torch/csrc/api/include/torch/nn/options/rnn.h)

#pragma once

#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>

namespace torch::nn {

namespace detail {

struct TORCH_API RNNOptionsBase {
  typedef std::variant<
      enumtype::kLSTM,
      enumtype::kGRU,
      enumtype::kRNN_TANH,
      enumtype::kRNN_RELU>
      rnn_options_base_mode_t;

  RNNOptionsBase(
      rnn_options_base_mode_t mode,
      int64_t input_size,
      int64_t hidden_size);

  TORCH_ARG(rnn_options_base_mode_t, mode);
  TORCH_ARG(int64_t, input_size);
  TORCH_ARG(int64_t, hidden_size);
  TORCH_ARG(int64_t, num_layers) = 1;
  TORCH_ARG(bool, bias) = true;
  TORCH_ARG(bool, batch_first) = false;
  TORCH_ARG(double, dropout) = 0.0;
  TORCH_ARG(bool, bidirectional) = false;
  TORCH_ARG(int64_t, proj_size) = 0;
};

} // namespace detail

struct TORCH_API RNNOptions {
  typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;

  RNNOptions(int64_t input_size, int64_t hidden_size);

  TORCH_ARG(int64_t, input_size);
  TORCH_ARG(int64_t, hidden_size);
  TORCH_ARG(int64_t, num_layers) = 1;
  TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh;
  TORCH_ARG(bool, bias) = true;
  TORCH_ARG(bool, batch_first) = false;
  TORCH_ARG(double, dropout) = 0.0;
  TORCH_ARG(bool, bidirectional) = false;
};

struct TORCH_API LSTMOptions {
  LSTMOptions(int64_t input_size, int64_t hidden_size);

  TORCH_ARG(int64_t, input_size);
  TORCH_ARG(int64_t, hidden_size);
  TORCH_ARG(int64_t, num_layers) = 1;
  TORCH_ARG(bool, bias) = true;
  TORCH_ARG(bool, batch_first) = false;
  TORCH_ARG(double, dropout) = 0.0;
  TORCH_ARG(bool, bidirectional) = false;
  TORCH_ARG(int64_t, proj_size) = 0;
};

struct TORCH_API GRUOptions {
  GRUOptions(int64_t input_size, int64_t hidden_size);

  TORCH_ARG(int64_t, input_size);
  TORCH_ARG(int64_t, hidden_size);
  TORCH_ARG(int64_t, num_layers) = 1;
  TORCH_ARG(bool, bias) = true;
  TORCH_ARG(bool, batch_first) = false;
  TORCH_ARG(double, dropout) = 0.0;
  TORCH_ARG(bool, bidirectional) = false;
};

namespace detail {

/// Common option set for the single-step recurrent cells.
///
/// Unlike the layer-level option structs, no field here has an in-class
/// default: all four values are supplied through the constructor, which is
/// defined out of line (not in this header).
///
/// NOTE(review): lives in `detail`, so presumably an internal base that the
/// public RNNCell/LSTMCell/GRUCell options are converted into — confirm.
struct TORCH_API RNNCellOptionsBase {
  RNNCellOptionsBase(
      int64_t input_size,
      int64_t hidden_size,
      bool bias,
      int64_t num_chunks);
  // Each TORCH_ARG (see torch/arg.h) introduces a data member, so the order
  // below is the class layout; reordering changes the layout of this type.
  TORCH_ARG(int64_t, input_size);
  TORCH_ARG(int64_t, hidden_size);
  TORCH_ARG(bool, bias);
  // NOTE(review): presumably the number of gate blocks stacked in the cell's
  // weight matrices (e.g. 4 for LSTM, 3 for GRU, 1 for a plain RNN cell) —
  // confirm against the cell implementation.
  TORCH_ARG(int64_t, num_chunks);
};

} // namespace detail

struct TORCH_API RNNCellOptions {
  typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;

  RNNCellOptions(int64_t input_size, int64_t hidden_size);

  TORCH_ARG(int64_t, input_size);
  TORCH_ARG(int64_t, hidden_size);
  TORCH_ARG(bool, bias) = true;
  TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh;
};

struct TORCH_API LSTMCellOptions {
  LSTMCellOptions(int64_t input_size, int64_t hidden_size);

  TORCH_ARG(int64_t, input_size);
  TORCH_ARG(int64_t, hidden_size);
  TORCH_ARG(bool, bias) = true;
};

struct TORCH_API GRUCellOptions {
  GRUCellOptions(int64_t input_size, int64_t hidden_size);

  TORCH_ARG(int64_t, input_size);
  TORCH_ARG(int64_t, hidden_size);
  TORCH_ARG(bool, bias) = true;
};

} // namespace torch::nn

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources