import torch
import numpy as np
if "1.11.0" not in torch.__version__:
!pip uninstall --y torch

# Import factory function


# First example of addition
data = torch.arange(5.)
mask = torch.tensor([True, True, False, True, False])

# NOTE(review): m0 and m1 are never constructed in this snippet — presumably
# the missing factory import above provides masked_tensor() and omitted lines
# built them from data/mask; confirm against the source notebook.
for shown in (m0, m1, torch.cos(m0), m0 + m0):
    print(shown)

try:
    # For now the masks must match. We treat them like shapes.
    # We can relax this later on, but should have a good reason for it.
    # We'll revisit this once we have reductions.
    print(m0 + m1)
except ValueError as err:
    print(err)

masked_tensor(
[  0.0000,   1.0000,       --,   3.0000,       --]
)
[      --,       --,   2.0000,       --,   4.0000]
)
[  1.0000,   0.5403,       --,  -0.9900,       --]
)
[  0.0000,   2.0000,       --,   6.0000,       --]
)
Input masks must match. If you need support for this, please open an issue on Github.


# NumPy comparison: np.ma masks where the mask is True, so the validity mask
# is inverted before handing it over.
npm0 = np.ma.masked_array(data.numpy(), np.logical_not(mask.numpy()))
# NOTE(review): npm1 is never built in this snippet (the transcript shows it
# masked where npm0 is valid) — presumably an omitted line created it; confirm.
print("npm0:       ", npm0)
print("npm1:       ", npm1)
print("npm0 + npm1:", npm0 + npm1)

npm0:        [0.0 1.0 -- 3.0 --]
npm1:        [-- -- 2.0 -- 4.0]
npm0 + npm1: [-- -- -- -- --]


NumPy of course has the opportunity to avoid the addition altogether in this case by checking whether any result elements are unmasked, but chooses not to. Presumably it is more expensive to all-reduce the mask on every call than it would save by skipping the binary addition of the data in this case.

# Materialize the MaskedTensors as regular tensors, writing 0 into the
# masked-out slots.
m0t = m0.to_tensor(0)
m1t = m1.to_tensor(0)

# NOTE(review): m2t is never defined in this snippet (the transcript shows a
# fully-masked display for it) — presumably an omitted line built it; confirm.
for filled in (m0t, m1t, m2t):
    print(filled)

tensor([0., 1., 0., 3., 0.])
tensor([0., 0., 2., 0., 4.])
[      --,       --,       --,       --,       --]
)


Example of printing a 2d MaskedTensor and setup for reductions below

# 2-D example: small random integers stored as floats, plus a random boolean mask.
data = (torch.randn(8, 3) * 10).int().float()
mask = torch.randint(2, (8, 3), dtype=torch.bool)
# NOTE(review): m is not constructed here — presumably an omitted
# m = masked_tensor(data, mask) line; confirm against the source notebook.
print(data)
print(m)

tensor([[-11., -10.,   3.],
[  3.,   0.,   6.],
[  4.,   7.,   8.],
[ 17.,   2.,   3.],
[  0.,   3.,   7.],
[ -2.,  -5.,  -4.],
[-16.,  -2.,  -6.],
[ -4.,  12.,   7.]])
tensor([[ True, False, False],
[ True, False, False],
[ True, False,  True],
[False,  True, False],
[False, False,  True],
[False,  True,  True],
[ True,  True, False],
[ True, False,  True]])
[
[-11.0000,       --,       --],
[  3.0000,       --,       --],
[  4.0000,       --,   8.0000],
[      --,   2.0000,       --],
[      --,       --,   7.0000],
[      --,  -5.0000,  -4.0000],
[-16.0000,  -2.0000,       --],
[ -4.0000,       --,   7.0000]
]
)


Reduction semantics based on https://github.com/pytorch/rfcs/pull/27

# Row-wise (dim=1) reductions over the masked tensor, one per line.
for label, reduce_fn in (
    ("sum", torch.sum),
    ("mean", torch.mean),
    ("prod", torch.prod),
    ("min", torch.amin),
    ("max", torch.amax),
):
    print(label + ":", reduce_fn(m, 1))

