Program Listing for File base.h
↰ Return to documentation for file (torch/csrc/api/include/torch/data/transforms/base.h)
#pragma once
#include <torch/types.h>
#include <utility>
#include <vector>
namespace torch {
namespace data {
namespace transforms {
/// Abstract interface for a transformation that turns one whole batch into
/// another batch.
///
/// `InputBatch` is the type of batch consumed and `OutputBatch` the type of
/// batch produced; both are re-exported as member aliases so that generic code
/// can query them from a concrete transform type.
template <typename InputBatch, typename OutputBatch>
class BatchTransform {
 public:
  /// The batch type this transform accepts.
  using InputBatchType = InputBatch;
  /// The batch type this transform produces.
  using OutputBatchType = OutputBatch;

  /// Virtual so that a concrete transform can be destroyed through a pointer
  /// or reference to this base.
  virtual ~BatchTransform() = default;

  /// Transforms `input_batch` into an output batch. The batch is received by
  /// value. Must be implemented by every concrete transform.
  virtual OutputBatchType apply_batch(InputBatchType input_batch) = 0;
};
/// A transformation defined on individual examples rather than whole batches.
///
/// It is a `BatchTransform` whose batches are `std::vector`s of examples:
/// subclasses implement `apply()` for a single example, and `apply_batch()`
/// is provided here in terms of it.
template <typename Input, typename Output>
class Transform
    : public BatchTransform<std::vector<Input>, std::vector<Output>> {
 public:
  /// The type of a single example this transform accepts.
  using InputType = Input;
  /// The type of a single example this transform produces.
  using OutputType = Output;

  /// Transforms one `input` example into one output example. Must be
  /// implemented by every concrete transform.
  virtual Output apply(Input input) = 0;

  /// Runs `apply()` on each element of `input_batch`, in order, and returns
  /// the results as a vector of the same length.
  std::vector<OutputType> apply_batch(
      std::vector<InputType> input_batch) override {
    std::vector<OutputType> results;
    // The final size is known up front, so allocate once.
    results.reserve(input_batch.size());
    for (auto it = input_batch.begin(); it != input_batch.end(); ++it) {
      // The batch was received by value, so each element can be moved into
      // `apply()` instead of being copied.
      results.push_back(this->apply(std::move(*it)));
    }
    return results;
  }
};
} // namespace transforms
} // namespace data
} // namespace torch