Shortcuts

Program Listing for File shared.h

Return to documentation for file (torch/csrc/api/include/torch/data/datasets/shared.h)

#pragma once

#include <torch/data/datasets/base.h>

#include <memory>
#include <utility>

namespace torch {
namespace data {
namespace datasets {

/// A `BatchDataset` that refers to its underlying dataset through a
/// `std::shared_ptr` instead of owning it by value. Copying a
/// `SharedBatchDataset` copies only the pointer, so every copy forwards its
/// calls to the same `UnderlyingDataset` instance.
///
/// The batch type and batch-request type are taken directly from
/// `UnderlyingDataset`.
template <typename UnderlyingDataset>
class SharedBatchDataset : public BatchDataset<
                               SharedBatchDataset<UnderlyingDataset>,
                               typename UnderlyingDataset::BatchType,
                               typename UnderlyingDataset::BatchRequestType> {
 public:
  using BatchType = typename UnderlyingDataset::BatchType;
  using BatchRequestType = typename UnderlyingDataset::BatchRequestType;

  /// Takes shared ownership of `shared_dataset`.
  ///
  /// Intentionally not `explicit`: a `std::shared_ptr<UnderlyingDataset>`
  /// converts directly into a `SharedBatchDataset`, which is what lets
  /// `make_shared_dataset` return the result of `std::make_shared` as-is.
  /// NOTE(review): no null check is performed here; every accessor below
  /// dereferences the stored pointer unconditionally.
  /* implicit */ SharedBatchDataset(
      std::shared_ptr<UnderlyingDataset> shared_dataset)
      : dataset_(std::move(shared_dataset)) {}

  /// Produces a batch by moving `request` into the shared dataset's
  /// `get_batch`.
  auto get_batch(BatchRequestType request) -> BatchType override {
    return this->dataset_->get_batch(std::move(request));
  }

  /// Returns whatever the shared dataset's `size()` returns.
  auto size() const -> optional<size_t> override {
    return this->dataset_->size();
  }

  /// Forwards to the shared dataset's `reset()`. Because the dataset is
  /// shared, this affects every copy of this wrapper.
  void reset() {
    this->dataset_->reset();
  }

  /// Mutable access to the shared dataset.
  auto operator*() -> UnderlyingDataset& {
    return *this->dataset_;
  }

  /// Read-only access; constness of the wrapper is propagated to the
  /// dataset (unlike a bare `std::shared_ptr`).
  auto operator*() const -> const UnderlyingDataset& {
    return *this->dataset_;
  }

  /// Member access on the shared dataset.
  auto operator->() -> UnderlyingDataset* {
    return this->dataset_.get();
  }

  /// Read-only member access on the shared dataset.
  auto operator->() const -> const UnderlyingDataset* {
    return this->dataset_.get();
  }

 private:
  /// The dataset shared by every copy of this wrapper.
  std::shared_ptr<UnderlyingDataset> dataset_;
};

/// Builds an `UnderlyingDataset` from `args` (perfectly forwarded to its
/// constructor) inside a `std::make_shared` allocation, and returns it wrapped
/// in a `SharedBatchDataset`.
template <typename UnderlyingDataset, typename... Args>
SharedBatchDataset<UnderlyingDataset> make_shared_dataset(Args&&... args) {
  using Wrapper = SharedBatchDataset<UnderlyingDataset>;
  // The freshly created shared_ptr is a prvalue, so it is moved (not copied)
  // into the wrapper's by-value constructor parameter.
  return Wrapper(
      std::make_shared<UnderlyingDataset>(std::forward<Args>(args)...));
}
} // namespace datasets
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources