Shortcuts

Program Listing for File lambda.h

Return to the documentation for file (torch/csrc/api/include/torch/data/transforms/lambda.h)

#pragma once

#include <torch/data/transforms/base.h>

#include <functional>
#include <utility>
#include <vector>

namespace torch {
namespace data {
namespace transforms {

/// A `BatchTransform` that hands each whole batch to a user-supplied callable
/// and returns whatever that callable produces.
///
/// The batch given to `apply_batch` is moved into the callable, so the callable
/// may consume or modify it freely.
template <typename Input, typename Output = Input>
class BatchLambda : public BatchTransform<Input, Output> {
  using Base = BatchTransform<Input, Output>;

 public:
  using typename Base::InputBatchType;
  using typename Base::OutputBatchType;
  /// Signature of the user callable: one input batch in, one output batch out.
  using FunctionType = std::function<OutputBatchType(InputBatchType)>;

  /// Takes ownership of `function`; it is invoked on every `apply_batch` call.
  explicit BatchLambda(FunctionType function)
      : function_{std::move(function)} {}

  /// Forwards `input_batch` (by move) to the stored callable and returns the
  /// callable's result. Invoking an empty `std::function` throws
  /// `std::bad_function_call`.
  OutputBatchType apply_batch(InputBatchType input_batch) override {
    return this->function_(std::move(input_batch));
  }

 private:
  /// The user-supplied batch-level callable.
  FunctionType function_;
};

/// A `Transform` that applies a user-provided functor to individual examples.
///
/// Each example passed to `apply` is moved into the stored callable, and the
/// callable's return value becomes the transformed example.
template <typename Input, typename Output = Input>
class Lambda : public Transform<Input, Output> {
  using Base = Transform<Input, Output>;

 public:
  using typename Base::InputType;
  using typename Base::OutputType;
  /// Signature of the user callable: one `Input` example in, one `Output` out.
  using FunctionType = std::function<Output(Input)>;

  /// Takes ownership of `function`; it is invoked once per example.
  explicit Lambda(FunctionType function) : function_{std::move(function)} {}

  /// Forwards `input` (by move) to the stored callable and returns its result.
  /// Invoking an empty `std::function` throws `std::bad_function_call`.
  OutputType apply(InputType input) override {
    return this->function_(std::move(input));
  }

 private:
  /// The user-supplied per-example callable.
  FunctionType function_;
};

} // namespace transforms
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources