Shortcuts

Program Listing for File base.h

Return to documentation for file (torch/csrc/api/include/torch/data/datasets/base.h)

#pragma once

#include <torch/data/example.h>
#include <torch/types.h>

#include <c10/util/ArrayRef.h>

#include <cstddef>
#include <cstdint>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>

namespace torch {
namespace data {
namespace datasets {
/// Forward declaration of the `MapDataset` class template. Only the name is
/// needed at this point so that it can appear in return types in this header.
template <typename S, typename T>
class MapDataset;
/// Forward declaration of the free function template `map`, which takes both
/// of its arguments by value and returns a `MapDataset<D, T>`. It is invoked
/// by the `BatchDataset::map` member functions later in this header.
template <typename D, typename T>
MapDataset<D, T> map(D, T); // NOLINT
} // namespace datasets
} // namespace data
} // namespace torch

namespace torch {
namespace data {
namespace datasets {
namespace detail {
/// Type trait reporting whether a type is a specialization of
/// `std::optional`. This primary template handles every type that is not
/// matched by the partial specialization below, so its `value` is `false`.
template <class Candidate>
struct is_optional : std::bool_constant<false> {};

/// Partial specialization selected for `std::optional<Payload>` with any
/// `Payload`; its `value` is `true`. Only the exact, cv-unqualified
/// `std::optional<...>` type matches this pattern.
template <class Payload>
struct is_optional<std::optional<Payload>> : std::bool_constant<true> {};
} // namespace detail

/// Abstract base class template for datasets that are queried one batch at a
/// time.
///
/// Template parameters:
///  - `Self`: the concrete dataset type. The `map()` members cast `*this` to
///    `Self`, so `Self` is expected to derive from this class.
///  - `Batch`: the type returned by `get_batch()`.
///  - `BatchRequest`: the type of the argument accepted by `get_batch()`.
template <
    typename Self,
    typename Batch = std::vector<Example<>>,
    typename BatchRequest = ArrayRef<size_t>>
class BatchDataset {
 public:
  using SelfType = Self;
  using BatchType = Batch;
  using BatchRequestType = BatchRequest;

  /// `true` exactly when `BatchType` is a `std::optional<...>`.
  static constexpr bool is_stateful = detail::is_optional<BatchType>::value;

  virtual ~BatchDataset() = default;

  /// Returns the batch described by `request`. Pure virtual: every concrete
  /// dataset must provide it.
  virtual Batch get_batch(BatchRequest request) = 0;

  /// Returns the dataset's size wrapped in an optional. Pure virtual.
  /// NOTE(review): presumably an empty optional means the size is unknown —
  /// confirm against implementations.
  virtual std::optional<size_t> size() const = 0;

  /// Lvalue overload: hands `*this` (viewed as `Self`) to `datasets::map` as
  /// an lvalue together with the moved `transform`, and returns the resulting
  /// `MapDataset`.
  template <typename TransformType>
  MapDataset<Self, TransformType> map(TransformType transform) & {
    Self& self = static_cast<Self&>(*this);
    return datasets::map(self, std::move(transform));
  }

  /// Rvalue overload: same as above, but `*this` is passed to
  /// `datasets::map` as an rvalue so it can be moved from.
  template <typename TransformType>
  MapDataset<Self, TransformType> map(TransformType transform) && {
    Self& self = static_cast<Self&>(*this);
    return datasets::map(std::move(self), std::move(transform));
  }
};

/// A `BatchDataset` whose batches are `std::vector<SingleExample>` and whose
/// batch request is the base class default (`ArrayRef<size_t>`). Subclasses
/// implement `get()` for a single index; `get_batch()` is derived from it.
template <typename Self, typename SingleExample = Example<>>
class Dataset : public BatchDataset<Self, std::vector<SingleExample>> {
 public:
  using ExampleType = SingleExample;

  /// Returns the example for `index`. Pure virtual.
  virtual ExampleType get(size_t index) = 0;

  /// Assembles a batch by calling `get()` once for each entry of `indices`,
  /// preserving the order in which the indices are given.
  std::vector<ExampleType> get_batch(ArrayRef<size_t> indices) override {
    std::vector<ExampleType> examples;
    // One allocation up front: the final size is known from the request.
    examples.reserve(indices.size());
    for (const size_t index : indices) {
      examples.push_back(this->get(index));
    }
    return examples;
  }
};

/// Alias for a `BatchDataset` whose batch request type is a plain `size_t`
/// rather than the default `ArrayRef<size_t>`.
/// NOTE(review): presumably the `size_t` is the number of examples requested
/// per batch — confirm against implementations.
template <typename Self, typename Batch = std::vector<Example<>>>
using StreamDataset = BatchDataset<Self, Batch, /*BatchRequest=*/size_t>;
} // namespace datasets
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources