
Template Class RNNImplBase

Inheritance Relationships

Base Type

public torch::nn::Cloneable<Derived>

Class Documentation

template<typename Derived>
class RNNImplBase : public torch::nn::Cloneable<Derived>

Base class for all RNN implementations (intended for code sharing).
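RNNImplBase uses the curiously recurring template pattern: each concrete module passes itself as Derived (in libtorch, for example, LSTMImpl is declared as deriving from detail::RNNImplBase<LSTMImpl>), so code written once in the base can still refer to the most-derived type. That is what lets Cloneable<Derived>::clone() copy the right class. The following is a minimal standalone sketch of the pattern only. It does not use libtorch, and CloneableSketch, RNNBaseSketch and LSTMSketch are hypothetical stand-ins, not the library's classes.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for torch::nn::Cloneable<Derived>. Because the base
// knows the concrete type at compile time, clone() can copy the derived object.
template <typename Derived>
class CloneableSketch {
 public:
  virtual ~CloneableSketch() = default;
  virtual void reset() = 0;

  std::shared_ptr<Derived> clone() const {
    // Copy-construct the concrete type, then re-initialize the copy.
    auto copy = std::make_shared<Derived>(static_cast<const Derived&>(*this));
    copy->reset();
    return copy;
  }
};

// Hypothetical stand-in for RNNImplBase: logic shared by every RNN variant.
template <typename Derived>
class RNNBaseSketch : public CloneableSketch<Derived> {
 public:
  explicit RNNBaseSketch(int hidden_size) : hidden_size_(hidden_size) {}

  void reset() override { ++reset_count_; }

  int hidden_size() const { return hidden_size_; }
  int reset_count() const { return reset_count_; }

 private:
  int hidden_size_;
  int reset_count_ = 0;
};

// A concrete module passes itself as the template argument.
class LSTMSketch : public RNNBaseSketch<LSTMSketch> {
 public:
  using RNNBaseSketch<LSTMSketch>::RNNBaseSketch;
};
```

With this layout, `LSTMSketch(20).clone()` returns a `std::shared_ptr<LSTMSketch>` without LSTMSketch writing any cloning code of its own.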

Public Functions

explicit RNNImplBase(const RNNOptionsBase &options_)
virtual void reset() override

Initializes the parameters of the RNN module.
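For each layer and direction, the parameters created here are an input-to-hidden weight, a hidden-to-hidden weight and, when bias is enabled, two bias vectors. Their leading dimension is a gate size that depends on the cell type (4 × hidden_size for LSTM, 3 × hidden_size for GRU, hidden_size for a plain tanh/ReLU RNN). The sketch below reproduces only that shape arithmetic, under stated assumptions: the mode is passed as a string (libtorch represents it differently), bias is enabled, and proj_size is 0. WeightShape, gate_multiplier and layer_weight_shapes are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical description of one parameter's shape; cols == 0 marks a 1-D bias.
struct WeightShape {
  std::int64_t rows;
  std::int64_t cols;
};

// Number of gates stacked into each weight matrix, per cell type.
// The string modes are an assumption made for this sketch.
std::int64_t gate_multiplier(const std::string& mode) {
  if (mode == "LSTM") return 4;
  if (mode == "GRU") return 3;
  if (mode == "RNN_TANH" || mode == "RNN_RELU") return 1;
  throw std::invalid_argument("unknown RNN mode: " + mode);
}

// Shapes of {w_ih, w_hh, b_ih, b_hh} for one layer and one direction,
// assuming bias = true and proj_size == 0.
std::vector<WeightShape> layer_weight_shapes(const std::string& mode,
                                             std::int64_t input_size,
                                             std::int64_t hidden_size,
                                             std::int64_t layer,
                                             bool bidirectional) {
  const std::int64_t gate_size = gate_multiplier(mode) * hidden_size;
  const std::int64_t num_directions = bidirectional ? 2 : 1;
  // Layers after the first consume the previous layer's output, which is
  // hidden_size wide per direction (concatenated when bidirectional).
  const std::int64_t layer_input_size =
      layer == 0 ? input_size : hidden_size * num_directions;
  return {
      {gate_size, layer_input_size},  // w_ih
      {gate_size, hidden_size},       // w_hh
      {gate_size, 0},                 // b_ih
      {gate_size, 0},                 // b_hh
  };
}
```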

void reset_parameters()
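reset_parameters() carries no description on this page. In libtorch it re-draws every parameter from a uniform distribution U(-stdv, stdv) with stdv = 1 / sqrt(hidden_size), the same scheme the Python nn.RNN, nn.LSTM and nn.GRU documentation gives as U(-sqrt(k), sqrt(k)) with k = 1 / hidden_size. The sketch below models the parameters as plain std::vector<float> buffers; reset_parameters_sketch is a hypothetical name, not a library function.

```cpp
#include <cassert>
#include <cmath>
#include <random>
#include <vector>

// Fill every parameter buffer with samples from U(-stdv, stdv),
// where stdv = 1 / sqrt(hidden_size). The nested vectors stand in for
// the module's parameter tensors (a simplification for this sketch).
void reset_parameters_sketch(std::vector<std::vector<float>>& params,
                             int hidden_size, std::mt19937& gen) {
  const float stdv = 1.0f / std::sqrt(static_cast<float>(hidden_size));
  std::uniform_real_distribution<float> dist(-stdv, stdv);
  for (auto& param : params) {
    for (auto& value : param) {
      value = dist(gen);
    }
  }
}
```

The bound shrinks as hidden_size grows, which keeps the initial pre-activations of wide layers from growing with their width.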
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false) override

Overrides nn::Module::to() to call flatten_parameters() after the original operation.

virtual void to(torch::Dtype dtype, bool non_blocking = false) override

Recursively casts all parameters to the given dtype.

If non_blocking is true and the source is in pinned memory and the destination is on the GPU, or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

virtual void to(torch::Device device, bool non_blocking = false) override

Recursively moves all parameters to the given device.

If non_blocking is true and the source is in pinned memory and the destination is on the GPU, or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

virtual void pretty_print(std::ostream &stream) const override

Pretty prints the RNN module into the given stream.

void flatten_parameters()

Modifies the internal storage of weights for optimization purposes.

On CPU, this method should be called if any of the weight or bias vectors are changed (i.e., weights are added or removed). On GPU, it should be called any time the storage of any parameter is modified, e.g. any time a parameter is assigned a new value. This lets the cuDNN implementations of the respective RNN forward() methods use their fast path. It is called once upon construction, inside reset().
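The idea behind flattening is that the fast cuDNN path expects all weights to live in one contiguous buffer, with each individual weight being a view into it. The standalone sketch below models only that packing idea with std::vector<float>; FlatWeights and flatten are hypothetical names, and libtorch's actual flatten_parameters() works on tensors and cuDNN-specific layouts that this does not reproduce.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical: one contiguous buffer plus (offset, size) views into it.
struct FlatWeights {
  std::vector<float> buffer;
  std::vector<std::size_t> offsets;
  std::vector<std::size_t> sizes;

  // Pointer to the start of weight i inside the shared buffer.
  float* view(std::size_t i) { return buffer.data() + offsets[i]; }
};

// Copy each weight into consecutive regions of a single buffer and
// remember where each one starts.
FlatWeights flatten(const std::vector<std::vector<float>>& weights) {
  FlatWeights flat;
  std::size_t total = 0;
  for (const auto& w : weights) total += w.size();
  flat.buffer.reserve(total);
  for (const auto& w : weights) {
    flat.offsets.push_back(flat.buffer.size());
    flat.sizes.push_back(w.size());
    flat.buffer.insert(flat.buffer.end(), w.begin(), w.end());
  }
  return flat;
}
```

Writing through a view updates the shared buffer. A weight that is later replaced by a separately allocated array no longer aliases the buffer, which is why the weights must be flattened again after such a change.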

std::vector<Tensor> all_weights() const

Public Members

RNNOptionsBase options_base

The RNN’s options.

Protected Functions

void reset_flat_weights()
void check_input(const Tensor &input, const Tensor &batch_sizes) const
std::tuple<int64_t, int64_t, int64_t> get_expected_hidden_size(const Tensor &input, const Tensor &batch_sizes) const
void check_hidden_size(const Tensor &hx, std::tuple<int64_t, int64_t, int64_t> expected_hidden_size, std::string msg = "Expected hidden size {1}, got {2}") const
void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const
Tensor permute_hidden(Tensor hx, const Tensor &permutation) const
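The protected helpers get_expected_hidden_size() and check_hidden_size() validate the initial hidden state against the shape (num_layers * num_directions, mini_batch, hidden_size), where mini_batch comes from the input (dimension 0 when batch_first is set, otherwise dimension 1). The placeholders {1} and {2} in the default message stand for the expected and actual sizes. The sketch below mirrors that check on plain shape vectors. It assumes an unpacked 3-D input (so batch_sizes is ignored) and proj_size == 0; HiddenOptions, expected_hidden_size, shape_to_string, replace_first and check_hidden_size_sketch are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

using Shape = std::vector<std::int64_t>;

// Hypothetical subset of the options that the hidden-size check depends on.
struct HiddenOptions {
  std::int64_t num_layers;
  std::int64_t hidden_size;
  bool bidirectional;
  bool batch_first;
};

// (num_layers * num_directions, mini_batch, hidden_size) for a 3-D input.
std::tuple<std::int64_t, std::int64_t, std::int64_t> expected_hidden_size(
    const Shape& input_shape, const HiddenOptions& opts) {
  const std::int64_t num_directions = opts.bidirectional ? 2 : 1;
  const std::int64_t mini_batch =
      opts.batch_first ? input_shape[0] : input_shape[1];
  return std::make_tuple(opts.num_layers * num_directions, mini_batch,
                         opts.hidden_size);
}

std::string shape_to_string(const Shape& shape) {
  std::ostringstream out;
  out << '[';
  for (std::size_t i = 0; i < shape.size(); ++i) {
    out << (i ? ", " : "") << shape[i];
  }
  out << ']';
  return out.str();
}

// Replace the first occurrence of `from` in `s` with `to`.
void replace_first(std::string& s, const std::string& from, const std::string& to) {
  const auto pos = s.find(from);
  if (pos != std::string::npos) s.replace(pos, from.size(), to);
}

// Throws if hx_shape differs from the expected size; {1} and {2} in msg
// are filled with the expected and actual shapes.
void check_hidden_size_sketch(
    const Shape& hx_shape,
    const std::tuple<std::int64_t, std::int64_t, std::int64_t>& expected,
    std::string msg = "Expected hidden size {1}, got {2}") {
  const Shape expected_shape{std::get<0>(expected), std::get<1>(expected),
                             std::get<2>(expected)};
  if (hx_shape != expected_shape) {
    replace_first(msg, "{1}", shape_to_string(expected_shape));
    replace_first(msg, "{2}", shape_to_string(hx_shape));
    throw std::invalid_argument(msg);
  }
}
```

For a bidirectional 2-layer module with hidden_size 20 and a (seq_len = 5, batch = 3, features = 10) input, the expected hidden state is (4, 3, 20), and passing a (2, 3, 20) tensor produces "Expected hidden size [4, 3, 20], got [2, 3, 20]".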

Protected Attributes

std::vector<std::string> flat_weights_names_
std::vector<std::vector<std::string>> all_weights_
std::vector<Tensor> flat_weights_
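The names stored in these attributes follow the same convention as the Python RNN modules: weight_ih_l{k} and weight_hh_l{k}, plus bias_ih_l{k} and bias_hh_l{k} when bias is enabled, with a _reverse suffix for the backward direction. The sketch below generates those names per (layer, direction) group, in the shape of all_weights_, and concatenates them in the shape of flat_weights_names_. It assumes that flat_weights_names_ holds the same names in the same layer/direction order and that proj_size is 0 (so no weight_hr_l{k} entries); all_weight_names and flat_weight_names are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// One group of parameter names per (layer, direction) pair.
std::vector<std::vector<std::string>> all_weight_names(std::int64_t num_layers,
                                                       bool bidirectional,
                                                       bool bias) {
  std::vector<std::vector<std::string>> all_weights;
  const int num_directions = bidirectional ? 2 : 1;
  for (std::int64_t layer = 0; layer < num_layers; ++layer) {
    for (int direction = 0; direction < num_directions; ++direction) {
      // "_l{layer}", with "_reverse" appended for the backward direction.
      const std::string suffix =
          "_l" + std::to_string(layer) + (direction == 1 ? "_reverse" : "");
      std::vector<std::string> names{"weight_ih" + suffix, "weight_hh" + suffix};
      if (bias) {
        names.push_back("bias_ih" + suffix);
        names.push_back("bias_hh" + suffix);
      }
      all_weights.push_back(names);
    }
  }
  return all_weights;
}

// The grouped names concatenated in order (assumed flat_weights_names_ layout).
std::vector<std::string> flat_weight_names(
    const std::vector<std::vector<std::string>>& all_weights) {
  std::vector<std::string> flat;
  for (const auto& group : all_weights) {
    flat.insert(flat.end(), group.begin(), group.end());
  }
  return flat;
}
```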
