(prototype) Introduction to Named Tensors in PyTorch¶
Author: Richard Zou
Named Tensors aim to make tensors easier to use by allowing users to associate explicit names with tensor dimensions. In most cases, operations that take dimension parameters will accept dimension names, avoiding the need to track dimensions by position. In addition, named tensors use names to automatically check that APIs are being used correctly at runtime, providing extra safety. Names can also be used to rearrange dimensions, for example, to support “broadcasting by name” rather than “broadcasting by position”.
This tutorial is intended as a guide to the functionality that will be included with the 1.3 launch. By the end of it, you will be able to:
- Create Tensors with named dimensions, as well as remove or rename those dimensions
- Understand the basics of how operations propagate dimension names
- See how naming dimensions enables clearer code in two key areas:
- Broadcasting operations
- Flattening and unflattening dimensions
Finally, we’ll put this into practice by writing a multi-head attention module using named tensors.
Named tensors in PyTorch are inspired by and done in collaboration with Sasha Rush. Sasha proposed the original idea and proof of concept in his January 2019 blog post.
Basics: named dimensions¶
PyTorch now allows Tensors to have named dimensions; factory functions take a new names argument that associates a name with each dimension. This works with most factory functions, such as
- tensor
- empty
- ones
- zeros
- randn
- rand
Here we construct a tensor with names:
import torch
imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
print(imgs.names)
Out:
('N', 'C', 'H', 'W')
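The other factory functions listed above take names the same way. As a small extra sketch (not part of the original tutorial output), using zeros:

# Hypothetical extra example: names work identically with other factory functions
zeros = torch.zeros(2, 3, names=('N', 'C'))
print(zeros.names)  # expected: ('N', 'C')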
Unlike in the original named tensors blogpost, named dimensions are ordered: tensor.names[i] is the name of the i-th dimension of tensor.

There are two ways to rename a Tensor's dimensions:
# Method #1: set the .names attribute (this changes name in-place)
imgs.names = ['batch', 'channel', 'width', 'height']
print(imgs.names)
# Method #2: specify new names (this changes names out-of-place)
imgs = imgs.rename(channel='C', width='W', height='H')
print(imgs.names)
Out:
('batch', 'channel', 'width', 'height')
('batch', 'C', 'W', 'H')
The preferred way to remove names is to call tensor.rename(None):
imgs = imgs.rename(None)
print(imgs.names)
Out:
(None, None, None, None)
Unnamed tensors (tensors with no named dimensions) still work as normal and do not have names in their repr.
unnamed = torch.randn(2, 1, 3)
print(unnamed)
print(unnamed.names)
Out:
tensor([[[-0.5647, 0.8112, 1.4354]],
[[-1.1201, -2.5431, 0.1843]]])
(None, None, None)
Named tensors do not require that all dimensions be named.
imgs = torch.randn(3, 1, 1, 2, names=('N', None, None, None))
print(imgs.names)
Out:
('N', None, None, None)
Because named tensors can coexist with unnamed tensors, we need a nice way to write named tensor-aware code that works with both named and unnamed tensors. Use tensor.refine_names(*names) to refine dimensions and lift unnamed dims to named dims. Refining a dimension is defined as a "rename" with the following constraints:

- A None dim can be refined to have any name
- A named dim can only be refined to have the same name.
imgs = torch.randn(3, 1, 1, 2)
named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
print(named_imgs.names)
# Refine the last two dims to 'H' and 'W'. In Python 2, use the string '...'
# instead of ...
named_imgs = imgs.refine_names(..., 'H', 'W')
print(named_imgs.names)
def catch_error(fn):
    try:
        fn()
        assert False
    except RuntimeError as err:
        err = str(err)
        if len(err) > 180:
            err = err[:180] + "..."
        print(err)
named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
# Tried to refine an existing name to a different name
catch_error(lambda: named_imgs.refine_names('N', 'C', 'H', 'width'))
Out:
('N', 'C', 'H', 'W')
(None, None, 'H', 'W')
refine_names: cannot coerce Tensor['N', 'C', 'H', 'W'] to Tensor['N', 'C', 'H', 'width'] because 'W' is different from 'width' at index 3
Most simple operations propagate names. The ultimate goal for named tensors is for all operations to propagate names in a reasonable, intuitive manner. Support for many common operations has been added at the time of the 1.3 release; here, for example, is .abs():
print(named_imgs.abs().names)
Out:
('N', 'C', 'H', 'W')
Accessors and Reduction¶
One can use dimension names to refer to dimensions instead of the positional dimension. These operations also propagate names. Indexing (basic and advanced) has not been implemented yet but is on the roadmap. Using the named_imgs tensor from above, we can do:
output = named_imgs.sum('C') # Perform a sum over the channel dimension
print(output.names)
img0 = named_imgs.select('N', 0) # get one image
print(img0.names)
Out:
('N', 'H', 'W')
('C', 'H', 'W')
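Other dimension arguments accept names as well. As a small extra sketch (not from the original tutorial; mean is assumed to accept a name like the reductions above):

# Hypothetical extra examples of name-based dimension arguments
print(named_imgs.size('C'))        # size of the channel dimension (1 for the tensor above)
print(named_imgs.mean('H').names)  # reduce over the height dimension by name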
Name inference¶
Names are propagated on operations in a two-step process called name inference:

- Check names: an operator may perform automatic checks at runtime that certain dimension names match.
- Propagate names: name inference propagates names to the output tensors.

Let's go through the very small example of adding two one-dimensional tensors with no broadcasting.
x = torch.randn(3, names=('X',))
y = torch.randn(3)
z = torch.randn(3, names=('Z',))
Check names: first, we will check whether the names of these two tensors match. Two names match if and only if they are equal (string equality) or at least one is None (None is essentially a special wildcard name). The only one of these three that will error, therefore, is x + z:
catch_error(lambda: x + z)
Out:
Error when attempting to broadcast dims ['X'] and dims ['Z']: dim 'X' and dim 'Z' are at the same position from the right but do not match.
Propagate names: unify the two names by returning the most refined name of the two. With x + y, X is more refined than None.
print((x + y).names)
Out:
('X',)
Most name inference rules are straightforward but some of them can have unexpected semantics. Let’s go through a couple you’re likely to encounter: broadcasting and matrix multiply.
Broadcasting¶
Named tensors do not change broadcasting behavior; they still broadcast by position. However, when checking whether two dimensions can be broadcast, PyTorch also checks that the names of those dimensions match.
This results in named tensors preventing unintended alignment during operations that broadcast. In the below example, we apply a per_batch_scale to imgs.
imgs = torch.randn(2, 2, 2, 2, names=('N', 'C', 'H', 'W'))
per_batch_scale = torch.rand(2, names=('N',))
catch_error(lambda: imgs * per_batch_scale)
Out:
Error when attempting to broadcast dims ['N', 'C', 'H', 'W'] and dims ['N']: dim 'W' and dim 'N' are at the same position from the right but do not match.
Without names, the per_batch_scale tensor is aligned with the last dimension of imgs, which is not what we intended. We really wanted to perform the operation by aligning per_batch_scale with the batch dimension of imgs.
See the new “explicit broadcasting by names” functionality for how to
align tensors by name, covered below.
Matrix multiply¶
torch.mm(A, B)
performs a dot product between the second dim of A
and the first dim of B
, returning a tensor with the first dim of A
and the second dim of B
. (other matmul functions, such as
torch.matmul
, torch.mv
, and torch.dot
, behave similarly).
markov_states = torch.randn(128, 5, names=('batch', 'D'))
transition_matrix = torch.randn(5, 5, names=('in', 'out'))
# Apply one transition
new_state = markov_states @ transition_matrix
print(new_state.names)
Out:
('batch', 'out')
As you can see, matrix multiply does not check if the contracted dimensions have the same name.
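To make this concrete, here is a small extra check (not in the original tutorial) with deliberately mismatched names on the contracted dimensions:

# Hypothetical extra example: the contracted dims 'B' and 'C' need not match
a = torch.randn(3, 4, names=('A', 'B'))
b = torch.randn(4, 5, names=('C', 'E'))
print(torch.mm(a, b).names)  # expected: ('A', 'E'); no error despite 'B' != 'C'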
Next, we'll cover two new behaviors that named tensors enable: explicit broadcasting by names, and flattening and unflattening dimensions by names.
New behavior: Explicit broadcasting by names¶
One of the main complaints about working with multiple dimensions is the need to unsqueeze "dummy" dimensions so that operations can occur. For example, in our per-batch-scale example before, with unnamed tensors we'd do the following:
imgs = torch.randn(2, 2, 2, 2) # N, C, H, W
per_batch_scale = torch.rand(2) # N
correct_result = imgs * per_batch_scale.view(2, 1, 1, 1) # N, C, H, W
incorrect_result = imgs * per_batch_scale.expand_as(imgs)
assert not torch.allclose(correct_result, incorrect_result)
We can make these operations safer (and easily agnostic to the number of dimensions) by using names. We provide a new tensor.align_as(other) operation that permutes the dimensions of tensor to match the order specified in other.names, adding one-sized dimensions where appropriate (tensor.align_to(*names) works as well):
imgs = imgs.refine_names('N', 'C', 'H', 'W')
per_batch_scale = per_batch_scale.refine_names('N')
named_result = imgs * per_batch_scale.align_as(imgs)
# note: named tensors do not yet work with allclose
assert torch.allclose(named_result.rename(None), correct_result)
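As a small extra sketch (not in the original tutorial), the same result can be obtained with tensor.align_to(*names) by spelling out the target dimension order:

# Hypothetical alternative to align_as: list the desired dimension order explicitly;
# missing names ('C', 'H', 'W') become one-sized dimensions.
named_result_2 = imgs * per_batch_scale.align_to('N', 'C', 'H', 'W')
assert torch.allclose(named_result_2.rename(None), correct_result)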
New behavior: Flattening and unflattening dimensions by names¶
One common operation is flattening and unflattening dimensions. Right now, users perform this using either view, reshape, or flatten; use cases include flattening batch dimensions to send tensors into operators that must take inputs with a certain number of dimensions (i.e., conv2d takes 4D input).
To make these operations more semantically meaningful than view or reshape, we introduce a new tensor.unflatten(dim, namedshape) method and update flatten to work with names: tensor.flatten(dims, new_dim).

flatten can only flatten dimensions that are adjacent in the tensor, though it also works on tensors that are not contiguous in memory. One must pass into unflatten a named shape, which is a list of (dim, size) tuples, to specify how to unflatten the dim. It is possible to save the sizes during a flatten for unflatten but we do not yet do that.
imgs = imgs.flatten(['C', 'H', 'W'], 'features')
print(imgs.names)
imgs = imgs.unflatten('features', (('C', 2), ('H', 2), ('W', 2)))
print(imgs.names)
Out:
('N', 'features')
('N', 'C', 'H', 'W')
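As a rough sketch of the conv2d use case mentioned above (not part of the original tutorial; the extra 'T' dim and all sizes are made up), one can flatten extra batch dims before a 4D-only operator and restore them afterwards:

# Hypothetical sketch: flatten ('N', 'T') into a single 'N' so conv2d sees 4D input
videos = torch.randn(2, 3, 3, 8, 8, names=('N', 'T', 'C', 'H', 'W'))
flat = videos.flatten(['N', 'T'], 'N')                # names: ('N', 'C', 'H', 'W')
conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1)
out = conv(flat.rename(None))                         # conv2d does not yet accept names
out = out.refine_names('N', 'C', 'H', 'W')
out = out.unflatten('N', [('N', 2), ('T', 3)])
print(out.names)  # expected: ('N', 'T', 'C', 'H', 'W')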
Autograd support¶
Autograd currently ignores names on all tensors and just treats them like regular tensors. Gradient computation is correct but we lose the safety that names give us. It is on the roadmap to introduce handling of names to autograd.
x = torch.randn(3, names=('D',))
weight = torch.randn(3, names=('D',), requires_grad=True)
loss = (x - weight).abs()
grad_loss = torch.randn(3)
loss.backward(grad_loss)
correct_grad = weight.grad.clone()
print(correct_grad) # Unnamed for now. Will be named in the future
weight.grad.zero_()
grad_loss = grad_loss.refine_names('C')
loss = (x - weight).abs()
# Ideally we'd check that the names of loss and grad_loss match, but we don't
# yet
loss.backward(grad_loss)
print(weight.grad) # still unnamed
assert torch.allclose(weight.grad, correct_grad)
Out:
tensor([ 0.3588, 0.4460, -0.4983])
tensor([ 0.3588, 0.4460, -0.4983])
Other supported (and unsupported) features¶
See here for a detailed breakdown of what is supported with the 1.3 release.
In particular, we want to call out three important features that are not currently supported:
- Saving or loading named tensors via torch.save or torch.load
- Multi-processing via torch.multiprocessing
- JIT support; for example, the following will error
imgs_named = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))

@torch.jit.script
def fn(x):
    return x

catch_error(lambda: fn(imgs_named))
Out:
NYI: Named tensors are currently unsupported in TorchScript. As a workaround please drop names via `tensor = tensor.rename(None)`.
As a workaround, please drop names via tensor = tensor.rename(None) before using anything that does not yet support named tensors.
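For example, a minimal sketch of that workaround for torch.save (not from the original tutorial; the path and names are arbitrary):

# Hypothetical sketch: drop names before saving, re-attach them after loading
named = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
torch.save(named.rename(None), '/tmp/unnamed_tensor.pt')
loaded = torch.load('/tmp/unnamed_tensor.pt').refine_names('N', 'C', 'H', 'W')
print(loaded.names)  # expected: ('N', 'C', 'H', 'W')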
Longer example: Multi-head attention¶
Now we'll go through a complete example of implementing a common PyTorch nn.Module: multi-head attention. We assume the reader is already familiar with multi-head attention; for a refresher, check out this explanation or this explanation.
We adapt the implementation of multi-head attention from ParlAI; specifically here. Read through the code at that example; then, compare with the code below, noting that there are four places labeled (I), (II), (III), and (IV), where using named tensors enables more readable code; we will dive into each of these after the code block.
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, dim, dropout=0):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.dim = dim

        self.attn_dropout = nn.Dropout(p=dropout)
        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        nn.init.xavier_normal_(self.q_lin.weight)
        nn.init.xavier_normal_(self.k_lin.weight)
        nn.init.xavier_normal_(self.v_lin.weight)
        self.out_lin = nn.Linear(dim, dim)
        nn.init.xavier_normal_(self.out_lin.weight)

    def forward(self, query, key=None, value=None, mask=None):
        # (I)
        query = query.refine_names(..., 'T', 'D')
        self_attn = key is None and value is None
        if self_attn:
            mask = mask.refine_names(..., 'T')
        else:
            mask = mask.refine_names(..., 'T', 'T_key')  # enc attn

        dim = query.size('D')
        assert dim == self.dim, \
            f'Dimensions do not match: {dim} query vs {self.dim} configured'
        assert mask is not None, 'Mask is None, please specify a mask'
        n_heads = self.n_heads
        dim_per_head = dim // n_heads
        scale = math.sqrt(dim_per_head)

        # (II)
        def prepare_head(tensor):
            tensor = tensor.refine_names(..., 'T', 'D')
            return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
                          .align_to(..., 'H', 'T', 'D_head'))

        assert value is None
        if self_attn:
            key = value = query
        elif value is None:
            # key and value are the same, but query differs
            key = key.refine_names(..., 'T', 'D')
            value = key

        dim = key.size('D')

        # Distinguish between query_len (T) and key_len (T_key) dims.
        k = prepare_head(self.k_lin(key)).rename(T='T_key')
        v = prepare_head(self.v_lin(value)).rename(T='T_key')
        q = prepare_head(self.q_lin(query))

        dot_prod = q.div_(scale).matmul(k.align_to(..., 'D_head', 'T_key'))
        dot_prod.refine_names(..., 'H', 'T', 'T_key')  # just a check

        # (III)
        attn_mask = (mask == 0).align_as(dot_prod)
        dot_prod.masked_fill_(attn_mask, -float(1e20))

        attn_weights = self.attn_dropout(F.softmax(dot_prod / scale,
                                                   dim='T_key'))

        # (IV)
        attentioned = (
            attn_weights.matmul(v).refine_names(..., 'H', 'T', 'D_head')
            .align_to(..., 'T', 'H', 'D_head')
            .flatten(['H', 'D_head'], 'D')
        )

        return self.out_lin(attentioned).refine_names(..., 'T', 'D')
(I) Refining the input tensor dims
def forward(self, query, key=None, value=None, mask=None):
    # (I)
    query = query.refine_names(..., 'T', 'D')
The query = query.refine_names(..., 'T', 'D') serves as enforceable documentation and lifts input dimensions to being named. It checks that the last two dimensions can be refined to ['T', 'D'], preventing potentially silent or confusing size mismatch errors later down the line.
(II) Manipulating dimensions in prepare_head
# (II)
def prepare_head(tensor):
    tensor = tensor.refine_names(..., 'T', 'D')
    return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
                  .align_to(..., 'H', 'T', 'D_head'))
The first thing to note is how the code clearly states the input and output dimensions: the input tensor must end with the T and D dims and the output tensor ends in H, T, and D_head dims.

The second thing to note is how clearly the code describes what is going on. prepare_head takes the key, query, and value and splits the embedding dim into multiple heads, finally rearranging the dim order to be [..., 'H', 'T', 'D_head'].
ParlAI implements prepare_head as the following, using view and transpose operations:
def prepare_head(tensor):
    # input is [batch_size, seq_len, n_heads * dim_per_head]
    # output is [batch_size * n_heads, seq_len, dim_per_head]
    batch_size, seq_len, _ = tensor.size()
    tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head)
    tensor = (
        tensor.transpose(1, 2)
        .contiguous()
        .view(batch_size * n_heads, seq_len, dim_per_head)
    )
    return tensor
Our named tensor variant uses ops that, though more verbose, have more semantic meaning than view and transpose, and it includes enforceable documentation in the form of names.
(III) Explicit broadcasting by names
def ignore():
    # (III)
    attn_mask = (mask == 0).align_as(dot_prod)
    dot_prod.masked_fill_(attn_mask, -float(1e20))
mask usually has dims [N, T] (in the case of self attention) or [N, T, T_key] (in the case of encoder attention) while dot_prod has dims [N, H, T, T_key]. To make mask broadcast correctly with dot_prod, we would usually unsqueeze dims 1 and -1 in the case of self attention or unsqueeze dim 1 in the case of encoder attention. Using named tensors, we simply align attn_mask to dot_prod using align_as and stop worrying about where to unsqueeze dims.
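For contrast, a rough sketch (not taken from ParlAI's actual code) of the positional, unsqueeze-based version described above:

def ignore():
    # self attention: mask is [N, T], dot_prod is [N, H, T, T_key]
    attn_mask = (mask == 0).unsqueeze(1).unsqueeze(-1).expand_as(dot_prod)
    # encoder attention: mask is [N, T, T_key]
    attn_mask = (mask == 0).unsqueeze(1).expand_as(dot_prod)
    dot_prod.masked_fill_(attn_mask, -float(1e20))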
(IV) More dimension manipulation using align_to and flatten
def ignore():
    # (IV)
    attentioned = (
        attn_weights.matmul(v).refine_names(..., 'H', 'T', 'D_head')
        .align_to(..., 'T', 'H', 'D_head')
        .flatten(['H', 'D_head'], 'D')
    )
Here, as in (II), align_to and flatten are more semantically meaningful than view and transpose (despite being more verbose).
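For comparison, roughly what the positional version looks like (a sketch in the spirit of the ParlAI code, not its exact implementation; batch_size and query_len are assumed local variables):

def ignore():
    # (IV), positionally: [batch_size * n_heads, query_len, dim_per_head]
    #   -> [batch_size, query_len, n_heads * dim_per_head]
    attentioned = (
        attn_weights.matmul(v)
        .view(batch_size, n_heads, query_len, dim_per_head)
        .transpose(1, 2)
        .contiguous()
        .view(batch_size, query_len, dim)
    )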
Running the example¶
n, t, d, h = 7, 5, 2 * 3, 3
query = torch.randn(n, t, d, names=('N', 'T', 'D'))
mask = torch.ones(n, t, names=('N', 'T'))
attn = MultiHeadAttention(h, d)
output = attn(query, mask=mask)
# works as expected!
print(output.names)
Out:
('N', 'T', 'D')
The above works as expected. Furthermore, note that in the code we did not mention the name of the batch dimension at all. In fact, our MultiHeadAttention module is agnostic to the existence of batch dimensions.
query = torch.randn(t, d, names=('T', 'D'))
mask = torch.ones(t, names=('T',))
output = attn(query, mask=mask)
print(output.names)
Out:
('T', 'D')