
Template Class DataLoaderBase

Class Documentation

template<typename Dataset, typename Batch, typename BatchRequest>
class DataLoaderBase

Public Types

using BatchType = Batch
using BatchRequestType = BatchRequest

Public Functions

inline DataLoaderBase(DataLoaderOptions options, std::unique_ptr<Dataset> main_thread_dataset = nullptr)

Constructs a new DataLoader configured with options. If main_thread_dataset is non-null, the main thread uses it to load batches itself; this is the case when the number of worker threads is configured as zero.

inline virtual ~DataLoaderBase()
inline Iterator<Batch> begin()

Returns an iterator into the DataLoader.

The lifetime of the iterator is bound to the DataLoader. In C++ standards language, the category of the iterator is InputIterator. See https://en.cppreference.com/w/cpp/named_req/InputIterator for what this means. In short: you may increment the iterator and dereference it, but you cannot go back or step forward more than one position at a time. When the DataLoader is exhausted, the iterator compares equal with the special “sentinel” iterator returned by DataLoader::end(). Most of the time, you should simply loop over the DataLoader with a range-for loop, but standard algorithms such as std::copy(dataloader.begin(), dataloader.end(), output_iterator) are supported too.
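The sentinel protocol can be illustrated with a small, standard-library-only sketch. ToyLoader and its nested Iterator are hypothetical names, not torch classes: next() returns std::optional, a default-constructed iterator plays the role of end(), and an iterator compares equal to it once next() returns nothing. This models the documented behaviour; it is not the libtorch implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <optional>
#include <vector>

// Hypothetical single-pass loader yielding the integers [0, limit) as "batches".
class ToyLoader {
 public:
  explicit ToyLoader(int limit) : limit_(limit) {}

  // Returns the next batch, or std::nullopt once the loader is exhausted.
  std::optional<int> next() {
    if (current_ >= limit_) return std::nullopt;
    return current_++;
  }

  class Iterator {
   public:
    using iterator_category = std::input_iterator_tag;
    using value_type = int;
    using difference_type = std::ptrdiff_t;
    using pointer = const int*;
    using reference = const int&;

    Iterator() = default;  // a default-constructed iterator is the sentinel
    explicit Iterator(ToyLoader* loader) : loader_(loader) { advance(); }

    reference operator*() const { return *batch_; }
    Iterator& operator++() {
      advance();
      return *this;
    }

    // Equal only when neither side holds a batch, i.e. the loader is
    // exhausted or the iterator is the sentinel.
    bool operator==(const Iterator& other) const {
      return !batch_ && !other.batch_;
    }
    bool operator!=(const Iterator& other) const { return !(*this == other); }

   private:
    // Pulls the next batch; the sentinel has no loader and stays empty.
    void advance() {
      if (loader_) {
        batch_ = loader_->next();
      } else {
        batch_.reset();
      }
    }

    ToyLoader* loader_ = nullptr;
    std::optional<int> batch_;
  };

  Iterator begin() { return Iterator(this); }
  Iterator end() { return Iterator(); }

 private:
  int limit_;
  int current_ = 0;
};

// Collects every batch with a range-for loop.
std::vector<int> collect_with_range_for(int limit) {
  ToyLoader loader(limit);
  std::vector<int> out;
  for (int batch : loader) out.push_back(batch);
  return out;
}

// Collects every batch with std::copy into a back_inserter.
std::vector<int> collect_with_copy(int limit) {
  ToyLoader loader(limit);
  std::vector<int> out;
  std::copy(loader.begin(), loader.end(), std::back_inserter(out));
  return out;
}
```

With libtorch itself, the equivalent loop is simply for (auto& batch : *data_loader) { ... }, where data_loader is the std::unique_ptr returned by torch::data::make_data_loader.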

inline Iterator<Batch> end()

Returns a special “sentinel” iterator that compares equal with a non-sentinel iterator once the DataLoader is exhausted.

inline void join()

Joins the DataLoader’s worker threads and drains internal queues.

This function may only be invoked from the main thread (in which the DataLoader lives).
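One way to realise the documented behaviour of join() can be sketched with the standard library only. ToyPool is a hypothetical class whose workers block on a shared job queue; std::nullopt stands in for a QuitWorker job. join() pushes one quit job per worker and then joins every thread. Because the queue is FIFO, all previously queued jobs are consumed before the quit signals, which is what drains the queue. The joined_ flag mirrors the documented joined_ attribute and makes a second call harmless. This is an illustration, not the libtorch code.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <optional>
#include <queue>
#include <thread>
#include <vector>

// Hypothetical pool: a job is either a request (an int here) or a quit
// signal (std::nullopt here).
class ToyPool {
 public:
  explicit ToyPool(std::size_t num_workers) {
    for (std::size_t i = 0; i < num_workers; ++i) {
      workers_.emplace_back([this] { worker_loop(); });
    }
  }
  ~ToyPool() { join(); }

  void push_job(std::optional<int> job) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      jobs_.push(job);
    }
    cv_.notify_one();
  }

  // One quit job per worker guarantees that every thread wakes up and exits.
  void join() {
    if (joined_) return;  // safe to call more than once
    for (std::size_t i = 0; i < workers_.size(); ++i) push_job(std::nullopt);
    for (std::thread& worker : workers_) worker.join();
    joined_ = true;
  }

  int processed() const { return processed_; }  // only read after join()

 private:
  void worker_loop() {
    while (true) {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [this] { return !jobs_.empty(); });
      std::optional<int> job = jobs_.front();
      jobs_.pop();
      if (!job) return;  // quit signal: leave the loop
      ++processed_;      // still guarded by mutex_
    }
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<std::optional<int>> jobs_;
  std::vector<std::thread> workers_;
  int processed_ = 0;
  bool joined_ = false;
};
```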

inline const FullDataLoaderOptions &options() const noexcept

Returns the options with which the DataLoader was configured.

Protected Functions

virtual std::optional<BatchRequestType> get_batch_request() = 0

Subclass hook for getting the next batch request.

In the stateless case, the hook asks the sampler for a new batch request (e.g. a vector of indices); in the stateful case, it simply returns the batch size.
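The split between the two flavours of the hook can be sketched as follows. LoaderBase, StatelessLoader and StatefulLoader are hypothetical stand-ins, and the sequential index generation assumes a non-shuffling sampler. In this sketch the stateful request never runs out, so exhaustion would have to be signalled by the dataset itself.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical base mirroring the shape of the documented pure-virtual hook.
template <typename BatchRequest>
class LoaderBase {
 public:
  virtual ~LoaderBase() = default;
  // Returns the next batch request, or std::nullopt when none is left.
  virtual std::optional<BatchRequest> get_batch_request() = 0;
};

// Stateless flavour: a sequential "sampler" hands out vectors of indices.
class StatelessLoader : public LoaderBase<std::vector<std::size_t>> {
 public:
  StatelessLoader(std::size_t dataset_size, std::size_t batch_size)
      : size_(dataset_size), batch_size_(batch_size) {}

  std::optional<std::vector<std::size_t>> get_batch_request() override {
    if (next_index_ >= size_) return std::nullopt;  // sampler exhausted
    std::vector<std::size_t> indices;
    while (indices.size() < batch_size_ && next_index_ < size_) {
      indices.push_back(next_index_++);
    }
    return indices;
  }

 private:
  std::size_t size_;
  std::size_t batch_size_;
  std::size_t next_index_ = 0;
};

// Stateful flavour: the dataset tracks its own position, so the request
// is just the batch size.
class StatefulLoader : public LoaderBase<std::size_t> {
 public:
  explicit StatefulLoader(std::size_t batch_size) : batch_size_(batch_size) {}
  std::optional<std::size_t> get_batch_request() override {
    return batch_size_;
  }

 private:
  std::size_t batch_size_;
};
```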

inline virtual void reset()

Resets the internal state of the DataLoader, optionally pre-fetching new jobs.

inline void prefetch(size_t requested_jobs)

Schedules up to requested_jobs new batches to be fetched.

The actual number of jobs scheduled may be smaller if the DataLoader is exhausted.

inline void prefetch()

Schedules the maximum number of jobs (based on the max_jobs option).
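The scheduling rule for both prefetch() overloads can be sketched as follows. ToyScheduler, ToyJob and the fixed batch size of 32 are hypothetical; the sketch assumes jobs go into a plain FIFO queue and, like push_job(), tags each job with an increasing sequence number. prefetch(n) stops early as soon as get_batch_request() returns nothing.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <optional>

// Hypothetical job: a batch request tagged with a sequence number.
struct ToyJob {
  std::size_t batch_request;  // e.g. a batch size, as in the stateful case
  std::size_t sequence_number;
};

class ToyScheduler {
 public:
  ToyScheduler(std::size_t total_batches, std::size_t max_jobs)
      : remaining_(total_batches), max_jobs_(max_jobs) {}

  // Schedules up to requested_jobs jobs; stops early once requests run out.
  void prefetch(std::size_t requested_jobs) {
    for (std::size_t r = 0; r < requested_jobs; ++r) {
      std::optional<std::size_t> request = get_batch_request();
      if (!request) break;  // exhausted: fewer jobs than requested
      push_job(*request);
    }
  }

  // Schedules the maximum number of jobs allowed by max_jobs.
  void prefetch() { prefetch(max_jobs_); }

  const std::deque<ToyJob>& queue() const { return queue_; }

 private:
  std::optional<std::size_t> get_batch_request() {
    if (remaining_ == 0) return std::nullopt;
    --remaining_;
    return std::size_t{32};  // hypothetical fixed batch size
  }

  // Tags each job with the next sequence number.
  void push_job(std::size_t request) {
    queue_.push_back(ToyJob{request, sequence_number_++});
  }

  std::size_t remaining_;
  std::size_t max_jobs_;
  std::size_t sequence_number_ = 0;
  std::deque<ToyJob> queue_;
};
```

For example, with three batches left and max_jobs set to 2, prefetch() schedules two jobs, and a following prefetch(5) schedules only the one that remains.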

inline std::optional<BatchType> next()

Returns the next batch of data, or an empty optional if the DataLoader is exhausted.

This operation will block until a batch is available if one is still expected.
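The blocking contract of next() can be modelled with a mutex and a condition variable. ToyResultQueue is a hypothetical class, not the DataShuttle: it counts jobs still in flight, blocks while at least one is outstanding and no result has arrived, and returns std::nullopt without blocking once nothing more is expected.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <optional>
#include <queue>
#include <thread>
#include <vector>

// Hypothetical result queue: next() blocks while a result is still expected.
class ToyResultQueue {
 public:
  // Called by the main thread before a job is handed to a worker.
  void expect_one() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++in_flight_;
  }

  // Called by a worker when a batch is ready.
  void push(int batch) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      results_.push(batch);
    }
    cv_.notify_one();
  }

  // Returns the next batch, or std::nullopt if nothing more is expected.
  std::optional<int> next() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (in_flight_ == 0) return std::nullopt;  // exhausted: do not block
    cv_.wait(lock, [this] { return !results_.empty(); });
    int batch = results_.front();
    results_.pop();
    --in_flight_;
    return batch;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<int> results_;
  std::size_t in_flight_ = 0;
};

// Produces `count` batches on a background thread and drains them via next().
std::vector<int> drain(int count) {
  ToyResultQueue queue;
  for (int i = 0; i < count; ++i) queue.expect_one();
  std::thread worker([&queue, count] {
    for (int i = 0; i < count; ++i) queue.push(i * 10);
  });
  std::vector<int> out;
  while (std::optional<int> batch = queue.next()) out.push_back(*batch);
  worker.join();
  return out;
}
```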

inline void worker_thread(Dataset &dataset)

The function that worker threads run.
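Because the documented Result can carry either a batch or a std::exception_ptr, a worker loop presumably catches errors on the worker thread and forwards them to the main thread. The sketch below assumes exactly that; toy_worker, load, ToyJob and ToyResult are hypothetical, and the job queue is filled before the worker starts so that no locking is needed.

```cpp
#include <cassert>
#include <exception>
#include <functional>
#include <optional>
#include <queue>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

struct QuitWorker {};

// Hypothetical job/result pair, shaped after the documented Job and Result.
struct ToyJob {
  std::optional<QuitWorker> quit;
  std::optional<int> batch_request;  // here: an index to "load"
};

struct ToyResult {
  std::optional<int> batch;
  std::exception_ptr exception;
};

// Hypothetical dataset access: throws for negative indices.
int load(int index) {
  if (index < 0) throw std::out_of_range("negative index");
  return index * 2;
}

// Worker loop: stop on QuitWorker; otherwise record either the loaded
// batch or the exception thrown while loading it.
void toy_worker(std::queue<ToyJob>& jobs, std::vector<ToyResult>& results) {
  while (!jobs.empty()) {
    ToyJob job = jobs.front();
    jobs.pop();
    if (job.quit) break;
    try {
      results.push_back(ToyResult{load(*job.batch_request), nullptr});
    } catch (...) {
      results.push_back(ToyResult{std::nullopt, std::current_exception()});
    }
  }
}

// Runs the worker on its own thread, then rethrows any captured exception
// on the calling thread. Returns the error message, or "ok".
std::string run(const std::vector<int>& indices) {
  std::queue<ToyJob> jobs;
  for (int i : indices) jobs.push(ToyJob{std::nullopt, i});
  jobs.push(ToyJob{QuitWorker{}, std::nullopt});
  std::vector<ToyResult> results;
  std::thread worker(toy_worker, std::ref(jobs), std::ref(results));
  worker.join();
  try {
    for (const ToyResult& r : results) {
      if (r.exception) std::rethrow_exception(r.exception);
    }
  } catch (const std::exception& e) {
    return e.what();
  }
  return "ok";
}
```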

template<typename T>
inline void push_job(T value)

Convenience method that calls shuttle_.push_job() with the next sequence number.

inline std::optional<Result> pop_result()

Convenience method that gets the next result from the sequencer.

inline std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer()

Convenience method that creates a new sequencer based on the enforce_ordering option.
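With ordering enforced, the sequencer has to hand results back in sequence-number order even when workers finish out of order. A hypothetical OrderedSequencer can do this by buffering early arrivals until the expected number shows up; the std::map buffer and the producer callback are simplifications chosen for this sketch, not the libtorch design.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

struct ToyResult {
  std::size_t sequence_number;
  std::string batch;
};

// Hypothetical ordered sequencer: pulls results from `get_result` (which may
// deliver them out of order) and returns them strictly by sequence number.
class OrderedSequencer {
 public:
  using ResultProducer = std::function<std::optional<ToyResult>()>;

  std::optional<ToyResult> next(const ResultProducer& get_result) {
    // Keep pulling until the expected sequence number has been buffered.
    auto it = buffer_.find(next_sequence_number_);
    while (it == buffer_.end()) {
      std::optional<ToyResult> result = get_result();
      if (!result) return std::nullopt;  // producer exhausted
      buffer_.emplace(result->sequence_number, std::move(*result));
      it = buffer_.find(next_sequence_number_);
    }
    ToyResult out = std::move(it->second);
    buffer_.erase(it);
    ++next_sequence_number_;
    return out;
  }

 private:
  std::size_t next_sequence_number_ = 0;
  std::map<std::size_t, ToyResult> buffer_;
};

// Feeds results arriving in the order 2, 0, 1 and collects them in order.
std::vector<std::string> reorder() {
  std::vector<ToyResult> arrivals = {{2, "c"}, {0, "a"}, {1, "b"}};
  std::size_t pos = 0;
  auto producer = [&]() -> std::optional<ToyResult> {
    if (pos == arrivals.size()) return std::nullopt;
    return arrivals[pos++];
  };
  OrderedSequencer sequencer;
  std::vector<std::string> out;
  while (auto result = sequencer.next(producer)) out.push_back(result->batch);
  return out;
}
```

Without ordering, a sequencer can simply return each result as it arrives, which avoids the buffering cost.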

Protected Attributes

const FullDataLoaderOptions options_

The options the DataLoader was configured with.

std::unique_ptr<Dataset> main_thread_dataset_

The dataset for the main thread. It only has a value if the number of worker threads was configured as zero, in which case the main thread does all the work synchronously.

NOTE: unique_ptr rather than optional is used deliberately: the dataset is stored on the heap, so when there is no main-thread dataset only a null pointer is kept instead of inline storage for a whole Dataset.

size_t sequence_number_ = 0

The sequence number for the next batch to be retrieved from the dataset.

std::vector<std::thread> workers_

The worker threads, running the worker_thread() method.

detail::DataShuttle<Job, Result> shuttle_

The DataShuttle which takes care of the life cycle of a job.

std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_

The Sequencer, which handles optional ordering of batches.

bool joined_ = false

True if the DataLoader has joined its worker threads.

struct Job : public torch::data::DataLoaderBase<Dataset, Batch, BatchRequest>::Sequenced

A Job is either a BatchRequest (new indices to fetch data at) or a QuitWorker object, to indicate the worker should shut down.

Public Functions

Job() = default
inline Job(QuitWorker q, size_t sqn)
inline Job(BatchRequest &&i, size_t sqn)

Public Members

std::optional<QuitWorker> quit
std::optional<BatchRequest> batch_request

struct QuitWorker

struct Result : public torch::data::DataLoaderBase<Dataset, Batch, BatchRequest>::Sequenced

The finished result of a job.

Public Functions

Result() = default
inline Result(std::optional<Batch> &&b, size_t sqn)
inline Result(std::exception_ptr exception, size_t sqn)

Public Members

std::optional<Batch> batch
std::exception_ptr exception

struct Sequenced

Simple mix-in to give something a sequence number.

Subclassed by torch::data::DataLoaderBase<Dataset, Batch, BatchRequest>::Job, torch::data::DataLoaderBase<Dataset, Batch, BatchRequest>::Result

Public Functions

Sequenced() = default
inline Sequenced(size_t sqn)

Public Members

size_t sequence_number
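The nested Sequenced, Job and Result types can be mirrored as standalone structs to show how each documented constructor fills the members. BatchRequest = std::vector<std::size_t> and Batch = int are hypothetical choices for the template parameters; in the real class these types are nested inside DataLoaderBase.

```cpp
#include <cassert>
#include <cstddef>
#include <exception>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical template arguments for this sketch.
using BatchRequest = std::vector<std::size_t>;
using Batch = int;

// Mix-in that gives something a sequence number.
struct Sequenced {
  Sequenced() = default;
  Sequenced(std::size_t sqn) : sequence_number(sqn) {}
  std::size_t sequence_number = 0;
};

struct QuitWorker {};

// Either a quit signal or a batch request, never both.
struct Job : Sequenced {
  Job() = default;
  Job(QuitWorker q, std::size_t sqn) : Sequenced(sqn), quit(q) {}
  Job(BatchRequest&& i, std::size_t sqn)
      : Sequenced(sqn), batch_request(std::move(i)) {}
  std::optional<QuitWorker> quit;
  std::optional<BatchRequest> batch_request;
};

// Either a (possibly empty) batch or the exception raised while loading it.
struct Result : Sequenced {
  Result() = default;
  Result(std::optional<Batch>&& b, std::size_t sqn)
      : Sequenced(sqn), batch(std::move(b)) {}
  Result(std::exception_ptr e, std::size_t sqn)
      : Sequenced(sqn), exception(std::move(e)) {}
  std::optional<Batch> batch;
  std::exception_ptr exception;
};
```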
