Program Listing for File stack.h
Return to documentation for file (torch/csrc/api/include/torch/data/transforms/stack.h)
#pragma once
#include <torch/data/example.h>
#include <torch/data/transforms/collate.h>
#include <torch/types.h>
#include <utility>
#include <vector>
namespace torch {
namespace data {
namespace transforms {
/// Collation transform that turns a batch of examples into a single example
/// by stacking their tensors.
///
/// Only this declaration of the primary template exists; all behavior comes
/// from explicit specializations (this header provides ones for `Example<>`
/// and `TensorExample`). Instantiating `Stack` with any other example type
/// will fail to compile unless a specialization is supplied elsewhere.
template <typename ExampleType = Example<>>
struct Stack;
/// `Stack` specialization for examples that carry both a data tensor and a
/// target tensor.
template <>
struct Stack<Example<>> : public Collation<Example<>> {
  /// Collates `examples` into one `Example<>`.
  ///
  /// All data tensors are combined with a single `torch::stack` call, and all
  /// target tensors with another. Both calls use `torch::stack`'s default
  /// dimension. The data tensors are stacked before the target tensors.
  ///
  /// The batch is taken by value, and each element's `data` and `target` are
  /// moved out of it. The caller's vector is therefore not modified when it
  /// is passed as an lvalue copy.
  ///
  /// NOTE(review): an empty `examples` is forwarded unchanged to
  /// `torch::stack`, which is expected to reject an empty tensor list.
  /// Mismatched tensor shapes are likewise left for `torch::stack` to report.
  Example<> apply_batch(std::vector<Example<>> examples) override {
    std::vector<torch::Tensor> input_tensors;
    std::vector<torch::Tensor> label_tensors;
    input_tensors.reserve(examples.size());
    label_tensors.reserve(examples.size());

    for (Example<>& sample : examples) {
      input_tensors.push_back(std::move(sample.data));
      label_tensors.push_back(std::move(sample.target));
    }

    // Keep the original evaluation order: data first, then targets.
    torch::Tensor batched_inputs = torch::stack(input_tensors);
    torch::Tensor batched_labels = torch::stack(label_tensors);
    return Example<>(std::move(batched_inputs), std::move(batched_labels));
  }
};
/// `Stack` specialization for examples that carry only a data tensor and no
/// target.
template <>
struct Stack<TensorExample>
    : public Collation<Example<Tensor, example::NoTarget>> {
  /// Collates `examples` into one `TensorExample` whose data is the result of
  /// a single `torch::stack` call over every element's data tensor. The call
  /// uses `torch::stack`'s default dimension.
  ///
  /// The batch is taken by value, and each element's `data` is moved out of
  /// it.
  ///
  /// NOTE(review): an empty batch, or tensors of differing shapes, are passed
  /// straight to `torch::stack` for it to reject.
  TensorExample apply_batch(std::vector<TensorExample> examples) override {
    std::vector<torch::Tensor> input_tensors;
    input_tensors.reserve(examples.size());

    for (TensorExample& sample : examples) {
      input_tensors.push_back(std::move(sample.data));
    }

    return TensorExample(torch::stack(input_tensors));
  }
};
} // namespace transforms
} // namespace data
} // namespace torch