Shortcuts

Program Listing for File rnn.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/rnn.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/options/rnn.h>
#include <torch/nn/pimpl.h>
#include <torch/nn/utils/rnn.h>
#include <torch/types.h>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

namespace torch {
namespace nn {

namespace detail {
/// Shared base for the recurrent module implementations declared later in this
/// header (`RNNImpl`, `LSTMImpl` and `GRUImpl` all derive from it).
///
/// `Derived` is the concrete implementation type and is passed straight
/// through to `torch::nn::Cloneable<Derived>`. Only declarations appear here;
/// the member definitions live outside this header.
template <typename Derived>
class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
 public:
  /// Builds the base from the option set common to all derived modules.
  explicit RNNImplBase(const RNNOptionsBase& options_);

  /// Overrides the base-class `reset()` hook.
  /// NOTE(review): presumably (re)creates the parameters and then calls
  /// `reset_parameters()` -- confirm against the out-of-line definition.
  void reset() override;

  /// NOTE(review): presumably re-initializes the weight values -- confirm
  /// against the out-of-line definition.
  void reset_parameters();

  /// Overrides of the three base-class `to()` overloads (device + dtype, dtype
  /// only, device only). `non_blocking` defaults to `false` in each.
  /// NOTE(review): the comment on `reset_flat_weights()` below suggests these
  /// refresh `flat_weights_` after the conversion -- confirm.
  void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
      override;
  void to(torch::Dtype dtype, bool non_blocking = false) override;
  void to(torch::Device device, bool non_blocking = false) override;

  /// Overrides the base-class `pretty_print()`; `stream` is the output stream
  /// handed in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// NOTE(review): presumably repacks the weights into a layout preferred by
  /// the backend kernels -- confirm against the out-of-line definition.
  void flatten_parameters();

  /// Returns a vector of tensors by value without modifying the module.
  /// NOTE(review): presumably every weight/bias tensor of the module --
  /// confirm ordering against the definition.
  std::vector<Tensor> all_weights() const;

  /// Options shared by every module derived from this base (public member).
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  RNNOptionsBase options_base;

 protected:
  // Resets flat_weights_
  // Note: be v. careful before removing this, as 3rd party device types
  // likely rely on this behavior to properly .to() modules like LSTM.
  void reset_flat_weights();

  /// Validation helper taking the input and its `batch_sizes` tensor.
  /// NOTE(review): failure mode (exception type / message) is not visible
  /// here -- confirm against the definition.
  void check_input(const Tensor& input, const Tensor& batch_sizes) const;

  /// Returns a triple of `int64_t` describing the hidden-state size expected
  /// for `input` / `batch_sizes`.
  /// NOTE(review): the meaning of each tuple slot is not visible in this
  /// header -- confirm against the definition.
  std::tuple<int64_t, int64_t, int64_t> get_expected_hidden_size(
      const Tensor& input,
      const Tensor& batch_sizes) const;

  /// Checks `hx` against `expected_hidden_size`. `msg` is a message template
  /// that defaults to "Expected hidden size {1}, got {2}".
  /// NOTE(review): `msg` is taken by value (one string copy per call);
  /// switching to a reference would also require changing the definition.
  void check_hidden_size(
      const Tensor& hx,
      std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
      std::string msg = "Expected hidden size {1}, got {2}") const;

  /// Validation entry point for the forward arguments; all three tensors are
  /// taken by value.
  /// NOTE(review): presumably combines `check_input()` and
  /// `check_hidden_size()` -- confirm against the definition.
  void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes)
      const;

  /// Returns a tensor derived from `hx` and `permutation`.
  /// NOTE(review): presumably reorders `hx` along its batch dimension --
  /// confirm against the definition.
  Tensor permute_hidden(Tensor hx, const Tensor& permutation) const;

  /// Names associated with the entries of `flat_weights_`.
  /// NOTE(review): assumed to be parallel to `flat_weights_` -- confirm.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::string> flat_weights_names_;
  /// Nested list of weight names.
  /// NOTE(review): presumably one inner list per layer/direction -- confirm.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::vector<std::string>> all_weights_;
  /// Flat list of weight tensors; rebuilt by `reset_flat_weights()` (see the
  /// comment on that method).
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<Tensor> flat_weights_;
};
} // namespace detail

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `RNN` module name registered with
/// `TORCH_MODULE(RNN)` below. Derives from `detail::RNNImplBase<RNNImpl>`.
class TORCH_API RNNImpl : public detail::RNNImplBase<RNNImpl> {
 public:
  /// Convenience constructor: delegates to the options-based constructor with
  /// `RNNOptions(input_size, hidden_size)`.
  RNNImpl(int64_t input_size, int64_t hidden_size)
      : RNNImpl(RNNOptions(input_size, hidden_size)) {}
  /// Constructs the module from a full `RNNOptions` object.
  explicit RNNImpl(const RNNOptions& options_);

