Program Listing for File iterator.h
Return to documentation for file (torch/csrc/api/include/torch/data/iterator.h)
#pragma once
#include <torch/csrc/utils/variadic.h>
#include <torch/types.h>
#include <c10/util/Exception.h>
#include <functional>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
namespace torch {
namespace data {
namespace detail {
// For increased safety and more separated logic, this implementation of
// `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A
// `ValidIterator` yields new batches until the `DataLoader` is exhausted. While
// the `DataLoader` is not exhausted, `ValidIterator`s compare equal if they are
// the same object. When the `ValidIterator` becomes exhausted, it compares
// equal to the `SentinelIterator`, but not before. About half of the code here
// exists only to implement double dispatch for the comparison, which C++ does
// not provide natively.
// Forward declarations of the two concrete iterator flavours. `IteratorImpl`
// only needs their names, for the type-specific comparison overloads below.
template <typename Batch>
struct SentinelIterator;
template <typename Batch>
struct ValidIterator;

/// Abstract interface shared by every iterator implementation.
///
/// Besides stepping (`next`) and access (`get`), it declares one equality
/// overload per operand type: one taking the base type and one for each
/// concrete flavour. The concrete implementations in this file answer the
/// base-type overload with `other == *this`, which re-dispatches the
/// comparison on the dynamic type of the other operand.
template <typename Batch>
struct IteratorImpl {
  virtual ~IteratorImpl() = default;

  /// Moves the iterator forward by one batch.
  virtual void next() = 0;

  /// Returns a mutable reference to the batch the iterator currently refers to.
  virtual Batch& get() = 0;

  /// Comparison against an operand whose concrete type is not known statically.
  virtual bool operator==(const IteratorImpl<Batch>& rhs) const = 0;

  /// Comparison against an operand known to be a `ValidIterator`.
  virtual bool operator==(const ValidIterator<Batch>& rhs) const = 0;

  /// Comparison against an operand known to be a `SentinelIterator`.
  virtual bool operator==(const SentinelIterator<Batch>& rhs) const = 0;
};
/// Iterator implementation that obtains its batches from a producer callback.
///
/// The first batch is fetched lazily, on the first operation that needs it.
/// Once the producer returns an empty optional, this iterator compares equal
/// to any `SentinelIterator`.
template <typename Batch>
struct ValidIterator : public IteratorImpl<Batch> {
  /// Callback that yields the next batch, or an empty optional when there are
  /// no more batches.
  using BatchProducer = std::function<std::optional<Batch>()>;

  explicit ValidIterator(BatchProducer next_batch)
      : next_batch_(std::move(next_batch)) {}

  /// Fetches the very first batch if that has not happened yet. This is
  /// `const` (operating on `mutable` members) so that the const comparison
  /// operators can trigger it as well.
  void lazy_initialize() const {
    if (initialized_) {
      return;
    }
    batch_ = next_batch_();
    initialized_ = true;
  }

  /// Replaces the current batch with the next one from the producer.
  /// Fails the `TORCH_CHECK` if the iterator is already past the end.
  void next() override {
    // Make sure there is a current batch to step away from.
    lazy_initialize();
    TORCH_CHECK(
        batch_.has_value(), "Attempted to increment iterator past the end");
    batch_ = next_batch_();
  }

  /// Returns the current batch, fetching the first one on demand.
  /// Fails the `TORCH_CHECK` if the iterator is past the end.
  Batch& get() override {
    lazy_initialize();
    TORCH_CHECK(
        batch_.has_value(),
        "Attempted to dereference iterator that was past the end");
    // The check above guarantees the optional is engaged here.
    return *batch_;
  }

  /// Double dispatch: hand the comparison back to `rhs`, which then selects
  /// its overload for `ValidIterator`.
  bool operator==(const IteratorImpl<Batch>& rhs) const override {
    return rhs == *this;
  }

  /// Equal to the sentinel exactly when no batch is available.
  bool operator==(const SentinelIterator<Batch>& /* unused */) const override {
    lazy_initialize();
    return !batch_.has_value();
  }

  /// Two valid iterators are equal only if they are the same object.
  bool operator==(const ValidIterator<Batch>& rhs) const override {
    return this == &rhs;
  }

  BatchProducer next_batch_;
  // Both are `mutable` because `lazy_initialize()` is const.
  mutable std::optional<Batch> batch_;
  mutable bool initialized_ = false;
};
/// Iterator implementation representing the past-the-end position.
///
/// It holds no state: it can be neither incremented nor dereferenced, and it
/// exists only to be compared against.
template <typename Batch>
struct SentinelIterator : public IteratorImpl<Batch> {
  /// Always raises an error via `AT_ERROR`.
  void next() override {
    AT_ERROR(
        "Incrementing the DataLoader's past-the-end iterator is not allowed");
  }

  /// Always raises an error via `AT_ERROR`.
  Batch& get() override {
    AT_ERROR(
        "Dereferencing the DataLoader's past-the-end iterator is not allowed");
  }

  /// Double dispatch: hand the comparison back to `other`, which then selects
  /// its overload for `SentinelIterator`.
  bool operator==(const IteratorImpl<Batch>& other) const override {
    return other == *this;
  }

  /// Delegates to the `ValidIterator`'s own comparison against a sentinel.
  bool operator==(const ValidIterator<Batch>& other) const override {
    return other == *this;
  }

  /// All sentinels are interchangeable, so the argument is not inspected. The
  /// parameter is left unnamed to avoid an unused-parameter warning, matching
  /// the `/* unused */` convention used by `ValidIterator`.
  bool operator==(const SentinelIterator<Batch>& /* unused */) const override {
    return true;
  }
};
} // namespace detail
/// Handle around a `detail::IteratorImpl` that exposes the operators of an
/// input iterator.
///
/// The implementation is held through a `std::shared_ptr`, so copies of an
/// `Iterator` refer to the same implementation object.
template <typename Batch>
class Iterator {
 public:
  // Type aliases to make the class recognized as a proper iterator.
  using difference_type = std::ptrdiff_t;
  using value_type = Batch;
  using pointer = Batch*;
  using reference = Batch&;
  using iterator_category = std::input_iterator_tag;

  /// Takes ownership of `impl`.
  explicit Iterator(std::unique_ptr<detail::IteratorImpl<Batch>> impl)
      : impl_(std::move(impl)) {}

  /// Advances the underlying implementation and returns `*this`.
  Iterator& operator++() {
    impl_->next();
    return *this;
  }

  /// Returns a reference to the current batch.
  /// Const-qualified so the iterator can also be dereferenced through a const
  /// reference; constness of the handle does not propagate to the batch.
  Batch& operator*() const {
    return impl_->get();
  }

  /// Returns a pointer to the current batch. Const-qualified for the same
  /// reason as `operator*`.
  Batch* operator->() const {
    return &impl_->get();
  }

  /// Compares the two underlying implementations.
  bool operator==(const Iterator& other) const {
    return *impl_ == *other.impl_;
  }

  /// Negation of `operator==`.
  bool operator!=(const Iterator& other) const {
    return !(*this == other);
  }

 private:
  std::shared_ptr<detail::IteratorImpl<Batch>> impl_;
};
} // namespace data
} // namespace torch