Shortcuts

Program Listing for File stack.h

Return to documentation for file (torch/csrc/api/include/torch/data/transforms/stack.h)

#pragma once

#include <torch/data/example.h>
#include <torch/data/transforms/collate.h>
#include <torch/types.h>

#include <utility>
#include <vector>

namespace torch {
namespace data {
namespace transforms {

/// A collation transform that combines a batch of examples into a single
/// example by stacking their tensors.
///
/// Only the primary template is declared here (with `T` defaulting to
/// `Example<>`); the behavior lives in the explicit specializations for
/// `Example<>` and `TensorExample` in this header. Instantiating `Stack` with
/// any other `T` would refer to this undefined primary template.
template <typename T = Example<>>
struct Stack;

/// Collation transform that batches `Example<>` values (data + target).
///
/// All `data` tensors are combined with `torch::stack` into one tensor and all
/// `target` tensors into another, using `torch::stack`'s default dimension.
/// The input examples are consumed: their tensors are moved out, so the
/// elements of `examples` are left in a moved-from state.
template <>
struct Stack<Example<>> : public Collation<Example<>> {
  Example<> apply_batch(std::vector<Example<>> examples) override {
    const std::size_t batch_size = examples.size();

    std::vector<torch::Tensor> inputs;
    std::vector<torch::Tensor> labels;
    inputs.reserve(batch_size);
    labels.reserve(batch_size);

    for (std::size_t idx = 0; idx < batch_size; ++idx) {
      inputs.emplace_back(std::move(examples[idx].data));
      labels.emplace_back(std::move(examples[idx].target));
    }

    // Data is stacked before targets; the named temporaries make that
    // ordering explicit.
    auto batched_inputs = torch::stack(inputs);
    auto batched_labels = torch::stack(labels);
    return {std::move(batched_inputs), std::move(batched_labels)};
  }
};

/// Collation transform that batches `TensorExample` values (data only).
///
/// The base class is `Collation<Example<Tensor, example::NoTarget>>`; the
/// `override` on `apply_batch` implies `TensorExample` names that same
/// target-less example type. All `data` tensors are combined with
/// `torch::stack` using its default dimension. The input examples are
/// consumed: their tensors are moved out.
template <>
struct Stack<TensorExample>
    : public Collation<Example<Tensor, example::NoTarget>> {
  TensorExample apply_batch(std::vector<TensorExample> examples) override {
    std::vector<torch::Tensor> inputs;
    inputs.reserve(examples.size());
    for (auto it = examples.begin(); it != examples.end(); ++it) {
      inputs.emplace_back(std::move(it->data));
    }
    // Wrap the stacked tensor back into a target-less example.
    return TensorExample(torch::stack(inputs));
  }
};
} // namespace transforms
} // namespace data
} // namespace torch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources