Program Listing for File base.h
↰ Return to documentation for file (torch/csrc/api/include/torch/data/dataloader/base.h)
#pragma once
#include <torch/data/dataloader_options.h>
#include <torch/data/detail/data_shuttle.h>
#include <torch/data/detail/sequencers.h>
#include <torch/data/iterator.h>
#include <torch/data/samplers/random.h>
#include <torch/data/worker_exception.h>
#include <torch/types.h>
#include <torch/csrc/utils/variadic.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <cstddef>
#include <exception>
#include <memory>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
namespace torch {
namespace data {

/// Shared machinery for data loaders.
///
/// The class turns batch requests (produced by the subclass through
/// `get_batch_request()`) into batches. With `options_.workers == 0` the
/// batch is fetched directly from `main_thread_dataset_` inside `next()`.
/// Otherwise requests are pushed as jobs onto `shuttle_`, picked up by
/// `worker_thread()`, and their results are popped again through
/// `sequencer_`.
///
/// NOTE(review): this class never starts a thread itself; `workers_` is
/// presumably populated by subclasses with threads running
/// `worker_thread()` -- confirm against the derived loaders.
///
/// \tparam Dataset      type providing `get_batch(BatchRequest)`.
/// \tparam Batch        type handed to the user by the iterator.
/// \tparam BatchRequest type describing which batch to fetch.
template <typename Dataset, typename Batch, typename BatchRequest>
class DataLoaderBase {
 public:
  using BatchType = Batch;
  using BatchRequestType = BatchRequest;

  /// Stores `options`, takes ownership of `main_thread_dataset` (which may
  /// be null) and creates the sequencer selected by the options.
  ///
  /// `options_` is declared before `sequencer_`, so it is already
  /// initialized when `new_sequencer()` reads it.
  DataLoaderBase(
      DataLoaderOptions options,
      std::unique_ptr<Dataset> main_thread_dataset = nullptr)
      : options_(std::move(options)),
        main_thread_dataset_(std::move(main_thread_dataset)),
        sequencer_(new_sequencer()) {}

  /// Joins the loader (see `join()`) before the members are destroyed.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  virtual ~DataLoaderBase() {
    join();
  }

  /// Starts a new pass and returns an iterator over its batches.
  ///
  /// Fails the `TORCH_CHECK` if the shuttle still reports in-flight jobs,
  /// i.e. a previous iterator was not run to exhaustion. Otherwise calls
  /// `reset()` and returns an iterator whose elements are produced by
  /// `next()`. The iterator captures `this`, so it must not outlive the
  /// loader.
  Iterator<Batch> begin() {
    TORCH_CHECK(
        shuttle_.in_flight_jobs() == 0,
        "Attempted to get a new DataLoader iterator "
        "while another iterator is not yet exhausted");
    reset();
    return Iterator<Batch>(std::make_unique<detail::ValidIterator<Batch>>(
        [this] { return this->next(); }));
  }

  /// Returns the sentinel iterator that marks the end of a pass.
  Iterator<Batch> end() {
    return Iterator<Batch>(std::make_unique<detail::SentinelIterator<Batch>>());
  }

  /// Shuts down the worker threads. Calling it again after it has completed
  /// is a no-op.
  ///
  /// Drains the shuttle, sends one quit job per configured worker and then
  /// joins every thread stored in `workers_`.
  void join() {
    if (joined_) {
      return;
    }
    shuttle_.drain();
    // Send one 'quit' message per worker. Since a worker dies (exits its
    // thread) after receiving this message, each `QuitWorker()` message will be
    // read by exactly one worker.
    for (const auto w : c10::irange(options_.workers)) {
      (void)w; // Suppress unused variable warning
      push_job(QuitWorker());
    }
    for (auto& worker : workers_) {
      // `std::thread::join()` throws `std::system_error` on a thread that is
      // not joinable (default-constructed or moved-from). `join()` runs from
      // the destructor, so skip such entries instead of throwing.
      if (worker.joinable()) {
        worker.join();
      }
    }
    joined_ = true;
  }

  /// Returns the options this loader was configured with.
  const FullDataLoaderOptions& options() const noexcept {
    return options_;
  }

 protected:
  /// Base for anything tagged with a sequence number. `push_job()` assigns
  /// the number and `worker_thread()` copies it from the job to its result.
  struct Sequenced {
    Sequenced() = default;
    Sequenced(size_t sqn) : sequence_number(sqn) {}
    // Zero-initialized so that a default-constructed `Job`/`Result` never
    // carries an indeterminate value.
    size_t sequence_number = 0;
  };

  /// Tag type: a job carrying it tells the receiving worker to exit.
  struct QuitWorker {};

  /// A unit of work for a worker: either a quit request or a batch request.
  struct Job : Sequenced {
    Job() = default;
    Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {}
    Job(BatchRequest&& i, size_t sqn)
        : Sequenced(sqn), batch_request(std::move(i)) {}
    optional<QuitWorker> quit;
    optional<BatchRequest> batch_request;
  };

  /// What a worker hands back for a job: a (possibly empty) batch, or the
  /// exception that was thrown while producing it.
  struct Result : Sequenced {
    Result() = default;
    Result(optional<Batch>&& b, size_t sqn)
        : Sequenced(sqn), batch(std::move(b)) {}
    Result(std::exception_ptr exception, size_t sqn)
        : Sequenced(sqn), exception(std::move(exception)) {}
    optional<Batch> batch;
    std::exception_ptr exception;
  };

  /// Produces the next batch request, or an empty optional when there is
  /// nothing more to request. An empty optional stops `prefetch()` and, in
  /// the worker-less path, makes `next()` return `nullopt`.
  virtual optional<BatchRequestType> get_batch_request() = 0;

  /// Prepares for a new pass: drains the shuttle, restarts sequence
  /// numbering at zero, replaces the sequencer and prefetches the initial
  /// set of jobs.
  virtual void reset() {
    shuttle_.drain();
    sequence_number_ = 0;
    sequencer_ = new_sequencer();
    prefetch();
  }

  /// Pushes up to `requested_jobs` batch-request jobs, stopping early as
  /// soon as `get_batch_request()` returns an empty optional.
  void prefetch(size_t requested_jobs) {
    for (const auto r : c10::irange(requested_jobs)) {
      (void)r; // Suppress unused variable
      if (auto batch_request = get_batch_request()) {
        this->push_job(std::move(*batch_request));
      } else {
        break;
      }
    }
  }

  /// Prefetches `options_.max_jobs` jobs.
  void prefetch() {
    prefetch(options_.max_jobs);
  }

  /// Returns the next batch, or `nullopt` when the pass is over.
  ///
  /// With workers: pops results until one carries a batch (pushing one
  /// replacement job before returning it) or no result is left. A result
  /// carrying an exception is rethrown wrapped in `WorkerException`.
  ///
  /// Without workers: obtains a batch request and fetches the batch from
  /// `main_thread_dataset_` on the calling thread.
  optional<BatchType> next() {
    if (options_.workers > 0) {
      while (optional<Result> result = this->pop_result()) {
        if (result->exception) {
          throw WorkerException(result->exception);
        } else if (result->batch) {
          prefetch(1);
          return std::move(result->batch);
        }
        // A result with neither a batch nor an exception is skipped.
      }
    } else if (auto batch_request = get_batch_request()) {
      // The constructor allows a null dataset; fail with a clear error
      // instead of dereferencing a null pointer.
      TORCH_CHECK(
          main_thread_dataset_ != nullptr,
          "DataLoader was configured with zero workers but has no "
          "main thread dataset to fetch batches from");
      return this->main_thread_dataset_->get_batch(std::move(*batch_request));
    }
    return nullopt;
  }

  /// Worker loop: pops jobs until a quit job arrives. For every other job
  /// the batch is fetched from `dataset` and pushed as a result under the
  /// job's sequence number. Anything thrown while doing so is captured and
  /// pushed as an exception result instead, so it never escapes this loop.
  void worker_thread(Dataset& dataset) {
    while (true) {
      auto job = shuttle_.pop_job();
      if (job.quit) {
        break;
      }
      try {
        auto batch = dataset.get_batch(std::move(*job.batch_request));
        shuttle_.push_result({std::move(batch), job.sequence_number});
      } catch (...) {
        shuttle_.push_result({std::current_exception(), job.sequence_number});
      }
    }
  }

  /// Wraps `value` (a `QuitWorker` or a `BatchRequest`) in a `Job` tagged
  /// with the current sequence number, pushes it, and advances the counter.
  template <typename T>
  void push_job(T value) {
    shuttle_.push_job({std::move(value), sequence_number_++});
  }

  /// Obtains the next result through the sequencer, which pulls raw results
  /// from the shuttle using `options_.timeout`.
  optional<Result> pop_result() {
    return sequencer_->next(
        [this] { return this->shuttle_.pop_result(this->options_.timeout); });
  }

  /// Creates the sequencer demanded by the options: an `OrderedSequencer`
  /// sized by `options_.max_jobs` when `options_.enforce_ordering` is set,
  /// otherwise a `NoSequencer`.
  std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
    if (options_.enforce_ordering) {
      return std::make_unique<detail::sequencers::OrderedSequencer<Result>>(
          options_.max_jobs);
    }
    return std::make_unique<detail::sequencers::NoSequencer<Result>>();
  }

  /// Options the loader was configured with.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const FullDataLoaderOptions options_;

  /// Dataset used by `next()` when `options_.workers == 0`; may be null.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unique_ptr<Dataset> main_thread_dataset_;

  /// Sequence number given to the next pushed job; reset to 0 by `reset()`.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  size_t sequence_number_ = 0;

  /// Worker threads joined by `join()`.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::thread> workers_;

  /// Channel carrying jobs to the workers and results back.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  detail::DataShuttle<Job, Result> shuttle_;

  /// Sequencer through which results are popped; recreated by `reset()`.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_;

  /// Set once `join()` has completed, making later calls no-ops.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool joined_ = false;
};

} // namespace data
} // namespace torch