Program Listing for File stateless.h
Return to documentation for file (torch/csrc/api/include/torch/data/dataloader/stateless.h)
#pragma once
#include <torch/data/dataloader/base.h>
#include <torch/data/worker_exception.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <cstddef>
#include <thread>
#include <utility>
namespace torch {
namespace data {
/// Data loader that owns a `Sampler` and asks it for the next batch request
/// each time the base class needs one.
///
/// `Dataset` must provide a nested `BatchType`; `Sampler` must provide a
/// nested `BatchRequestType` as well as `reset()` and `next(batch_size)`.
template <typename Dataset, typename Sampler>
class StatelessDataLoader : public DataLoaderBase<
                                Dataset,
                                typename Dataset::BatchType,
                                typename Sampler::BatchRequestType> {
 public:
  using super = DataLoaderBase<
      Dataset,
      typename Dataset::BatchType,
      typename Sampler::BatchRequestType>;
  using typename super::BatchRequestType;

  /// Builds the loader from a dataset, a sampler and the loader options.
  ///
  /// When `options.workers` is zero, the dataset is moved into
  /// `main_thread_dataset_` and no thread is started. Otherwise one entry is
  /// appended to `workers_` per configured worker, and each of them runs
  /// `worker_thread` on its own copy of `dataset`.
  StatelessDataLoader(
      Dataset dataset,
      Sampler sampler,
      DataLoaderOptions options)
      : super(std::move(options)), sampler_(std::move(sampler)) {
    if (this->options_.workers == 0) {
      // No workers configured: hand the only dataset instance to the loader.
      this->main_thread_dataset_ =
          std::make_unique<Dataset>(std::move(dataset));
      return;
    }

    for (const auto worker_index : c10::irange(this->options_.workers)) {
      // Only the number of iterations matters, not the index itself.
      static_cast<void>(worker_index);

      // The dataset is captured by value, so every worker closure carries a
      // private copy of it. Datasets used with more than one worker are
      // therefore expected to be cheap to copy.
      auto worker_entry = [this, dataset]() mutable {
        this->worker_thread(dataset);
      };
      this->workers_.emplace_back(std::move(worker_entry));
    }
  }

 private:
  /// Rewinds the sampler, then resets the base class.
  void reset() override {
    sampler_.reset();
    // Order matters: the base-class reset calls `prefetch()`, so it has to
    // run after the sampler has been rewound.
    super::reset();
  }

  /// Asks the sampler for the next batch request.
  ///
  /// Returns `nullopt` when the sampler yields nothing, or when it yields
  /// fewer than `batch_size` entries while `drop_last` is set. Otherwise
  /// returns the (asserted non-empty) request produced by the sampler.
  optional<BatchRequestType> get_batch_request() override {
    auto batch_indices = sampler_.next(this->options_.batch_size);
    if (!batch_indices) {
      return nullopt;
    }

    const bool is_short_batch =
        batch_indices->size() < this->options_.batch_size;
    if (is_short_batch && this->options_.drop_last) {
      return nullopt;
    }

    AT_ASSERT(batch_indices->size() > 0);
    return batch_indices;
  }

  /// Source of batch requests; owned by this loader.
  Sampler sampler_;
};
} // namespace data
} // namespace torch