Shortcuts

Program Listing for File base.h

Return to documentation for file (torch/csrc/api/include/torch/data/dataloader/base.h)

#pragma once

#include <torch/data/dataloader_options.h>
#include <torch/data/detail/data_shuttle.h>
#include <torch/data/detail/sequencers.h>
#include <torch/data/iterator.h>
#include <torch/data/samplers/random.h>
#include <torch/data/worker_exception.h>
#include <torch/types.h>

#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/variadic.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <exception>
#include <memory>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

namespace torch {
namespace data {
/// Common machinery shared by data loader implementations.
///
/// The loader turns "batch requests" (obtained from `get_batch_request()`,
/// which subclasses implement) into batches. With `options_.workers > 0`,
/// requests are wrapped in `Job`s and pushed into `shuttle_`; threads running
/// `worker_thread()` pop jobs, call `Dataset::get_batch()` and push `Result`s
/// back. With `options_.workers == 0`, batches are fetched synchronously from
/// `main_thread_dataset_` inside `next()`.
///
/// This base class never starts threads itself; it only joins whatever is
/// stored in `workers_`. Presumably subclasses populate `workers_` with
/// threads running `worker_thread()` -- confirm against the derived classes.
///
/// \tparam Dataset       Type providing `get_batch(BatchRequest)`.
/// \tparam Batch         Type of the batches handed to the user.
/// \tparam BatchRequest  Type describing which batch to fetch.
template <typename Dataset, typename Batch, typename BatchRequest>
class DataLoaderBase {
 public:
  using BatchType = Batch;
  using BatchRequestType = BatchRequest;

  /// Stores the options and the (optional) dataset used for main-thread
  /// loading, and creates the initial sequencer.
  ///
  /// \param options              Loader configuration; converted into the
  ///                             `FullDataLoaderOptions` member `options_`.
  /// \param main_thread_dataset  Dataset used by `next()` when
  ///                             `options_.workers == 0`. May be null here;
  ///                             it must be non-null by the time `next()`
  ///                             fetches a batch on the main thread.
  DataLoaderBase(
      DataLoaderOptions options,
      std::unique_ptr<Dataset> main_thread_dataset = nullptr)
      // NOLINTNEXTLINE(performance-move-const-arg)
      : options_(std::move(options)),
        main_thread_dataset_(std::move(main_thread_dataset)),
        sequencer_(new_sequencer()) {}

  /// Joins the worker threads (see `join()`) before the members are destroyed.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  virtual ~DataLoaderBase() {
    join();
  }

  /// Starts a new pass over the data and returns an iterator to its first
  /// batch.
  ///
  /// Fails the `TORCH_CHECK` if the shuttle still reports in-flight jobs, i.e.
  /// a previous iterator was not exhausted. Otherwise calls `reset()` and
  /// returns an iterator that pulls batches through `next()`.
  ///
  /// The returned iterator captures `this`, so it must not outlive the loader.
  Iterator<Batch> begin() {
    TORCH_CHECK(
        shuttle_.in_flight_jobs() == 0,
        "Attempted to get a new DataLoader iterator "
        "while another iterator is not yet exhausted");
    reset();
    return Iterator<Batch>(torch::make_unique<detail::ValidIterator<Batch>>(
        [this] { return this->next(); }));
  }

  /// Returns the past-the-end (sentinel) iterator.
  Iterator<Batch> end() {
    return Iterator<Batch>(
        torch::make_unique<detail::SentinelIterator<Batch>>());
  }

  /// Shuts the worker threads down and joins them.
  ///
  /// Does nothing if a previous call already completed. Otherwise drains the
  /// shuttle, pushes one `QuitWorker` job per configured worker and joins
  /// every thread stored in `workers_`.
  ///
  /// NOTE(review): the quit jobs go through `push_job()`, so they presumably
  /// count towards `shuttle_.in_flight_jobs()`; if so, calling `begin()` after
  /// `join()` with `workers > 0` would trip its check -- confirm against
  /// `DataShuttle`.
  void join() {
    if (joined_) {
      return;
    }
    shuttle_.drain();
    // Send one 'quit' message per worker. Since a worker dies (exits its
    // thread) after receiving this message, each `QuitWorker()` message will be
    // read by exactly one worker.
    for (const auto w : c10::irange(options_.workers)) {
      (void)w; // Suppress unused variable warning
      push_job(QuitWorker());
    }
    for (auto& worker : workers_) {
      worker.join();
    }
    joined_ = true;
  }

  /// Returns the options this loader was configured with.
  const FullDataLoaderOptions& options() const noexcept {
    return options_;
  }

 protected:
  /// Base for anything that carries a sequence number (jobs and results).
  struct Sequenced {
    Sequenced() = default;
    Sequenced(size_t sqn) : sequence_number(sqn) {}
    // Zero-initialized so that a default-constructed `Job`/`Result` never
    // carries an indeterminate sequence number.
    size_t sequence_number = 0;
  };

  /// Tag type: a job carrying this tells the receiving worker to exit.
  struct QuitWorker {};

  /// Unit of work sent to a worker: either a quit message or a batch request.
  struct Job : Sequenced {
    Job() = default;
    Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {}
    Job(BatchRequest&& i, size_t sqn)
        : Sequenced(sqn), batch_request(std::move(i)) {}
    optional<QuitWorker> quit;
    optional<BatchRequest> batch_request;
  };

