Program Listing for File tensor.h
Return to documentation for file (torch/csrc/api/include/torch/data/transforms/tensor.h)
#pragma once
#include <torch/data/example.h>
#include <torch/data/transforms/base.h>
#include <torch/types.h>
#include <functional>
#include <utility>
namespace torch {
namespace data {
namespace transforms {
/// Base class for transforms that rewrite only the `data` field of an
/// `Example<Tensor, Target>`. Subclasses provide the tensor-to-tensor mapping
/// by implementing `operator()`; `apply()` routes each example's `data`
/// through that mapping.
template <typename Target = Tensor>
class TensorTransform
    : public Transform<Example<Tensor, Target>, Example<Tensor, Target>> {
 public:
  using E = Example<Tensor, Target>;
  using typename Transform<E, E>::InputType;
  using typename Transform<E, E>::OutputType;

  /// Hook for subclasses: maps a single input tensor to an output tensor.
  virtual Tensor operator()(Tensor input) = 0;

  /// Overwrites `input.data` with the result of `operator()` applied to it and
  /// returns the example. No other field of `input` is modified here.
  OutputType apply(InputType input) override {
    // The old `data` is moved into the hook, so it is handed over rather than
    // copied before being replaced.
    Tensor transformed = this->operator()(std::move(input.data));
    input.data = std::move(transformed);
    return input;
  }
};
/// A `TensorTransform` whose tensor-to-tensor mapping is a caller-supplied
/// `std::function<Tensor(Tensor)>`.
template <typename Target = Tensor>
class TensorLambda : public TensorTransform<Target> {
 public:
  using FunctionType = std::function<Tensor(Tensor)>;

  /// Takes ownership of `function` (moved into the object) for later use by
  /// `operator()`.
  explicit TensorLambda(FunctionType function) : fn_(std::move(function)) {}

  /// Passes `input` (moved) to the stored function and returns its result.
  Tensor operator()(Tensor input) override {
    Tensor output = fn_(std::move(input));
    return output;
  }

 private:
  /// The callable invoked for every tensor passed to `operator()`.
  FunctionType fn_;
};
/// Tensor transform that subtracts `mean` from its input and divides the
/// result by `stddev`. Both are stored as float32 tensors built once in the
/// constructor.
template <typename Target = Tensor>
struct Normalize : public TensorTransform<Target> {
  /// Converts `mean` and `stddev` to float32 tensors and inserts singleton
  /// dimensions at positions 1 and 2 of each.
  /// NOTE(review): presumably N values become a tensor of shape [N, 1, 1] so
  /// they broadcast over the two trailing dimensions of the input -- confirm
  /// against the shapes callers actually pass.
  Normalize(ArrayRef<double> mean, ArrayRef<double> stddev)
      : mean(with_trailing_singleton_dims(mean)),
        stddev(with_trailing_singleton_dims(stddev)) {}

  /// Returns `input.sub(mean).div(stddev)`.
  torch::Tensor operator()(Tensor input) override {
    const torch::Tensor centered = input.sub(mean);
    return centered.div(stddev);
  }

  /// Tensor subtracted from every input.
  torch::Tensor mean;
  /// Tensor every mean-subtracted input is divided by.
  torch::Tensor stddev;

 private:
  /// Builds a float32 tensor from `values`, then applies `unsqueeze(1)`
  /// followed by `unsqueeze(2)`.
  static torch::Tensor with_trailing_singleton_dims(ArrayRef<double> values) {
    torch::Tensor as_float = torch::tensor(values, torch::kFloat32);
    return as_float.unsqueeze(/*dim=*/1).unsqueeze(/*dim=*/2);
  }
};
} // namespace transforms
} // namespace data
} // namespace torch