Program Listing for File rnn.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/utils/rnn.h)
#pragma once
#include <c10/util/irange.h>
#include <torch/types.h>
#include <utility>
namespace torch {
namespace nn {
namespace utils {
namespace rnn {
/// Computes the inverse of an index permutation.
///
/// The result is built by scattering `arange(0, permutation.numel())` along
/// dimension 0 at the positions given by `permutation`, i.e. for a 1-D input
/// `result[permutation[i]] = i`.
///
/// @param permutation Index tensor to invert (presumably a 1-D integer
///     permutation -- this is not validated here).
/// @return A contiguous tensor shaped like `permutation`, or an undefined
///     Tensor when `permutation` itself is undefined.
inline Tensor invert_permutation(const Tensor& permutation) {
  if (permutation.defined()) {
    Tensor inverse =
        torch::empty_like(permutation, torch::MemoryFormat::Contiguous);
    // Source values 0..numel-1, created on the same device as the indices.
    const Tensor positions =
        torch::arange(0, permutation.numel(), permutation.device());
    inverse.scatter_(0, permutation, positions);
    return inverse;
  }
  // Nothing to invert: propagate the "undefined" state to the caller.
  return torch::Tensor();
}
class PackedSequence {
public:
explicit PackedSequence(
Tensor data,
Tensor batch_sizes,
Tensor sorted_indices = {},
Tensor unsorted_indices = {}) {
// NB: if unsorted_indices is provided, it should be the inverse permutation
// to sorted_indices. Don't assert it here because the PackedSequence ctor
// should only be used internally.
if (!unsorted_indices.defined()) {
unsorted_indices = invert_permutation(sorted_indices);
}
TORCH_CHECK(
batch_sizes.device().type() == kCPU,
"batch_sizes should always be on CPU. "
"Instances of PackedSequence should never be created manually. "
"They should be instantiated by functions like pack_sequence "
"and pack_padded_sequences in nn::utils::rnn. "
"https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence");
data_ = std::move(data);
batch_sizes_ = std::move(batch_sizes);
sorted_indices_ = std::move(sorted_indices);
unsorted_indices_ = std::move(unsorted_indices);
}
const Tensor& data() const {
return data_;
}
const Tensor& batch_sizes() const {
return batch_sizes_;
}
const Tensor& sorted_indices() const {
return sorted_indices_;
}
const Tensor& unsorted_indices() const {
return unsorted_indices_;
}
PackedSequence pin_memory() const {
// Why not convert `batch_sizes`?
// See NOTE [ device and dtype of a PackedSequence ]
return PackedSequence(
data_.pin_memory(),
batch_sizes_,
sorted_indices_.defined() ? sorted_indices_.pin_memory() : Tensor(),
unsorted_indices_.defined() ? unsorted_indices_.pin_memory()
: Tensor());
}
PackedSequence to(TensorOptions options) const {
// Performs dtype and/or device conversion on `data_`.
//
// If the ``data_`` Tensor already has the correct `torch::Dtype`
// and `torch::Device`, then ``self`` is returned.
// Otherwise, returns a copy with the desired configuration.
// Why not convert `batch_sizes`?
// See NOTE [ device and dtype of a PackedSequence ]
Tensor data = data_.to(options);
if (data.is_same(data_)) {
return *this;
} else {
// Does not forward device or dtype args, device is set from data.device()
Tensor sorted_indices = sorted_indices_.defined()
? sorted_indices_.to(
options.device(data.device()).dtype(sorted_indices_.dtype()))
: Tensor();
Tensor unsorted_indices = unsorted_indices_.defined()
? unsorted_indices_.to(
options.device(data.device()).dtype(unsorted_indices_.dtype()))
: Tensor();
return PackedSequence(
std::move(data),
batch_sizes_,
std::move(sorted_indices),
std::move(unsorted_indices));
}
}
PackedSequence cuda() const {
return to(kCUDA);
}
PackedSequence cpu() const {
return to(kCPU);
}
bool is_cuda() const {
return data_.is_cuda();
}
bool is_pinned() const {
return data_.is_pinned();
}
private:
Tensor data_;
Tensor batch_sizes_;
Tensor sorted_indices_;
Tensor unsorted_indices_;
};
/// Packs a padded batch of sequences into a PackedSequence.
///
/// `lengths` is first converted to int64. When `enforce_sorted` is false the
/// lengths are sorted in descending order and `input` is reordered to match
/// along the batch dimension (0 if `batch_first`, otherwise 1); the sorting
/// permutation is stored in the result. When `enforce_sorted` is true no
/// reordering happens and the stored `sorted_indices` stay undefined.
///
/// @param input Padded batch tensor.
/// @param lengths Per-sequence lengths.
/// @param batch_first Whether the batch dimension of `input` is dimension 0.
/// @param enforce_sorted Skip sorting; the input is taken as given.
/// @return The PackedSequence built from `torch::_pack_padded_sequence`.
inline PackedSequence pack_padded_sequence(
    Tensor input,
    Tensor lengths,
    bool batch_first = false,
    bool enforce_sorted = true) {
  lengths = lengths.to(kInt64);

  // Left undefined unless we have to sort the batch ourselves.
  Tensor sorted_indices;
  if (!enforce_sorted) {
    auto [descending_lengths, permutation] =
        torch::sort(lengths, /*dim=*/-1, /*descending=*/true);
    lengths = std::move(descending_lengths);
    // The permutation must live on the input's device for index_select.
    sorted_indices = permutation.to(input.device());
    const int64_t batch_dim = batch_first ? 0 : 1;
    input = input.index_select(batch_dim, sorted_indices);
  }

  auto [packed_data, packed_batch_sizes] =
      torch::_pack_padded_sequence(input, lengths, batch_first);
  // unsorted_indices is left empty so the constructor derives it.
  return PackedSequence(
      std::move(packed_data),
      std::move(packed_batch_sizes),
      std::move(sorted_indices),
      {});
}
/// Expands a PackedSequence back into a padded tensor plus a lengths tensor.
///
/// The padded length defaults to `sequence.batch_sizes().size(0)`. If
/// `total_length` is given it must not be smaller than that value (checked
/// with TORCH_CHECK) and is used instead. When the sequence carries defined
/// `unsorted_indices`, both outputs are reordered with them: the padded
/// output along the batch dimension (0 if `batch_first`, otherwise 1) and the
/// lengths by indexing with a CPU copy of the indices.
///
/// @param sequence The packed batch to unpack.
/// @param batch_first Whether the batch dimension of the output is dim 0.
/// @param padding_value Fill value forwarded to `_pad_packed_sequence`.
/// @param total_length Optional padded length to use for the output.
/// @return Tuple of (padded output, lengths).
inline std::tuple<Tensor, Tensor> pad_packed_sequence(
    PackedSequence sequence,
    bool batch_first = false,
    double padding_value = 0.0,
    c10::optional<int64_t> total_length = torch::nullopt) {
  int64_t max_seq_length = sequence.batch_sizes().size(0);
  if (total_length.has_value()) {
    const int64_t total_length_val = *total_length;
    TORCH_CHECK(
        total_length_val >= max_seq_length,
        "Expected total_length to be at least the length "
        "of the longest sequence in input, but got "
        "total_length=",
        total_length_val,
        " and max sequence length being ",
        max_seq_length);
    max_seq_length = total_length_val;
  }

  auto [padded, seq_lengths] = torch::_pad_packed_sequence(
      sequence.data(),
      sequence.batch_sizes(),
      batch_first,
      padding_value,
      max_seq_length);

  const Tensor& restore_order = sequence.unsorted_indices();
  if (!restore_order.defined()) {
    // No permutation recorded: the outputs are already in caller order.
    return std::make_tuple(padded, seq_lengths);
  }

  const int64_t batch_dim = batch_first ? 0 : 1;
  Tensor reordered_output = padded.index_select(batch_dim, restore_order);
  // The lengths tensor is indexed with a CPU copy of the permutation.
  Tensor reordered_lengths = seq_lengths.index({restore_order.cpu()});
  return std::make_tuple(
      std::move(reordered_output), std::move(reordered_lengths));
}
/// Pads a list of tensors by forwarding to `at::pad_sequence`.
///
/// @param sequences Tensors to pad.
/// @param batch_first Forwarded unchanged to `at::pad_sequence`.
/// @param padding_value Forwarded unchanged to `at::pad_sequence`.
/// @return The tensor produced by `at::pad_sequence`.
inline Tensor pad_sequence(
    ArrayRef<Tensor> sequences,
    bool batch_first = false,
    double padding_value = 0) {
  Tensor padded = at::pad_sequence(sequences, batch_first, padding_value);
  return padded;
}
/// Packs a list of tensors into a PackedSequence.
///
/// The length of each entry is taken to be its `size(0)`. The list is padded
/// with `at::pad_sequence` (default arguments) and the result is handed to
/// `pack_padded_sequence` with `batch_first = false`.
///
/// @param sequences Tensors to pack.
/// @param enforce_sorted Forwarded to `pack_padded_sequence`.
/// @return The resulting PackedSequence.
inline PackedSequence pack_sequence(
    ArrayRef<Tensor> sequences,
    bool enforce_sorted = true) {
  const auto num_sequences = static_cast<int64_t>(sequences.size());
  Tensor lengths = torch::empty({num_sequences}, kInt64);
  // Record the leading-dimension size of every sequence.
  for (const auto idx : c10::irange(sequences.size())) {
    lengths[idx] = sequences[idx].size(0);
  }

  Tensor padded = at::pad_sequence(sequences);
  return pack_padded_sequence(
      std::move(padded),
      std::move(lengths),
      /*batch_first=*/false,
      /*enforce_sorted=*/enforce_sorted);
}
} // namespace rnn
} // namespace utils
} // namespace nn
} // namespace torch