// Program Listing for File rnn.h
// Return to documentation for file
// (torch/csrc/api/include/torch/nn/modules/rnn.h)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/options/rnn.h>
#include <torch/nn/pimpl.h>
#include <torch/nn/utils/rnn.h>
#include <torch/types.h>
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <cstddef>
#include <functional>
#include <memory>
#include <vector>
namespace torch {
namespace nn {
namespace detail {
/// Shared base class for the sequence-level recurrent modules declared in
/// this header: RNNImpl, LSTMImpl and GRUImpl each derive from
/// `RNNImplBase<Self>`.
///
/// `Derived` is the concrete module type; it is forwarded to
/// `torch::nn::Cloneable<Derived>`, from which this class inherits.
/// Only declarations live here; every member function is defined elsewhere.
template <typename Derived>
class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
public:
/// Constructs the base from the option set common to all derived modules.
explicit RNNImplBase(const RNNOptionsBase& options_);
/// Overrides the virtual `reset()` inherited from the base class.
/// NOTE(review): presumably (re)builds the parameters from `options_base`;
/// confirm against the out-of-line definition.
void reset() override;
/// NOTE(review): presumably re-initializes the parameter values without
/// re-registering them; confirm against the definition.
void reset_parameters();
/// Overrides of the three inherited `to()` conversions (device + dtype,
/// dtype only, device only). `non_blocking` defaults to false in each.
/// NOTE(review): the comment on reset_flat_weights() below suggests these
/// overrides exist so `flat_weights_` is refreshed after a conversion;
/// confirm against the definitions.
void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
override;
void to(torch::Dtype dtype, bool non_blocking = false) override;
void to(torch::Device device, bool non_blocking = false) override;
/// Overrides the inherited `pretty_print`; writes to `stream`.
void pretty_print(std::ostream& stream) const override;
/// NOTE(review): name suggests it compacts the weights into contiguous
/// storage; the effect is not visible from this header.
void flatten_parameters();
/// Returns a vector of tensors by value.
/// NOTE(review): presumably the module's weights; ordering is defined by the
/// implementation, not by this header.
std::vector<Tensor> all_weights() const;
/// The base option set this module was constructed with (public by design,
/// hence the lint suppression below).
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
RNNOptionsBase options_base;
protected:
// Resets flat_weights_
// Note: be v. careful before removing this, as 3rd party device types
// likely rely on this behavior to properly .to() modules like LSTM.
void reset_flat_weights();
/// Validates `input` (with the accompanying `batch_sizes` tensor).
/// NOTE(review): failure mode (presumably a thrown c10::Error via
/// TORCH_CHECK) is not visible here.
void check_input(const Tensor& input, const Tensor& batch_sizes) const;
/// Returns a 3-tuple of int64_t describing the hidden-state size expected
/// for `input` / `batch_sizes`.
/// NOTE(review): the meaning of each tuple element is not stated in this
/// header; confirm against the definition.
std::tuple<int64_t, int64_t, int64_t> get_expected_hidden_size(
const Tensor& input,
const Tensor& batch_sizes) const;
/// Checks `hx` against `expected_hidden_size`. `msg` is the error-message
/// template; its default contains the placeholders `{1}` and `{2}`.
/// NOTE(review): placeholder substitution happens in the definition.
void check_hidden_size(
const Tensor& hx,
std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
std::string msg = "Expected hidden size {1}, got {2}") const;
/// Validates the arguments of a forward call. All three tensors are taken
/// by value.
void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes)
const;
/// Returns a tensor derived from `hx` and `permutation`.
/// NOTE(review): presumably reorders `hx` along its batch dimension;
/// confirm against the definition.
Tensor permute_hidden(Tensor hx, const Tensor& permutation) const;
/// Names associated with the entries of `flat_weights_`.
/// NOTE(review): the one-to-one correspondence is assumed from the naming.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::string> flat_weights_names_;
/// Weight names grouped into nested vectors.
/// NOTE(review): the grouping key (presumably per layer/direction) is not
/// visible in this header.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::vector<std::string>> all_weights_;
/// Flat list of weight tensors; reset by reset_flat_weights().
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<Tensor> flat_weights_;
};
} // namespace detail
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class of the `RNN` module; derives from
/// `detail::RNNImplBase<RNNImpl>`. Member functions other than the
/// convenience constructor are defined outside this header.
class TORCH_API RNNImpl : public detail::RNNImplBase<RNNImpl> {
public:
/// Convenience constructor: builds `RNNOptions(input_size, hidden_size)` and
/// delegates to the options constructor below.
RNNImpl(int64_t input_size, int64_t hidden_size)
: RNNImpl(RNNOptions(input_size, hidden_size)) {}
/// Constructs the module from a full option set.
explicit RNNImpl(const RNNOptions& options_);
/// Runs the module on `input`. `hx` is taken by value and defaults to a
/// default-constructed Tensor. Returns a pair of tensors.
/// NOTE(review): presumably (output, final hidden state); confirm against
/// the definition.
std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});
protected:
/// Declares the default for forward() argument index 1 (`hx`) as a
/// default-constructed Tensor, mirroring the `hx = {}` default above.
/// NOTE(review): the macro is defined outside this file; its exact
/// expansion and purpose should be confirmed there.
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
public:
/// Variant of forward() that takes a PackedSequence and returns a
/// PackedSequence together with a tensor. `hx` defaults to a
/// default-constructed Tensor.
std::tuple<torch::nn::utils::rnn::PackedSequence, Tensor>
forward_with_packed_input(
const torch::nn::utils::rnn::PackedSequence& packed_input,
Tensor hx = {});
/// The options this module was constructed with.
RNNOptions options;
protected:
/// Helper declared with the union of the inputs of the two public forward
/// entry points.
/// NOTE(review): presumably both forward() and forward_with_packed_input()
/// funnel into this; confirm against the definitions.
std::tuple<Tensor, Tensor> forward_helper(
const Tensor& input,
const Tensor& batch_sizes,
const Tensor& sorted_indices,
int64_t max_batch_size,
Tensor hx);
};
/// NOTE(review): TORCH_MODULE is defined outside this file; presumably it
/// declares the `RNN` holder type wrapping `RNNImpl`.
TORCH_MODULE(RNN);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class of the `LSTM` module; derives from
/// `detail::RNNImplBase<LSTMImpl>`. Unlike RNNImpl/GRUImpl, the hidden state
/// is a pair of tensors, passed as an optional tuple.
class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
public:
/// Convenience constructor: builds `LSTMOptions(input_size, hidden_size)`
/// and delegates to the options constructor below.
LSTMImpl(int64_t input_size, int64_t hidden_size)
: LSTMImpl(LSTMOptions(input_size, hidden_size)) {}
/// Constructs the module from a full option set.
explicit LSTMImpl(const LSTMOptions& options_);
/// Runs the module on `input`. `hx_opt` is an optional pair of tensors and
/// defaults to an empty optional. Returns a tensor plus a pair of tensors.
/// NOTE(review): presumably (output, (h_n, c_n)); confirm against the
/// definition.
std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(
const Tensor& input,
torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
protected:
/// Declares the default for forward() argument index 1 (`hx_opt`) as an
/// empty optional, mirroring the `hx_opt = {}` default above.
/// NOTE(review): the macro is defined outside this file; confirm its
/// expansion there.
FORWARD_HAS_DEFAULT_ARGS(
{1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})
public:
/// Variant of forward() that takes a PackedSequence and returns a
/// PackedSequence together with a pair of tensors.
std::tuple<torch::nn::utils::rnn::PackedSequence, std::tuple<Tensor, Tensor>>
forward_with_packed_input(
const torch::nn::utils::rnn::PackedSequence& packed_input,
torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
/// The options this module was constructed with.
LSTMOptions options;
protected:
/// Tuple-taking counterpart of RNNImplBase::check_forward_args. Declaring
/// the same name here hides the base-class overload within LSTMImpl.
void check_forward_args(
const Tensor& input,
std::tuple<Tensor, Tensor> hidden,
const Tensor& batch_sizes) const;
/// Returns a 3-tuple of int64_t describing the expected cell-state size
/// for `input` / `batch_sizes`.
/// NOTE(review): element meaning is not stated in this header.
std::tuple<int64_t, int64_t, int64_t> get_expected_cell_size(
const Tensor& input,
const Tensor& batch_sizes) const;
/// Tuple-taking counterpart of RNNImplBase::permute_hidden; hides the
/// base-class overload within LSTMImpl.
std::tuple<Tensor, Tensor> permute_hidden(
std::tuple<Tensor, Tensor> hx,
const Tensor& permutation) const;
/// Helper declared with the union of the inputs of the two public forward
/// entry points.
/// NOTE(review): presumably both public entry points funnel into this;
/// confirm against the definitions.
std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward_helper(
const Tensor& input,
const Tensor& batch_sizes,
const Tensor& sorted_indices,
int64_t max_batch_size,
torch::optional<std::tuple<Tensor, Tensor>> hx_opt);
};
/// NOTE(review): TORCH_MODULE is defined outside this file; presumably it
/// declares the `LSTM` holder type wrapping `LSTMImpl`.
TORCH_MODULE(LSTM);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class of the `GRU` module; derives from
/// `detail::RNNImplBase<GRUImpl>`. Its public surface mirrors RNNImpl's.
class TORCH_API GRUImpl : public detail::RNNImplBase<GRUImpl> {
public:
/// Convenience constructor: builds `GRUOptions(input_size, hidden_size)` and
/// delegates to the options constructor below.
GRUImpl(int64_t input_size, int64_t hidden_size)
: GRUImpl(GRUOptions(input_size, hidden_size)) {}
/// Constructs the module from a full option set.
explicit GRUImpl(const GRUOptions& options_);
/// Runs the module on `input`. `hx` is taken by value and defaults to a
/// default-constructed Tensor. Returns a pair of tensors.
/// NOTE(review): presumably (output, final hidden state); confirm against
/// the definition.
std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});
protected:
/// Declares the default for forward() argument index 1 (`hx`) as a
/// default-constructed tensor. This class spells the type `torch::Tensor`
/// where RNNImpl spells it `Tensor`.
/// NOTE(review): the macro is defined outside this file; confirm its
/// expansion there.
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::Tensor())})
public:
/// Variant of forward() that takes a PackedSequence and returns a
/// PackedSequence together with a tensor.
std::tuple<torch::nn::utils::rnn::PackedSequence, Tensor>
forward_with_packed_input(
const torch::nn::utils::rnn::PackedSequence& packed_input,
Tensor hx = {});
/// The options this module was constructed with.
GRUOptions options;
protected:
/// Helper declared with the union of the inputs of the two public forward
/// entry points.
/// NOTE(review): presumably both public entry points funnel into this;
/// confirm against the definitions.
std::tuple<Tensor, Tensor> forward_helper(
const Tensor& input,
const Tensor& batch_sizes,
const Tensor& sorted_indices,
int64_t max_batch_size,
Tensor hx);
};
/// NOTE(review): TORCH_MODULE is defined outside this file; presumably it
/// declares the `GRU` holder type wrapping `GRUImpl`.
TORCH_MODULE(GRU);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
namespace detail {
/// Shared base class for the single-step cell modules declared in this
/// header: RNNCellImpl, LSTMCellImpl and GRUCellImpl each derive from
/// `RNNCellImplBase<Self>`.
///
/// `Derived` is the concrete module type; it is forwarded to
/// `torch::nn::Cloneable<Derived>`, from which this class inherits.
template <typename Derived>
class TORCH_API RNNCellImplBase : public torch::nn::Cloneable<Derived> {
public:
/// Constructs the base from the option set common to all cell modules.
explicit RNNCellImplBase(const RNNCellOptionsBase& options_);
/// Overrides the virtual `reset()` inherited from the base class.
/// NOTE(review): presumably creates the four tensors below from
/// `options_base`; confirm against the out-of-line definition.
void reset() override;
/// NOTE(review): presumably re-initializes the values of the tensors below;
/// confirm against the definition.
void reset_parameters();
/// Overrides the inherited `pretty_print`; writes to `stream`.
void pretty_print(std::ostream& stream) const override;
/// The base option set this cell was constructed with.
RNNCellOptionsBase options_base;
/// Public tensor members of the cell.
/// NOTE(review): by naming, the `_ih` pair presumably applies to the input
/// and the `_hh` pair to the hidden state; shapes are not visible here.
Tensor weight_ih;
Tensor weight_hh;
Tensor bias_ih;
Tensor bias_hh;
protected:
/// Validates `input`; `name` is presumably used to label the tensor in
/// error messages (confirm against the definition).
/// NOTE(review): `name` is passed as a const value, which copies the string
/// on each call. Switching to `const std::string&` would need the
/// out-of-line definition updated in the same change.
void check_forward_input(const Tensor& input, const std::string name) const;
/// Virtual hook returning a string; RNNCellImpl overrides it.
/// NOTE(review): presumably the nonlinearity name used by pretty_print.
virtual std::string get_nonlinearity_str() const;
};
} // namespace detail
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCell
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class of the `RNNCell` module; derives from
/// `detail::RNNCellImplBase<RNNCellImpl>`.
class TORCH_API RNNCellImpl : public detail::RNNCellImplBase<RNNCellImpl> {
public:
/// Convenience constructor: builds `RNNCellOptions(input_size, hidden_size)`
/// and delegates to the options constructor below.
RNNCellImpl(int64_t input_size, int64_t hidden_size)
: RNNCellImpl(RNNCellOptions(input_size, hidden_size)) {}
/// Constructs the cell from a full option set.
explicit RNNCellImpl(const RNNCellOptions& options_);
/// Runs the cell on `input`. `hx` is taken by value and defaults to a
/// default-constructed Tensor. Returns a single tensor.
/// NOTE(review): presumably the next hidden state; confirm against the
/// definition.
Tensor forward(const Tensor& input, Tensor hx = {});
protected:
/// Declares the default for forward() argument index 1 (`hx`) as a
/// default-constructed Tensor, mirroring the `hx = {}` default above.
/// NOTE(review): the macro is defined outside this file; confirm its
/// expansion there.
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
public:
/// The options this cell was constructed with.
RNNCellOptions options;
protected:
/// Overrides the virtual hook declared in RNNCellImplBase.
std::string get_nonlinearity_str() const override;
};
/// NOTE(review): TORCH_MODULE is defined outside this file; presumably it
/// declares the `RNNCell` holder type wrapping `RNNCellImpl`.
TORCH_MODULE(RNNCell);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCell
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class of the `LSTMCell` module; derives from
/// `detail::RNNCellImplBase<LSTMCellImpl>`. It does not override
/// get_nonlinearity_str().
class TORCH_API LSTMCellImpl : public detail::RNNCellImplBase<LSTMCellImpl> {
public:
/// Convenience constructor: builds
/// `LSTMCellOptions(input_size, hidden_size)` and delegates to the options
/// constructor below.
LSTMCellImpl(int64_t input_size, int64_t hidden_size)
: LSTMCellImpl(LSTMCellOptions(input_size, hidden_size)) {}
/// Constructs the cell from a full option set.
explicit LSTMCellImpl(const LSTMCellOptions& options_);
/// Runs the cell on `input`. `hx_opt` is an optional pair of tensors and
/// defaults to an empty optional. Returns a pair of tensors.
/// NOTE(review): presumably (next hidden state, next cell state); confirm
/// against the definition.
std::tuple<Tensor, Tensor> forward(
const Tensor& input,
torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
protected:
/// Declares the default for forward() argument index 1 (`hx_opt`) as an
/// empty optional, mirroring the `hx_opt = {}` default above.
/// NOTE(review): the macro is defined outside this file; confirm its
/// expansion there.
FORWARD_HAS_DEFAULT_ARGS(
{1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})
public:
/// The options this cell was constructed with.
LSTMCellOptions options;
};
/// NOTE(review): TORCH_MODULE is defined outside this file; presumably it
/// declares the `LSTMCell` holder type wrapping `LSTMCellImpl`.
TORCH_MODULE(LSTMCell);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCell
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Implementation class of the `GRUCell` module; derives from
/// `detail::RNNCellImplBase<GRUCellImpl>`. It does not override
/// get_nonlinearity_str().
class TORCH_API GRUCellImpl : public detail::RNNCellImplBase<GRUCellImpl> {
public:
/// Convenience constructor: builds `GRUCellOptions(input_size, hidden_size)`
/// and delegates to the options constructor below.
GRUCellImpl(int64_t input_size, int64_t hidden_size)
: GRUCellImpl(GRUCellOptions(input_size, hidden_size)) {}
/// Constructs the cell from a full option set.
explicit GRUCellImpl(const GRUCellOptions& options_);
/// Runs the cell on `input`. `hx` is taken by value and defaults to a
/// default-constructed Tensor. Returns a single tensor.
/// NOTE(review): presumably the next hidden state; confirm against the
/// definition.
Tensor forward(const Tensor& input, Tensor hx = {});
protected:
/// Declares the default for forward() argument index 1 (`hx`) as a
/// default-constructed Tensor, mirroring the `hx = {}` default above.
/// NOTE(review): the macro is defined outside this file; confirm its
/// expansion there.
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
public:
/// The options this cell was constructed with.
GRUCellOptions options;
};
/// NOTE(review): TORCH_MODULE is defined outside this file; presumably it
/// declares the `GRUCell` holder type wrapping `GRUCellImpl`.
TORCH_MODULE(GRUCell);
} // namespace nn
} // namespace torch