Shortcuts

Program Listing for File data_shuttle.h

Return to the documentation for file torch/csrc/api/include/torch/data/detail/data_shuttle.h

#pragma once

#include <torch/data/detail/queue.h>
#include <torch/types.h>

#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

#include <chrono>
#include <utility>

namespace torch {
namespace data {
namespace detail {

/// Connects a producer of jobs with the workers that process them using two
/// queues: one holding pending jobs and one holding finished results. It also
/// keeps a count of jobs that were pushed but whose result has not been
/// popped yet (jobs discarded by `drain` are subtracted from that count).
///
/// NOTE(review): the counter is a plain, non-atomic `size_t` that is modified
/// by `push_job`, `pop_result` and `drain`. Those three therefore appear to be
/// intended for a single (driving) thread, with `pop_job` / `push_result` as
/// the worker-side entry points. Whether `Queue` itself tolerates concurrent
/// access is defined in queue.h — confirm there.
template <typename Job, typename Result>
class DataShuttle {
 public:
  /// Enqueues `job` and counts it as in flight.
  /// The counter is bumped only after the push returns, so a push that throws
  /// leaves the count unchanged.
  void push_job(Job job) {
    new_jobs_.push(std::move(job));
    in_flight_jobs_ += 1;
  }

  /// Enqueues a finished `result`.
  /// The in-flight count is not changed here; it drops when the result is
  /// collected via `pop_result`.
  void push_result(Result result) {
    results_.push(std::move(result));
  }

  /// Dequeues the next pending job.
  /// Waiting behavior is whatever `Queue::pop` does (presumably it blocks
  /// until a job is available — confirm in queue.h).
  Job pop_job() {
    return new_jobs_.pop();
  }

  /// Collects one result, provided at least one job is still in flight.
  ///
  /// @param timeout Forwarded unchanged to `Queue::pop`. What happens when it
  ///        expires is decided there. NOTE(review): if `Queue::pop` throws on
  ///        timeout, the in-flight count is left unchanged — confirm.
  /// @return The popped result, or `nullopt` when no job is in flight (in
  ///         which case the result queue is not touched at all).
  optional<Result> pop_result(
      optional<std::chrono::milliseconds> timeout = nullopt) {
    if (in_flight_jobs_ == 0) {
      // Nothing outstanding: report "no result" instead of waiting on the
      // result queue.
      return nullopt;
    }
    auto popped = results_.pop(timeout);
    // Only count the job as finished once its result was actually obtained.
    in_flight_jobs_ -= 1;
    return popped;
  }

  /// Drops every job that no worker has taken yet, then collects and
  /// discards the result of each job that is still in flight.
  /// `pop_result` is called without a timeout here. Presumably that waits
  /// indefinitely for each result — confirm in queue.h.
  void drain() {
    // Jobs removed from the queue will never produce a result, so stop
    // counting them as in flight.
    in_flight_jobs_ -= new_jobs_.clear();
    // Each successful pop_result() consumes one result and decrements the
    // counter, so this loop ends once every outstanding result is gone.
    while (in_flight_jobs() != 0) {
      pop_result();
    }
  }

  /// Number of jobs pushed whose result has not yet been popped, excluding
  /// jobs discarded by `drain`.
  size_t in_flight_jobs() const noexcept {
    return in_flight_jobs_;
  }

 private:
  /// Jobs waiting to be picked up via `pop_job`.
  Queue<Job> new_jobs_;
  /// Pushed jobs whose result has not been popped yet.
  size_t in_flight_jobs_ = 0;
  /// Results waiting to be collected via `pop_result`.
  Queue<Result> results_;
};

} // namespace detail
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources