Program Listing for File tensor.h
↰ Return to documentation for file (torch/csrc/api/include/torch/data/datasets/tensor.h)
#pragma once
#include <torch/data/datasets/base.h>
#include <torch/data/example.h>
#include <torch/types.h>
#include <cstddef>
#include <vector>
namespace torch {
namespace data {
namespace datasets {
/// A dataset that wraps a single tensor and serves its entries by index.
///
/// `get(i)` yields `tensor[i]`, and `size()` reports `tensor.size(0)`.
struct TensorDataset : public Dataset<TensorDataset, TensorExample> {
  /// Passes `tensors` through `torch::stack` and forwards the result to the
  /// single-tensor constructor.
  explicit TensorDataset(const std::vector<Tensor>& tensors)
      : TensorDataset(torch::stack(tensors)) {}

  /// Stores `tensor` in the member of the same name; the argument is taken by
  /// value and moved into place.
  explicit TensorDataset(torch::Tensor tensor) : tensor(std::move(tensor)) {}

  /// Returns the example obtained by indexing the wrapped tensor at `index`.
  TensorExample get(size_t index) override {
    return this->tensor[index];
  }

  /// Returns the extent of the wrapped tensor along dimension 0.
  optional<size_t> size() const override {
    return this->tensor.size(0);
  }

  /// The tensor this dataset reads from.
  Tensor tensor;
};
} // namespace datasets
} // namespace data
} // namespace torch