Distinguishing between 0 and NaN gradient
import torch
import numpy as np
from maskedtensor import masked_tensor
from maskedtensor import as_masked_tensor
Resolving Issues
One issue that vanilla tensors run into is the inability to distinguish between gradients that are undefined (NaN) and gradients that are actually 0.
Below, by way of example, we show several different issues where torch.Tensor falls short and MaskedTensor can resolve and/or work around the NaN gradient problem.
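As a quick primer on the API used throughout, here is a minimal sketch (assuming the prototype maskedtensor package imported above): a MaskedTensor pairs a data tensor with a boolean mask, and masked-out entries are displayed as --. The variable names below are illustrative only.
data = torch.arange(4, dtype=torch.float)
mask = torch.tensor([True, False, True, False])
mt = masked_tensor(data, mask)
print(mt)  # entries where mask is False print as --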
PyTorch Issue 10729 - torch.where
PyTorch result:
# This behavior underlies the fix to clamp, which uses where in its derivative
x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
print("y:", y)
y.sum().backward()
print("x.grad:", x.grad)
print("y.grad:", y.grad)
y: tensor([4.5400e-05, 6.7379e-03, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00],
grad_fn=<WhereBackward0>)
x.grad: tensor([4.5400e-05, 6.7379e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, nan, nan])
y.grad: None
/tmp/ipykernel_2351/3791710618.py:7: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:478.)
print("y.grad:", y.grad)
In the output above, exp(x) has overflowed to inf for x = 90 and 100 in float32; the backward of where sends a zero gradient into the unselected exp branch, and 0 * inf is NaN, which is where the NaN entries in x.grad come from.
MaskedTensor result:
x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True)
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
s = y.sum()
s.backward()
# Gradient is only provided to selected subset.
# Effectively this changes the gradient of where to mask out elements instead
# of setting them to zero.
print("mx.grad: ", mx.grad)
mx.grad: masked_tensor(
[ 0.0000, 0.0067, --, --, --, --, --, --, --, --, --]
)
The gradient here is only provided to the selected subset. Effectively, this changes the gradient of where to mask out elements instead of setting them to zero.
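For completeness (this check is not part of the original example), the other branch's gradient can be inspected the same way; continuing from the snippet above, my.grad should hold ones at the positions where mask is False and be masked out elsewhere:
print("my.grad: ", my.grad)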
PyTorch Issue 52248 - another torch.where
PyTorch result:
# A more recent, where-specific incarnation of the same issue:
# https://github.com/pytorch/pytorch/issues/52248
a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print(torch.where(b, a/0, c))
print(torch.autograd.grad(torch.where(b, a/0, c), a))
tensor(1., grad_fn=<WhereBackward0>)
(tensor(nan),)
Even though the False branch is never selected in the forward pass, the backward of where still sends a zero gradient into the a/0 branch; the derivative of a/0 with respect to a is 1/0 = inf, and 0 * inf is NaN.
MaskedTensor result:
a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print(torch.where(b, a/0, c))
print(torch.autograd.grad(torch.where(b, a/0, c), a))
masked_tensor( 1.0000, True)
(masked_tensor(--, False),)
PyTorch Issue 67180 - torch.nansum and torch.nanmean
PyTorch result:
a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c) # or torch.nanmean
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1
tensor(nan)
nansum ignores the NaN in the forward pass, but its backward behaves like an ordinary sum and sends a gradient of 1 to every element, including the NaN position; the chain rule then multiplies by a, so the gradient for b becomes NaN.
MaskedTensor result:
a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
ma = masked_tensor(a, ~torch.isnan(a))
c = ma * b
c1 = torch.sum(c)  # masked counterpart of torch.nansum (a mean variant is sketched below)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1
masked_tensor( 3.0000, True)
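The mean-based variant follows the same pattern. A minimal sketch, assuming torch.mean is also supported as a reduction on the prototype MaskedTensor (only the sum appears in the original); c2 and bgrad2 are illustrative names:
a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
ma = masked_tensor(a, ~torch.isnan(a))
c2 = torch.mean(ma * b)  # masked counterpart of torch.nanmean
bgrad2, = torch.autograd.grad(c2, b, retain_graph=True)
bgrad2  # expected: the mean of the unmasked values of a * b, i.e. 1.5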
PyTorch Issue 4132 - when using mask, x/0 yields NaN grad
PyTorch result:
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [False, True]
loss = y[mask]
loss.backward()
x.grad # grad is [nan, 1], but expected [0, 1]
tensor([nan, 1.])
MaskedTensor result:
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [False, True]
loss = as_masked_tensor(y, mask)
# We could add autograd support for indexing here instead of using sum
loss = loss.sum()
loss.backward()
x.grad # the gradient at the masked-out position is now masked out (--) instead of nan
masked_tensor(
[ --, 1.0000]
)
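If a plain torch.Tensor gradient is required, a common workaround with vanilla tensors (not part of the original example, and not using MaskedTensor) is to make the unsafe operation itself safe before autograd sees it, by substituting a harmless divisor at the masked positions:
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
mask = (div != 0)
safe_div = torch.where(mask, div, torch.ones_like(div))  # no zeros in the divisor
y = torch.where(mask, x / safe_div, torch.zeros_like(x))
y.sum().backward()
x.grad  # tensor([0., 1.]); the masked-out element gets a 0 gradient, not NaN
Because the division never sees a zero divisor, no inf enters the backward pass, so the gradient at the masked position is a genuine 0 rather than NaN.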