  /// Forward pass on a plain tensor input. `hx` may be omitted, in which case
  /// a default-constructed `Tensor` is passed.
  /// NOTE(review): the returned pair is presumably (output, final hidden
  /// state) -- confirm against the definition.
  std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});

 protected:
  // Declares the default value for `forward`'s argument at index 1.
  // NOTE(review): index 1 is assumed to be `hx` (0-based) and the macro is
  // presumably what lets type-erased callers omit it -- confirm against the
  // macro definition.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})

 public:
  /// Variant of `forward` that takes and returns a
  /// `torch::nn::utils::rnn::PackedSequence` instead of a plain tensor.
  /// `hx` may be omitted, as in `forward`.
  std::tuple<torch::nn::utils::rnn::PackedSequence, Tensor>
  forward_with_packed_input(
      const torch::nn::utils::rnn::PackedSequence& packed_input,
      Tensor hx = {});

  /// Options this module was constructed with (public member).
  RNNOptions options;

 protected:
  /// Helper taking the unpacked pieces of an input (`input`, `batch_sizes`,
  /// `sorted_indices`, `max_batch_size`) plus the hidden state.
  /// NOTE(review): presumably the shared implementation behind `forward` and
  /// `forward_with_packed_input` -- confirm against the definition.
  std::tuple<Tensor, Tensor> forward_helper(
      const Tensor& input,
      const Tensor& batch_sizes,
      const Tensor& sorted_indices,
      int64_t max_batch_size,
      Tensor hx);
};

// NOTE(review): presumably generates the `RNN` holder type wrapping `RNNImpl`
// (macro expected from torch/nn/pimpl.h) -- confirm.
TORCH_MODULE(RNN);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `LSTM` module name registered with
/// `TORCH_MODULE(LSTM)` below. Derives from `detail::RNNImplBase<LSTMImpl>`.
///
/// Unlike `RNNImpl`/`GRUImpl`, the hidden state here is a pair of tensors.
/// NOTE(review): the pair is presumably (hidden state, cell state) -- confirm.
class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
 public:
  /// Convenience constructor: delegates to the options-based constructor with
  /// `LSTMOptions(input_size, hidden_size)`.
  LSTMImpl(int64_t input_size, int64_t hidden_size)
      : LSTMImpl(LSTMOptions(input_size, hidden_size)) {}
  /// Constructs the module from a full `LSTMOptions` object.
  explicit LSTMImpl(const LSTMOptions& options_);

  /// Forward pass on a plain tensor input. `hx_opt` may be omitted, in which
  /// case an empty optional is passed. Returns a tensor together with a pair
  /// of tensors.
  /// NOTE(review): presumably (output, (h_n, c_n)) -- confirm.
  std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(
      const Tensor& input,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});

 protected:
  // Declares the default value (an empty optional) for `forward`'s argument
  // at index 1.
  // NOTE(review): index 1 is assumed to be `hx_opt` (0-based) -- confirm
  // against the macro definition.
  FORWARD_HAS_DEFAULT_ARGS(
      {1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})

 public:
  /// Variant of `forward` that takes and returns a
  /// `torch::nn::utils::rnn::PackedSequence` instead of a plain tensor.
  std::tuple<torch::nn::utils::rnn::PackedSequence, std::tuple<Tensor, Tensor>>
  forward_with_packed_input(
      const torch::nn::utils::rnn::PackedSequence& packed_input,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});

  /// Options this module was constructed with (public member).
  LSTMOptions options;

 protected:
  /// LSTM-specific argument check taking the hidden state as a pair of
  /// tensors. Declaring it here hides the base-class `check_forward_args`
  /// for lookups through `LSTMImpl`.
  void check_forward_args(
      const Tensor& input,
      std::tuple<Tensor, Tensor> hidden,
      const Tensor& batch_sizes) const;

  /// Returns a triple of `int64_t` for the expected cell-state size; same
  /// signature shape as the base-class `get_expected_hidden_size`.
  /// NOTE(review): meaning of each tuple slot is not visible here -- confirm.
  std::tuple<int64_t, int64_t, int64_t> get_expected_cell_size(
      const Tensor& input,
      const Tensor& batch_sizes) const;

  /// Pair-of-tensors counterpart of the base-class `permute_hidden`; hides
  /// the base-class member of the same name for lookups through `LSTMImpl`.
  std::tuple<Tensor, Tensor> permute_hidden(
      std::tuple<Tensor, Tensor> hx,
      const Tensor& permutation) const;

  /// Helper taking the unpacked pieces of an input (`input`, `batch_sizes`,
  /// `sorted_indices`, `max_batch_size`) plus the optional hidden state.
  /// NOTE(review): presumably the shared implementation behind `forward` and
  /// `forward_with_packed_input` -- confirm against the definition.
  std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward_helper(
      const Tensor& input,
      const Tensor& batch_sizes,
      const Tensor& sorted_indices,
      int64_t max_batch_size,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt);
};

