
Program Listing for File rnn.h

File: torch/csrc/api/include/torch/nn/utils/rnn.h

#pragma once

#include <c10/util/irange.h>
#include <torch/types.h>

#include <utility>

namespace torch {
namespace nn {
namespace utils {
namespace rnn {

inline Tensor invert_permutation(const Tensor& permutation) {
  if (!permutation.defined()) {
    return torch::Tensor();
  }
  Tensor output =
      torch::empty_like(permutation, torch::MemoryFormat::Contiguous);
  output.scatter_(
      0,
      permutation,
      torch::arange(0, permutation.numel(), permutation.device()));
  return output;
}
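
// Example (illustrative, not part of the original header): for
// permutation = [2, 0, 1], scatter_ writes output[2] = 0, output[0] = 1,
// output[1] = 2, giving output = [1, 2, 0] -- the inverse permutation,
// since output[permutation[i]] == i for every i.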

class PackedSequence {
 public:
  explicit PackedSequence(
      Tensor data,
      Tensor batch_sizes,
      Tensor sorted_indices = {},
      Tensor unsorted_indices = {}) {
    // NB: if unsorted_indices is provided, it should be the inverse permutation
    // to sorted_indices. Don't assert it here because the PackedSequence ctor
    // should only be used internally.
    if (!unsorted_indices.defined()) {
      unsorted_indices = invert_permutation(sorted_indices);
    }
    TORCH_CHECK(
        batch_sizes.device().type() == kCPU,
        "batch_sizes should always be on CPU. "
        "Instances of PackedSequence should never be created manually. "
        "They should be instantiated by functions like pack_sequence "
        "and pack_padded_sequences in nn::utils::rnn. "
        "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence");
    data_ = std::move(data);
    batch_sizes_ = std::move(batch_sizes);
    sorted_indices_ = std::move(sorted_indices);
    unsorted_indices_ = std::move(unsorted_indices);
  }

  const Tensor& data() const {
    return data_;
  }

  const Tensor& batch_sizes() const {
    return batch_sizes_;
  }

  const Tensor& sorted_indices() const {
    return sorted_indices_;
  }

  const Tensor& unsorted_indices() const {
    return unsorted_indices_;
  }

  PackedSequence pin_memory() const {
    // Why not convert `batch_sizes`?
    // See NOTE [ device and dtype of a PackedSequence ]
    return PackedSequence(
        data_.pin_memory(),
        batch_sizes_,
        sorted_indices_.defined() ? sorted_indices_.pin_memory() : Tensor(),
        unsorted_indices_.defined() ? unsorted_indices_.pin_memory()
                                    : Tensor());
  }
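
  // Usage sketch (illustrative, not part of the original header):
  //   PackedSequence pinned = packed.pin_memory();
  // pins `data_` and the index tensors so that host-to-device copies can be
  // asynchronous; `batch_sizes_` is left untouched since it stays on CPU.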

  PackedSequence to(TensorOptions options) const {
    // Performs dtype and/or device conversion on `data_`.
    //
    // If the `data_` Tensor already has the correct `torch::Dtype`
    // and `torch::Device`, then `*this` is returned.
    // Otherwise, returns a copy with the desired configuration.

    // Why not convert `batch_sizes`?
    // See NOTE [ device and dtype of a PackedSequence ]
    Tensor data = data_.to(options);
    if (data.is_same(data_)) {
      return *this;
    } else {
      // Does not forward device or dtype args; the device is taken from
      // data.device()
      Tensor sorted_indices = sorted_indices_.defined()
          ? sorted_indices_.to(
                options.device(data.device()).dtype(sorted_indices_.dtype()))
          : Tensor();
      Tensor unsorted_indices = unsorted_indices_.defined()
          ? unsorted_indices_.to(
                options.device(data.device()).dtype(unsorted_indices_.dtype()))
          : Tensor();
      return PackedSequence(
          std::move(data),
          batch_sizes_,
          std::move(sorted_indices),
          std::move(unsorted_indices));
    }
  }
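
  // Usage sketch (illustrative, not part of the original header):
  //   packed.to(torch::device(torch::kCUDA).dtype(torch::kHalf));
  // converts `data_`; the index tensors follow only the device (keeping
  // their integer dtype), and `batch_sizes_` always stays on CPU.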

  PackedSequence cuda() const {
    return to(kCUDA);
  }

  PackedSequence cpu() const {
    return to(kCPU);
  }

  bool is_cuda() const {
    return data_.is_cuda();
  }

  bool is_pinned() const {
    return data_.is_pinned();
  }

 private:
  Tensor data_;
  Tensor batch_sizes_;
  Tensor sorted_indices_;
  Tensor unsorted_indices_;
};

inline PackedSequence pack_padded_sequence(
    Tensor input,
    Tensor lengths,
    bool batch_first = false,
    bool enforce_sorted = true) {
  lengths = lengths.to(kInt64);
  Tensor sorted_indices;
  if (enforce_sorted) {
    sorted_indices = Tensor();
  } else {
    std::tie(lengths, sorted_indices) =
        torch::sort(lengths, /*dim=*/-1, /*descending=*/true);
    sorted_indices = sorted_indices.to(input.device());
    int64_t batch_dim = batch_first ? 0 : 1;
    input = input.index_select(batch_dim, sorted_indices);
  }

  auto [data, batch_sizes] =
      torch::_pack_padded_sequence(input, lengths, batch_first);
  return PackedSequence(
      std::move(data), std::move(batch_sizes), std::move(sorted_indices), {});
}
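
// Example (illustrative, not part of the original header): for `padded` of
// shape [max_len, batch, *] and lengths {4, 2},
//   auto packed = pack_padded_sequence(padded, torch::tensor({4, 2}));
// packed.data() has shape [6, *] and packed.batch_sizes() is [2, 2, 1, 1]
// (the number of still-active sequences at each time step).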

inline std::tuple<Tensor, Tensor> pad_packed_sequence(
    PackedSequence sequence,
    bool batch_first = false,
    double padding_value = 0.0,
    c10::optional<int64_t> total_length = torch::nullopt) {
  int64_t max_seq_length = sequence.batch_sizes().size(0);
  if (total_length.has_value()) {
    int64_t total_length_val = total_length.value();
    TORCH_CHECK(
        total_length_val >= max_seq_length,
        "Expected total_length to be at least the length "
        "of the longest sequence in input, but got "
        "total_length=",
        total_length_val,
        " and max sequence length being ",
        max_seq_length);
    max_seq_length = total_length_val;
  }
  auto [padded_output, lengths] = torch::_pad_packed_sequence(
      sequence.data(),
      sequence.batch_sizes(),
      batch_first,
      padding_value,
      max_seq_length);
  const Tensor& unsorted_indices = sequence.unsorted_indices();
  if (unsorted_indices.defined()) {
    int64_t batch_dim = batch_first ? 0 : 1;
    return std::make_tuple(
        padded_output.index_select(batch_dim, unsorted_indices),
        lengths.index({unsorted_indices.cpu()}));
  }
  return std::make_tuple(padded_output, lengths);
}
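
// Example (illustrative, not part of the original header): inverting the
// call above,
//   auto [padded, lengths] = pad_packed_sequence(packed);
// yields `padded` of shape [4, 2, *] ([2, 4, *] with batch_first = true)
// and lengths = [4, 2] on CPU.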

inline Tensor pad_sequence(
    ArrayRef<Tensor> sequences,
    bool batch_first = false,
    double padding_value = 0) {
  return at::pad_sequence(sequences, batch_first, padding_value);
}
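
// Example (illustrative, not part of the original header): for `a` of shape
// [4, 8] and `b` of shape [2, 8],
//   Tensor out = pad_sequence({a, b});
// has shape [4, 2, 8]; with batch_first = true it would be [2, 4, 8].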

inline PackedSequence pack_sequence(
    ArrayRef<Tensor> sequences,
    bool enforce_sorted = true) {
  Tensor lengths = torch::empty({(int64_t)sequences.size()}, kInt64);
  for (const auto i : c10::irange(sequences.size())) {
    lengths[i] = sequences[i].size(0);
  }
  return pack_padded_sequence(
      at::pad_sequence(sequences),
      std::move(lengths),
      /*batch_first=*/false,
      /*enforce_sorted=*/enforce_sorted);
}
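
// Example (illustrative, not part of the original header):
//   PackedSequence packed = pack_sequence({torch::ones({4, 8}),
//                                          torch::zeros({2, 8})});
// With the default enforce_sorted = true, the sequences must already be
// ordered by decreasing length; pass enforce_sorted = false to have the
// function sort them for you.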

} // namespace rnn
} // namespace utils
} // namespace nn
} // namespace torch
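
The listing is easiest to read end to end with a round trip. The following
standalone sketch is not part of the header; it assumes a program linked
against libtorch and compiled as C++17, packs two variable-length sequences,
pads them back, and prints the intermediate shapes.

#include <torch/torch.h>

#include <iostream>

int main() {
  namespace rnn = torch::nn::utils::rnn;

  // Two sequences with different lengths, 3 features per time step.
  torch::Tensor a = torch::randn({5, 3}); // length 5
  torch::Tensor b = torch::randn({3, 3}); // length 3

  // pack_sequence pads internally, then packs; the inputs are already
  // sorted by decreasing length, so the default enforce_sorted = true holds.
  rnn::PackedSequence packed = rnn::pack_sequence({a, b});
  std::cout << "data: " << packed.data().sizes()                 // [8, 3]
            << " batch_sizes: " << packed.batch_sizes() << '\n'; // [2, 2, 2, 1, 1]

  // Round-trip back to a padded tensor plus the original lengths.
  auto [padded, lengths] = rnn::pad_packed_sequence(packed);
  std::cout << "padded: " << padded.sizes()     // [5, 2, 3]
            << " lengths: " << lengths << '\n'; // [5, 3]
}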
