torch.unflatten

torch.unflatten(input, dim, sizes) → Tensor

Expands a dimension of the input tensor over multiple dimensions.

See also

torch.flatten() is the inverse of this function. It coalesces several dimensions into one.
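
As a minimal sketch of that inverse relationship (the tensor `x` and its shape are chosen here purely for illustration), splitting a dimension with torch.unflatten and then merging the resulting dimensions with torch.flatten should recover the original tensor:

```python
import torch

# Illustrative input: 2 rows of 12 values each.
x = torch.arange(24).reshape(2, 12)

# Split dimension 1 (size 12) into two dimensions of sizes 3 and 4.
y = torch.unflatten(x, 1, (3, 4))

# Merge dimensions 1 and 2 back into a single dimension of size 12.
z = torch.flatten(y, start_dim=1, end_dim=2)

# The round trip should reproduce the original values and shape.
roundtrip_ok = torch.equal(x, z)
```

Note that torch.flatten needs to be told which range of dimensions to merge (`start_dim` and `end_dim`); with its defaults it flattens the whole tensor to one dimension.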

Parameters:
  • input (Tensor) – the input tensor.

  • dim (int) – Dimension to be unflattened, specified as an index into input.shape.

  • sizes (Tuple[int]) – New shape of the unflattened dimension. One of its elements can be -1 in which case the corresponding output dimension is inferred. Otherwise, the product of sizes must equal input.shape[dim].
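
The parameter rules above can be sketched as follows. The input shape is an arbitrary choice for illustration, and catching RuntimeError reflects the behavior I would expect from current PyTorch releases rather than something stated on this page:

```python
import torch

# Illustrative input with a middle dimension of size 12.
x = torch.zeros(3, 12, 5)

# One entry of sizes may be -1; it is inferred so the product equals 12.
a = torch.unflatten(x, 1, (3, -1))

# dim indexes into input.shape, so a negative index counts from the end.
b = torch.unflatten(x, -2, (2, 6))

# Without a -1 entry, the product of sizes must equal input.shape[dim].
# Here 5 * 2 = 10 != 12, so the call is expected to fail.
try:
    torch.unflatten(x, 1, (5, 2))
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
```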

Returns:

A view of input with the specified dimension unflattened.
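
Because the result is a view, it shares storage with input rather than copying it. A small sketch, using an illustrative contiguous tensor, shows that an in-place write through the view is visible in the original:

```python
import torch

# Illustrative contiguous input.
x = torch.zeros(2, 6)

# View dimension 1 (size 6) as two dimensions of sizes 2 and 3.
v = torch.unflatten(x, 1, (2, 3))

# Writing through the view modifies the underlying storage.
# Index (0, 1, 2) in v corresponds to x[0, 1 * 3 + 2] = x[0, 5].
v[0, 1, 2] = 7.0

# For this contiguous input, both tensors start at the same memory address.
shares_storage = v.data_ptr() == x.data_ptr()
```

If an independent copy is needed, calling `.clone()` on the result is one way to get it.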

Examples:
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape
torch.Size([5, 2, 2, 3, 1, 1, 3])