// NOTE(review): presumably generates the `LSTM` holder type wrapping
// `LSTMImpl` (macro expected from torch/nn/pimpl.h) -- confirm.
TORCH_MODULE(LSTM);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `GRU` module name registered with
/// `TORCH_MODULE(GRU)` below. Derives from `detail::RNNImplBase<GRUImpl>`.
/// Its public surface mirrors `RNNImpl`, with `GRUOptions` as the option type.
class TORCH_API GRUImpl : public detail::RNNImplBase<GRUImpl> {
 public:
  /// Convenience constructor: delegates to the options-based constructor with
  /// `GRUOptions(input_size, hidden_size)`.
  GRUImpl(int64_t input_size, int64_t hidden_size)
      : GRUImpl(GRUOptions(input_size, hidden_size)) {}
  /// Constructs the module from a full `GRUOptions` object.
  explicit GRUImpl(const GRUOptions& options_);

  /// Forward pass on a plain tensor input. `hx` may be omitted, in which case
  /// a default-constructed `Tensor` is passed.
  /// NOTE(review): the returned pair is presumably (output, final hidden
  /// state) -- confirm against the definition.
  std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});

 protected:
  // Declares the default value for `forward`'s argument at index 1.
  // NOTE(review): index 1 is assumed to be `hx` (0-based) -- confirm against
  // the macro definition.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::Tensor())})

 public:
  /// Variant of `forward` that takes and returns a
  /// `torch::nn::utils::rnn::PackedSequence` instead of a plain tensor.
  /// `hx` may be omitted, as in `forward`.
  std::tuple<torch::nn::utils::rnn::PackedSequence, Tensor>
  forward_with_packed_input(
      const torch::nn::utils::rnn::PackedSequence& packed_input,
      Tensor hx = {});

  /// Options this module was constructed with (public member).
  GRUOptions options;

 protected:
  /// Helper taking the unpacked pieces of an input (`input`, `batch_sizes`,
  /// `sorted_indices`, `max_batch_size`) plus the hidden state.
  /// NOTE(review): presumably the shared implementation behind `forward` and
  /// `forward_with_packed_input` -- confirm against the definition.
  std::tuple<Tensor, Tensor> forward_helper(
      const Tensor& input,
      const Tensor& batch_sizes,
      const Tensor& sorted_indices,
      int64_t max_batch_size,
      Tensor hx);
};

