Program Listing for File data_shuttle.h
↰ Return to documentation for file (torch/csrc/api/include/torch/data/detail/data_shuttle.h)
#pragma once
#include <torch/data/detail/queue.h>
#include <torch/types.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <chrono>
#include <utility>
namespace torch {
namespace data {
namespace detail {
/// Moves jobs and their results through two internal queues while counting
/// how many pushed jobs have not yet had a result popped for them.
///
/// NOTE(review): `in_flight_jobs_` is a plain `size_t` and no locking is
/// visible in this class; presumably `push_job`, `pop_result` and `drain` are
/// only ever called from one thread -- confirm against the callers.
template <typename Job, typename Result>
class DataShuttle {
 public:
  /// Enqueues `job` on the job queue, then counts it as in flight.
  void push_job(Job job) {
    new_jobs_.push(std::move(job));
    in_flight_jobs_ += 1;
  }

  /// Enqueues `result` on the result queue. The in-flight counter is not
  /// touched here; it only goes down when a result is popped (or in `drain`).
  void push_result(Result result) {
    results_.push(std::move(result));
  }

  /// Removes and returns the next job from the job queue. No timeout is
  /// passed to the queue (presumably this waits until a job is available;
  /// verify against `Queue::pop`).
  Job pop_job() {
    return new_jobs_.pop();
  }

  /// Returns the next result, or `nullopt` when no jobs are in flight.
  ///
  /// `timeout` is forwarded unchanged to the result queue's `pop`. How an
  /// expired timeout is reported is decided by `Queue::pop`, which is not
  /// visible here -- TODO confirm.
  optional<Result> pop_result(
      optional<std::chrono::milliseconds> timeout = nullopt) {
    // Nothing outstanding: do not touch the result queue at all.
    if (in_flight_jobs_ == 0) {
      return nullopt;
    }
    auto next_result = results_.pop(timeout);
    // The counter is lowered only after `pop` has returned, so if `pop`
    // exits via an exception the job is still counted as in flight.
    --in_flight_jobs_;
    return next_result;
  }

  /// Discards all jobs that are still queued and then pops (and drops) one
  /// result for every job that remains in flight.
  void drain() {
    // Jobs still sitting in the queue will never yield a result, so they
    // stop counting as in flight.
    const auto discarded_jobs = new_jobs_.clear();
    in_flight_jobs_ -= discarded_jobs;
    // Consume the results of the remaining jobs; `pop_result` lowers the
    // counter by one on each successful pop.
    while (in_flight_jobs_ != 0) {
      pop_result();
    }
  }

  /// Returns the number of jobs pushed whose result has not been popped yet.
  size_t in_flight_jobs() const noexcept {
    return in_flight_jobs_;
  }

 private:
  /// Jobs pushed via `push_job` and not yet taken by `pop_job`.
  Queue<Job> new_jobs_;
  /// Count of jobs pushed minus results popped and queued jobs cleared.
  size_t in_flight_jobs_ = 0;
  /// Results pushed via `push_result` and not yet taken by `pop_result`.
  Queue<Result> results_;
};
} // namespace detail
} // namespace data
} // namespace torch