Shortcuts

Program Listing for File stateful.h

Return to the documentation for file torch/csrc/api/include/torch/data/dataloader/stateful.h

#pragma once

#include <c10/util/irange.h>
#include <torch/data/dataloader/base.h>

#include <cstddef>
#include <thread>
#include <utility>

namespace torch {
namespace data {

/// A DataLoader for *stateful* datasets.
///
/// Unlike the stateless loader, this loader does not consult a sampler:
/// every batch request it issues is simply the configured batch size (see
/// `get_batch_request()`). Presumably the dataset itself decides which
/// examples to return — confirm against the Dataset contract.
///
/// All worker threads are given the same dataset object
/// (`*this->main_thread_dataset_`) rather than per-worker copies.
/// NOTE(review): this implies `Dataset` must tolerate use from several
/// worker threads at once. Confirm the exact thread-safety requirement
/// against `DataLoaderBase::worker_thread`.
///
/// @tparam Dataset Dataset type exposing `BatchType` (whose `value_type` is
///   used as the example type), `BatchRequestType`, and a `reset()` method.
template <typename Dataset>
class StatefulDataLoader : public DataLoaderBase<
                               Dataset,
                               typename Dataset::BatchType::value_type,
                               typename Dataset::BatchRequestType> {
 public:
  using super = DataLoaderBase<
      Dataset,
      typename Dataset::BatchType::value_type,
      typename Dataset::BatchRequestType>;
  using typename super::BatchRequestType;

  /// Constructs the loader and launches one worker per `options.workers`.
  ///
  /// @param dataset The dataset. It is moved into a heap-allocated instance
  ///   (`std::make_unique<Dataset>`) that is handed to the base class.
  /// @param options Loader configuration (worker count, batch size, ...);
  ///   moved into the base class and read back via `this->options_`.
  StatefulDataLoader(Dataset dataset, DataLoaderOptions options)
      : super(
            std::move(options),
            std::make_unique<Dataset>(std::move(dataset))) {
    // The loop variable is only a counter. `[[maybe_unused]]` keeps
    // -Wunused-variable quiet, which is fatal in -Werror builds.
    for ([[maybe_unused]] const auto w : c10::irange(this->options_.workers)) {
      // As opposed to the stateless case, here all worker threads access the
      // same underlying dataset.
      this->workers_.emplace_back(
          [this] { this->worker_thread(*this->main_thread_dataset_); });
    }
  }

 private:
  /// Resets the shared dataset, then the base-class loader state.
  void reset() override {
    this->main_thread_dataset_->reset();
    // Call the base class method last because it calls `prefetch()`, which
    // must observe the already-reset dataset.
    super::reset();
  }

  /// Produces the request for the next batch.
  ///
  /// @return The configured `batch_size`, converted to the request type.
  ///   No indices are sampled here.
  optional<BatchRequestType> get_batch_request() override {
    return this->options_.batch_size;
  }
};
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources