Shortcuts

Program Listing for File base.h

Return to documentation for file (torch/csrc/api/include/torch/data/transforms/base.h)

#pragma once

#include <torch/types.h>

#include <utility>
#include <vector>

namespace torch {
namespace data {
namespace transforms {

/// Abstract interface for a transform that consumes a whole batch at once.
///
/// `InputBatch` is the type handed to `apply_batch`, and `OutputBatch` is the
/// type it produces. Both are re-exported as member aliases so generic code
/// can query them from a transform type.
template <typename InputBatch, typename OutputBatch>
class BatchTransform {
 public:
  /// The batch type accepted by `apply_batch`.
  using InputBatchType = InputBatch;
  /// The batch type returned by `apply_batch`.
  using OutputBatchType = OutputBatch;

  /// Virtual so a derived transform can be destroyed through a base pointer.
  virtual ~BatchTransform() = default;

  /// Maps `input_batch`, received by value, to an output batch.
  /// Every concrete transform must implement this.
  virtual OutputBatchType apply_batch(InputBatchType input_batch) = 0;
};

/// A transform defined one example at a time.
///
/// Subclasses implement `apply` for a single `Input`. The batch entry point
/// inherited from `BatchTransform` is then provided here: it invokes `apply`
/// on every element of the incoming vector, front to back, and collects the
/// results into a vector of the same length.
template <typename Input, typename Output>
class Transform
    : public BatchTransform<std::vector<Input>, std::vector<Output>> {
 public:
  /// The per-example input type.
  using InputType = Input;
  /// The per-example output type.
  using OutputType = Output;

  /// Maps a single example to its transformed value.
  virtual OutputType apply(InputType input) = 0;

  /// Applies `apply` element-wise to `input_batch`, preserving order.
  ///
  /// The batch is owned by value here, so each element is moved into
  /// `apply` rather than copied. Capacity for the results is reserved up
  /// front to avoid reallocation while appending.
  std::vector<Output> apply_batch(std::vector<Input> input_batch) override {
    const typename std::vector<Input>::size_type count = input_batch.size();
    std::vector<OutputType> results;
    results.reserve(count);
    for (typename std::vector<Input>::size_type index = 0; index < count;
         ++index) {
      results.push_back(apply(std::move(input_batch[index])));
    }
    return results;
  }
};
} // namespace transforms
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources