Shortcuts

Program Listing for File TensorShape.h

Return to documentation for file (aten/src/ATen/native/TensorShape.h)

#pragma once
#include <ATen/core/IListRef.h>
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>

namespace at::native {

// Declaration only (marked TORCH_API); the definition is not in this header.
// Judging by the name it presumably returns a copy of `self` that keeps
// `self`'s strides rather than producing a contiguous clone.
// NOTE(review): confirm against the definition before relying on this.
TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);

// Returns true when `t` is one-dimensional and holds no elements; the name
// indicates cat uses this to decide which inputs to skip.
// The element count is read through sym_numel() (not numel()), presumably so
// the predicate also works for tensors with symbolic sizes -- TODO confirm.
inline bool cat_should_skip_tensor(const Tensor& t) {
  const bool has_no_elements = t.sym_numel() == 0;
  return has_no_elements && t.dim() == 1;
}

// Check to see if the shape of tensors is compatible
// for being concatenated along a given dimension.
inline void check_cat_shape_except_dim(
    const Tensor& first,
    const Tensor& second,
    int64_t dimension,
    int64_t index) {
  int64_t first_dims = first.dim();
  int64_t second_dims = second.dim();
  TORCH_CHECK(
      first_dims == second_dims,
      "Tensors must have same number of dimensions: got ",
      first_dims,
      " and ",
      second_dims);
  for (const auto dim : c10::irange(first_dims)) {
    if (dim == dimension) {
      continue;
    }
    int64_t first_dim_size = first.sizes()[dim];
    int64_t second_dim_size = second.sizes()[dim];
    TORCH_CHECK(
        first_dim_size == second_dim_size,
        "Sizes of tensors must match except in dimension ",
        dimension,
        ". Expected size ",
        static_cast<long long>(first_dim_size),
        " but got size ",
        static_cast<long long>(second_dim_size),
        " for tensor number ",
        index,
        " in the list.");
  }
}

inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) {
  [[maybe_unused]] int64_t i = 0;
  for (const Tensor& t : tensors) {
    TORCH_CHECK(
        t.dim() > 0,
        "zero-dimensional tensor (at position ",
        i,
        ") cannot be concatenated");
    i++;
  }
}

// Returns how many pieces `split` produces when `self` is cut along `dim`
// into chunks of `split_size` elements (the last chunk may be smaller).
//
// Throws (via TORCH_CHECK) if:
//   * `self` is zero-dimensional,
//   * `split_size` is negative,
//   * `split_size` is 0 while the size of dimension `dim` is non-zero.
// `dim` itself is not validated here; it is forwarded to self.size(dim).
inline int64_t get_num_splits(
    const Tensor& self,
    int64_t split_size,
    int64_t dim) {
  TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
  TORCH_CHECK(
      split_size >= 0,
      "split expects split_size be non-negative, but got split_size=",
      split_size);
  const int64_t dim_size = self.size(dim);
  TORCH_CHECK(
      split_size > 0 || dim_size == 0,
      "split_size can only be 0 if dimension size is 0, "
      "but got dimension size of ",
      dim_size);
  // split_size == 0 is only reachable when dim_size == 0: that is 1 split.
  if (split_size == 0) {
    return 1;
  }
  // Ceiling division written as quotient + remainder. The textbook
  // `(dim_size + split_size - 1) / split_size` form is signed overflow (UB)
  // when split_size is close to INT64_MAX; this form cannot overflow because
  // both operands are non-negative here.
  const int64_t num_splits =
      dim_size / split_size + (dim_size % split_size != 0 ? 1 : 0);
  // With dim_size == 0 and split_size > 0 the division yields 0; clamping to
  // 1 makes an empty dimension still produce a single (empty) split, matching
  // the split_size == 0 case. We might want to error here, but keep it for BC.
  return std::max<int64_t>(num_splits, 1);
}

