Resolving NaN Grad

import torch
import numpy as np
if "1.11.0" not in torch.__version__:
    !pip uninstall --y torch
    !pip install torch -f --pre
# Import factory function
from maskedtensor import masked_tensor
from maskedtensor import as_masked_tensor

Resolving Issues

One issue that vanilla tensors run into is that they cannot distinguish between gradients that are undefined (nan) and gradients that are actually 0.
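
To see where such a nan can come from, here is a minimal sketch in plain Python (no PyTorch). It assumes the usual chain rule step, where the gradient arriving from upstream is multiplied by the local derivative. The helper name `chain` is hypothetical, and `float("inf")` stands in for a local derivative that has overflowed, as `exp(100)` does in float32.

```python
import math

def chain(upstream, local):
    """One chain rule step: upstream gradient times local derivative."""
    return upstream * local

# An element that was not selected receives an upstream gradient of 0.
# If its local derivative is finite, the result is a genuine 0.
finite_case = chain(0.0, 5.0)

# If its local derivative has overflowed to inf, 0 * inf is nan under IEEE 754,
# even though the element never contributed to the output.
overflow_case = chain(0.0, float("inf"))

print(finite_case)                 # 0.0
print(math.isnan(overflow_case))   # True
```

The second case is the one a dense tensor cannot express properly: the element should carry no gradient at all, but the only available encodings are 0 or nan.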

The examples below revisit several reported PyTorch issues where torch.Tensor falls short and show how MaskedTensor resolves or works around the NaN gradient problem.

PyTorch Issue 10729 - torch.where

PyTorch result:

# This behavior underlies the fix to clamp, which uses where in its derivative
x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
print("y:", y)
# Run the backward pass so that x.grad is populated
y.sum().backward()
print("x.grad:", x.grad)
print("y.grad:", y.grad)
y: tensor([4.5400e-05, 6.7379e-03, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00],
x.grad: tensor([4.5400e-05, 6.7379e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00,        nan,        nan])
y.grad: None

Accessing y.grad also raises a UserWarning: y is not a leaf tensor, so its .grad attribute is not populated during backward() unless .retain_grad() is called on it.

MaskedTensor result:

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True)
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
s = y.sum()
# Run the backward pass so that mx.grad is populated
s.backward()
# Gradient is only provided to the selected subset.
print("mx.grad: ", mx.grad)
mx.grad:  masked_tensor(
  [  0.0000,   0.0067,       --,       --,       --,       --,       --,       --,       --,       --,       --]
)

The gradient here is only provided to the selected subset. Effectively, this changes the gradient of torch.where so that unselected elements are masked out instead of being set to zero.
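
The difference between the two conventions can be sketched in plain Python. This is an illustration only, not MaskedTensor's implementation: the helper names are hypothetical, `None` stands in for a masked-out entry, and `float("inf")` stands in for the derivative of `exp` after float32 overflow.

```python
import math

def where_backward_dense(cond, upstream):
    """Dense convention: unselected entries receive a gradient of 0."""
    return [g if c else 0.0 for c, g in zip(cond, upstream)]

def where_backward_masked(cond, upstream):
    """Masked convention: unselected entries receive no gradient (None)."""
    return [g if c else None for c, g in zip(cond, upstream)]

def chain(grads, local):
    """Multiply by the local derivative, skipping masked (None) entries."""
    return [None if g is None else g * d for g, d in zip(grads, local)]

x = [-5.0, 0.0, 100.0]
cond = [v < 0 for v in x]
# Local derivative of exp(x); inf stands in for float32 overflow at x = 100.
local = [math.exp(-5.0), 1.0, float("inf")]
upstream = [1.0, 1.0, 1.0]

dense = chain(where_backward_dense(cond, upstream), local)
masked = chain(where_backward_masked(cond, upstream), local)

print(dense)   # last entry is nan, because 0 * inf is nan
print(masked)  # unselected entries stay None
```

In the dense version the zero still has to pass through the overflowed derivative, which is where the nan appears. In the masked version the unselected entries never take part in the multiplication.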

PyTorch Issue 52248 - another torch.where

PyTorch result:

# A more recent incarnation of the same problem, specific to torch.where

a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())

print(torch.where(b, a/0, c))
print(torch.autograd.grad(torch.where(b, a/0, c), a))
tensor(1., grad_fn=<SWhereBackward0>)
(tensor(nan),)

MaskedTensor result:

a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())

print(torch.where(b, a/0, c))
print(torch.autograd.grad(torch.where(b, a/0, c), a))
masked_tensor(  1.0000, True)
(masked_tensor(--, False),)

PyTorch Issue 67180 - torch.nansum and torch.nanmean

PyTorch result:

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)  # or torch.nanmean

bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
print(bgrad1)
tensor(nan)

MaskedTensor result:

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
ma = masked_tensor(a, ~torch.isnan(a))
c = ma * b
c1 = torch.sum(c)  # or torch.mean

bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
print(bgrad1)
masked_tensor(  3.0000, True)
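
The two results can be traced by hand with a small plain-Python sketch. It assumes that the nansum backward pass sends a gradient of 1 to valid entries and 0 to nan entries, and that the gradient with respect to `b` is the sum of `grad_c[i] * a[i]`. The variable names are illustrative only.

```python
import math

a = [1.0, 2.0, float("nan")]
valid = [not math.isnan(v) for v in a]

# Assumed nansum backward: 1 for valid entries of c = a * b, 0 for nan entries.
grad_c = [1.0 if ok else 0.0 for ok in valid]

# d(c_i)/db = a_i, so the dense gradient sums grad_c_i * a_i, including 0 * nan.
dense_grad_b = sum(g * v for g, v in zip(grad_c, a))

# Masked variant: entries outside the mask never enter the sum.
masked_grad_b = sum(v for v, ok in zip(a, valid) if ok)

print(dense_grad_b)   # nan
print(masked_grad_b)  # 3.0
```

Ignoring nan in the forward reduction is therefore not enough: the nan in `a` is reintroduced when the zero gradient is multiplied by it on the way back to `b`.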

PyTorch Issue 4132 - when using mask, x/0 yields NaN grad

PyTorch result:

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]

mask = (div != 0) # => mask is [False, True]
loss = y[mask]
# Run the backward pass so that x.grad is populated
loss.sum().backward()

x.grad # grad is [nan, 1], but expected [0, 1]
tensor([nan, 1.])

MaskedTensor result:

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]

mask = (div != 0) # => mask is [False, True]
loss = as_masked_tensor(y, mask)
# We could add autograd support for indexing here instead of using sum
loss = loss.sum()
# Run the backward pass so that x.grad is populated
loss.backward()

x.grad # the undefined entry is masked out (--) instead of being nan
  [      --,   1.0000]