sum: masked_tensor(
[-11.0000,   3.0000,  12.0000,   2.0000,   7.0000,  -9.0000, -18.0000,   3.0000]
)
[-11.0000,   3.0000,   6.0000,   2.0000,   7.0000,  -4.5000,  -9.0000,   1.5000]
)
[-11.0000,   3.0000,  32.0000,   2.0000,   7.0000,  20.0000,  32.0000, -28.0000]
)
[-11.0000,   3.0000,   4.0000,   2.0000,   7.0000,  -5.0000, -16.0000,  -4.0000]
)
[-11.0000,   3.0000,   8.0000,   2.0000,   7.0000,  -4.0000,  -2.0000,   7.0000]
)


Now that we have reductions, let’s revisit why we’ll probably want to require a good reason before allowing addition of MaskedTensors with different masks.

# Demonstrate that addition and sum do not commute under NumPy's masked
# semantics when the two operands carry different masks.
data0 = torch.arange(10.).reshape(2, 5)
data1 = torch.arange(10.).reshape(2, 5) + 10
mask0 = torch.tensor([[True, True, False, False, False], [False, False, False, True, True]])
mask1 = torch.tensor([[False, False, False, True, True], [True, True, False, False, False]])

# Bug fix: npm0/npm1 were never rebuilt from data0/data1 in this cell, so the
# prints below would have shown the stale 1-D arrays from the earlier example.
# The transcript shows 2x5 arrays masked exactly where mask0/mask1 are True,
# hence the masks are passed directly (not inverted) here.
npm0 = np.ma.masked_array(data0.numpy(), mask0.numpy())
npm1 = np.ma.masked_array(data1.numpy(), mask1.numpy())

print("\nnpm0:\n", npm0)
print("\nnpm1:\n", npm1)
print("\n(npm0 + npm1).sum(0):\n", (npm0 + npm1).sum(0))
print("\nnpm0.sum(0) + npm1.sum(0):\n", (npm0.sum(0) + npm1.sum(0)))
print("\n(data0 + data1).sum(0):\n", (data0 + data1).sum(0))
print("\n(data0 + data1).sum(0):\n", (data0.sum(0) + data1.sum(0)))

npm0:
[[-- -- 2.0 3.0 4.0]
[5.0 6.0 7.0 -- --]]

npm1:
[[10.0 11.0 12.0 -- --]
[-- -- 17.0 18.0 19.0]]

(npm0 + npm1).sum(0):
[-- -- 38.0 -- --]

npm0.sum(0) + npm1.sum(0):
[15.0 17.0 38.0 21.0 23.0]

(data0 + data1).sum(0):
tensor([30., 34., 38., 42., 46.])

(data0 + data1).sum(0):
tensor([30., 34., 38., 42., 46.])


Sum and addition should be associative. However, NumPy’s semantics allow them not to be. Instead of adopting those semantics — at least in the case of addition and sum — we could ask the user to fill the MaskedTensor’s undefined elements with 0 values, or, as in the MaskedTensor addition examples above, be very specific about the semantics used.

While it’s obviously possible to support this, we think we should cover other operators first and really make sure we can’t avoid this behavior via other means.

# 3-D example to exercise printing of nested masked data.
data = (torch.randn(4, 5, 3) * 5).float()
mask = torch.randint(2, (4, 5, 3), dtype=torch.bool)
# NOTE(review): m is not constructed here — presumably an omitted
# m = masked_tensor(data, mask) line; confirm against the source notebook.
print(m)

masked_tensor(
[
[
[ -3.5697,  -8.1847,       --],
[  3.6362,       --,   9.7370],
[      --,   2.9715,  -0.1606],
[      --,       --,   3.3154],
[  3.7489,  -0.0817,  -0.3159]
],
[
[      --,   8.0847,   7.2663],
[      --,   4.7614,       --],
[ -0.1326,   0.7159,  -8.0863],
[  6.0987,   1.4388,   1.4861],
[  0.4542,       --,       --]
],
[
[-10.4363,   4.0115,       --],
[ -0.6105,   4.2993,  -0.7551],
[      --,       --,       --],
[      --,       --,   0.0252],
[ -1.3458,       --,       --]
],
[
[      --,       --,   2.8867],
[      --,  10.2686,   1.7118],
[      --,       --,       --],
[  2.3257,   3.9612,       --],
[ -3.8898,       --,   9.8560]
]
]
)


Example of indexing and advanced indexing

# Plain integer indexing: take the first 5x3 slice.
print(m[0])
# Advanced indexing with an index tensor: select slices 0 and 2.
rows = torch.tensor([0, 2])
print(m[rows])

masked_tensor(
[
[ -3.5697,  -8.1847,       --],
[  3.6362,       --,   9.7370],
[      --,   2.9715,  -0.1606],
[      --,       --,   3.3154],
[  3.7489,  -0.0817,  -0.3159]
]
)
[
[
[ -3.5697,  -8.1847,       --],
[  3.6362,       --,   9.7370],
[      --,   2.9715,  -0.1606],
[      --,       --,   3.3154],
[  3.7489,  -0.0817,  -0.3159]
],
[
[-10.4363,   4.0115,       --],
[ -0.6105,   4.2993,  -0.7551],
[      --,       --,       --],
[      --,       --,   0.0252],
[ -1.3458,       --,       --]
]
]
)
[ -3.5697,  -8.1847,   3.6362,   9.7370,   2.9715,  -0.1606,   3.3154,   3.7489,  -0.0817,  -0.3159,   8.0847,   7.2663,   4.7614,  -0.1326,   0.7159,  -8.0863,   6.0987,   1.4388,   1.4861,   0.4542, -10.4363,   4.0115,  -0.6105,   4.2993,  -0.7551,   0.0252,  -1.3458,   2.8867,  10.2686,   1.7118,   2.3257,   3.9612,  -3.8898,   9.8560]
)


torch.manual_seed(22)
# Sum needs custom autograd, since the mask of the input should be maintained
# (call order of the RNG consumers below is deliberately unchanged).
data = (torch.randn(2, 2, 3) * 5).float()
mask = torch.randint(2, (2, 2, 3), dtype=torch.bool)
# NOTE(review): m is not rebuilt from data/mask in this snippet — presumably
# an omitted m = masked_tensor(data, mask, requires_grad=True) line; confirm.
print(m)
total = torch.sum(m)
print("s: ", total)
total.backward()

# sum needs to return a scalar MaskedTensor because the input might be fully masked
data = (torch.randn(2, 2, 3) * 5).float()
mask = torch.zeros(2, 2, 3, dtype=torch.bool)
print("\n", m)
masked_total = torch.sum(m)
print("s: ", masked_total)
masked_total.backward()

masked_tensor(
[
[
[      --,       --,  -0.5084],
[  6.7935, -15.3725,       --]
],
[
[      --,   1.2078,       --],
[  6.5820,       --,  -1.6679]
]
]
)
[
[
[      --,       --,   1.0000],
[  1.0000,   1.0000,       --]
],
[
[      --,   1.0000,       --],
[  1.0000,       --,   1.0000]
]
]
)

[
[
[      --,       --,       --],
[      --,       --,       --]
],
[
[      --,       --,       --],
[      --,       --,       --]
]
]
)
[
[
[      --,       --,       --],
[      --,       --,       --]
],
[
[      --,       --,       --],
[      --,       --,       --]
]
]
)

# Grad of multiplication of MaskedTensor and Tensor
# NOTE(review): x and y are not created in this snippet; the transcript shows
# x as a length-2 MaskedTensor and y printing nothing notable — confirm their
# construction against the source notebook.
print("x:\n", x)
print("y:\n", y)
# In general a MaskedTensor is considered a generalization of Tensor's shape.
# The mask is a more complex, higher dimensional shape and thus the Tensor
# broadcasts to it. I'd love to find a more rigorous definition of this.
product = x * y
print("x * y:\n", product)
product.sum().backward()

x:
[  3.0000,       --]
)
y:
x * y:
[  6.0000,       --]
)

[  2.0000,       --]
)
[  3.0000,       --]
)


### A note on is_contiguous

# is_contiguous doesn't work
t = torch.arange(4).reshape(1, 2, 2).float()
t = t.clone()

# NOTE(review): mt is never built from t in this snippet — presumably an
# omitted mt = masked_tensor(t, mask) line; confirm against the source notebook.
def _report(masked):
    # One-line layout summary: contiguity flag, sizes, strides.
    print(masked.is_contiguous(), masked.size(), masked.stride())

mt = mt.view(mt.size())
mt = mt.transpose(0, 1)
_report(mt)
mt = mt.view(mt.size())
_report(mt)
mt = mt.contiguous()
_report(mt)

True torch.Size([2, 1, 2]) (2, 2, 1)
True torch.Size([2, 1, 2]) (2, 4, 1)
True torch.Size([2, 1, 2]) (2, 2, 1)
True torch.Size([2, 1, 2]) (2, 2, 1)

# Because .contiguous doesn't work we need to modify view to use reshape instead