# torch.tensor_split

torch.tensor_split(input, indices_or_sections, dim=0) → List of Tensors

Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension dim according to the indices or number of sections specified by indices_or_sections. This function is based on NumPy’s numpy.array_split().
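Because the returned sub-tensors are views, they share storage with input rather than copying it. The following is a minimal sketch of what that implies in practice; the variable names are illustrative only.

```python
import torch

x = torch.arange(6)
parts = torch.tensor_split(x, 3)

# The first section starts at the same memory address as the input.
shares_memory = parts[0].data_ptr() == x.data_ptr()

# Writing through a section modifies the original tensor in place:
# parts[1] is a view of x[2:4], so this sets x[2].
parts[1][0] = 100

print(shares_memory)  # True
print(x.tolist())     # [0, 1, 100, 3, 4, 5]
```

If an independent copy of a section is needed, calling .clone() on it is one way to detach it from input.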

Parameters
• input (Tensor) – the tensor to split

• indices_or_sections (Tensor, int, or list or tuple of ints) –

If indices_or_sections is an integer n or a zero-dimensional long tensor with value n, input is split into n sections along dimension dim. If input.size(dim) is divisible by n, each section will be of equal size, input.size(dim) / n. Otherwise, the first int(input.size(dim) % n) sections will have size int(input.size(dim) / n) + 1, and the rest will have size int(input.size(dim) / n).

If indices_or_sections is a list or tuple of ints, or a one-dimensional long tensor, then input is split along dimension dim at each of the indices in the list, tuple or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors input[:2], input[2:3], and input[3:].

If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional long tensor on the CPU.

• dim (int, optional) – dimension along which to split the tensor. Default: 0
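The two forms of indices_or_sections described above can be checked with a short sketch. The helper expected_section_sizes is hypothetical (it is not part of PyTorch); it only restates the section-size rule so that it can be compared against what torch.tensor_split returns.

```python
import torch


def expected_section_sizes(length, n):
    """Hypothetical helper: section sizes predicted by the integer rule."""
    base, extra = divmod(length, n)
    # The first `extra` sections get one additional element.
    return [base + 1] * extra + [base] * (n - extra)


# Integer form: compare the predicted sizes with the actual ones.
for length, n in [(6, 3), (7, 3), (8, 3), (10, 4)]:
    x = torch.arange(length)
    actual = [t.size(0) for t in torch.tensor_split(x, n)]
    assert actual == expected_section_sizes(length, n)

# Index form: splitting at [2, 3] matches plain slicing.
x = torch.arange(5)
a, b, c = torch.tensor_split(x, [2, 3])
assert torch.equal(a, x[:2])
assert torch.equal(b, x[2:3])
assert torch.equal(c, x[3:])
```

For example, a length of 7 with n=3 gives sizes 3, 2, 2, and a length of 8 with n=3 gives sizes 3, 3, 2.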

Example:

>>> x = torch.arange(8)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))

>>> x = torch.arange(7)
>>> torch.tensor_split(x, 3)
(tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
>>> torch.tensor_split(x, (1, 6))
(tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6]))

>>> x = torch.arange(14).reshape(2, 7)
>>> x
tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12, 13]])
>>> torch.tensor_split(x, 3, dim=1)
(tensor([[0, 1, 2],
        [7, 8, 9]]),
 tensor([[ 3,  4],
        [10, 11]]),
 tensor([[ 5,  6],
        [12, 13]]))
>>> torch.tensor_split(x, (1, 6), dim=1)
(tensor([[0],
        [7]]),
 tensor([[ 1,  2,  3,  4,  5],
        [ 8,  9, 10, 11, 12]]),
 tensor([[ 6],
        [13]]))
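
The examples above pass Python ints and tuples. As a complementary sketch, indices_or_sections can also be given as a long tensor on the CPU: a zero-dimensional tensor is read as a number of sections, and a one-dimensional tensor as split indices.

```python
import torch

x = torch.arange(7)

# Zero-dimensional long tensor: treated as a number of sections.
by_count = torch.tensor_split(x, torch.tensor(3))

# One-dimensional long tensor: treated as split indices.
by_index = torch.tensor_split(x, torch.tensor([1, 6]))

print([t.tolist() for t in by_count])  # [[0, 1, 2], [3, 4], [5, 6]]
print([t.tolist() for t in by_index])  # [[0], [1, 2, 3, 4, 5], [6]]
```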