Program Listing for File shared.h
Return to documentation for file (torch/csrc/api/include/torch/data/datasets/shared.h)
#pragma once
#include <torch/data/datasets/base.h>
#include <memory>
#include <utility>
namespace torch {
namespace data {
namespace datasets {
/// Presents a dataset held through `std::shared_ptr` as a `BatchDataset`.
///
/// The batch and batch-request types are taken verbatim from
/// `UnderlyingDataset`, and each dataset operation is forwarded to the
/// pointed-to object. Because the dataset lives behind a `std::shared_ptr`,
/// any copy of this wrapper refers to the same underlying object rather than
/// a duplicate. None of the members check the pointer for null.
template <typename UnderlyingDataset>
class SharedBatchDataset : public BatchDataset<
                               SharedBatchDataset<UnderlyingDataset>,
                               typename UnderlyingDataset::BatchType,
                               typename UnderlyingDataset::BatchRequestType> {
 public:
  using BatchType = typename UnderlyingDataset::BatchType;
  using BatchRequestType = typename UnderlyingDataset::BatchRequestType;

  /// Takes shared ownership of `shared_dataset`. Deliberately not `explicit`,
  /// so a `std::shared_ptr<UnderlyingDataset>` converts to this wrapper
  /// implicitly.
  /* implicit */ SharedBatchDataset(
      std::shared_ptr<UnderlyingDataset> shared_dataset)
      : underlying_(std::move(shared_dataset)) {}

  /// Moves `request` into the underlying dataset's `get_batch` and returns
  /// whatever it produces.
  BatchType get_batch(BatchRequestType request) override {
    return underlying_->get_batch(std::move(request));
  }

  /// Reports the underlying dataset's size unchanged.
  optional<size_t> size() const override {
    // `std::shared_ptr` does not propagate constness to its pointee, so this
    // reaches the underlying dataset through a non-const pointer.
    return underlying_->size();
  }

  /// Raw pointer to the wrapped dataset (const-qualified in const contexts).
  UnderlyingDataset* operator->() {
    return underlying_.get();
  }
  const UnderlyingDataset* operator->() const {
    return underlying_.get();
  }

  /// Reference to the wrapped dataset, obtained through `operator->`.
  UnderlyingDataset& operator*() {
    return *operator->();
  }
  const UnderlyingDataset& operator*() const {
    return *operator->();
  }

  /// Calls `reset()` on the underlying dataset. As a member of a class
  /// template, this is only instantiated when called, so `UnderlyingDataset`
  /// needs a `reset()` only if this method is actually used.
  void reset() {
    underlying_->reset();
  }

 private:
  /// The wrapped dataset, shared with every other holder of this pointer.
  std::shared_ptr<UnderlyingDataset> underlying_;
};
/// Builds an `UnderlyingDataset` inside a `std::shared_ptr`, perfectly
/// forwarding `args` to its constructor, and wraps the result in a
/// `SharedBatchDataset`.
template <typename UnderlyingDataset, typename... Args>
SharedBatchDataset<UnderlyingDataset> make_shared_dataset(Args&&... args) {
  auto owned =
      std::make_shared<UnderlyingDataset>(std::forward<Args>(args)...);
  return SharedBatchDataset<UnderlyingDataset>(std::move(owned));
}
} // namespace datasets
} // namespace data
} // namespace torch