Shortcuts

Program listing for file `stateless.h`

Return to the documentation for file `torch/csrc/api/include/torch/data/dataloader/stateless.h`.

#pragma once

#include <torch/data/dataloader/base.h>
#include <torch/data/worker_exception.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <thread>
#include <utility>

namespace torch {
namespace data {

template <typename Dataset, typename Sampler>
/// Data loader whose batch requests are produced by a `Sampler` owned by the
/// loader. The loader is built on `DataLoaderBase`, parameterized on the
/// dataset's `BatchType` and the sampler's `BatchRequestType`.
///
/// Worker threads (if any) each receive their own copy of the dataset. With
/// zero workers, the single dataset instance is stored on the loader instead.
class StatelessDataLoader : public DataLoaderBase<
                                Dataset,
                                typename Dataset::BatchType,
                                typename Sampler::BatchRequestType> {
 public:
  using super = DataLoaderBase<
      Dataset,
      typename Dataset::BatchType,
      typename Sampler::BatchRequestType>;
  using typename super::BatchRequestType;

  /// Builds the loader.
  ///
  /// @param dataset  Dataset to load from. It is copied once per worker
  ///                 thread, or moved onto the loader when no workers are
  ///                 configured.
  /// @param sampler  Source of batch requests; moved into the loader.
  /// @param options  Loader options; forwarded (moved) to the base class.
  StatelessDataLoader(
      Dataset dataset,
      Sampler sampler,
      DataLoaderOptions options)
      : super(std::move(options)), sampler_(std::move(sampler)) {
    if (this->options_.workers == 0) {
      // No worker threads will be spawned, so keep the one dataset instance
      // on the loader itself (presumably so the base class can produce
      // batches on the calling thread -- see DataLoaderBase).
      this->main_thread_dataset_ =
          std::make_unique<Dataset>(std::move(dataset));
      return;
    }
    // Spawn one thread per configured worker. Each closure captures its own
    // copy of `dataset` by value, so Dataset must be copyable; if copies are
    // not cheap and independent of each other, configure at most one worker.
    for (const auto worker_index : c10::irange(this->options_.workers)) {
      (void)worker_index; // Only the iteration count matters here.
      this->workers_.emplace_back(
          [this, dataset]() mutable { this->worker_thread(dataset); });
    }
  }

 private:
  /// Rewinds the sampler, then resets the base class.
  void reset() override {
    // Order matters: super::reset() calls `prefetch()`, which presumably asks
    // for new batch requests via get_batch_request(). The sampler therefore
    // has to be rewound before the base class is reset.
    sampler_.reset();
    super::reset();
  }

  /// Asks the sampler for the next batch of up to `batch_size` indices.
  ///
  /// @return nullopt if the sampler is exhausted, or if it yields a short
  ///         (final) batch while `drop_last` is enabled; otherwise the
  ///         non-empty batch request.
  optional<BatchRequestType> get_batch_request() override {
    auto indices = sampler_.next(this->options_.batch_size);
    if (!indices) {
      return nullopt; // Sampler has nothing left.
    }
    const bool is_short_batch = indices->size() < this->options_.batch_size;
    if (is_short_batch && this->options_.drop_last) {
      return nullopt; // Discard the incomplete trailing batch.
    }
    AT_ASSERT(indices->size() > 0);
    return indices;
  }

  /// Produces the batch requests consumed by get_batch_request().
  Sampler sampler_;
};
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources