Program Listing for File lambda.h
Return to documentation for file (torch/csrc/api/include/torch/data/transforms/lambda.h)
#pragma once
#include <torch/data/transforms/base.h>
#include <functional>
#include <utility>
#include <vector>
namespace torch {
namespace data {
namespace transforms {
/// A `BatchTransform` that hands every batch it receives to a user-supplied
/// callable and returns whatever that callable produces.
///
/// @tparam Input  type parameter forwarded to `BatchTransform`, from which the
///                `InputBatchType` alias is taken.
/// @tparam Output type parameter forwarded to `BatchTransform`, from which the
///                `OutputBatchType` alias is taken; defaults to `Input`.
template <typename Input, typename Output = Input>
class BatchLambda : public BatchTransform<Input, Output> {
 private:
  // Short local name for the base so the alias imports below stay readable.
  using BatchTransformBase = BatchTransform<Input, Output>;

 public:
  using typename BatchTransformBase::InputBatchType;
  using typename BatchTransformBase::OutputBatchType;

  /// Callable signature: receives the input batch by value and returns the
  /// transformed batch.
  using FunctionType = std::function<OutputBatchType(InputBatchType)>;

  /// Takes ownership of `function`. It is invoked once per `apply_batch` call.
  explicit BatchLambda(FunctionType function)
      : function_{std::move(function)} {}

  /// Moves `input_batch` into the stored callable and returns its result
  /// directly. If the stored `std::function` is empty, invoking it throws
  /// `std::bad_function_call`.
  auto apply_batch(InputBatchType input_batch) -> OutputBatchType override {
    return function_(std::move(input_batch));
  }

 private:
  /// User-supplied batch callable.
  FunctionType function_;
};
/// A `Transform` that applies a user-provided functor to individual examples.
///
/// Every example passed to `apply` is moved into the stored callable, and the
/// callable's return value is handed back unchanged.
///
/// @tparam Input  type of a single incoming example (the callable's argument).
/// @tparam Output type of a single produced example (the callable's result);
///                defaults to `Input`.
template <typename Input, typename Output = Input>
class Lambda : public Transform<Input, Output> {
 private:
  // Short local name for the base so the alias imports below stay readable.
  using TransformBase = Transform<Input, Output>;

 public:
  using typename TransformBase::InputType;
  using typename TransformBase::OutputType;

  /// Callable signature: one example in (by value), one example out.
  using FunctionType = std::function<Output(Input)>;

  /// Takes ownership of `function`. It is invoked once per `apply` call.
  explicit Lambda(FunctionType function) : function_{std::move(function)} {}

  /// Moves `input` into the stored callable and returns its result directly.
  /// If the stored `std::function` is empty, invoking it throws
  /// `std::bad_function_call`.
  auto apply(InputType input) -> OutputType override {
    return function_(std::move(input));
  }

 private:
  /// User-supplied per-example callable.
  FunctionType function_;
};
} // namespace transforms
} // namespace data
} // namespace torch