torch.squeeze

torch.squeeze(input: Tensor, dim: Optional[Union[int, List[int]]]) → Tensor

Returns a tensor with all specified dimensions of input of size 1 removed.

For example, if input is of shape (A×1×B×C×1×D), then input.squeeze() will be of shape (A×B×C×D).

When dim is given, a squeeze operation is done only in the given dimension(s). If input is of shape (A×1×B), squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape (A×B).
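The following is a minimal sketch of the dim behavior described above, using an illustrative tensor of shape (3×1×4). The variable names are chosen for this example only.

```python
import torch

x = torch.zeros(3, 1, 4)

# Dimension 0 has size 3, so squeezing it removes nothing.
unchanged = torch.squeeze(x, 0)

# Dimension 1 has size 1, so it is removed.
squeezed = torch.squeeze(x, 1)

print(tuple(unchanged.shape))  # (3, 1, 4)
print(tuple(squeezed.shape))   # (3, 4)
```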

Note

The returned tensor shares the storage with the input tensor, so changing the contents of one will change the contents of the other.
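As a small illustration of the shared storage noted above, the sketch below writes through a squeezed tensor and reads the change back from the input. Calling clone() to obtain an independent copy is one possible way to avoid this aliasing, shown here as a suggestion rather than a requirement.

```python
import torch

x = torch.zeros(2, 1, 3)
y = torch.squeeze(x, 1)   # shape (2, 3), shares storage with x

y[0, 0] = 7.0             # write through the squeezed tensor
print(x[0, 0, 0].item())  # 7.0, the input sees the change

# clone() makes an independent copy, so later writes do not reach x.
z = torch.squeeze(x, 1).clone()
z[1, 2] = 5.0
print(x[1, 0, 2].item())  # 0.0, the input is unaffected
```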

Warning

If the tensor has a batch dimension of size 1, then squeeze(input) will also remove the batch dimension, which can lead to unexpected errors. Consider specifying only the dims you wish to be squeezed.
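The pitfall in this warning can be sketched as follows. The (batch, channels, 1, width) layout is a hypothetical one chosen for illustration.

```python
import torch

# Hypothetical layout: (batch, channels, 1, width) with a batch of one sample.
batch = torch.zeros(1, 3, 1, 5)

# With no dim, every size-1 dimension is removed, including the batch dimension.
squeezed_all = torch.squeeze(batch)

# Naming the dimension explicitly keeps the batch dimension in place.
squeezed_dim2 = torch.squeeze(batch, 2)

print(tuple(squeezed_all.shape))   # (3, 5)
print(tuple(squeezed_dim2.shape))  # (1, 3, 5)
```

Specifying dim keeps the output rank the same regardless of how many samples the batch happens to contain.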

Parameters
  • input (Tensor) – the input tensor.

  • dim (int or tuple of ints, optional) – if given, the input will be squeezed only in the specified dimensions.

    Changed in version 2.0: dim now accepts tuples of dimensions.

Example:

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])
>>> y = torch.squeeze(x, (1, 2, 3))
>>> y.size()
torch.Size([2, 2, 2])
