Shortcuts

Program Listing for File tensor.h

Return to documentation for file (torch/csrc/api/include/torch/data/transforms/tensor.h)

#pragma once

#include <torch/data/example.h>
#include <torch/data/transforms/base.h>
#include <torch/types.h>

#include <functional>
#include <utility>

namespace torch {
namespace data {
namespace transforms {

/// Base class for transforms that rewrite only the `data` member of an
/// example. The input and output of the transform are both
/// `Example<Tensor, Target>`.
///
/// Subclasses implement `operator()`, which maps one tensor to another;
/// `apply()` takes care of threading the example through it.
template <typename Target = Tensor>
class TensorTransform
    : public Transform<Example<Tensor, Target>, Example<Tensor, Target>> {
 public:
  using E = Example<Tensor, Target>;
  using InputType = typename Transform<E, E>::InputType;
  using OutputType = typename Transform<E, E>::OutputType;

  /// Maps a single tensor to a new tensor. Must be provided by subclasses.
  virtual Tensor operator()(Tensor input) = 0;

  /// Passes `input.data` through `operator()`, writes the result back into
  /// `input.data`, and returns the example. No other member of `input` is
  /// assigned here.
  OutputType apply(InputType input) override {
    // The old data tensor is moved into the call; the slot is then refilled
    // with whatever the subclass returned.
    Tensor transformed = this->operator()(std::move(input.data));
    input.data = std::move(transformed);
    return input;
  }
};

/// A `TensorTransform` whose tensor-to-tensor mapping is supplied by the
/// caller as a `std::function<Tensor(Tensor)>`.
template <typename Target = Tensor>
class TensorLambda : public TensorTransform<Target> {
 public:
  using FunctionType = std::function<Tensor(Tensor)>;

  /// Stores `function`; it is invoked on every call to `operator()`.
  explicit TensorLambda(FunctionType function)
      : function_{std::move(function)} {}

  /// Forwards `input` to the stored function and returns its result.
  Tensor operator()(Tensor input) override {
    Tensor result = function_(std::move(input));
    return result;
  }

 private:
  /// The user-provided callable applied to each input tensor.
  FunctionType function_;
};

/// A `TensorTransform` that subtracts a stored `mean` tensor from its input
/// and divides the result by a stored `stddev` tensor.
template <typename Target = Tensor>
struct Normalize : public TensorTransform<Target> {
  /// Builds the `mean` and `stddev` tensors from the given value lists.
  /// Each list becomes a float32 tensor with dimensions 1 and 2 unsqueezed.
  ///
  /// Inside the initializer expressions the names `mean` and `stddev` refer
  /// to the constructor parameters, not to the members being initialized.
  Normalize(ArrayRef<double> mean, ArrayRef<double> stddev)
      : mean(as_unsqueezed_float(mean)),
        stddev(as_unsqueezed_float(stddev)) {}

  /// Returns `(input - mean) / stddev`.
  torch::Tensor operator()(Tensor input) override {
    const torch::Tensor centered = input.sub(mean);
    return centered.div(stddev);
  }

  torch::Tensor mean;
  torch::Tensor stddev;

 private:
  /// Converts `values` to a float32 tensor and inserts size-1 dimensions at
  /// positions 1 and 2.
  // NOTE(review): presumably this is so the statistics broadcast over a
  // (channels, height, width) input -- confirm against callers.
  static torch::Tensor as_unsqueezed_float(ArrayRef<double> values) {
    torch::Tensor column = torch::tensor(values, torch::kFloat32);
    column = column.unsqueeze(/*dim=*/1);
    column = column.unsqueeze(/*dim=*/2);
    return column;
  }
};
} // namespace transforms
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources