Program Listing for File stateful.h
Return to documentation for file (torch/csrc/api/include/torch/data/dataloader/stateful.h)
#pragma once
#include <c10/util/irange.h>
#include <torch/data/dataloader/base.h>
#include <cstddef>
#include <thread>
#include <utility>
namespace torch {
namespace data {
/// A dataloader for stateful datasets.
///
/// Unlike the stateless case, all workers pull batches from one shared
/// dataset instance (see the constructor). Every batch request handed to the
/// dataset is just the configured batch size; this class supplies no other
/// per-request information.
///
/// NOTE(review): because the workers share a single dataset, the dataset (or
/// the base class) must make concurrent access safe — confirm where that
/// synchronization lives.
template <typename Dataset>
class StatefulDataLoader : public DataLoaderBase<
                               Dataset,
                               typename Dataset::BatchType::value_type,
                               typename Dataset::BatchRequestType> {
 public:
  using super = DataLoaderBase<
      Dataset,
      typename Dataset::BatchType::value_type,
      typename Dataset::BatchRequestType>;

  using typename super::BatchRequestType;

  /// Constructs the loader from a `dataset` and `options`.
  ///
  /// The dataset is moved into a `std::unique_ptr` whose ownership is handed
  /// to the base class. Then `options_.workers` workers are created, each
  /// running `worker_thread` on the single shared dataset. Presumably these
  /// are `std::thread`s, given the `<thread>` include — confirm against
  /// `DataLoaderBase::workers_`.
  StatefulDataLoader(Dataset dataset, DataLoaderOptions options)
      : super(
            std::move(options),
            std::make_unique<Dataset>(std::move(dataset))) {
    for (const auto w : c10::irange(this->options_.workers)) {
      // Only the iteration count matters. The cast silences unused-variable
      // warnings, which can fail builds that treat warnings as errors.
      (void)w;
      // As opposed to the stateless case, here all worker threads access the
      // same underlying dataset.
      this->workers_.emplace_back(
          [this] { this->worker_thread(*this->main_thread_dataset_); });
    }
  }

 private:
  /// Resets the shared dataset first, then the base-class loader state.
  void reset() override {
    this->main_thread_dataset_->reset();
    // Call the base class method last because it calls `prefetch()`
    super::reset();
  }

  /// Returns the next batch request, which is always the configured batch
  /// size (`options_.batch_size`).
  optional<BatchRequestType> get_batch_request() override {
    return this->options_.batch_size;
  }
};
} // namespace data
} // namespace torch