Shortcuts

Program Listing for File map.h

Return to documentation for file (torch/csrc/api/include/torch/data/datasets/map.h)

#pragma once

#include <torch/data/datasets/base.h>
#include <torch/types.h>

#include <c10/util/ArrayRef.h>

#include <cstddef>
#include <type_traits>
#include <utility>

namespace torch {
namespace data {
namespace datasets {
namespace detail {
/// Alias that resolves to `torch::optional<T>` when `C` is `true`, and to
/// plain `T` when `C` is `false`.
template <bool C, typename T>
using optional_if_t = std::conditional_t<C, torch::optional<T>, T>;
} // namespace detail

/// A dataset that wraps a `SourceDataset` and passes every batch obtained from
/// it through an `AppliedTransform` (via `apply_batch`) before returning it.
///
/// The produced batch type depends on `SourceDataset::is_stateful`: when it is
/// `true`, batches are `torch::optional<AppliedTransform::OutputBatchType>`;
/// otherwise they are `AppliedTransform::OutputBatchType` directly.
template <typename SourceDataset, typename AppliedTransform>
class MapDataset : public BatchDataset<
                       MapDataset<SourceDataset, AppliedTransform>,
                       detail::optional_if_t<
                           SourceDataset::is_stateful,
                           typename AppliedTransform::OutputBatchType>,
                       typename SourceDataset::BatchRequestType> {
 public:
  using DatasetType = SourceDataset;
  using TransformType = AppliedTransform;
  using BatchRequestType = typename SourceDataset::BatchRequestType;
  using OutputBatchType = detail::optional_if_t<
      SourceDataset::is_stateful,
      typename AppliedTransform::OutputBatchType>;

  /// Takes ownership of `dataset` and `transform` by moving them into the
  /// corresponding members.
  MapDataset(DatasetType dataset, TransformType transform)
      : dataset_(std::move(dataset)), transform_(std::move(transform)) {}

  /// Fetches the batch described by `indices` from the wrapped dataset and
  /// returns the result of applying the transform to it. The actual work is
  /// delegated to the `get_batch_impl` overload matching the statefulness of
  /// `SourceDataset`.
  OutputBatchType get_batch(BatchRequestType indices) override {
    return get_batch_impl(std::move(indices));
  }

  /// Returns whatever `size()` of the wrapped dataset returns.
  /// This function is declared `noexcept`, so an exception escaping from the
  /// wrapped dataset's `size()` would terminate the program; the lint
  /// suppression below acknowledges that.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  optional<size_t> size() const noexcept override {
    return dataset_.size();
  }

  /// Calls `reset()` on the wrapped dataset.
  /// As a member of a class template this function is only instantiated when
  /// it is used, so `SourceDataset` needs a `reset()` method only if this
  /// function is actually called.
  void reset() {
    dataset_.reset();
  }

  /// Returns a read-only reference to the wrapped dataset.
  /// `const`-qualified so that it is also usable on a `const MapDataset`.
  const SourceDataset& dataset() const noexcept {
    return dataset_;
  }

  /// Returns a read-only reference to the transform being applied.
  /// `const`-qualified so that it is also usable on a `const MapDataset`.
  const AppliedTransform& transform() const noexcept {
    return transform_;
  }

 private:
  /// `get_batch` implementation guarded by
  /// `torch::disable_if_t<D::is_stateful>`, i.e. intended for source datasets
  /// that are not stateful: the batch returned by the source dataset is handed
  /// to the transform unconditionally.
  template <
      typename D = SourceDataset,
      typename = torch::disable_if_t<D::is_stateful>>
  OutputBatchType get_batch_impl(BatchRequestType indices) {
    return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
  }

  /// `get_batch` implementation enabled when `D::is_stateful` is `true`.
  /// The source dataset's result is tested first: if it holds a batch, that
  /// batch is moved into the transform; if it is empty, `nullopt` is returned
  /// and the transform is not invoked.
  template <typename D = SourceDataset>
  torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
      BatchRequestType indices) {
    if (auto batch = dataset_.get_batch(std::move(indices))) {
      return transform_.apply_batch(std::move(*batch));
    }
    return nullopt;
  }

  // The underlying dataset that batches are fetched from.
  SourceDataset dataset_;

  // The transformation that is applied to batches received from the dataset.
  AppliedTransform transform_;
};

namespace detail {
/// Computes the batch type that `map()` requires a transform to accept from
/// `Dataset`. The primary template (used when `Dataset::is_stateful` is
/// `false`) yields `Dataset::BatchType` itself.
///
/// This is a pair of specializations rather than a `std::conditional` so that
/// `Dataset::BatchType::value_type` is only named when it is actually
/// selected; `std::conditional` would require that nested type to exist even
/// when the other branch is chosen.
template <typename Dataset, bool IsStateful = Dataset::is_stateful>
struct map_input_batch {
  using type = typename Dataset::BatchType;
};

/// Specialization for `Dataset::is_stateful == true`: yields the `value_type`
/// nested in `Dataset::BatchType`.
template <typename Dataset>
struct map_input_batch<Dataset, true> {
  using type = typename Dataset::BatchType::value_type;
};
} // namespace detail

/// Creates a `MapDataset` that applies `transform` to the batches of
/// `dataset`. Both arguments are taken by value and moved into the result.
///
/// Compilation fails with a `static_assert` unless
/// `TransformType::InputBatchType` is exactly the batch type delivered by the
/// dataset: `DatasetType::BatchType::value_type` when
/// `DatasetType::is_stateful` is `true`, `DatasetType::BatchType` otherwise.
template <typename DatasetType, typename TransformType>
MapDataset<DatasetType, TransformType> map(
    DatasetType dataset,
    TransformType transform) {
  static_assert(
      std::is_same<
          typename detail::map_input_batch<DatasetType>::type,
          typename TransformType::InputBatchType>::value,
      "BatchType type of dataset does not match input type of transform");
  return {std::move(dataset), std::move(transform)};
}

} // namespace datasets
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources