
Template Class RNNImplBase

Inheritance Relationships

Base Type

public torch::nn::Cloneable<Derived>

Class Documentation

template<typename Derived>
class RNNImplBase : public torch::nn::Cloneable<Derived>

Base class for all RNN implementations (intended for code sharing).

Public Functions

explicit RNNImplBase(const RNNOptionsBase &options_)
virtual void reset() override

Initializes the parameters of the RNN module.

void reset_parameters()
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false) override

Overrides nn::Module::to() to call flatten_parameters() after the original operation.

virtual void to(torch::Dtype dtype, bool non_blocking = false) override

Recursively casts all parameters to the given dtype.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

virtual void to(torch::Device device, bool non_blocking = false) override

Recursively moves all parameters to the given device.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

virtual void pretty_print(std::ostream &stream) const override

Pretty prints the RNN module into the given stream.

void flatten_parameters()

Modifies the internal storage of weights for optimization purposes.

On CPU, this method should be called if any of the weight or bias vectors are changed (i.e. weights are added or removed). On GPU, it should be called any time the storage of any parameter is modified, e.g. any time a parameter is assigned a new value. This allows using the fast path in cuDNN implementations of respective RNN forward() methods. It is called once upon construction, inside reset().

std::vector<Tensor> all_weights() const

Public Members

RNNOptionsBase options_base

The RNN’s options.

Protected Functions

void reset_flat_weights()
void check_input(const Tensor &input, const Tensor &batch_sizes) const
std::tuple<int64_t, int64_t, int64_t> get_expected_hidden_size(const Tensor &input, const Tensor &batch_sizes) const
void check_hidden_size(const Tensor &hx, std::tuple<int64_t, int64_t, int64_t> expected_hidden_size, std::string msg = "Expected hidden size {1}, got {2}") const
void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const
Tensor permute_hidden(Tensor hx, const Tensor &permutation) const

Protected Attributes

std::vector<std::string> flat_weights_names_
std::vector<std::vector<std::string>> all_weights_
std::vector<Tensor> flat_weights_