// NOTE(review): presumably generates the `GRU` holder type wrapping `GRUImpl`
// (macro expected from torch/nn/pimpl.h) -- confirm.
TORCH_MODULE(GRU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

namespace detail {
/// Shared base for the single-step cell implementations declared later in
/// this header (`RNNCellImpl`, `LSTMCellImpl` and `GRUCellImpl` all derive
/// from it).
///
/// `Derived` is the concrete implementation type and is passed straight
/// through to `torch::nn::Cloneable<Derived>`. Only declarations appear here;
/// the member definitions live outside this header.
template <typename Derived>
class TORCH_API RNNCellImplBase : public torch::nn::Cloneable<Derived> {
 public:
  /// Builds the base from the option set common to all derived cells.
  explicit RNNCellImplBase(const RNNCellOptionsBase& options_);

  /// Overrides the base-class `reset()` hook.
  /// NOTE(review): presumably (re)creates the four tensors below and then
  /// calls `reset_parameters()` -- confirm against the definition.
  void reset() override;

  /// NOTE(review): presumably re-initializes the weight values -- confirm
  /// against the out-of-line definition.
  void reset_parameters();

  /// Overrides the base-class `pretty_print()`; `stream` is the output stream
  /// handed in by the caller.
  void pretty_print(std::ostream& stream) const override;

  /// Options shared by every cell derived from this base (public member).
  RNNCellOptionsBase options_base;

  // Public parameter tensors of the cell.
  // NOTE(review): by their names, `*_ih` presumably apply to the input and
  // `*_hh` to the hidden state; shapes are not visible here -- confirm.
  Tensor weight_ih;
  Tensor weight_hh;
  Tensor bias_ih;
  Tensor bias_hh;

 protected:
  /// Validation helper for a forward input; `name` identifies the argument
  /// being checked.
  /// NOTE(review): `name` is taken by const value (one string copy per call);
  /// switching to `const std::string&` would also require changing the
  /// out-of-line definition.
  void check_forward_input(const Tensor& input, const std::string name) const;
  /// Virtual hook returning a string describing the cell's nonlinearity;
  /// `RNNCellImpl` overrides it.
  /// NOTE(review): presumably consumed by `pretty_print()` -- confirm.
  virtual std::string get_nonlinearity_str() const;
};
} // namespace detail

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCell
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `RNNCell` module name registered with
/// `TORCH_MODULE(RNNCell)` below. Derives from
/// `detail::RNNCellImplBase<RNNCellImpl>`.
class TORCH_API RNNCellImpl : public detail::RNNCellImplBase<RNNCellImpl> {
 public:
  /// Convenience constructor: delegates to the options-based constructor with
  /// `RNNCellOptions(input_size, hidden_size)`.
  RNNCellImpl(int64_t input_size, int64_t hidden_size)
      : RNNCellImpl(RNNCellOptions(input_size, hidden_size)) {}
  /// Constructs the cell from a full `RNNCellOptions` object.
  explicit RNNCellImpl(const RNNCellOptions& options_);

  /// Single-step forward. `hx` may be omitted, in which case a
  /// default-constructed `Tensor` is passed. Returns one tensor.
  /// NOTE(review): presumably the next hidden state -- confirm.
  Tensor forward(const Tensor& input, Tensor hx = {});

 protected:
  // Declares the default value for `forward`'s argument at index 1.
  // NOTE(review): index 1 is assumed to be `hx` (0-based) -- confirm against
  // the macro definition.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})

 public:
  /// Options this cell was constructed with (public member).
  RNNCellOptions options;

 protected:
  /// Overrides the base-class hook that names the cell's nonlinearity.
  std::string get_nonlinearity_str() const override;
};

// NOTE(review): presumably generates the `RNNCell` holder type wrapping
// `RNNCellImpl` (macro expected from torch/nn/pimpl.h) -- confirm.
TORCH_MODULE(RNNCell);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCell
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `LSTMCell` module name registered with
/// `TORCH_MODULE(LSTMCell)` below. Derives from
/// `detail::RNNCellImplBase<LSTMCellImpl>`.
class TORCH_API LSTMCellImpl : public detail::RNNCellImplBase<LSTMCellImpl> {
 public:
  /// Convenience constructor: delegates to the options-based constructor with
  /// `LSTMCellOptions(input_size, hidden_size)`.
  LSTMCellImpl(int64_t input_size, int64_t hidden_size)
      : LSTMCellImpl(LSTMCellOptions(input_size, hidden_size)) {}
  /// Constructs the cell from a full `LSTMCellOptions` object.
  explicit LSTMCellImpl(const LSTMCellOptions& options_);

  /// Single-step forward. `hx_opt` may be omitted, in which case an empty
  /// optional is passed. Returns a pair of tensors.
  /// NOTE(review): the pair is presumably (next hidden state, next cell
  /// state) -- confirm against the definition.
  std::tuple<Tensor, Tensor> forward(
      const Tensor& input,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});

 protected:
  // Declares the default value (an empty optional) for `forward`'s argument
  // at index 1.
  // NOTE(review): index 1 is assumed to be `hx_opt` (0-based) -- confirm
  // against the macro definition.
  FORWARD_HAS_DEFAULT_ARGS(
      {1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})

 public:
  /// Options this cell was constructed with (public member).
  LSTMCellOptions options;
};

// NOTE(review): presumably generates the `LSTMCell` holder type wrapping
// `LSTMCellImpl` (macro expected from torch/nn/pimpl.h) -- confirm.
TORCH_MODULE(LSTMCell);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCell
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Implementation class behind the `GRUCell` module name registered with
/// `TORCH_MODULE(GRUCell)` below. Derives from
/// `detail::RNNCellImplBase<GRUCellImpl>`.
class TORCH_API GRUCellImpl : public detail::RNNCellImplBase<GRUCellImpl> {
 public:
  /// Convenience constructor: delegates to the options-based constructor with
  /// `GRUCellOptions(input_size, hidden_size)`.
  GRUCellImpl(int64_t input_size, int64_t hidden_size)
      : GRUCellImpl(GRUCellOptions(input_size, hidden_size)) {}
  /// Constructs the cell from a full `GRUCellOptions` object.
  explicit GRUCellImpl(const GRUCellOptions& options_);

  /// Single-step forward. `hx` may be omitted, in which case a
  /// default-constructed `Tensor` is passed. Returns one tensor.
  /// NOTE(review): presumably the next hidden state -- confirm.
  Tensor forward(const Tensor& input, Tensor hx = {});

 protected:
  // Declares the default value for `forward`'s argument at index 1.
  // NOTE(review): index 1 is assumed to be `hx` (0-based) -- confirm against
  // the macro definition.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})

 public:
  /// Options this cell was constructed with (public member).
  GRUCellOptions options;
};

// NOTE(review): presumably generates the `GRUCell` holder type wrapping
// `GRUCellImpl` (macro expected from torch/nn/pimpl.h) -- confirm.
TORCH_MODULE(GRUCell);

} // namespace nn
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources