Shortcuts

Program Listing for File iterator.h

Return to documentation for file (torch/csrc/api/include/torch/data/iterator.h)

#pragma once

#include <torch/csrc/utils/variadic.h>
#include <torch/types.h>

#include <c10/util/Exception.h>

#include <functional>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>

namespace torch {
namespace data {
namespace detail {
// For increased safety and more separated logic, this implementation of
// `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A
// `ValidIterator` yields new batches until the `DataLoader` is exhausted. While
// the `DataLoader` is not exhausted, `ValidIterator`s compare equal if they are
// the same object. When the `ValidIterator` becomes exhausted, it compares
// equal to the `SentinelIterator`, but not before. Half the code here exists
// only to implement double dispatch for the comparison, because C++ has no
// built-in multiple dispatch.

// Forward declarations: `IteratorImpl` names both concrete iterator kinds in
// its comparison overloads before either of them is defined.
template <typename Batch>
struct ValidIterator;

template <typename Batch>
struct SentinelIterator;

/// Abstract interface shared by every iterator implementation.
///
/// Equality is declared three times: once for an operand whose concrete type
/// is unknown, and once for each concrete iterator kind. Implementations
/// forward the first overload to one of the typed overloads on the other
/// operand, which is how the comparison is resolved on both dynamic types.
template <typename Batch>
struct IteratorImpl {
  virtual ~IteratorImpl() = default;

  /// Moves on to the following batch; called by `Iterator::operator++`.
  virtual void next() = 0;

  /// Gives mutable access to the current batch; called by
  /// `Iterator::operator*` and `Iterator::operator->`.
  virtual Batch& get() = 0;

  /// Comparison against an operand whose concrete type is not known.
  virtual bool operator==(const IteratorImpl<Batch>& other) const = 0;

  /// Comparison against an operand known to be a `ValidIterator`.
  virtual bool operator==(const ValidIterator<Batch>& other) const = 0;

  /// Comparison against an operand known to be a `SentinelIterator`.
  virtual bool operator==(const SentinelIterator<Batch>& other) const = 0;
};

/// Iterator implementation that pulls batches from a producer callback.
///
/// The first batch is not requested in the constructor; it is fetched on the
/// first call that needs it (`next()`, `get()`, or a comparison against a
/// `SentinelIterator`). Once the producer returns an empty optional, this
/// iterator compares equal to a `SentinelIterator`.
template <typename Batch>
struct ValidIterator : public IteratorImpl<Batch> {
  /// Callback returning the next batch, or an empty optional when there is
  /// none left.
  using BatchProducer = std::function<std::optional<Batch>()>;

  /// @param next_batch callback invoked every time a new batch is required.
  explicit ValidIterator(BatchProducer next_batch)
      : next_batch_(std::move(next_batch)) {}

  /// Replaces the current batch with the next one from the producer.
  /// Fails a `TORCH_CHECK` when there is no current batch to step past.
  void next() override {
    // The very first batch may not have been requested yet.
    lazy_initialize();
    TORCH_CHECK(
        batch_.has_value(), "Attempted to increment iterator past the end");
    batch_ = next_batch_();
  }

  /// Returns the current batch.
  /// Fails a `TORCH_CHECK` when the producer has already been exhausted.
  Batch& get() override {
    // The very first batch may not have been requested yet.
    lazy_initialize();
    TORCH_CHECK(
        batch_.has_value(),
        "Attempted to dereference iterator that was past the end");
    // The check above guarantees the optional is engaged.
    return *batch_;
  }

  /// First leg of the double dispatch: hand the comparison to `other`, which
  /// now sees this operand with its concrete `ValidIterator` type.
  bool operator==(const IteratorImpl<Batch>& other) const override {
    return other == *this;
  }

  /// Equal to the sentinel exactly when no batch is available. This may
  /// trigger the initial fetch, hence the `mutable` state below.
  bool operator==(const SentinelIterator<Batch>& /* unused */) const override {
    lazy_initialize();
    return !batch_.has_value();
  }

  /// Two valid iterators are equal only if they are the same object.
  bool operator==(const ValidIterator<Batch>& other) const override {
    return this == &other;
  }

  /// Fetches the first batch if that has not happened yet; later calls are
  /// no-ops.
  void lazy_initialize() const {
    if (initialized_) {
      return;
    }
    batch_ = next_batch_();
    initialized_ = true;
  }

  // Source of batches.
  BatchProducer next_batch_;
  // Current batch; empty once the producer has returned an empty optional.
  // `mutable` so the const comparison operators can perform the lazy fetch.
  mutable std::optional<Batch> batch_;
  // Whether the first batch has been requested from `next_batch_`.
  mutable bool initialized_ = false;
};

/// Past-the-end marker for the iterator pair.
///
/// It holds no batch: incrementing or dereferencing it raises an error via
/// `AT_ERROR`. Its only purpose is to be compared against.
template <typename Batch>
struct SentinelIterator : public IteratorImpl<Batch> {
  /// Always raises: a past-the-end iterator cannot be advanced.
  void next() override {
    AT_ERROR(
        "Incrementing the DataLoader's past-the-end iterator is not allowed");
  }

  /// Always raises: a past-the-end iterator has no batch to return.
  Batch& get() override {
    AT_ERROR(
        "Dereferencing the DataLoader's past-the-end iterator is not allowed");
  }

  /// First leg of the double dispatch: hand the comparison to `other`, which
  /// now sees this operand with its concrete `SentinelIterator` type.
  bool operator==(const IteratorImpl<Batch>& other) const override {
    return other == *this;
  }

  /// Delegates to `ValidIterator`'s sentinel comparison, so the exhaustion
  /// test lives in one place.
  bool operator==(const ValidIterator<Batch>& other) const override {
    return other == *this;
  }

  /// All sentinels are interchangeable, so the argument is not inspected.
  bool operator==(const SentinelIterator<Batch>& /* unused */) const override {
    return true;
  }
};
} // namespace detail

template <typename Batch>
class Iterator {
 public:
  // Type aliases to make the class recognized as a proper iterator.
  using difference_type = std::ptrdiff_t;
  using value_type = Batch;
  using pointer = Batch*;
  using reference = Batch&;
  using iterator_category = std::input_iterator_tag;

  explicit Iterator(std::unique_ptr<detail::IteratorImpl<Batch>> impl)
      : impl_(std::move(impl)) {}

  Iterator& operator++() {
    impl_->next();
    return *this;
  }

  Batch& operator*() {
    return impl_->get();
  }

  Batch* operator->() {
    return &impl_->get();
  }

  bool operator==(const Iterator& other) const {
    return *impl_ == *other.impl_;
  }

  bool operator!=(const Iterator& other) const {
    return !(*this == other);
  }

 private:
  std::shared_ptr<detail::IteratorImpl<Batch>> impl_;
};
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources