Overview of MaskedTensors

import torch
import numpy as np
from maskedtensor import masked_tensor
from maskedtensor import as_masked_tensor

Basic masking semantics

MaskedTensor vs NumPy’s MaskedArray semantics

# First example of addition
data = torch.arange(5.)
mask = torch.tensor([True, True, False, True, False])
m0 = masked_tensor(data, mask)
m1 = masked_tensor(data, ~mask)

print(m0)
print(m1)
print(torch.cos(m0))
print(m0 + m0)

try:
  # For now the masks must match. We treat them like shapes.
  # We can relax this later on, but should have a good reason for it.
  # We'll revisit this once we have reductions.
  print(m0 + m1)
except ValueError as e:
  print(e)
masked_tensor(
  [  0.0000,   1.0000,       --,   3.0000,       --]
)
masked_tensor(
  [      --,       --,   2.0000,       --,   4.0000]
)
masked_tensor(
  [  1.0000,   0.5403,       --,  -0.9900,       --]
)
masked_tensor(
  [  0.0000,   2.0000,       --,   6.0000,       --]
)
Input masks must match. If you need support for this, please open an issue on Github.

NumPy’s MaskedArray implements intersection semantics here: if either of two corresponding elements is masked out, the resulting element is masked out as well. Note that MaskedArray’s factory function inverts the mask (similar to the masks in torch.nn.MultiheadAttention, where True means "ignore"). To get NumPy’s semantics for MaskedTensor we’d apply the logical_and operator to both masks during a binary operation; since NumPy stores the inverted mask, it applies the logical_or operator instead. To repeat the point above, however, we suggest not supporting addition between MaskedTensors whose masks don’t match. See the section on reductions for why we should have a good reason before relaxing this.
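To make the inversion concrete, here is a small hedged check using the tensors defined above (the np_mask* and result-mask names are just for illustration; mask() is the accessor used later in this overview):

# MaskedTensor's mask is True where an element is valid; NumPy's
# MaskedArray mask is True where an element is masked out.
np_mask0 = (~mask).numpy()   # NumPy-style mask for m0
np_mask1 = mask.numpy()      # NumPy-style mask for m1
# NumPy ORs its inverted masks on binary ops; by De Morgan that is the
# complement of ANDing the MaskedTensor-style validity masks.
np_result_mask = np_mask0 | np_mask1
mt_result_mask = m0.mask() & m1.mask()
assert (torch.from_numpy(np_result_mask) == ~mt_result_mask).all()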

# NumPy's mask convention is inverted: True means masked out
npm0 = np.ma.masked_array(data.numpy(), (~mask).numpy())
npm1 = np.ma.masked_array(data.numpy(), mask.numpy())
print("npm0:       ", npm0)
print("npm1:       ", npm1)
print("npm0 + npm1:", npm0 + npm1)
npm0:        [0.0 1.0 -- 3.0 --]
npm1:        [-- -- 2.0 -- 4.0]
npm0 + npm1: [-- -- -- -- --]

MaskedTensor also supports these semantics: it gives access to the masks, and it can conveniently convert a MaskedTensor to a Tensor with the masked values filled in with a particular value.

NumPy, of course, has the opportunity to avoid the addition altogether in this case by checking whether any result elements are unmasked, but it chooses not to. Presumably it’s more expensive to reduce over the entire mask on every operation than to simply perform the element-wise addition of the data.
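As an illustration, here is a hedged sketch of such a short-circuit (neither NumPy nor this prototype actually does this):

# Reducing the intersected mask up front would let us skip the data op
# entirely when no element survives, at the cost of a full mask reduction.
result_mask = m0.mask() & m1.mask()
if not result_mask.any():
    print("every result element is masked out; the data addition could be skipped")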

m0t = m0.to_tensor(0)
m1t = m1.to_tensor(0)

m2t = masked_tensor(m0t + m1t, m0.mask() & m1.mask())
print(m0t)
print(m1t)
print(m2t)
tensor([0., 1., 0., 3., 0.])
tensor([0., 0., 2., 0., 4.])
masked_tensor(
  [      --,       --,       --,       --,       --]
)

MaskedTensor reduction semantics

Example of printing a 2d MaskedTensor and setup for reductions below

data = torch.randn(8, 3).mul(10).int().float()
mask = torch.randint(2, (8, 3), dtype=torch.bool)
print(data)
print(mask)
m = masked_tensor(data, mask)
print(m)
tensor([[ 20., -14.,  -4.],
        [ 22.,   0.,  16.],
        [ -9.,  -6.,  -6.],
        [  4.,   0.,  10.],
        [  2.,  -1.,  -5.],
        [ -4.,  -3., -10.],
        [  1.,  -5.,   0.],
        [  0.,   6., -20.]])
tensor([[False, False, False],
        [ True, False, False],
        [ True,  True,  True],
        [ True, False,  True],
        [ True, False, False],
        [False,  True,  True],
        [False,  True,  True],
        [ True,  True, False]])
masked_tensor(
  [
    [      --,       --,       --],
    [ 22.0000,       --,       --],
    [ -9.0000,  -6.0000,  -6.0000],
    [  4.0000,       --,  10.0000],
    [  2.0000,       --,       --],
    [      --,  -3.0000, -10.0000],
    [      --,  -5.0000,   0.0000],
    [  0.0000,   6.0000,       --]
  ]
)

Reduction semantics based on https://github.com/pytorch/rfcs/pull/27

print("sum:", torch.sum(m, 1))
print("mean:", torch.mean(m, 1))
print("prod:", torch.prod(m, 1))
print("min:", torch.amin(m, 1))
print("max:", torch.amax(m, 1))
sum: masked_tensor(
  [      --,  22.0000, -21.0000,  14.0000,   2.0000, -13.0000,  -5.0000,   6.0000]
)
mean: masked_tensor(
  [      --,  22.0000,  -7.0000,   7.0000,   2.0000,  -6.5000,  -2.5000,   3.0000]
)
prod: masked_tensor(
  [       --,  22.0000, -324.0000,  40.0000,   2.0000,  30.0000,  -0.0000,   0.0000]
)
min: masked_tensor(
  [      --,  22.0000,  -9.0000,   4.0000,   2.0000, -10.0000,  -5.0000,   0.0000]
)
max: masked_tensor(
  [      --,  22.0000,  -6.0000,  10.0000,   2.0000,  -3.0000,   0.0000,   6.0000]
)
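As a quick cross-check, NumPy’s MaskedArray reductions behave similarly on this data; remember that its mask convention is inverted, so we pass ~mask:

# Fully masked rows (like row 0 above) also reduce to a masked element in NumPy.
npm = np.ma.masked_array(data.numpy(), (~mask).numpy())
print("np sum: ", npm.sum(1))
print("np mean:", npm.mean(1))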

Now that we have reductions, let’s revisit why we’ll want a good reason before allowing addition of MaskedTensors with mismatched masks.

data0 = torch.arange(10.).reshape(2, 5)
data1 = torch.arange(10.).reshape(2, 5) + 10
mask0 = torch.tensor([[True, True, False, False, False], [False, False, False, True, True]])
mask1 = torch.tensor([[False, False, False, True, True], [True, True, False, False, False]])

# mask0/mask1 follow NumPy's convention here: True means masked out
npm0 = np.ma.masked_array(data0.numpy(), mask0.numpy())
npm1 = np.ma.masked_array(data1.numpy(), mask1.numpy())
print("\nnpm0:\n", npm0)
print("\nnpm1:\n", npm1)
print("\n(npm0 + npm1).sum(0):\n", (npm0 + npm1).sum(0))
print("\nnpm0.sum(0) + npm1.sum(0):\n", (npm0.sum(0) + npm1.sum(0)))
print("\n(data0 + data1).sum(0):\n", (data0 + data1).sum(0))
print("\n(data0 + data1).sum(0):\n", (data0.sum(0) + data1.sum(0)))
npm0:
 [[-- -- 2.0 3.0 4.0]
 [5.0 6.0 7.0 -- --]]

npm1:
 [[10.0 11.0 12.0 -- --]
 [-- -- 17.0 18.0 19.0]]

(npm0 + npm1).sum(0):
 [-- -- 38.0 -- --]

npm0.sum(0) + npm1.sum(0):
 [15.0 17.0 38.0 21.0 23.0]

(data0 + data1).sum(0):
 tensor([30., 34., 38., 42., 46.])

data0.sum(0) + data1.sum(0):
 tensor([30., 34., 38., 42., 46.])

Sum and addition should be associative: summing a + b element-wise and then reducing should give the same result as reducing each operand and then adding. With NumPy’s semantics, however, they are allowed to disagree, as the example above shows. Instead of adopting these semantics, at least in the case of addition and sum, we could ask the user to fill the MaskedTensor’s undefined elements with 0 (as sketched below) or, as in the MaskedTensor addition examples above, to be very specific about the semantics used.
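A hedged sketch of the fill-with-0 option, using the factory and to_tensor from above; note that mask0/mask1 in this example follow NumPy’s convention, so they are inverted for masked_tensor:

mt0 = masked_tensor(data0, ~mask0)  # ~: convert True=masked to True=valid
mt1 = masked_tensor(data1, ~mask1)
# Filling the undefined elements with sum's identity (0) makes reduction
# and addition agree again.
print((mt0.to_tensor(0) + mt1.to_tensor(0)).sum(0))
print(mt0.to_tensor(0).sum(0) + mt1.to_tensor(0).sum(0))
# Both print tensor([15., 17., 38., 21., 23.]), matching
# npm0.sum(0) + npm1.sum(0) above.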

While it’s obviously possible to support this, we think we should cover other operators first and really make sure we can’t avoid this behavior via other means.

Indexing and Advanced Indexing

data = torch.randn(4, 5, 3).mul(5).float()
mask = torch.randint(2, (4, 5, 3), dtype=torch.bool)
m = masked_tensor(data, mask)
print(m)
masked_tensor(
  [
    [
      [ -6.0792,       --,       --],
      [      --,       --,  -1.9683],
      [ -2.3700,  -1.9161,  -1.8073],
      [ -3.9613,  -2.5290,  -9.9581],
      [      --,       --,  -1.3996]
    ],
    [
      [ -4.7627,   0.4052,  -2.2773],
      [      --,  -0.7082,       --],
      [      --,       --,       --],
      [      --,       --,   0.5614],
      [  2.4055,       --,       --]
    ],
    [
      [  4.9276,  -2.9012,       --],
      [      --,  -6.6361,   2.5489],
      [  3.0996,  -3.8993,       --],
      [ -0.2633,   0.0727,       --],
      [ -4.5588,   4.1068,       --]
    ],
    [
      [ -3.0962,   3.4735,       --],
      [ -9.2337,  -2.1719,       --],
      [  2.9923,       --,       --],
      [      --,   2.8241,   6.9662],
      [      --,  -3.4862,   4.1371]
    ]
  ]
)

Example of indexing and advanced indexing

print(m[0])
print(m[torch.tensor([0, 2])])
print(m[m.mask()])
masked_tensor(
  [
    [ -6.0792,       --,       --],
    [      --,       --,  -1.9683],
    [ -2.3700,  -1.9161,  -1.8073],
    [ -3.9613,  -2.5290,  -9.9581],
    [      --,       --,  -1.3996]
  ]
)
masked_tensor(
  [
    [
      [ -6.0792,       --,       --],
      [      --,       --,  -1.9683],
      [ -2.3700,  -1.9161,  -1.8073],
      [ -3.9613,  -2.5290,  -9.9581],
      [      --,       --,  -1.3996]
    ],
    [
      [  4.9276,  -2.9012,       --],
      [      --,  -6.6361,   2.5489],
      [  3.0996,  -3.8993,       --],
      [ -0.2633,   0.0727,       --],
      [ -4.5588,   4.1068,       --]
    ]
  ]
)
masked_tensor(
  [ -6.0792,  -1.9683,  -2.3700,  -1.9161,  -1.8073,  -3.9613,  -2.5290,  -9.9581,  -1.3996,  -4.7627,   0.4052,  -2.2773,  -0.7082,   0.5614,   2.4055,   4.9276,  -2.9012,  -6.6361,   2.5489,   3.0996,  -3.8993,  -0.2633,   0.0727,  -4.5588,   4.1068,  -3.0962,   3.4735,  -9.2337,  -2.1719,   2.9923,   2.8241,   6.9662,  -3.4862,   4.1371]
)

MaskedTensor gradient examples

torch.manual_seed(22)
# Sum needs custom autograd, since the mask of the input should be maintained
data = torch.randn(2, 2, 3).mul(5).float()
mask = torch.randint(2, (2, 2, 3), dtype=torch.bool)
m = masked_tensor(data, mask, requires_grad=True)
print(m)
s = torch.sum(m)
print("s: ", s)
s.backward()
print("m.grad: ", m.grad)

# sum needs to return a scalar MaskedTensor because the input might be fully masked
data = torch.randn(2, 2, 3).mul(5).float()
mask = torch.zeros(2, 2, 3, dtype=torch.bool)
m = masked_tensor(data, mask, requires_grad=True)
print("\n", m)
s = torch.sum(m)
print("s: ", s)
s.backward()
print("m.grad: ", m.grad)
masked_tensor(
  [
    [
      [      --,       --,  -0.5084],
      [  6.7935, -15.3725,       --]
    ],
    [
      [      --,   1.2078,       --],
      [  6.5820,       --,  -1.6679]
    ]
  ]
)
s:  masked_tensor( -2.9655, True)
m.grad:  masked_tensor(
  [
    [
      [      --,       --,   1.0000],
      [  1.0000,   1.0000,       --]
    ],
    [
      [      --,   1.0000,       --],
      [  1.0000,       --,   1.0000]
    ]
  ]
)

 masked_tensor(
  [
    [
      [      --,       --,       --],
      [      --,       --,       --]
    ],
    [
      [      --,       --,       --],
      [      --,       --,       --]
    ]
  ]
)
s:  masked_tensor(--, False)
m.grad:  masked_tensor(
  [
    [
      [      --,       --,       --],
      [      --,       --,       --]
    ],
    [
      [      --,       --,       --],
      [      --,       --,       --]
    ]
  ]
)
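To see why sum needs custom autograd, here is a minimal plain-Tensor sketch of the idea (MaskedSumSketch and the variable names are hypothetical; this is not the library’s implementation): the backward re-applies the forward’s mask so masked-out positions never receive a gradient.

class MaskedSumSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, data, mask):
        ctx.save_for_backward(mask)
        # Reduce only the unmasked elements. (The real MaskedTensor sum
        # additionally returns a masked scalar if everything is masked,
        # as shown above.)
        return data[mask].sum()

    @staticmethod
    def backward(ctx, grad_output):
        mask, = ctx.saved_tensors
        # Re-apply the input's mask: masked-out positions get no gradient.
        grad = torch.zeros(mask.shape, dtype=grad_output.dtype)
        grad[mask] = grad_output
        return grad, None

d = torch.randn(2, 3, requires_grad=True)
k = torch.randint(2, (2, 3), dtype=torch.bool)
MaskedSumSketch.apply(d, k).backward()
print(d.grad)  # 1. where k is True, 0. elsewhere
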
# Grad of multiplication of MaskedTensor and Tensor
x = masked_tensor(torch.tensor([3.0, 4.0]), torch.tensor([True, False]), requires_grad=True)
print("x:\n", x)
y = torch.tensor([2., 1.]).requires_grad_()
print("y:\n", y)
# The mask broadcasts in the sense that the result is masked wherever x is.
# In general a MaskedTensor can be considered a generalization of a Tensor's
# shape: the mask is a more complex, higher-dimensional shape and the Tensor
# broadcasts to it. I'd love to find a more rigorous definition of this.
z = x * y
print("x * y:\n", z)
z.sum().backward()
print("\nx.grad: ", x.grad)
# The regular torch.Tensor now has a MaskedTensor grad
print("y.grad: ", y.grad)
x:
 masked_tensor(
  [  3.0000,       --]
)
y:
 tensor([2., 1.], requires_grad=True)
x * y:
 masked_tensor(
  [  6.0000,       --]
)

x.grad:  masked_tensor(
  [  2.0000,       --]
)
y.grad:  masked_tensor(
  [  3.0000,       --]
)
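
If a plain Tensor gradient is needed, for example to feed an optimizer, the to_tensor conversion from earlier applies here as well (a hedged sketch, assuming the gradient supports the same conversion):

# Fill the masked-out gradient positions with 0 to recover a dense Tensor.
print(y.grad.to_tensor(0))  # tensor([3., 0.])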