  /// What a worker sends back: either a batch (possibly empty) or the
  /// exception that was raised while producing it.
  struct Result : Sequenced {
    Result() = default;
    Result(optional<Batch>&& b, size_t sqn)
        : Sequenced(sqn), batch(std::move(b)) {}
    Result(std::exception_ptr exception, size_t sqn)
        // NOLINTNEXTLINE(performance-move-const-arg)
        : Sequenced(sqn), exception(std::move(exception)) {}
    optional<Batch> batch;
    std::exception_ptr exception;
  };

  /// Produces the next batch request, or an empty optional when there are no
  /// more requests. Implemented by subclasses.
  virtual optional<BatchRequestType> get_batch_request() = 0;

  /// Prepares the loader for a new pass: drains the shuttle, restarts the
  /// sequence numbering, installs a fresh sequencer and prefetches the first
  /// `options_.max_jobs` jobs.
  virtual void reset() {
    shuttle_.drain();
    sequence_number_ = 0;
    sequencer_ = new_sequencer();
    prefetch();
  }

  /// Pushes up to `requested_jobs` new jobs, stopping early as soon as
  /// `get_batch_request()` returns an empty optional.
  void prefetch(size_t requested_jobs) {
    for (const auto r : c10::irange(requested_jobs)) {
      (void)r; // Suppress unused variable
      if (auto batch_request = get_batch_request()) {
        this->push_job(std::move(*batch_request));
      } else {
        break;
      }
    }
  }

  /// Pushes up to `options_.max_jobs` new jobs.
  void prefetch() {
    prefetch(options_.max_jobs);
  }

  /// Returns the next batch, or `nullopt` when no further batch is available.
  ///
  /// With workers: pops results until one carries a batch (which is returned
  /// after queueing one replacement job) or an exception (which is rethrown
  /// wrapped in `WorkerException`). Results with neither are skipped.
  ///
  /// Without workers: fetches the batch synchronously from
  /// `main_thread_dataset_`; fails a `TORCH_CHECK` if that dataset is null.
  optional<BatchType> next() {
    if (options_.workers > 0) {
      while (optional<Result> result = this->pop_result()) {
        if (result->exception) {
          throw WorkerException(result->exception);
        } else if (result->batch) {
          // Keep the pipeline full: one batch out, one new job in.
          prefetch(1);
          return std::move(result->batch);
        }
      }
    } else if (auto batch_request = get_batch_request()) {
      // The constructor accepts a null dataset, so guard the dereference
      // instead of invoking undefined behaviour.
      TORCH_CHECK(
          main_thread_dataset_ != nullptr,
          "DataLoader has no worker threads and no main thread dataset "
          "to fetch batches from");
      return this->main_thread_dataset_->get_batch(std::move(*batch_request));
    }
    return nullopt;
  }

  /// Worker loop: pops jobs from the shuttle until a quit job arrives. For
  /// every other job, fetches the batch from `dataset` and pushes it back as
  /// a result tagged with the job's sequence number. Any exception thrown
  /// inside the `try` block is captured and pushed as the result instead, so
  /// it does not escape through that path.
  void worker_thread(Dataset& dataset) {
    while (true) {
      auto job = shuttle_.pop_job();
      if (job.quit) {
        break;
      }
      try {
        auto batch = dataset.get_batch(std::move(*job.batch_request));
        shuttle_.push_result({std::move(batch), job.sequence_number});
      } catch (...) {
        shuttle_.push_result({std::current_exception(), job.sequence_number});
      }
    }
  }

  /// Wraps `value` (a `QuitWorker` or a `BatchRequest`) in a `Job` stamped
  /// with the next sequence number and pushes it into the shuttle.
  template <typename T>
  void push_job(T value) {
    shuttle_.push_job({std::move(value), sequence_number_++});
  }

  /// Asks the sequencer for the next result; the sequencer in turn pulls raw
  /// results from the shuttle, passing along `options_.timeout`.
  optional<Result> pop_result() {
    return sequencer_->next(
        [this] { return this->shuttle_.pop_result(this->options_.timeout); });
  }

  /// Creates the sequencer matching `options_.enforce_ordering`: an
  /// `OrderedSequencer` constructed with `options_.max_jobs`, or a
  /// `NoSequencer` otherwise.
  std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
    if (options_.enforce_ordering) {
      return torch::make_unique<detail::sequencers::OrderedSequencer<Result>>(
          options_.max_jobs);
    }
    return torch::make_unique<detail::sequencers::NoSequencer<Result>>();
  }

  /// Options the loader was configured with.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const FullDataLoaderOptions options_;

  /// Dataset used by `next()` when `options_.workers == 0`; may be null.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unique_ptr<Dataset> main_thread_dataset_;

  /// Sequence number assigned to the next pushed job.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  size_t sequence_number_ = 0;

  /// Worker threads joined by `join()`.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::thread> workers_;

  /// Channel carrying jobs to the workers and results back.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  detail::DataShuttle<Job, Result> shuttle_;

  /// Decides in which order popped results are handed to `next()`.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_;

  /// Set once `join()` has completed, making later calls no-ops.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool joined_ = false;
};
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources