torch.nested

Introduction

Warning

The PyTorch API of nested tensors is in prototype stage and will change in the near future.

NestedTensor allows the user to pack a list of Tensors into a single, efficient data structure.

The only constraint on the input Tensors is that their dimension must match.

This enables more efficient metadata representations and access to purpose-built kernels.

One application of NestedTensors is to express sequential data in various domains. While the conventional approach is to pad variable length sequences, NestedTensor enables users to bypass padding. The API for calling operations on a nested tensor is no different from that of a regular torch.Tensor, which should allow seamless integration with existing models, with the main difference being construction of the inputs.
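To make the padding comparison concrete, here is a minimal sketch that stores the same two variable-length sequences both ways. The sequence lengths, the feature size, and the use of torch.nn.utils.rnn.pad_sequence for the conventional approach are choices made for this illustration only.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two "sentences" of different lengths, each token embedded in 4 features.
seq_a = torch.randn(2, 4)
seq_b = torch.randn(5, 4)

# Conventional approach: pad the shorter sequence up to the longest one.
padded = pad_sequence([seq_a, seq_b], batch_first=True, padding_value=0.0)

# Nested approach: each sequence is stored at its own length.
nt = torch.nested.nested_tensor([seq_a, seq_b])

# Count how many elements each representation has to hold.
padded_elements = padded.numel()
nested_elements = sum(t.numel() for t in nt.unbind())
print(padded.shape, padded_elements, nested_elements)
```

The padded batch carries extra filler elements for the shorter sequence, while the nested tensor holds only the real data.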

As this is a prototype feature, the operations supported are still limited. However, we welcome issues, feature requests and contributions. More information on contributing can be found on this wiki.

Construction

Construction is straightforward and involves passing a list of Tensors to the torch.nested.nested_tensor constructor.

>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b])
>>> nt
nested_tensor([
  tensor([0, 1, 2]),
  tensor([3, 4, 5, 6, 7])
])

Data type, device and whether gradients are required can be chosen via the usual keyword arguments.

>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32, device="cuda", requires_grad=True)
>>> nt
nested_tensor([
  tensor([0., 1., 2.], device='cuda:0', requires_grad=True),
  tensor([3., 4., 5., 6., 7.], device='cuda:0', requires_grad=True)
], device='cuda:0', requires_grad=True)

In the vein of torch.as_tensor, torch.nested.as_nested_tensor can be used to preserve autograd history from the tensors passed to the constructor. For more information, refer to the section on Nested tensor constructor and conversion functions.

In order to form a valid NestedTensor all the passed Tensors need to match in dimension, but none of the other attributes need to.

>>> a = torch.randn(3, 50, 70) # image 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt.dim()
4

If the Tensors do not all have the same number of dimensions, the constructor raises an error.

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: All Tensors given to nested_tensor must have the same dimension. Found dimension 3 for Tensor at index 1 and dimension 2 for Tensor at index 0.

Note that the passed Tensors are being copied into a contiguous piece of memory. The resulting NestedTensor allocates new memory to store them and does not keep a reference.
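The copy semantics can be checked with a short sketch: after construction, an in-place change to one of the input Tensors should not show up in the NestedTensor. The values used here are arbitrary.

```python
import torch

a = torch.tensor([0.0, 1.0, 2.0])
b = torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0])
nt = torch.nested.nested_tensor([a, b])

# Modify the original input in place after construction.
a.add_(100.0)

# The nested tensor holds its own copy, so its first constituent
# still contains the values that `a` had at construction time.
first = nt.unbind()[0]
print(first)
print(a)
```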

At this moment we only support one level of nesting, i.e. a simple, flat list of Tensors. In the future we can add support for multiple levels of nesting, such as a list that consists entirely of lists of Tensors. Note that for this extension it is important to maintain an even level of nesting across entries so that the resulting NestedTensor has a well defined dimension. If you have a need for this feature, please feel encouraged to open a feature request so that we can track it and plan accordingly.

size

Even though a NestedTensor does not support .size() (or .shape), it supports .size(i) if dimension i is regular.

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Given dimension 1 is irregular and does not have a size.
>>> nt.size(2)
128

If all dimensions are regular, the NestedTensor is intended to be semantically indistinguishable from a regular torch.Tensor.

>>> a = torch.randn(20, 128) # text 1
>>> nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
20
>>> nt.size(2)
128
>>> torch.stack(nt.unbind()).size()
torch.Size([2, 20, 128])
>>> torch.stack([a, a]).size()
torch.Size([2, 20, 128])
>>> torch.equal(torch.stack(nt.unbind()), torch.stack([a, a]))
True

In the future we might make it easier to detect this condition and convert seamlessly.

Please open a feature request if you have a need for this (or any other related feature for that matter).
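Until such a conversion exists, the check can be approximated by hand. The sketch below defines a hypothetical helper, to_dense_if_regular, which is not part of torch.nested; it relies only on unbind and torch.stack, both shown above.

```python
from typing import Optional

import torch


def to_dense_if_regular(nt: torch.Tensor) -> Optional[torch.Tensor]:
    """Hypothetical helper (not a torch.nested API): return a stacked regular
    tensor if every constituent has the same shape, otherwise None."""
    parts = nt.unbind()
    if not parts:
        return None
    first_shape = parts[0].shape
    if all(p.shape == first_shape for p in parts):
        return torch.stack(parts)
    return None


a = torch.randn(20, 128)
regular = torch.nested.nested_tensor([a, a])
ragged = torch.nested.nested_tensor([a, torch.randn(32, 128)])

dense = to_dense_if_regular(regular)      # regular torch.Tensor of shape (2, 20, 128)
not_dense = to_dense_if_regular(ragged)   # None, because dimension 1 is irregular
```

Note that torch.stack copies the data, so the result does not share memory with the NestedTensor.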

unbind

unbind allows you to retrieve a view of the constituents.

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt
nested_tensor([
  tensor([[ 1.2286, -1.2343, -1.4842],
          [-0.7827,  0.6745,  0.0658]]),
  tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
          [-0.2871, -0.2980,  0.5559,  1.9885],
          [ 0.4074,  2.4855,  0.0733,  0.8285]])
])
>>> nt.unbind()
(tensor([[ 1.2286, -1.2343, -1.4842],
        [-0.7827,  0.6745,  0.0658]]), tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
        [-0.2871, -0.2980,  0.5559,  1.9885],
        [ 0.4074,  2.4855,  0.0733,  0.8285]]))
>>> nt.unbind()[0] is not a
True
>>> nt.unbind()[0].mul_(3)
tensor([[ 3.6858, -3.7030, -4.4525],
        [-2.3481,  2.0236,  0.1975]])
>>> nt
nested_tensor([
  tensor([[ 3.6858, -3.7030, -4.4525],
          [-2.3481,  2.0236,  0.1975]]),
  tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
          [-0.2871, -0.2980,  0.5559,  1.9885],
          [ 0.4074,  2.4855,  0.0733,  0.8285]])
])

Note that nt.unbind()[0] is not a copy, but rather a slice of the underlying memory, which represents the first entry or constituent of the NestedTensor.

Nested tensor constructor and conversion functions

The following functions are related to nested tensors:

torch.nested.nested_tensor(tensor_list, *, dtype=None, device=None, requires_grad=False, pin_memory=False)

Constructs a nested tensor with no autograd history (also known as a “leaf tensor”, see Autograd mechanics) from tensor_list, a list of tensors.

Parameters:

tensor_list (List[Tensor]) – a list of tensors with the same ndim

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired type of returned nested tensor. Default: if None, same torch.dtype as leftmost tensor in the list.

  • device (torch.device, optional) – the desired device of returned nested tensor. Default: if None, same torch.device as leftmost tensor in the list.

  • requires_grad (bool, optional) – If autograd should record operations on the returned nested tensor. Default: False.

  • pin_memory (bool, optional) – If set, the returned nested tensor is allocated in pinned memory. Works only for CPU tensors. Default: False.

Return type:

Tensor

Example:

>>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
>>> nt = torch.nested.nested_tensor([a, b], requires_grad=True)
>>> nt.is_leaf
True

torch.nested.as_nested_tensor(tensor_list, dtype=None, device=None)

Constructs a nested tensor preserving autograd history from tensor_list, a list of tensors.

Note

Tensors within the list are always copied by this function due to current nested tensor semantics.

Parameters:

tensor_list (List[Tensor]) – a list of tensors with the same ndim

Keyword Arguments:
  • dtype (torch.dtype, optional) – the desired type of returned nested tensor. Default: if None, same torch.dtype as leftmost tensor in the list.

  • device (torch.device, optional) – the desired device of returned nested tensor. Default: if None, same torch.device as leftmost tensor in the list.

Return type:

Tensor

Example:

>>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b])
>>> nt.is_leaf
False
>>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)])
>>> nt.backward(fake_grad)
>>> a.grad
tensor([1., 1., 1.])
>>> b.grad
tensor([0., 0., 0., 0., 0.])

torch.nested.to_padded_tensor(input, padding, output_size=None, out=None) → Tensor

Returns a new (non-nested) Tensor by padding the input nested tensor. The leading entries will be filled with the nested data, while the trailing entries will be padded.

Warning

to_padded_tensor() always copies the underlying data, since the nested and the non-nested tensors differ in memory layout.

Parameters:

padding (float) – The padding value for the trailing entries.

Keyword Arguments:
  • output_size (Tuple[int]) – The size of the output tensor. If given, it must be large enough to contain all nested data; otherwise, the size is inferred by taking the maximum size of each nested sub-tensor along each dimension.

  • out (Tensor, optional) – the output tensor.

Example:

>>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))])
>>> nt
nested_tensor([
  tensor([[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
          [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995]]),
  tensor([[-1.8546, -0.7194, -0.2918, -0.1846],
          [ 0.2773,  0.8793, -0.5183, -0.6447],
          [ 1.8009,  1.8468, -0.9832, -1.5272]])
])
>>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0)
>>> pt_infer
tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276],
         [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
        [[-1.8546, -0.7194, -0.2918, -0.1846,  0.0000],
         [ 0.2773,  0.8793, -0.5183, -0.6447,  0.0000],
         [ 1.8009,  1.8468, -0.9832, -1.5272,  0.0000]]])
>>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6))
>>> pt_large
tensor([[[ 1.6862, -1.1282,  1.1031,  0.0464, -1.3276,  1.0000],
         [-1.9967, -1.0054,  1.8972,  0.9174, -1.4995,  1.0000],
         [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
         [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]],
        [[-1.8546, -0.7194, -0.2918, -0.1846,  1.0000,  1.0000],
         [ 0.2773,  0.8793, -0.5183, -0.6447,  1.0000,  1.0000],
         [ 1.8009,  1.8468, -0.9832, -1.5272,  1.0000,  1.0000],
         [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000]]])
>>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2))
RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.
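Once a NestedTensor has been padded, downstream code often needs to know which entries are real data and which are padding. The following sketch derives a boolean mask from the constituent lengths; the mask construction is an illustration and is not something to_padded_tensor returns itself.

```python
import torch

# Two sequences with 2 and 3 rows, both with 5 features.
nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])

# Dense tensor of shape (2, 3, 5); rows past each sequence's length are 0.0.
padded = torch.nested.to_padded_tensor(nt, 0.0)

# Length of every constituent along the ragged dimension.
lengths = torch.tensor([t.size(0) for t in nt.unbind()])

# mask[i, j] is True where row j of sequence i holds real data.
positions = torch.arange(padded.size(1))
mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
print(mask)
```

The lengths are read from unbind() before padding, since the padded tensor alone cannot distinguish padding from real values that happen to equal the padding value.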

Supported operations

In this section, we summarize the operations that are currently supported on NestedTensor and any constraints they have.

PyTorch operation – Constraints

  • torch.matmul() – Supports matrix multiplication between two (>= 3d) nested tensors where the last two dimensions are matrix dimensions and the leading (batch) dimensions have the same size (i.e. no broadcasting support for batch dimensions yet).

  • torch.bmm() – Supports batch matrix multiplication of two 3-d nested tensors.

  • torch.nn.Linear() – Supports 3-d nested input and a dense 2-d weight matrix.

  • torch.nn.functional.softmax() – Supports softmax along all dims except dim=0.

  • torch.nn.Dropout() – Behavior is the same as on regular tensors.

  • torch.relu() – Behavior is the same as on regular tensors.

  • torch.gelu() – Behavior is the same as on regular tensors.

  • torch.add() – Supports elementwise addition of two nested tensors. Supports addition of a scalar to a nested tensor.

  • torch.mul() – Supports elementwise multiplication of two nested tensors. Supports multiplication of a nested tensor by a scalar.

  • torch.select() – Supports selecting along dim=0 only (analogously nt[i]).

  • torch.clone() – Behavior is the same as on regular tensors.

  • torch.detach() – Behavior is the same as on regular tensors.

  • torch.unbind() – Supports unbinding along dim=0 only.

  • torch.reshape() – Supports reshaping with size of dim=0 preserved (i.e. number of tensors nested cannot be changed). Unlike regular tensors, a size of -1 here means that the existing size is inherited. In particular, the only valid size for a ragged dimension is -1. Size inference is not implemented yet and hence for new dimensions the size cannot be -1.

  • torch.Tensor.reshape_as() – Similar constraint as for reshape.

  • torch.transpose() – Supports transposing of all dims except dim=0.

  • torch.Tensor.view() – Rules for the new shape are similar to those of reshape.
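As a rough illustration of several entries above, the sketch below applies a few of the listed operations to a small NestedTensor. The shapes are arbitrary, and because the API is a prototype, the exact set of supported operations may differ between PyTorch versions.

```python
import torch
import torch.nn.functional as F

a, b = torch.randn(2, 6), torch.randn(3, 6)
nt = torch.nested.nested_tensor([a, b])

# Elementwise ops apply to each constituent independently.
doubled = nt + nt
activated = torch.relu(nt)

# Softmax over the last (regular) dimension; dim=0 is not supported.
probs = F.softmax(nt, dim=-1)

# Transpose any pair of dims other than dim=0.
swapped = nt.transpose(1, 2)

# reshape: dim=0 keeps its size of 2, and -1 inherits the existing ragged
# size instead of inferring one. Each (n_i, 6) constituent becomes (n_i, 2, 3).
reshaped = nt.reshape(2, -1, 2, 3)

for part in reshaped.unbind():
    print(part.shape)
```

Comparing each result's unbind() output against the same operation on a and b is a simple way to confirm the per-constituent behavior.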