// Returns true if every tensor in `tensors` has the same number of dimensions
// as the first one.
// An empty list is vacuously uniform and returns true; previously it indexed
// tensors[0], which is an unchecked out-of-bounds read on an empty TensorList.
inline bool have_same_ndims(TensorList tensors) {
  if (tensors.empty()) {
    return true;
  }
  const auto ndim = tensors[0].dim();
  for (const Tensor& tensor : tensors) {
    if (tensor.dim() != ndim) {
      return false;
    }
  }
  return true;
}

// Checks that every tensor in `tensors` has the same sizes as tensors[0] in
// dimensions 0 .. dim-1 (the dimensions preceding `dim`).
// Throws (via TORCH_CHECK) on the first mismatch.
//
// Preconditions:
//   * `dim` must lie in [0, tensors[0].dim()]. This is now checked; it used
//     to be undefined behavior (an out-of-range begin() + dim).
//   * Each other tensor needs at least `dim` dimensions. This is left to
//     Tensor::size(j).
// An empty list has nothing to compare and is accepted.
inline void leading_dimension_matches(TensorList tensors, int64_t dim) {
  if (tensors.empty()) {
    return;
  }
  // Compare directly against the first tensor's sizes: `tensors` outlives
  // this call, so there is no need to copy them into a separate vector.
  const auto reference_sizes = tensors[0].sizes();
  TORCH_CHECK(
      dim >= 0 && dim <= static_cast<int64_t>(reference_sizes.size()),
      "_chunk_cat expects 0 <= dim <= ndim of the first tensor, but got dim=",
      dim);
  // Bind by const reference: copying an at::Tensor costs a refcount inc/dec.
  for (const at::Tensor& tensor : tensors) {
    for (const auto j : c10::irange(dim)) {
      TORCH_CHECK(
          tensor.size(j) == reference_sizes[j],
          "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors");
    }
  }
}

// Validates the inputs of _chunk_cat and returns the concatenation dimension,
// wrapped when possible.
//
// Enforced with TORCH_CHECK:
//   * num_chunks >= 1 and `tensors` is non-empty;
//   * every tensor has at least one element and shares the dtype and device
//     of tensors[0];
//   * if all tensors have the same ndim, a negative `dim` is wrapped with
//     maybe_wrap_dim; otherwise `dim` must be non-negative and smaller than
//     every tensor's ndim;
//   * sizes of dimensions 0 .. dim-1 agree across all tensors
//     (leading_dimension_matches).
inline int64_t preprocess_chunk_cat_inputs(
    TensorList tensors,
    int64_t dim,
    int64_t num_chunks) {
  TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks");
  TORCH_CHECK(
      !tensors.empty(), "_chunk_cat expects a non-empty input tensor list");
  const auto expected_dtype = tensors[0].dtype();
  const auto expected_device = tensors[0].device();
  for (const Tensor& tensor : tensors) {
    TORCH_CHECK(tensor.numel() > 0, "_chunk_cat expects non-empty tensor");
    TORCH_CHECK(
        tensor.dtype() == expected_dtype,
        "_chunk_cat expects all input tensors with the same dtype");
    TORCH_CHECK(
        tensor.device() == expected_device,
        "_chunk_cat expects all inputs tensors on the same device");
  }
  if (have_same_ndims(tensors)) {
    dim = maybe_wrap_dim(dim, tensors[0].dim());
  } else {
    // With mixed ndims there is no single ndim to wrap a negative dim
    // against, so only non-negative dims are accepted.
    TORCH_CHECK(
        dim >= 0,
        "_chunk_cat expects non-negative dim when input tensors have different ndims");
    for (const Tensor& tensor : tensors) {
      TORCH_CHECK(
          dim < tensor.ndimension(),
          "_chunk_cat expects dim < ndim for all input tensors");
    }
  }
  leading_dimension_matches(tensors, dim);
  return dim;
}

} // namespace at::native

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources